use std::collections::HashMap; use std::net::IpAddr; use std::time::{Duration, Instant}; use tokio::sync::RwLock; #[derive(Debug, Clone)] pub struct RateLimitConfig { pub max_connections_per_ip: usize, pub max_connections_per_ip_window: Duration, pub max_global_connections: usize, pub max_global_connections_window: Duration, pub max_auth_attempts_per_ip: usize, pub max_auth_attempts_window: Duration, pub ban_duration: Duration, pub whitelist: Vec, pub blacklist: Vec, } impl Default for RateLimitConfig { fn default() -> Self { Self { max_connections_per_ip: 10, max_connections_per_ip_window: Duration::from_secs(60), max_global_connections: 100, max_global_connections_window: Duration::from_secs(60), max_auth_attempts_per_ip: 5, max_auth_attempts_window: Duration::from_secs(60), ban_duration: Duration::from_secs(300), whitelist: Vec::new(), blacklist: Vec::new(), } } } #[derive(Debug)] struct ConnectionRecord { attempts: Vec, auth_attempts: Vec, banned_until: Option, } impl ConnectionRecord { fn new() -> Self { Self { attempts: Vec::new(), auth_attempts: Vec::new(), banned_until: None, } } fn add_connection(&mut self, now: Instant) { self.attempts.push(now); } fn add_auth_attempt(&mut self, now: Instant) { self.auth_attempts.push(now); } fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) { self.attempts.retain(|t| now.duration_since(*t) < window); self.auth_attempts.retain(|t| now.duration_since(*t) < window); } fn connection_count(&self) -> usize { self.attempts.len() } fn auth_attempt_count(&self) -> usize { self.auth_attempts.len() } fn ban(&mut self, until: Instant) { self.banned_until = Some(until); } fn is_banned(&self, now: Instant) -> bool { if let Some(banned_until) = self.banned_until { now < banned_until } else { false } } fn unban(&mut self) { self.banned_until = None; } } #[derive(Debug)] struct GlobalConnectionRecord { attempts: Vec, } impl GlobalConnectionRecord { fn new() -> Self { Self { attempts: Vec::new(), } } fn add_connection(&mut self, now: Instant) { self.attempts.push(now); } fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) { self.attempts.retain(|t| now.duration_since(*t) < window); } fn connection_count(&self) -> usize { self.attempts.len() } } pub struct ConnectionRateLimiter { config: RwLock, ip_records: RwLock>, global_record: RwLock, } impl ConnectionRateLimiter { pub fn new(config: RateLimitConfig) -> Self { Self { config: RwLock::new(config), ip_records: RwLock::new(HashMap::new()), global_record: RwLock::new(GlobalConnectionRecord::new()), } } pub fn default() -> Self { Self::new(RateLimitConfig::default()) } pub fn with_config(config: RateLimitConfig) -> Self { Self::new(config) } pub async fn check_connection_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> { let now = Instant::now(); let config = self.config.read().await; if config.whitelist.contains(&ip) { return Ok(()); } if config.blacklist.contains(&ip) { return Err(RateLimitError::Blacklisted); } let max_conn_per_ip = config.max_connections_per_ip; let max_conn_window = config.max_connections_per_ip_window; let max_global_conn = config.max_global_connections; let max_global_window = config.max_global_connections_window; let mut ip_records = self.ip_records.write().await; let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new); record.cleanup_old_attempts(now, max_conn_window); if record.is_banned(now) { return Err(RateLimitError::Banned); } if record.connection_count() >= max_conn_per_ip { return Err(RateLimitError::IpRateExceeded); } let mut global_record = self.global_record.write().await; global_record.cleanup_old_attempts(now, max_global_window); if global_record.connection_count() >= max_global_conn { return Err(RateLimitError::GlobalRateExceeded); } record.add_connection(now); global_record.add_connection(now); Ok(()) } pub async fn check_auth_attempt_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> { let now = Instant::now(); let config = self.config.read().await; if config.whitelist.contains(&ip) { return Ok(()); } let max_auth = config.max_auth_attempts_per_ip; let auth_window = config.max_auth_attempts_window; let ban_duration = config.ban_duration; let mut ip_records = self.ip_records.write().await; let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new); record.cleanup_old_attempts(now, auth_window); if record.is_banned(now) { return Err(RateLimitError::Banned); } if record.auth_attempt_count() >= max_auth { let ban_until = now + ban_duration; record.ban(ban_until); return Err(RateLimitError::AuthRateExceeded); } record.add_auth_attempt(now); Ok(()) } pub async fn ban_ip(&self, ip: IpAddr, duration: Duration) { let now = Instant::now(); let ban_until = now + duration; let mut ip_records = self.ip_records.write().await; if let Some(record) = ip_records.get_mut(&ip) { record.ban(ban_until); } else { let mut record = ConnectionRecord::new(); record.ban(ban_until); ip_records.insert(ip, record); } } pub async fn unban_ip(&self, ip: IpAddr) { let mut ip_records = self.ip_records.write().await; if let Some(record) = ip_records.get_mut(&ip) { record.unban(); } } pub async fn add_to_blacklist(&self, ip: IpAddr) { let mut config = self.config.write().await; if !config.blacklist.contains(&ip) { config.blacklist.push(ip); } } pub async fn remove_from_blacklist(&self, ip: IpAddr) { let mut config = self.config.write().await; config.blacklist.retain(|&x| x != ip); } pub async fn add_to_whitelist(&self, ip: IpAddr) { let mut config = self.config.write().await; if !config.whitelist.contains(&ip) { config.whitelist.push(ip); } } pub async fn remove_from_whitelist(&self, ip: IpAddr) { let mut config = self.config.write().await; config.whitelist.retain(|&x| x != ip); } pub async fn get_stats(&self) -> RateLimitStats { let ip_records = self.ip_records.read().await; let config = self.config.read().await; let banned_ips = ip_records .values() .filter(|r| r.banned_until.is_some()) .count(); let active_ips = ip_records.len(); let total_connections = ip_records.values().map(|r| r.connection_count()).sum(); let total_auth_attempts = ip_records.values().map(|r| r.auth_attempt_count()).sum(); RateLimitStats { active_ips, banned_ips, total_connections, total_auth_attempts, blacklist_size: config.blacklist.len(), whitelist_size: config.whitelist.len(), } } } #[derive(Debug, Clone)] pub struct RateLimitStats { pub active_ips: usize, pub banned_ips: usize, pub total_connections: usize, pub total_auth_attempts: usize, pub blacklist_size: usize, pub whitelist_size: usize, } #[derive(Debug, Clone)] pub enum RateLimitError { IpRateExceeded, GlobalRateExceeded, AuthRateExceeded, Banned, Blacklisted, } impl std::fmt::Display for RateLimitError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { RateLimitError::IpRateExceeded => write!(f, "IP connection rate exceeded"), RateLimitError::GlobalRateExceeded => write!(f, "Global connection rate exceeded"), RateLimitError::AuthRateExceeded => write!(f, "Authentication rate exceeded - IP banned"), RateLimitError::Banned => write!(f, "IP is banned"), RateLimitError::Blacklisted => write!(f, "IP is blacklisted"), } } } impl std::error::Error for RateLimitError {} #[cfg(test)] mod tests { use super::*; use std::net::{IpAddr, Ipv4Addr}; fn test_ip() -> IpAddr { IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) } #[tokio::test] async fn test_connection_allowed_initial() { let limiter = ConnectionRateLimiter::default(); let ip = test_ip(); let result = limiter.check_connection_allowed(ip).await; assert!(result.is_ok()); } #[tokio::test] async fn test_connection_rate_exceeded() { let config = RateLimitConfig { max_connections_per_ip: 2, max_connections_per_ip_window: Duration::from_secs(60), ..Default::default() }; let limiter = ConnectionRateLimiter::new(config); let ip = test_ip(); limiter.check_connection_allowed(ip).await.unwrap(); limiter.check_connection_allowed(ip).await.unwrap(); let result = limiter.check_connection_allowed(ip).await; assert!(matches!(result, Err(RateLimitError::IpRateExceeded))); } #[tokio::test] async fn test_auth_rate_exceeded() { let config = RateLimitConfig { max_auth_attempts_per_ip: 2, max_auth_attempts_window: Duration::from_secs(60), ban_duration: Duration::from_secs(10), ..Default::default() }; let limiter = ConnectionRateLimiter::new(config); let ip = test_ip(); limiter.check_auth_attempt_allowed(ip).await.unwrap(); limiter.check_auth_attempt_allowed(ip).await.unwrap(); let result = limiter.check_auth_attempt_allowed(ip).await; assert!(matches!(result, Err(RateLimitError::AuthRateExceeded))); let conn_result = limiter.check_connection_allowed(ip).await; assert!(matches!(conn_result, Err(RateLimitError::Banned))); } #[tokio::test] async fn test_whitelist_bypass() { let config = RateLimitConfig { max_connections_per_ip: 1, max_connections_per_ip_window: Duration::from_secs(60), whitelist: vec![test_ip()], ..Default::default() }; let limiter = ConnectionRateLimiter::new(config); let ip = test_ip(); limiter.check_connection_allowed(ip).await.unwrap(); limiter.check_connection_allowed(ip).await.unwrap(); let result = limiter.check_connection_allowed(ip).await; assert!(result.is_ok()); } #[tokio::test] async fn test_blacklist_blocked() { let config = RateLimitConfig { blacklist: vec![test_ip()], ..Default::default() }; let limiter = ConnectionRateLimiter::new(config); let ip = test_ip(); let result = limiter.check_connection_allowed(ip).await; assert!(matches!(result, Err(RateLimitError::Blacklisted))); } #[tokio::test] async fn test_global_rate_exceeded() { let config = RateLimitConfig { max_global_connections: 2, max_global_connections_window: Duration::from_secs(60), max_connections_per_ip: 100, ..Default::default() }; let limiter = ConnectionRateLimiter::new(config); let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)); let ip3 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3)); limiter.check_connection_allowed(ip1).await.unwrap(); limiter.check_connection_allowed(ip2).await.unwrap(); let result = limiter.check_connection_allowed(ip3).await; assert!(matches!(result, Err(RateLimitError::GlobalRateExceeded))); } #[tokio::test] async fn test_ban_unban() { let limiter = ConnectionRateLimiter::default(); let ip = test_ip(); limiter.ban_ip(ip, Duration::from_secs(10)).await; let result = limiter.check_connection_allowed(ip).await; assert!(matches!(result, Err(RateLimitError::Banned))); limiter.unban_ip(ip).await; let result = limiter.check_connection_allowed(ip).await; assert!(result.is_ok()); } #[tokio::test] async fn test_stats() { let limiter = ConnectionRateLimiter::default(); let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)); limiter.check_connection_allowed(ip1).await.unwrap(); limiter.check_connection_allowed(ip2).await.unwrap(); limiter.check_auth_attempt_allowed(ip1).await.unwrap(); let stats = limiter.get_stats().await; assert_eq!(stats.active_ips, 2); assert_eq!(stats.total_connections, 2); assert_eq!(stats.total_auth_attempts, 1); } #[tokio::test] async fn test_rate_limit_window_expiry() { let config = RateLimitConfig { max_connections_per_ip: 2, max_connections_per_ip_window: Duration::from_millis(100), ..Default::default() }; let limiter = ConnectionRateLimiter::new(config); let ip = test_ip(); limiter.check_connection_allowed(ip).await.unwrap(); limiter.check_connection_allowed(ip).await.unwrap(); let result = limiter.check_connection_allowed(ip).await; assert!(matches!(result, Err(RateLimitError::IpRateExceeded))); tokio::time::sleep(Duration::from_millis(150)).await; let result = limiter.check_connection_allowed(ip).await; assert!(result.is_ok()); } }