Implement SSH Connection Rate Limiting: IP rate limit + global rate limit + auth brute force prevention

This commit is contained in:
Warren
2026-06-21 05:01:04 +08:00
parent 56e73ad8a4
commit b014390d12
2 changed files with 482 additions and 0 deletions

View File

@@ -14,6 +14,7 @@ pub mod kex_exchange;
pub mod packet; pub mod packet;
pub mod port_forward; pub mod port_forward;
pub mod port_forward_listener; pub mod port_forward_listener;
pub mod rate_limiter;
pub mod rsync_handler; pub mod rsync_handler;
pub mod scp_handler; pub mod scp_handler;
pub mod server; pub mod server;

View File

@@ -0,0 +1,481 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
pub max_connections_per_ip: usize,
pub max_connections_per_ip_window: Duration,
pub max_global_connections: usize,
pub max_global_connections_window: Duration,
pub max_auth_attempts_per_ip: usize,
pub max_auth_attempts_window: Duration,
pub ban_duration: Duration,
pub whitelist: Vec<IpAddr>,
pub blacklist: Vec<IpAddr>,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
max_connections_per_ip: 10,
max_connections_per_ip_window: Duration::from_secs(60),
max_global_connections: 100,
max_global_connections_window: Duration::from_secs(60),
max_auth_attempts_per_ip: 5,
max_auth_attempts_window: Duration::from_secs(60),
ban_duration: Duration::from_secs(300),
whitelist: Vec::new(),
blacklist: Vec::new(),
}
}
}
#[derive(Debug)]
struct ConnectionRecord {
attempts: Vec<Instant>,
auth_attempts: Vec<Instant>,
banned_until: Option<Instant>,
}
impl ConnectionRecord {
fn new() -> Self {
Self {
attempts: Vec::new(),
auth_attempts: Vec::new(),
banned_until: None,
}
}
fn add_connection(&mut self, now: Instant) {
self.attempts.push(now);
}
fn add_auth_attempt(&mut self, now: Instant) {
self.auth_attempts.push(now);
}
fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) {
self.attempts.retain(|t| now.duration_since(*t) < window);
self.auth_attempts.retain(|t| now.duration_since(*t) < window);
}
fn connection_count(&self) -> usize {
self.attempts.len()
}
fn auth_attempt_count(&self) -> usize {
self.auth_attempts.len()
}
fn ban(&mut self, until: Instant) {
self.banned_until = Some(until);
}
fn is_banned(&self, now: Instant) -> bool {
if let Some(banned_until) = self.banned_until {
now < banned_until
} else {
false
}
}
fn unban(&mut self) {
self.banned_until = None;
}
}
#[derive(Debug)]
struct GlobalConnectionRecord {
attempts: Vec<Instant>,
}
impl GlobalConnectionRecord {
fn new() -> Self {
Self {
attempts: Vec::new(),
}
}
fn add_connection(&mut self, now: Instant) {
self.attempts.push(now);
}
fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) {
self.attempts.retain(|t| now.duration_since(*t) < window);
}
fn connection_count(&self) -> usize {
self.attempts.len()
}
}
pub struct ConnectionRateLimiter {
config: RwLock<RateLimitConfig>,
ip_records: RwLock<HashMap<IpAddr, ConnectionRecord>>,
global_record: RwLock<GlobalConnectionRecord>,
}
impl ConnectionRateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config: RwLock::new(config),
ip_records: RwLock::new(HashMap::new()),
global_record: RwLock::new(GlobalConnectionRecord::new()),
}
}
pub fn default() -> Self {
Self::new(RateLimitConfig::default())
}
pub fn with_config(config: RateLimitConfig) -> Self {
Self::new(config)
}
pub async fn check_connection_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> {
let now = Instant::now();
let config = self.config.read().await;
if config.whitelist.contains(&ip) {
return Ok(());
}
if config.blacklist.contains(&ip) {
return Err(RateLimitError::Blacklisted);
}
let max_conn_per_ip = config.max_connections_per_ip;
let max_conn_window = config.max_connections_per_ip_window;
let max_global_conn = config.max_global_connections;
let max_global_window = config.max_global_connections_window;
let mut ip_records = self.ip_records.write().await;
let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new);
record.cleanup_old_attempts(now, max_conn_window);
if record.is_banned(now) {
return Err(RateLimitError::Banned);
}
if record.connection_count() >= max_conn_per_ip {
return Err(RateLimitError::IpRateExceeded);
}
let mut global_record = self.global_record.write().await;
global_record.cleanup_old_attempts(now, max_global_window);
if global_record.connection_count() >= max_global_conn {
return Err(RateLimitError::GlobalRateExceeded);
}
record.add_connection(now);
global_record.add_connection(now);
Ok(())
}
pub async fn check_auth_attempt_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> {
let now = Instant::now();
let config = self.config.read().await;
if config.whitelist.contains(&ip) {
return Ok(());
}
let max_auth = config.max_auth_attempts_per_ip;
let auth_window = config.max_auth_attempts_window;
let ban_duration = config.ban_duration;
let mut ip_records = self.ip_records.write().await;
let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new);
record.cleanup_old_attempts(now, auth_window);
if record.is_banned(now) {
return Err(RateLimitError::Banned);
}
if record.auth_attempt_count() >= max_auth {
let ban_until = now + ban_duration;
record.ban(ban_until);
return Err(RateLimitError::AuthRateExceeded);
}
record.add_auth_attempt(now);
Ok(())
}
pub async fn ban_ip(&self, ip: IpAddr, duration: Duration) {
let now = Instant::now();
let ban_until = now + duration;
let mut ip_records = self.ip_records.write().await;
if let Some(record) = ip_records.get_mut(&ip) {
record.ban(ban_until);
} else {
let mut record = ConnectionRecord::new();
record.ban(ban_until);
ip_records.insert(ip, record);
}
}
pub async fn unban_ip(&self, ip: IpAddr) {
let mut ip_records = self.ip_records.write().await;
if let Some(record) = ip_records.get_mut(&ip) {
record.unban();
}
}
pub async fn add_to_blacklist(&self, ip: IpAddr) {
let mut config = self.config.write().await;
if !config.blacklist.contains(&ip) {
config.blacklist.push(ip);
}
}
pub async fn remove_from_blacklist(&self, ip: IpAddr) {
let mut config = self.config.write().await;
config.blacklist.retain(|&x| x != ip);
}
pub async fn add_to_whitelist(&self, ip: IpAddr) {
let mut config = self.config.write().await;
if !config.whitelist.contains(&ip) {
config.whitelist.push(ip);
}
}
pub async fn remove_from_whitelist(&self, ip: IpAddr) {
let mut config = self.config.write().await;
config.whitelist.retain(|&x| x != ip);
}
pub async fn get_stats(&self) -> RateLimitStats {
let ip_records = self.ip_records.read().await;
let config = self.config.read().await;
let banned_ips = ip_records
.values()
.filter(|r| r.banned_until.is_some())
.count();
let active_ips = ip_records.len();
let total_connections = ip_records.values().map(|r| r.connection_count()).sum();
let total_auth_attempts = ip_records.values().map(|r| r.auth_attempt_count()).sum();
RateLimitStats {
active_ips,
banned_ips,
total_connections,
total_auth_attempts,
blacklist_size: config.blacklist.len(),
whitelist_size: config.whitelist.len(),
}
}
}
#[derive(Debug, Clone)]
pub struct RateLimitStats {
pub active_ips: usize,
pub banned_ips: usize,
pub total_connections: usize,
pub total_auth_attempts: usize,
pub blacklist_size: usize,
pub whitelist_size: usize,
}
#[derive(Debug, Clone)]
pub enum RateLimitError {
IpRateExceeded,
GlobalRateExceeded,
AuthRateExceeded,
Banned,
Blacklisted,
}
impl std::fmt::Display for RateLimitError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RateLimitError::IpRateExceeded => write!(f, "IP connection rate exceeded"),
RateLimitError::GlobalRateExceeded => write!(f, "Global connection rate exceeded"),
RateLimitError::AuthRateExceeded => write!(f, "Authentication rate exceeded - IP banned"),
RateLimitError::Banned => write!(f, "IP is banned"),
RateLimitError::Blacklisted => write!(f, "IP is blacklisted"),
}
}
}
impl std::error::Error for RateLimitError {}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
fn test_ip() -> IpAddr {
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
}
#[tokio::test]
async fn test_connection_allowed_initial() {
let limiter = ConnectionRateLimiter::default();
let ip = test_ip();
let result = limiter.check_connection_allowed(ip).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_connection_rate_exceeded() {
let config = RateLimitConfig {
max_connections_per_ip: 2,
max_connections_per_ip_window: Duration::from_secs(60),
..Default::default()
};
let limiter = ConnectionRateLimiter::new(config);
let ip = test_ip();
limiter.check_connection_allowed(ip).await.unwrap();
limiter.check_connection_allowed(ip).await.unwrap();
let result = limiter.check_connection_allowed(ip).await;
assert!(matches!(result, Err(RateLimitError::IpRateExceeded)));
}
#[tokio::test]
async fn test_auth_rate_exceeded() {
let config = RateLimitConfig {
max_auth_attempts_per_ip: 2,
max_auth_attempts_window: Duration::from_secs(60),
ban_duration: Duration::from_secs(10),
..Default::default()
};
let limiter = ConnectionRateLimiter::new(config);
let ip = test_ip();
limiter.check_auth_attempt_allowed(ip).await.unwrap();
limiter.check_auth_attempt_allowed(ip).await.unwrap();
let result = limiter.check_auth_attempt_allowed(ip).await;
assert!(matches!(result, Err(RateLimitError::AuthRateExceeded)));
let conn_result = limiter.check_connection_allowed(ip).await;
assert!(matches!(conn_result, Err(RateLimitError::Banned)));
}
#[tokio::test]
async fn test_whitelist_bypass() {
let config = RateLimitConfig {
max_connections_per_ip: 1,
max_connections_per_ip_window: Duration::from_secs(60),
whitelist: vec![test_ip()],
..Default::default()
};
let limiter = ConnectionRateLimiter::new(config);
let ip = test_ip();
limiter.check_connection_allowed(ip).await.unwrap();
limiter.check_connection_allowed(ip).await.unwrap();
let result = limiter.check_connection_allowed(ip).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_blacklist_blocked() {
let config = RateLimitConfig {
blacklist: vec![test_ip()],
..Default::default()
};
let limiter = ConnectionRateLimiter::new(config);
let ip = test_ip();
let result = limiter.check_connection_allowed(ip).await;
assert!(matches!(result, Err(RateLimitError::Blacklisted)));
}
#[tokio::test]
async fn test_global_rate_exceeded() {
let config = RateLimitConfig {
max_global_connections: 2,
max_global_connections_window: Duration::from_secs(60),
max_connections_per_ip: 100,
..Default::default()
};
let limiter = ConnectionRateLimiter::new(config);
let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
let ip3 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3));
limiter.check_connection_allowed(ip1).await.unwrap();
limiter.check_connection_allowed(ip2).await.unwrap();
let result = limiter.check_connection_allowed(ip3).await;
assert!(matches!(result, Err(RateLimitError::GlobalRateExceeded)));
}
#[tokio::test]
async fn test_ban_unban() {
let limiter = ConnectionRateLimiter::default();
let ip = test_ip();
limiter.ban_ip(ip, Duration::from_secs(10)).await;
let result = limiter.check_connection_allowed(ip).await;
assert!(matches!(result, Err(RateLimitError::Banned)));
limiter.unban_ip(ip).await;
let result = limiter.check_connection_allowed(ip).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_stats() {
let limiter = ConnectionRateLimiter::default();
let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
limiter.check_connection_allowed(ip1).await.unwrap();
limiter.check_connection_allowed(ip2).await.unwrap();
limiter.check_auth_attempt_allowed(ip1).await.unwrap();
let stats = limiter.get_stats().await;
assert_eq!(stats.active_ips, 2);
assert_eq!(stats.total_connections, 2);
assert_eq!(stats.total_auth_attempts, 1);
}
#[tokio::test]
async fn test_rate_limit_window_expiry() {
let config = RateLimitConfig {
max_connections_per_ip: 2,
max_connections_per_ip_window: Duration::from_millis(100),
..Default::default()
};
let limiter = ConnectionRateLimiter::new(config);
let ip = test_ip();
limiter.check_connection_allowed(ip).await.unwrap();
limiter.check_connection_allowed(ip).await.unwrap();
let result = limiter.check_connection_allowed(ip).await;
assert!(matches!(result, Err(RateLimitError::IpRateExceeded)));
tokio::time::sleep(Duration::from_millis(150)).await;
let result = limiter.check_connection_allowed(ip).await;
assert!(result.is_ok());
}
}