diff --git a/markbase-core/src/ssh_server/mod.rs b/markbase-core/src/ssh_server/mod.rs index 716680b..f8b716c 100644 --- a/markbase-core/src/ssh_server/mod.rs +++ b/markbase-core/src/ssh_server/mod.rs @@ -12,6 +12,7 @@ pub mod kex; pub mod kex_complete; pub mod kex_exchange; pub mod known_hosts; +pub mod multiplex; pub mod packet; pub mod port_forward; pub mod port_forward_listener; diff --git a/markbase-core/src/ssh_server/multiplex.rs b/markbase-core/src/ssh_server/multiplex.rs new file mode 100644 index 0000000..9403cbf --- /dev/null +++ b/markbase-core/src/ssh_server/multiplex.rs @@ -0,0 +1,593 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +#[derive(Debug, Clone)] +pub struct MultiplexConfig { + pub max_sessions_per_connection: usize, + pub max_channels_per_session: usize, + pub session_timeout: Duration, + pub control_persist_timeout: Duration, + pub enable_multiplexing: bool, +} + +impl Default for MultiplexConfig { + fn default() -> Self { + Self { + max_sessions_per_connection: 10, + max_channels_per_session: 100, + session_timeout: Duration::from_secs(3600), + control_persist_timeout: Duration::from_secs(120), + enable_multiplexing: true, + } + } +} + +#[derive(Debug, Clone)] +pub struct MultiplexSession { + pub session_id: u64, + pub connection_id: u64, + pub client_addr: SocketAddr, + pub created_at: Instant, + pub last_activity: Instant, + pub channel_count: usize, + pub username: Option, + pub is_authenticated: bool, +} + +impl MultiplexSession { + pub fn is_expired(&self, now: Instant, timeout: Duration) -> bool { + now.duration_since(self.last_activity) > timeout + } + + pub fn update_activity(&mut self) { + self.last_activity = Instant::now(); + } +} + +#[derive(Debug, Clone)] +pub struct MultiplexConnection { + pub connection_id: u64, + pub client_addr: SocketAddr, + pub sessions: HashMap, + pub created_at: Instant, + pub total_channels: usize, + pub bytes_sent: u64, + pub bytes_received: u64, +} + +impl MultiplexConnection { + pub fn session_count(&self) -> usize { + self.sessions.len() + } + + pub fn add_session(&mut self, session: MultiplexSession) { + self.sessions.insert(session.session_id, session); + } + + pub fn remove_session(&mut self, session_id: u64) -> Option { + self.sessions.remove(&session_id) + } + + pub fn get_session(&self, session_id: u64) -> Option<&MultiplexSession> { + self.sessions.get(&session_id) + } + + pub fn get_session_mut(&mut self, session_id: u64) -> Option<&mut MultiplexSession> { + self.sessions.get_mut(&session_id) + } +} + +pub struct MultiplexManager { + config: MultiplexConfig, + connections: RwLock>, + next_connection_id: RwLock, + next_session_id: RwLock, +} + +impl MultiplexManager { + pub fn new(config: MultiplexConfig) -> Self { + Self { + config, + connections: RwLock::new(HashMap::new()), + next_connection_id: RwLock::new(1), + next_session_id: RwLock::new(1), + } + } + + pub fn default() -> Self { + Self::new(MultiplexConfig::default()) + } + + pub async fn alloc_connection_id(&self) -> u64 { + let mut next_id = self.next_connection_id.write().await; + let id = *next_id; + *next_id += 1; + id + } + + pub async fn alloc_session_id(&self) -> u64 { + let mut next_id = self.next_session_id.write().await; + let id = *next_id; + *next_id += 1; + id + } + + pub async fn register_connection(&self, client_addr: SocketAddr) -> Result { + if !self.config.enable_multiplexing { + return Err(MultiplexError::Disabled); + } + + let connections = self.connections.read().await; + if connections.len() >= self.config.max_sessions_per_connection * 10 { + return Err(MultiplexError::MaxConnectionsReached); + } + drop(connections); + + let connection_id = self.alloc_connection_id().await; + + let connection = MultiplexConnection { + connection_id, + client_addr, + sessions: HashMap::new(), + created_at: Instant::now(), + total_channels: 0, + bytes_sent: 0, + bytes_received: 0, + }; + + let mut connections = self.connections.write().await; + connections.insert(connection_id, connection); + + Ok(connection_id) + } + + pub async fn register_session( + &self, + connection_id: u64, + client_addr: SocketAddr, + username: Option, + ) -> Result { + let mut connections = self.connections.write().await; + + let connection = connections + .get_mut(&connection_id) + .ok_or(MultiplexError::ConnectionNotFound)?; + + if connection.session_count() >= self.config.max_sessions_per_connection { + return Err(MultiplexError::MaxSessionsReached); + } + + let session_id = self.alloc_session_id().await; + + let session = MultiplexSession { + session_id, + connection_id, + client_addr, + created_at: Instant::now(), + last_activity: Instant::now(), + channel_count: 0, + username, + is_authenticated: false, + }; + + connection.add_session(session); + + Ok(session_id) + } + + pub async fn authenticate_session(&self, session_id: u64) -> Result<(), MultiplexError> { + let mut connections = self.connections.write().await; + + for connection in connections.values_mut() { + if let Some(session) = connection.get_session_mut(session_id) { + session.is_authenticated = true; + session.update_activity(); + return Ok(()); + } + } + + Err(MultiplexError::SessionNotFound) + } + + pub async fn add_channel_to_session(&self, session_id: u64) -> Result<(), MultiplexError> { + let mut connections = self.connections.write().await; + + for connection in connections.values_mut() { + let session = connection + .sessions + .get_mut(&session_id) + .ok_or(MultiplexError::SessionNotFound)?; + + if session.channel_count >= self.config.max_channels_per_session { + return Err(MultiplexError::MaxChannelsReached); + } + session.channel_count += 1; + session.update_activity(); + connection.total_channels += 1; + return Ok(()); + } + + Err(MultiplexError::SessionNotFound) + } + + pub async fn remove_channel_from_session(&self, session_id: u64) -> Result<(), MultiplexError> { + let mut connections = self.connections.write().await; + + for connection in connections.values_mut() { + let session = connection + .sessions + .get_mut(&session_id) + .ok_or(MultiplexError::SessionNotFound)?; + + session.channel_count = session.channel_count.saturating_sub(1); + session.update_activity(); + connection.total_channels = connection.total_channels.saturating_sub(1); + return Ok(()); + } + + Err(MultiplexError::SessionNotFound) + } + + pub async fn update_session_activity(&self, session_id: u64) -> Result<(), MultiplexError> { + let mut connections = self.connections.write().await; + + for connection in connections.values_mut() { + if let Some(session) = connection.get_session_mut(session_id) { + session.update_activity(); + return Ok(()); + } + } + + Err(MultiplexError::SessionNotFound) + } + + pub async fn update_bytes(&self, connection_id: u64, sent: u64, received: u64) { + let mut connections = self.connections.write().await; + + if let Some(connection) = connections.get_mut(&connection_id) { + connection.bytes_sent += sent; + connection.bytes_received += received; + } + } + + pub async fn remove_session(&self, session_id: u64) -> Result<(), MultiplexError> { + let mut connections = self.connections.write().await; + + for connection in connections.values_mut() { + if connection.remove_session(session_id).is_some() { + return Ok(()); + } + } + + Err(MultiplexError::SessionNotFound) + } + + pub async fn remove_connection(&self, connection_id: u64) -> Result<(), MultiplexError> { + let mut connections = self.connections.write().await; + connections.remove(&connection_id); + Ok(()) + } + + pub async fn cleanup_expired_sessions(&self) -> usize { + let now = Instant::now(); + let mut connections = self.connections.write().await; + let mut total_removed = 0; + + for connection in connections.values_mut() { + let expired_session_ids: Vec = connection + .sessions + .values() + .filter(|s| s.is_expired(now, self.config.session_timeout)) + .map(|s| s.session_id) + .collect(); + + for session_id in expired_session_ids { + connection.remove_session(session_id); + connection.total_channels = connection.total_channels.saturating_sub(1); + total_removed += 1; + } + } + + connections.retain(|_, c| c.session_count() > 0); + + total_removed + } + + pub async fn cleanup_expired_connections(&self) -> usize { + let now = Instant::now(); + let mut connections = self.connections.write().await; + + let expired_count = connections.len(); + connections.retain(|_, c| { + c.session_count() > 0 + || now.duration_since(c.created_at) <= self.config.control_persist_timeout + }); + let retained_count = connections.len(); + + expired_count - retained_count + } + + pub async fn get_stats(&self) -> MultiplexStats { + let connections = self.connections.read().await; + + let total_connections = connections.len(); + let total_sessions = connections.values().map(|c| c.session_count()).sum(); + let total_channels = connections.values().map(|c| c.total_channels).sum(); + let total_bytes_sent = connections.values().map(|c| c.bytes_sent).sum(); + let total_bytes_received = connections.values().map(|c| c.bytes_received).sum(); + + let authenticated_sessions = connections + .values() + .flat_map(|c| c.sessions.values()) + .filter(|s| s.is_authenticated) + .count(); + + MultiplexStats { + total_connections, + total_sessions, + total_channels, + authenticated_sessions, + total_bytes_sent, + total_bytes_received, + max_sessions_per_connection: self.config.max_sessions_per_connection, + max_channels_per_session: self.config.max_channels_per_session, + } + } + + pub async fn get_connection(&self, connection_id: u64) -> Option { + let connections = self.connections.read().await; + connections.get(&connection_id).cloned() + } + + pub async fn get_session(&self, session_id: u64) -> Option { + let connections = self.connections.read().await; + + for connection in connections.values() { + if let Some(session) = connection.get_session(session_id) { + return Some(session.clone()); + } + } + + None + } +} + +#[derive(Debug, Clone)] +pub struct MultiplexStats { + pub total_connections: usize, + pub total_sessions: usize, + pub total_channels: usize, + pub authenticated_sessions: usize, + pub total_bytes_sent: u64, + pub total_bytes_received: u64, + pub max_sessions_per_connection: usize, + pub max_channels_per_session: usize, +} + +#[derive(Debug, Clone)] +pub enum MultiplexError { + Disabled, + MaxConnectionsReached, + MaxSessionsReached, + MaxChannelsReached, + ConnectionNotFound, + SessionNotFound, +} + +impl std::fmt::Display for MultiplexError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MultiplexError::Disabled => write!(f, "Multiplexing is disabled"), + MultiplexError::MaxConnectionsReached => write!(f, "Maximum connections reached"), + MultiplexError::MaxSessionsReached => write!(f, "Maximum sessions per connection reached"), + MultiplexError::MaxChannelsReached => write!(f, "Maximum channels per session reached"), + MultiplexError::ConnectionNotFound => write!(f, "Connection not found"), + MultiplexError::SessionNotFound => write!(f, "Session not found"), + } + } +} + +impl std::error::Error for MultiplexError {} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + fn test_addr() -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 22) + } + + #[tokio::test] + async fn test_register_connection() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + assert_ne!(conn_id, 0); + + let conn = manager.get_connection(conn_id).await.unwrap(); + assert_eq!(conn.client_addr, addr); + } + + #[tokio::test] + async fn test_register_session() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + let session_id = manager + .register_session(conn_id, addr, Some("testuser".to_string())) + .await + .unwrap(); + + assert_ne!(session_id, 0); + + let session = manager.get_session(session_id).await.unwrap(); + assert_eq!(session.username, Some("testuser".to_string())); + assert!(!session.is_authenticated); + } + + #[tokio::test] + async fn test_authenticate_session() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + let session_id = manager.register_session(conn_id, addr, None).await.unwrap(); + + manager.authenticate_session(session_id).await.unwrap(); + + let session = manager.get_session(session_id).await.unwrap(); + assert!(session.is_authenticated); + } + + #[tokio::test] + async fn test_add_remove_channel() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + let session_id = manager.register_session(conn_id, addr, None).await.unwrap(); + + manager.add_channel_to_session(session_id).await.unwrap(); + manager.add_channel_to_session(session_id).await.unwrap(); + + let session = manager.get_session(session_id).await.unwrap(); + assert_eq!(session.channel_count, 2); + + manager.remove_channel_from_session(session_id).await.unwrap(); + + let session = manager.get_session(session_id).await.unwrap(); + assert_eq!(session.channel_count, 1); + } + + #[tokio::test] + async fn test_max_sessions_per_connection() { + let config = MultiplexConfig { + max_sessions_per_connection: 2, + ..Default::default() + }; + let manager = MultiplexManager::new(config); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + manager.register_session(conn_id, addr, None).await.unwrap(); + manager.register_session(conn_id, addr, None).await.unwrap(); + + let result = manager.register_session(conn_id, addr, None).await; + assert!(matches!(result, Err(MultiplexError::MaxSessionsReached))); + } + + #[tokio::test] + async fn test_max_channels_per_session() { + let config = MultiplexConfig { + max_channels_per_session: 2, + ..Default::default() + }; + let manager = MultiplexManager::new(config); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + let session_id = manager.register_session(conn_id, addr, None).await.unwrap(); + + manager.add_channel_to_session(session_id).await.unwrap(); + manager.add_channel_to_session(session_id).await.unwrap(); + + let result = manager.add_channel_to_session(session_id).await; + assert!(matches!(result, Err(MultiplexError::MaxChannelsReached))); + } + + #[tokio::test] + async fn test_remove_session() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + let session_id = manager.register_session(conn_id, addr, None).await.unwrap(); + + manager.remove_session(session_id).await.unwrap(); + + let session = manager.get_session(session_id).await; + assert!(session.is_none()); + } + + #[tokio::test] + async fn test_remove_connection() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + manager.remove_connection(conn_id).await.unwrap(); + + let conn = manager.get_connection(conn_id).await; + assert!(conn.is_none()); + } + + #[tokio::test] + async fn test_cleanup_expired_sessions() { + let config = MultiplexConfig { + session_timeout: Duration::from_millis(100), + ..Default::default() + }; + let manager = MultiplexManager::new(config); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + let session_id = manager.register_session(conn_id, addr, None).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(150)).await; + + let removed = manager.cleanup_expired_sessions().await; + assert_eq!(removed, 1); + + let session = manager.get_session(session_id).await; + assert!(session.is_none()); + } + + #[tokio::test] + async fn test_get_stats() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + let session_id = manager.register_session(conn_id, addr, None).await.unwrap(); + manager.authenticate_session(session_id).await.unwrap(); + manager.add_channel_to_session(session_id).await.unwrap(); + + let stats = manager.get_stats().await; + assert_eq!(stats.total_connections, 1); + assert_eq!(stats.total_sessions, 1); + assert_eq!(stats.total_channels, 1); + assert_eq!(stats.authenticated_sessions, 1); + } + + #[tokio::test] + async fn test_update_bytes() { + let manager = MultiplexManager::default(); + let addr = test_addr(); + + let conn_id = manager.register_connection(addr).await.unwrap(); + manager.update_bytes(conn_id, 100, 50).await; + + let conn = manager.get_connection(conn_id).await.unwrap(); + assert_eq!(conn.bytes_sent, 100); + assert_eq!(conn.bytes_received, 50); + } + + #[tokio::test] + async fn test_disabled_multiplexing() { + let config = MultiplexConfig { + enable_multiplexing: false, + ..Default::default() + }; + let manager = MultiplexManager::new(config); + let addr = test_addr(); + + let result = manager.register_connection(addr).await; + assert!(matches!(result, Err(MultiplexError::Disabled))); + } +} \ No newline at end of file