Implement SSH Connection Rate Limiting: IP rate limit + global rate limit + auth brute force prevention
This commit is contained in:
@@ -14,6 +14,7 @@ pub mod kex_exchange;
|
||||
pub mod packet;
|
||||
pub mod port_forward;
|
||||
pub mod port_forward_listener;
|
||||
pub mod rate_limiter;
|
||||
pub mod rsync_handler;
|
||||
pub mod scp_handler;
|
||||
pub mod server;
|
||||
|
||||
481
markbase-core/src/ssh_server/rate_limiter.rs
Normal file
481
markbase-core/src/ssh_server/rate_limiter.rs
Normal file
@@ -0,0 +1,481 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitConfig {
|
||||
pub max_connections_per_ip: usize,
|
||||
pub max_connections_per_ip_window: Duration,
|
||||
pub max_global_connections: usize,
|
||||
pub max_global_connections_window: Duration,
|
||||
pub max_auth_attempts_per_ip: usize,
|
||||
pub max_auth_attempts_window: Duration,
|
||||
pub ban_duration: Duration,
|
||||
pub whitelist: Vec<IpAddr>,
|
||||
pub blacklist: Vec<IpAddr>,
|
||||
}
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_connections_per_ip: 10,
|
||||
max_connections_per_ip_window: Duration::from_secs(60),
|
||||
max_global_connections: 100,
|
||||
max_global_connections_window: Duration::from_secs(60),
|
||||
max_auth_attempts_per_ip: 5,
|
||||
max_auth_attempts_window: Duration::from_secs(60),
|
||||
ban_duration: Duration::from_secs(300),
|
||||
whitelist: Vec::new(),
|
||||
blacklist: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConnectionRecord {
|
||||
attempts: Vec<Instant>,
|
||||
auth_attempts: Vec<Instant>,
|
||||
banned_until: Option<Instant>,
|
||||
}
|
||||
|
||||
impl ConnectionRecord {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
attempts: Vec::new(),
|
||||
auth_attempts: Vec::new(),
|
||||
banned_until: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_connection(&mut self, now: Instant) {
|
||||
self.attempts.push(now);
|
||||
}
|
||||
|
||||
fn add_auth_attempt(&mut self, now: Instant) {
|
||||
self.auth_attempts.push(now);
|
||||
}
|
||||
|
||||
fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) {
|
||||
self.attempts.retain(|t| now.duration_since(*t) < window);
|
||||
self.auth_attempts.retain(|t| now.duration_since(*t) < window);
|
||||
}
|
||||
|
||||
fn connection_count(&self) -> usize {
|
||||
self.attempts.len()
|
||||
}
|
||||
|
||||
fn auth_attempt_count(&self) -> usize {
|
||||
self.auth_attempts.len()
|
||||
}
|
||||
|
||||
fn ban(&mut self, until: Instant) {
|
||||
self.banned_until = Some(until);
|
||||
}
|
||||
|
||||
fn is_banned(&self, now: Instant) -> bool {
|
||||
if let Some(banned_until) = self.banned_until {
|
||||
now < banned_until
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn unban(&mut self) {
|
||||
self.banned_until = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GlobalConnectionRecord {
|
||||
attempts: Vec<Instant>,
|
||||
}
|
||||
|
||||
impl GlobalConnectionRecord {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
attempts: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_connection(&mut self, now: Instant) {
|
||||
self.attempts.push(now);
|
||||
}
|
||||
|
||||
fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) {
|
||||
self.attempts.retain(|t| now.duration_since(*t) < window);
|
||||
}
|
||||
|
||||
fn connection_count(&self) -> usize {
|
||||
self.attempts.len()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnectionRateLimiter {
|
||||
config: RwLock<RateLimitConfig>,
|
||||
ip_records: RwLock<HashMap<IpAddr, ConnectionRecord>>,
|
||||
global_record: RwLock<GlobalConnectionRecord>,
|
||||
}
|
||||
|
||||
impl ConnectionRateLimiter {
|
||||
pub fn new(config: RateLimitConfig) -> Self {
|
||||
Self {
|
||||
config: RwLock::new(config),
|
||||
ip_records: RwLock::new(HashMap::new()),
|
||||
global_record: RwLock::new(GlobalConnectionRecord::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default() -> Self {
|
||||
Self::new(RateLimitConfig::default())
|
||||
}
|
||||
|
||||
pub fn with_config(config: RateLimitConfig) -> Self {
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
pub async fn check_connection_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> {
|
||||
let now = Instant::now();
|
||||
|
||||
let config = self.config.read().await;
|
||||
|
||||
if config.whitelist.contains(&ip) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if config.blacklist.contains(&ip) {
|
||||
return Err(RateLimitError::Blacklisted);
|
||||
}
|
||||
|
||||
let max_conn_per_ip = config.max_connections_per_ip;
|
||||
let max_conn_window = config.max_connections_per_ip_window;
|
||||
let max_global_conn = config.max_global_connections;
|
||||
let max_global_window = config.max_global_connections_window;
|
||||
|
||||
let mut ip_records = self.ip_records.write().await;
|
||||
let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new);
|
||||
|
||||
record.cleanup_old_attempts(now, max_conn_window);
|
||||
|
||||
if record.is_banned(now) {
|
||||
return Err(RateLimitError::Banned);
|
||||
}
|
||||
|
||||
if record.connection_count() >= max_conn_per_ip {
|
||||
return Err(RateLimitError::IpRateExceeded);
|
||||
}
|
||||
|
||||
let mut global_record = self.global_record.write().await;
|
||||
global_record.cleanup_old_attempts(now, max_global_window);
|
||||
|
||||
if global_record.connection_count() >= max_global_conn {
|
||||
return Err(RateLimitError::GlobalRateExceeded);
|
||||
}
|
||||
|
||||
record.add_connection(now);
|
||||
global_record.add_connection(now);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn check_auth_attempt_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> {
|
||||
let now = Instant::now();
|
||||
|
||||
let config = self.config.read().await;
|
||||
|
||||
if config.whitelist.contains(&ip) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let max_auth = config.max_auth_attempts_per_ip;
|
||||
let auth_window = config.max_auth_attempts_window;
|
||||
let ban_duration = config.ban_duration;
|
||||
|
||||
let mut ip_records = self.ip_records.write().await;
|
||||
let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new);
|
||||
|
||||
record.cleanup_old_attempts(now, auth_window);
|
||||
|
||||
if record.is_banned(now) {
|
||||
return Err(RateLimitError::Banned);
|
||||
}
|
||||
|
||||
if record.auth_attempt_count() >= max_auth {
|
||||
let ban_until = now + ban_duration;
|
||||
record.ban(ban_until);
|
||||
return Err(RateLimitError::AuthRateExceeded);
|
||||
}
|
||||
|
||||
record.add_auth_attempt(now);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn ban_ip(&self, ip: IpAddr, duration: Duration) {
|
||||
let now = Instant::now();
|
||||
let ban_until = now + duration;
|
||||
|
||||
let mut ip_records = self.ip_records.write().await;
|
||||
if let Some(record) = ip_records.get_mut(&ip) {
|
||||
record.ban(ban_until);
|
||||
} else {
|
||||
let mut record = ConnectionRecord::new();
|
||||
record.ban(ban_until);
|
||||
ip_records.insert(ip, record);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn unban_ip(&self, ip: IpAddr) {
|
||||
let mut ip_records = self.ip_records.write().await;
|
||||
if let Some(record) = ip_records.get_mut(&ip) {
|
||||
record.unban();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_to_blacklist(&self, ip: IpAddr) {
|
||||
let mut config = self.config.write().await;
|
||||
if !config.blacklist.contains(&ip) {
|
||||
config.blacklist.push(ip);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_from_blacklist(&self, ip: IpAddr) {
|
||||
let mut config = self.config.write().await;
|
||||
config.blacklist.retain(|&x| x != ip);
|
||||
}
|
||||
|
||||
pub async fn add_to_whitelist(&self, ip: IpAddr) {
|
||||
let mut config = self.config.write().await;
|
||||
if !config.whitelist.contains(&ip) {
|
||||
config.whitelist.push(ip);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_from_whitelist(&self, ip: IpAddr) {
|
||||
let mut config = self.config.write().await;
|
||||
config.whitelist.retain(|&x| x != ip);
|
||||
}
|
||||
|
||||
pub async fn get_stats(&self) -> RateLimitStats {
|
||||
let ip_records = self.ip_records.read().await;
|
||||
let config = self.config.read().await;
|
||||
|
||||
let banned_ips = ip_records
|
||||
.values()
|
||||
.filter(|r| r.banned_until.is_some())
|
||||
.count();
|
||||
|
||||
let active_ips = ip_records.len();
|
||||
|
||||
let total_connections = ip_records.values().map(|r| r.connection_count()).sum();
|
||||
|
||||
let total_auth_attempts = ip_records.values().map(|r| r.auth_attempt_count()).sum();
|
||||
|
||||
RateLimitStats {
|
||||
active_ips,
|
||||
banned_ips,
|
||||
total_connections,
|
||||
total_auth_attempts,
|
||||
blacklist_size: config.blacklist.len(),
|
||||
whitelist_size: config.whitelist.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitStats {
|
||||
pub active_ips: usize,
|
||||
pub banned_ips: usize,
|
||||
pub total_connections: usize,
|
||||
pub total_auth_attempts: usize,
|
||||
pub blacklist_size: usize,
|
||||
pub whitelist_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RateLimitError {
|
||||
IpRateExceeded,
|
||||
GlobalRateExceeded,
|
||||
AuthRateExceeded,
|
||||
Banned,
|
||||
Blacklisted,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RateLimitError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
RateLimitError::IpRateExceeded => write!(f, "IP connection rate exceeded"),
|
||||
RateLimitError::GlobalRateExceeded => write!(f, "Global connection rate exceeded"),
|
||||
RateLimitError::AuthRateExceeded => write!(f, "Authentication rate exceeded - IP banned"),
|
||||
RateLimitError::Banned => write!(f, "IP is banned"),
|
||||
RateLimitError::Blacklisted => write!(f, "IP is blacklisted"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for RateLimitError {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
|
||||
fn test_ip() -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connection_allowed_initial() {
|
||||
let limiter = ConnectionRateLimiter::default();
|
||||
let ip = test_ip();
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connection_rate_exceeded() {
|
||||
let config = RateLimitConfig {
|
||||
max_connections_per_ip: 2,
|
||||
max_connections_per_ip_window: Duration::from_secs(60),
|
||||
..Default::default()
|
||||
};
|
||||
let limiter = ConnectionRateLimiter::new(config);
|
||||
let ip = test_ip();
|
||||
|
||||
limiter.check_connection_allowed(ip).await.unwrap();
|
||||
limiter.check_connection_allowed(ip).await.unwrap();
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(matches!(result, Err(RateLimitError::IpRateExceeded)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_rate_exceeded() {
|
||||
let config = RateLimitConfig {
|
||||
max_auth_attempts_per_ip: 2,
|
||||
max_auth_attempts_window: Duration::from_secs(60),
|
||||
ban_duration: Duration::from_secs(10),
|
||||
..Default::default()
|
||||
};
|
||||
let limiter = ConnectionRateLimiter::new(config);
|
||||
let ip = test_ip();
|
||||
|
||||
limiter.check_auth_attempt_allowed(ip).await.unwrap();
|
||||
limiter.check_auth_attempt_allowed(ip).await.unwrap();
|
||||
|
||||
let result = limiter.check_auth_attempt_allowed(ip).await;
|
||||
assert!(matches!(result, Err(RateLimitError::AuthRateExceeded)));
|
||||
|
||||
let conn_result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(matches!(conn_result, Err(RateLimitError::Banned)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_whitelist_bypass() {
|
||||
let config = RateLimitConfig {
|
||||
max_connections_per_ip: 1,
|
||||
max_connections_per_ip_window: Duration::from_secs(60),
|
||||
whitelist: vec![test_ip()],
|
||||
..Default::default()
|
||||
};
|
||||
let limiter = ConnectionRateLimiter::new(config);
|
||||
let ip = test_ip();
|
||||
|
||||
limiter.check_connection_allowed(ip).await.unwrap();
|
||||
limiter.check_connection_allowed(ip).await.unwrap();
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_blacklist_blocked() {
|
||||
let config = RateLimitConfig {
|
||||
blacklist: vec![test_ip()],
|
||||
..Default::default()
|
||||
};
|
||||
let limiter = ConnectionRateLimiter::new(config);
|
||||
let ip = test_ip();
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(matches!(result, Err(RateLimitError::Blacklisted)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_global_rate_exceeded() {
|
||||
let config = RateLimitConfig {
|
||||
max_global_connections: 2,
|
||||
max_global_connections_window: Duration::from_secs(60),
|
||||
max_connections_per_ip: 100,
|
||||
..Default::default()
|
||||
};
|
||||
let limiter = ConnectionRateLimiter::new(config);
|
||||
|
||||
let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
|
||||
let ip3 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3));
|
||||
|
||||
limiter.check_connection_allowed(ip1).await.unwrap();
|
||||
limiter.check_connection_allowed(ip2).await.unwrap();
|
||||
|
||||
let result = limiter.check_connection_allowed(ip3).await;
|
||||
assert!(matches!(result, Err(RateLimitError::GlobalRateExceeded)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ban_unban() {
|
||||
let limiter = ConnectionRateLimiter::default();
|
||||
let ip = test_ip();
|
||||
|
||||
limiter.ban_ip(ip, Duration::from_secs(10)).await;
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(matches!(result, Err(RateLimitError::Banned)));
|
||||
|
||||
limiter.unban_ip(ip).await;
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stats() {
|
||||
let limiter = ConnectionRateLimiter::default();
|
||||
|
||||
let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
|
||||
|
||||
limiter.check_connection_allowed(ip1).await.unwrap();
|
||||
limiter.check_connection_allowed(ip2).await.unwrap();
|
||||
limiter.check_auth_attempt_allowed(ip1).await.unwrap();
|
||||
|
||||
let stats = limiter.get_stats().await;
|
||||
assert_eq!(stats.active_ips, 2);
|
||||
assert_eq!(stats.total_connections, 2);
|
||||
assert_eq!(stats.total_auth_attempts, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rate_limit_window_expiry() {
|
||||
let config = RateLimitConfig {
|
||||
max_connections_per_ip: 2,
|
||||
max_connections_per_ip_window: Duration::from_millis(100),
|
||||
..Default::default()
|
||||
};
|
||||
let limiter = ConnectionRateLimiter::new(config);
|
||||
let ip = test_ip();
|
||||
|
||||
limiter.check_connection_allowed(ip).await.unwrap();
|
||||
limiter.check_connection_allowed(ip).await.unwrap();
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(matches!(result, Err(RateLimitError::IpRateExceeded)));
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||
|
||||
let result = limiter.check_connection_allowed(ip).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user