Implement SSH Connection Rate Limiting: IP rate limit + global rate limit + auth brute force prevention
This commit is contained in:
@@ -14,6 +14,7 @@ pub mod kex_exchange;
|
|||||||
pub mod packet;
|
pub mod packet;
|
||||||
pub mod port_forward;
|
pub mod port_forward;
|
||||||
pub mod port_forward_listener;
|
pub mod port_forward_listener;
|
||||||
|
pub mod rate_limiter;
|
||||||
pub mod rsync_handler;
|
pub mod rsync_handler;
|
||||||
pub mod scp_handler;
|
pub mod scp_handler;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
|
|||||||
481
markbase-core/src/ssh_server/rate_limiter.rs
Normal file
481
markbase-core/src/ssh_server/rate_limiter.rs
Normal file
@@ -0,0 +1,481 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RateLimitConfig {
|
||||||
|
pub max_connections_per_ip: usize,
|
||||||
|
pub max_connections_per_ip_window: Duration,
|
||||||
|
pub max_global_connections: usize,
|
||||||
|
pub max_global_connections_window: Duration,
|
||||||
|
pub max_auth_attempts_per_ip: usize,
|
||||||
|
pub max_auth_attempts_window: Duration,
|
||||||
|
pub ban_duration: Duration,
|
||||||
|
pub whitelist: Vec<IpAddr>,
|
||||||
|
pub blacklist: Vec<IpAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RateLimitConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_connections_per_ip: 10,
|
||||||
|
max_connections_per_ip_window: Duration::from_secs(60),
|
||||||
|
max_global_connections: 100,
|
||||||
|
max_global_connections_window: Duration::from_secs(60),
|
||||||
|
max_auth_attempts_per_ip: 5,
|
||||||
|
max_auth_attempts_window: Duration::from_secs(60),
|
||||||
|
ban_duration: Duration::from_secs(300),
|
||||||
|
whitelist: Vec::new(),
|
||||||
|
blacklist: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ConnectionRecord {
|
||||||
|
attempts: Vec<Instant>,
|
||||||
|
auth_attempts: Vec<Instant>,
|
||||||
|
banned_until: Option<Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionRecord {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
attempts: Vec::new(),
|
||||||
|
auth_attempts: Vec::new(),
|
||||||
|
banned_until: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_connection(&mut self, now: Instant) {
|
||||||
|
self.attempts.push(now);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_auth_attempt(&mut self, now: Instant) {
|
||||||
|
self.auth_attempts.push(now);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) {
|
||||||
|
self.attempts.retain(|t| now.duration_since(*t) < window);
|
||||||
|
self.auth_attempts.retain(|t| now.duration_since(*t) < window);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn connection_count(&self) -> usize {
|
||||||
|
self.attempts.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_attempt_count(&self) -> usize {
|
||||||
|
self.auth_attempts.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ban(&mut self, until: Instant) {
|
||||||
|
self.banned_until = Some(until);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_banned(&self, now: Instant) -> bool {
|
||||||
|
if let Some(banned_until) = self.banned_until {
|
||||||
|
now < banned_until
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unban(&mut self) {
|
||||||
|
self.banned_until = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct GlobalConnectionRecord {
|
||||||
|
attempts: Vec<Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GlobalConnectionRecord {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
attempts: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_connection(&mut self, now: Instant) {
|
||||||
|
self.attempts.push(now);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cleanup_old_attempts(&mut self, now: Instant, window: Duration) {
|
||||||
|
self.attempts.retain(|t| now.duration_since(*t) < window);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn connection_count(&self) -> usize {
|
||||||
|
self.attempts.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ConnectionRateLimiter {
|
||||||
|
config: RwLock<RateLimitConfig>,
|
||||||
|
ip_records: RwLock<HashMap<IpAddr, ConnectionRecord>>,
|
||||||
|
global_record: RwLock<GlobalConnectionRecord>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionRateLimiter {
|
||||||
|
pub fn new(config: RateLimitConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config: RwLock::new(config),
|
||||||
|
ip_records: RwLock::new(HashMap::new()),
|
||||||
|
global_record: RwLock::new(GlobalConnectionRecord::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default() -> Self {
|
||||||
|
Self::new(RateLimitConfig::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_config(config: RateLimitConfig) -> Self {
|
||||||
|
Self::new(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_connection_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> {
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let config = self.config.read().await;
|
||||||
|
|
||||||
|
if config.whitelist.contains(&ip) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.blacklist.contains(&ip) {
|
||||||
|
return Err(RateLimitError::Blacklisted);
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_conn_per_ip = config.max_connections_per_ip;
|
||||||
|
let max_conn_window = config.max_connections_per_ip_window;
|
||||||
|
let max_global_conn = config.max_global_connections;
|
||||||
|
let max_global_window = config.max_global_connections_window;
|
||||||
|
|
||||||
|
let mut ip_records = self.ip_records.write().await;
|
||||||
|
let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new);
|
||||||
|
|
||||||
|
record.cleanup_old_attempts(now, max_conn_window);
|
||||||
|
|
||||||
|
if record.is_banned(now) {
|
||||||
|
return Err(RateLimitError::Banned);
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.connection_count() >= max_conn_per_ip {
|
||||||
|
return Err(RateLimitError::IpRateExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut global_record = self.global_record.write().await;
|
||||||
|
global_record.cleanup_old_attempts(now, max_global_window);
|
||||||
|
|
||||||
|
if global_record.connection_count() >= max_global_conn {
|
||||||
|
return Err(RateLimitError::GlobalRateExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
record.add_connection(now);
|
||||||
|
global_record.add_connection(now);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_auth_attempt_allowed(&self, ip: IpAddr) -> Result<(), RateLimitError> {
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let config = self.config.read().await;
|
||||||
|
|
||||||
|
if config.whitelist.contains(&ip) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_auth = config.max_auth_attempts_per_ip;
|
||||||
|
let auth_window = config.max_auth_attempts_window;
|
||||||
|
let ban_duration = config.ban_duration;
|
||||||
|
|
||||||
|
let mut ip_records = self.ip_records.write().await;
|
||||||
|
let record = ip_records.entry(ip).or_insert_with(ConnectionRecord::new);
|
||||||
|
|
||||||
|
record.cleanup_old_attempts(now, auth_window);
|
||||||
|
|
||||||
|
if record.is_banned(now) {
|
||||||
|
return Err(RateLimitError::Banned);
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.auth_attempt_count() >= max_auth {
|
||||||
|
let ban_until = now + ban_duration;
|
||||||
|
record.ban(ban_until);
|
||||||
|
return Err(RateLimitError::AuthRateExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
record.add_auth_attempt(now);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ban_ip(&self, ip: IpAddr, duration: Duration) {
|
||||||
|
let now = Instant::now();
|
||||||
|
let ban_until = now + duration;
|
||||||
|
|
||||||
|
let mut ip_records = self.ip_records.write().await;
|
||||||
|
if let Some(record) = ip_records.get_mut(&ip) {
|
||||||
|
record.ban(ban_until);
|
||||||
|
} else {
|
||||||
|
let mut record = ConnectionRecord::new();
|
||||||
|
record.ban(ban_until);
|
||||||
|
ip_records.insert(ip, record);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn unban_ip(&self, ip: IpAddr) {
|
||||||
|
let mut ip_records = self.ip_records.write().await;
|
||||||
|
if let Some(record) = ip_records.get_mut(&ip) {
|
||||||
|
record.unban();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_to_blacklist(&self, ip: IpAddr) {
|
||||||
|
let mut config = self.config.write().await;
|
||||||
|
if !config.blacklist.contains(&ip) {
|
||||||
|
config.blacklist.push(ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove_from_blacklist(&self, ip: IpAddr) {
|
||||||
|
let mut config = self.config.write().await;
|
||||||
|
config.blacklist.retain(|&x| x != ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_to_whitelist(&self, ip: IpAddr) {
|
||||||
|
let mut config = self.config.write().await;
|
||||||
|
if !config.whitelist.contains(&ip) {
|
||||||
|
config.whitelist.push(ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove_from_whitelist(&self, ip: IpAddr) {
|
||||||
|
let mut config = self.config.write().await;
|
||||||
|
config.whitelist.retain(|&x| x != ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_stats(&self) -> RateLimitStats {
|
||||||
|
let ip_records = self.ip_records.read().await;
|
||||||
|
let config = self.config.read().await;
|
||||||
|
|
||||||
|
let banned_ips = ip_records
|
||||||
|
.values()
|
||||||
|
.filter(|r| r.banned_until.is_some())
|
||||||
|
.count();
|
||||||
|
|
||||||
|
let active_ips = ip_records.len();
|
||||||
|
|
||||||
|
let total_connections = ip_records.values().map(|r| r.connection_count()).sum();
|
||||||
|
|
||||||
|
let total_auth_attempts = ip_records.values().map(|r| r.auth_attempt_count()).sum();
|
||||||
|
|
||||||
|
RateLimitStats {
|
||||||
|
active_ips,
|
||||||
|
banned_ips,
|
||||||
|
total_connections,
|
||||||
|
total_auth_attempts,
|
||||||
|
blacklist_size: config.blacklist.len(),
|
||||||
|
whitelist_size: config.whitelist.len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RateLimitStats {
|
||||||
|
pub active_ips: usize,
|
||||||
|
pub banned_ips: usize,
|
||||||
|
pub total_connections: usize,
|
||||||
|
pub total_auth_attempts: usize,
|
||||||
|
pub blacklist_size: usize,
|
||||||
|
pub whitelist_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum RateLimitError {
|
||||||
|
IpRateExceeded,
|
||||||
|
GlobalRateExceeded,
|
||||||
|
AuthRateExceeded,
|
||||||
|
Banned,
|
||||||
|
Blacklisted,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for RateLimitError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
RateLimitError::IpRateExceeded => write!(f, "IP connection rate exceeded"),
|
||||||
|
RateLimitError::GlobalRateExceeded => write!(f, "Global connection rate exceeded"),
|
||||||
|
RateLimitError::AuthRateExceeded => write!(f, "Authentication rate exceeded - IP banned"),
|
||||||
|
RateLimitError::Banned => write!(f, "IP is banned"),
|
||||||
|
RateLimitError::Blacklisted => write!(f, "IP is blacklisted"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for RateLimitError {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
|
||||||
|
fn test_ip() -> IpAddr {
|
||||||
|
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_connection_allowed_initial() {
|
||||||
|
let limiter = ConnectionRateLimiter::default();
|
||||||
|
let ip = test_ip();
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_connection_rate_exceeded() {
|
||||||
|
let config = RateLimitConfig {
|
||||||
|
max_connections_per_ip: 2,
|
||||||
|
max_connections_per_ip_window: Duration::from_secs(60),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let limiter = ConnectionRateLimiter::new(config);
|
||||||
|
let ip = test_ip();
|
||||||
|
|
||||||
|
limiter.check_connection_allowed(ip).await.unwrap();
|
||||||
|
limiter.check_connection_allowed(ip).await.unwrap();
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(matches!(result, Err(RateLimitError::IpRateExceeded)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_auth_rate_exceeded() {
|
||||||
|
let config = RateLimitConfig {
|
||||||
|
max_auth_attempts_per_ip: 2,
|
||||||
|
max_auth_attempts_window: Duration::from_secs(60),
|
||||||
|
ban_duration: Duration::from_secs(10),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let limiter = ConnectionRateLimiter::new(config);
|
||||||
|
let ip = test_ip();
|
||||||
|
|
||||||
|
limiter.check_auth_attempt_allowed(ip).await.unwrap();
|
||||||
|
limiter.check_auth_attempt_allowed(ip).await.unwrap();
|
||||||
|
|
||||||
|
let result = limiter.check_auth_attempt_allowed(ip).await;
|
||||||
|
assert!(matches!(result, Err(RateLimitError::AuthRateExceeded)));
|
||||||
|
|
||||||
|
let conn_result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(matches!(conn_result, Err(RateLimitError::Banned)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_whitelist_bypass() {
|
||||||
|
let config = RateLimitConfig {
|
||||||
|
max_connections_per_ip: 1,
|
||||||
|
max_connections_per_ip_window: Duration::from_secs(60),
|
||||||
|
whitelist: vec![test_ip()],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let limiter = ConnectionRateLimiter::new(config);
|
||||||
|
let ip = test_ip();
|
||||||
|
|
||||||
|
limiter.check_connection_allowed(ip).await.unwrap();
|
||||||
|
limiter.check_connection_allowed(ip).await.unwrap();
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_blacklist_blocked() {
|
||||||
|
let config = RateLimitConfig {
|
||||||
|
blacklist: vec![test_ip()],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let limiter = ConnectionRateLimiter::new(config);
|
||||||
|
let ip = test_ip();
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(matches!(result, Err(RateLimitError::Blacklisted)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_global_rate_exceeded() {
|
||||||
|
let config = RateLimitConfig {
|
||||||
|
max_global_connections: 2,
|
||||||
|
max_global_connections_window: Duration::from_secs(60),
|
||||||
|
max_connections_per_ip: 100,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let limiter = ConnectionRateLimiter::new(config);
|
||||||
|
|
||||||
|
let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
|
||||||
|
let ip3 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3));
|
||||||
|
|
||||||
|
limiter.check_connection_allowed(ip1).await.unwrap();
|
||||||
|
limiter.check_connection_allowed(ip2).await.unwrap();
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip3).await;
|
||||||
|
assert!(matches!(result, Err(RateLimitError::GlobalRateExceeded)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_ban_unban() {
|
||||||
|
let limiter = ConnectionRateLimiter::default();
|
||||||
|
let ip = test_ip();
|
||||||
|
|
||||||
|
limiter.ban_ip(ip, Duration::from_secs(10)).await;
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(matches!(result, Err(RateLimitError::Banned)));
|
||||||
|
|
||||||
|
limiter.unban_ip(ip).await;
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_stats() {
|
||||||
|
let limiter = ConnectionRateLimiter::default();
|
||||||
|
|
||||||
|
let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
|
||||||
|
|
||||||
|
limiter.check_connection_allowed(ip1).await.unwrap();
|
||||||
|
limiter.check_connection_allowed(ip2).await.unwrap();
|
||||||
|
limiter.check_auth_attempt_allowed(ip1).await.unwrap();
|
||||||
|
|
||||||
|
let stats = limiter.get_stats().await;
|
||||||
|
assert_eq!(stats.active_ips, 2);
|
||||||
|
assert_eq!(stats.total_connections, 2);
|
||||||
|
assert_eq!(stats.total_auth_attempts, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rate_limit_window_expiry() {
|
||||||
|
let config = RateLimitConfig {
|
||||||
|
max_connections_per_ip: 2,
|
||||||
|
max_connections_per_ip_window: Duration::from_millis(100),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let limiter = ConnectionRateLimiter::new(config);
|
||||||
|
let ip = test_ip();
|
||||||
|
|
||||||
|
limiter.check_connection_allowed(ip).await.unwrap();
|
||||||
|
limiter.check_connection_allowed(ip).await.unwrap();
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(matches!(result, Err(RateLimitError::IpRateExceeded)));
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||||
|
|
||||||
|
let result = limiter.check_connection_allowed(ip).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user