diff --git a/markbase-core/src/ssh_server/mod.rs b/markbase-core/src/ssh_server/mod.rs index 14ef7ec..8475ad0 100644 --- a/markbase-core/src/ssh_server/mod.rs +++ b/markbase-core/src/ssh_server/mod.rs @@ -14,6 +14,7 @@ pub mod kex_exchange; pub mod packet; pub mod port_forward; pub mod port_forward_listener; +pub mod rate_limiter; pub mod rsync_handler; pub mod scp_handler; pub mod server; diff --git a/markbase-core/src/ssh_server/rate_limiter.rs b/markbase-core/src/ssh_server/rate_limiter.rs new file mode 100644 index 0000000..8611fe0 --- /dev/null +++ b/markbase-core/src/ssh_server/rate_limiter.rs @@ -0,0 +1,481 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + pub max_connections_per_ip: usize, + pub max_connections_per_ip_window: Duration, + pub max_global_connections: usize, + pub max_global_connections_window: Duration, + pub max_auth_attempts_per_ip: usize, + pub max_auth_attempts_window: Duration, + pub ban_duration: Duration, + pub whitelist: Vec, + pub blacklist: Vec, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + max_connections_per_ip: 10, + max_connections_per_ip_window: Duration::from_secs(60), + max_global_connections: 100, + max_global_connections_window: Duration::from_secs(60), + max_auth_attempts_per_ip: 5, + max_auth_attempts_window: Duration::from_secs(60), + ban_duration: Duration::from_secs(300), + whitelist: Vec::new(), + blacklist: Vec::new(), + } + } +} + +#[derive(Debug)] +struct ConnectionRecord { + attempts: Vec, + auth_attempts: Vec, + banned_until: Option, +} + +impl ConnectionRecord { + fn new() -> Self { + Self { + attempts: Vec::new(), + auth_attempts: Vec::new(), + banned_until: None, + } + } + + fn add_connection(&mut self, now: Instant) { + self.attempts.push(now); + } + + fn add_auth_attempt(&mut self, now: Instant) { + self.auth_attempts.push(now); + } + + fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) { + self.attempts.retain(|t| now.duration_since(*t) < window); + self.auth_attempts.retain(|t| now.duration_since(*t) < window); + } + + fn connection_count(&self) -> usize { + self.attempts.len() + } + + fn auth_attempt_count(&self) -> usize { + self.auth_attempts.len() + } + + fn ban(&mut self, until: Instant) { + self.banned_until = Some(until); + } + + fn is_banned(&self, now: Instant) -> bool { + if let Some(banned_until) = self.banned_until { + now < banned_until + } else { + false + } + } + + fn unban(&mut self) { + self.banned_until = None; + } +} + +#[derive(Debug)] +struct GlobalConnectionRecord { + attempts: Vec, +} + +impl GlobalConnectionRecord { + fn new() -> Self { + Self { + attempts: Vec::new(), + } + } + + fn add_connection(&mut self, now: Instant) { + self.attempts.push(now); + } + + fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) { + self.attempts.retain(|t| now.duration_since(*t) < window); + } + + fn connection_count(&self) -> usize { + self.attempts.len() + } +} + +pub struct ConnectionRateLimiter { + config: RwLock, + ip_records: RwLock>, + global_record: RwLock, +} + +impl ConnectionRateLimiter { + pub fn new(config: RateLimitConfig) -> Self { + Self { + config: RwLock::new(config), + ip_records: RwLock::new(HashMap::new()), + global_record: RwLock::new(GlobalConnectionRecord::new()), + } + } + + pub fn default() -> Self { + Self::new(RateLimitConfig::default()) + } + + pub fn with_config(config: RateLimitConfig) -> Self { + Self::new(config) + } + + pub async fn check_connection_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> { + let now = Instant::now(); + + let config = self.config.read().await; + + if config.whitelist.contains(&ip) { + return Ok(()); + } + + if config.blacklist.contains(&ip) { + return Err(RateLimitError::Blacklisted); + } + + let max_conn_per_ip = config.max_connections_per_ip; + let max_conn_window = config.max_connections_per_ip_window; + let max_global_conn = config.max_global_connections; + let max_global_window = config.max_global_connections_window; + + let mut ip_records = self.ip_records.write().await; + let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new); + + record.cleanup_old_attempts(now, max_conn_window); + + if record.is_banned(now) { + return Err(RateLimitError::Banned); + } + + if record.connection_count() >= max_conn_per_ip { + return Err(RateLimitError::IpRateExceeded); + } + + let mut global_record = self.global_record.write().await; + global_record.cleanup_old_attempts(now, max_global_window); + + if global_record.connection_count() >= max_global_conn { + return Err(RateLimitError::GlobalRateExceeded); + } + + record.add_connection(now); + global_record.add_connection(now); + + Ok(()) + } + + pub async fn check_auth_attempt_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> { + let now = Instant::now(); + + let config = self.config.read().await; + + if config.whitelist.contains(&ip) { + return Ok(()); + } + + let max_auth = config.max_auth_attempts_per_ip; + let auth_window = config.max_auth_attempts_window; + let ban_duration = config.ban_duration; + + let mut ip_records = self.ip_records.write().await; + let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new); + + record.cleanup_old_attempts(now, auth_window); + + if record.is_banned(now) { + return Err(RateLimitError::Banned); + } + + if record.auth_attempt_count() >= max_auth { + let ban_until = now + ban_duration; + record.ban(ban_until); + return Err(RateLimitError::AuthRateExceeded); + } + + record.add_auth_attempt(now); + + Ok(()) + } + + pub async fn ban_ip(&self, ip: IpAddr, duration: Duration) { + let now = Instant::now(); + let ban_until = now + duration; + + let mut ip_records = self.ip_records.write().await; + if let Some(record) = ip_records.get_mut(&ip) { + record.ban(ban_until); + } else { + let mut record = ConnectionRecord::new(); + record.ban(ban_until); + ip_records.insert(ip, record); + } + } + + pub async fn unban_ip(&self, ip: IpAddr) { + let mut ip_records = self.ip_records.write().await; + if let Some(record) = ip_records.get_mut(&ip) { + record.unban(); + } + } + + pub async fn add_to_blacklist(&self, ip: IpAddr) { + let mut config = self.config.write().await; + if !config.blacklist.contains(&ip) { + config.blacklist.push(ip); + } + } + + pub async fn remove_from_blacklist(&self, ip: IpAddr) { + let mut config = self.config.write().await; + config.blacklist.retain(|&x| x != ip); + } + + pub async fn add_to_whitelist(&self, ip: IpAddr) { + let mut config = self.config.write().await; + if !config.whitelist.contains(&ip) { + config.whitelist.push(ip); + } + } + + pub async fn remove_from_whitelist(&self, ip: IpAddr) { + let mut config = self.config.write().await; + config.whitelist.retain(|&x| x != ip); + } + + pub async fn get_stats(&self) -> RateLimitStats { + let ip_records = self.ip_records.read().await; + let config = self.config.read().await; + + let banned_ips = ip_records + .values() + .filter(|r| r.banned_until.is_some()) + .count(); + + let active_ips = ip_records.len(); + + let total_connections = ip_records.values().map(|r| r.connection_count()).sum(); + + let total_auth_attempts = ip_records.values().map(|r| r.auth_attempt_count()).sum(); + + RateLimitStats { + active_ips, + banned_ips, + total_connections, + total_auth_attempts, + blacklist_size: config.blacklist.len(), + whitelist_size: config.whitelist.len(), + } + } +} + +#[derive(Debug, Clone)] +pub struct RateLimitStats { + pub active_ips: usize, + pub banned_ips: usize, + pub total_connections: usize, + pub total_auth_attempts: usize, + pub blacklist_size: usize, + pub whitelist_size: usize, +} + +#[derive(Debug, Clone)] +pub enum RateLimitError { + IpRateExceeded, + GlobalRateExceeded, + AuthRateExceeded, + Banned, + Blacklisted, +} + +impl std::fmt::Display for RateLimitError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RateLimitError::IpRateExceeded => write!(f, "IP connection rate exceeded"), + RateLimitError::GlobalRateExceeded => write!(f, "Global connection rate exceeded"), + RateLimitError::AuthRateExceeded => write!(f, "Authentication rate exceeded - IP banned"), + RateLimitError::Banned => write!(f, "IP is banned"), + RateLimitError::Blacklisted => write!(f, "IP is blacklisted"), + } + } +} + +impl std::error::Error for RateLimitError {} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + + fn test_ip() -> IpAddr { + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) + } + + #[tokio::test] + async fn test_connection_allowed_initial() { + let limiter = ConnectionRateLimiter::default(); + let ip = test_ip(); + + let result = limiter.check_connection_allowed(ip).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_connection_rate_exceeded() { + let config = RateLimitConfig { + max_connections_per_ip: 2, + max_connections_per_ip_window: Duration::from_secs(60), + ..Default::default() + }; + let limiter = ConnectionRateLimiter::new(config); + let ip = test_ip(); + + limiter.check_connection_allowed(ip).await.unwrap(); + limiter.check_connection_allowed(ip).await.unwrap(); + + let result = limiter.check_connection_allowed(ip).await; + assert!(matches!(result, Err(RateLimitError::IpRateExceeded))); + } + + #[tokio::test] + async fn test_auth_rate_exceeded() { + let config = RateLimitConfig { + max_auth_attempts_per_ip: 2, + max_auth_attempts_window: Duration::from_secs(60), + ban_duration: Duration::from_secs(10), + ..Default::default() + }; + let limiter = ConnectionRateLimiter::new(config); + let ip = test_ip(); + + limiter.check_auth_attempt_allowed(ip).await.unwrap(); + limiter.check_auth_attempt_allowed(ip).await.unwrap(); + + let result = limiter.check_auth_attempt_allowed(ip).await; + assert!(matches!(result, Err(RateLimitError::AuthRateExceeded))); + + let conn_result = limiter.check_connection_allowed(ip).await; + assert!(matches!(conn_result, Err(RateLimitError::Banned))); + } + + #[tokio::test] + async fn test_whitelist_bypass() { + let config = RateLimitConfig { + max_connections_per_ip: 1, + max_connections_per_ip_window: Duration::from_secs(60), + whitelist: vec![test_ip()], + ..Default::default() + }; + let limiter = ConnectionRateLimiter::new(config); + let ip = test_ip(); + + limiter.check_connection_allowed(ip).await.unwrap(); + limiter.check_connection_allowed(ip).await.unwrap(); + + let result = limiter.check_connection_allowed(ip).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_blacklist_blocked() { + let config = RateLimitConfig { + blacklist: vec![test_ip()], + ..Default::default() + }; + let limiter = ConnectionRateLimiter::new(config); + let ip = test_ip(); + + let result = limiter.check_connection_allowed(ip).await; + assert!(matches!(result, Err(RateLimitError::Blacklisted))); + } + + #[tokio::test] + async fn test_global_rate_exceeded() { + let config = RateLimitConfig { + max_global_connections: 2, + max_global_connections_window: Duration::from_secs(60), + max_connections_per_ip: 100, + ..Default::default() + }; + let limiter = ConnectionRateLimiter::new(config); + + let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)); + let ip3 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3)); + + limiter.check_connection_allowed(ip1).await.unwrap(); + limiter.check_connection_allowed(ip2).await.unwrap(); + + let result = limiter.check_connection_allowed(ip3).await; + assert!(matches!(result, Err(RateLimitError::GlobalRateExceeded))); + } + + #[tokio::test] + async fn test_ban_unban() { + let limiter = ConnectionRateLimiter::default(); + let ip = test_ip(); + + limiter.ban_ip(ip, Duration::from_secs(10)).await; + + let result = limiter.check_connection_allowed(ip).await; + assert!(matches!(result, Err(RateLimitError::Banned))); + + limiter.unban_ip(ip).await; + + let result = limiter.check_connection_allowed(ip).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_stats() { + let limiter = ConnectionRateLimiter::default(); + + let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)); + + limiter.check_connection_allowed(ip1).await.unwrap(); + limiter.check_connection_allowed(ip2).await.unwrap(); + limiter.check_auth_attempt_allowed(ip1).await.unwrap(); + + let stats = limiter.get_stats().await; + assert_eq!(stats.active_ips, 2); + assert_eq!(stats.total_connections, 2); + assert_eq!(stats.total_auth_attempts, 1); + } + + #[tokio::test] + async fn test_rate_limit_window_expiry() { + let config = RateLimitConfig { + max_connections_per_ip: 2, + max_connections_per_ip_window: Duration::from_millis(100), + ..Default::default() + }; + let limiter = ConnectionRateLimiter::new(config); + let ip = test_ip(); + + limiter.check_connection_allowed(ip).await.unwrap(); + limiter.check_connection_allowed(ip).await.unwrap(); + + let result = limiter.check_connection_allowed(ip).await; + assert!(matches!(result, Err(RateLimitError::IpRateExceeded))); + + tokio::time::sleep(Duration::from_millis(150)).await; + + let result = limiter.check_connection_allowed(ip).await; + assert!(result.is_ok()); + } +} \ No newline at end of file