Implement SSH Multiplexing: Connection/Session/Channel management with expiration and cleanup
This commit is contained in:
@@ -12,6 +12,7 @@ pub mod kex;
|
||||
pub mod kex_complete;
|
||||
pub mod kex_exchange;
|
||||
pub mod known_hosts;
|
||||
pub mod multiplex;
|
||||
pub mod packet;
|
||||
pub mod port_forward;
|
||||
pub mod port_forward_listener;
|
||||
|
||||
593
markbase-core/src/ssh_server/multiplex.rs
Normal file
593
markbase-core/src/ssh_server/multiplex.rs
Normal file
@@ -0,0 +1,593 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiplexConfig {
|
||||
pub max_sessions_per_connection: usize,
|
||||
pub max_channels_per_session: usize,
|
||||
pub session_timeout: Duration,
|
||||
pub control_persist_timeout: Duration,
|
||||
pub enable_multiplexing: bool,
|
||||
}
|
||||
|
||||
impl Default for MultiplexConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_sessions_per_connection: 10,
|
||||
max_channels_per_session: 100,
|
||||
session_timeout: Duration::from_secs(3600),
|
||||
control_persist_timeout: Duration::from_secs(120),
|
||||
enable_multiplexing: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiplexSession {
|
||||
pub session_id: u64,
|
||||
pub connection_id: u64,
|
||||
pub client_addr: SocketAddr,
|
||||
pub created_at: Instant,
|
||||
pub last_activity: Instant,
|
||||
pub channel_count: usize,
|
||||
pub username: Option<String>,
|
||||
pub is_authenticated: bool,
|
||||
}
|
||||
|
||||
impl MultiplexSession {
|
||||
pub fn is_expired(&self, now: Instant, timeout: Duration) -> bool {
|
||||
now.duration_since(self.last_activity) > timeout
|
||||
}
|
||||
|
||||
pub fn update_activity(&mut self) {
|
||||
self.last_activity = Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiplexConnection {
|
||||
pub connection_id: u64,
|
||||
pub client_addr: SocketAddr,
|
||||
pub sessions: HashMap<u64, MultiplexSession>,
|
||||
pub created_at: Instant,
|
||||
pub total_channels: usize,
|
||||
pub bytes_sent: u64,
|
||||
pub bytes_received: u64,
|
||||
}
|
||||
|
||||
impl MultiplexConnection {
|
||||
pub fn session_count(&self) -> usize {
|
||||
self.sessions.len()
|
||||
}
|
||||
|
||||
pub fn add_session(&mut self, session: MultiplexSession) {
|
||||
self.sessions.insert(session.session_id, session);
|
||||
}
|
||||
|
||||
pub fn remove_session(&mut self, session_id: u64) -> Option<MultiplexSession> {
|
||||
self.sessions.remove(&session_id)
|
||||
}
|
||||
|
||||
pub fn get_session(&self, session_id: u64) -> Option<&MultiplexSession> {
|
||||
self.sessions.get(&session_id)
|
||||
}
|
||||
|
||||
pub fn get_session_mut(&mut self, session_id: u64) -> Option<&mut MultiplexSession> {
|
||||
self.sessions.get_mut(&session_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MultiplexManager {
|
||||
config: MultiplexConfig,
|
||||
connections: RwLock<HashMap<u64, MultiplexConnection>>,
|
||||
next_connection_id: RwLock<u64>,
|
||||
next_session_id: RwLock<u64>,
|
||||
}
|
||||
|
||||
impl MultiplexManager {
|
||||
pub fn new(config: MultiplexConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
connections: RwLock::new(HashMap::new()),
|
||||
next_connection_id: RwLock::new(1),
|
||||
next_session_id: RwLock::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default() -> Self {
|
||||
Self::new(MultiplexConfig::default())
|
||||
}
|
||||
|
||||
pub async fn alloc_connection_id(&self) -> u64 {
|
||||
let mut next_id = self.next_connection_id.write().await;
|
||||
let id = *next_id;
|
||||
*next_id += 1;
|
||||
id
|
||||
}
|
||||
|
||||
pub async fn alloc_session_id(&self) -> u64 {
|
||||
let mut next_id = self.next_session_id.write().await;
|
||||
let id = *next_id;
|
||||
*next_id += 1;
|
||||
id
|
||||
}
|
||||
|
||||
pub async fn register_connection(&self, client_addr: SocketAddr) -> Result<u64, MultiplexError> {
|
||||
if !self.config.enable_multiplexing {
|
||||
return Err(MultiplexError::Disabled);
|
||||
}
|
||||
|
||||
let connections = self.connections.read().await;
|
||||
if connections.len() >= self.config.max_sessions_per_connection * 10 {
|
||||
return Err(MultiplexError::MaxConnectionsReached);
|
||||
}
|
||||
drop(connections);
|
||||
|
||||
let connection_id = self.alloc_connection_id().await;
|
||||
|
||||
let connection = MultiplexConnection {
|
||||
connection_id,
|
||||
client_addr,
|
||||
sessions: HashMap::new(),
|
||||
created_at: Instant::now(),
|
||||
total_channels: 0,
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
};
|
||||
|
||||
let mut connections = self.connections.write().await;
|
||||
connections.insert(connection_id, connection);
|
||||
|
||||
Ok(connection_id)
|
||||
}
|
||||
|
||||
pub async fn register_session(
|
||||
&self,
|
||||
connection_id: u64,
|
||||
client_addr: SocketAddr,
|
||||
username: Option<String>,
|
||||
) -> Result<u64, MultiplexError> {
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
let connection = connections
|
||||
.get_mut(&connection_id)
|
||||
.ok_or(MultiplexError::ConnectionNotFound)?;
|
||||
|
||||
if connection.session_count() >= self.config.max_sessions_per_connection {
|
||||
return Err(MultiplexError::MaxSessionsReached);
|
||||
}
|
||||
|
||||
let session_id = self.alloc_session_id().await;
|
||||
|
||||
let session = MultiplexSession {
|
||||
session_id,
|
||||
connection_id,
|
||||
client_addr,
|
||||
created_at: Instant::now(),
|
||||
last_activity: Instant::now(),
|
||||
channel_count: 0,
|
||||
username,
|
||||
is_authenticated: false,
|
||||
};
|
||||
|
||||
connection.add_session(session);
|
||||
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
pub async fn authenticate_session(&self, session_id: u64) -> Result<(), MultiplexError> {
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
for connection in connections.values_mut() {
|
||||
if let Some(session) = connection.get_session_mut(session_id) {
|
||||
session.is_authenticated = true;
|
||||
session.update_activity();
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(MultiplexError::SessionNotFound)
|
||||
}
|
||||
|
||||
pub async fn add_channel_to_session(&self, session_id: u64) -> Result<(), MultiplexError> {
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
for connection in connections.values_mut() {
|
||||
let session = connection
|
||||
.sessions
|
||||
.get_mut(&session_id)
|
||||
.ok_or(MultiplexError::SessionNotFound)?;
|
||||
|
||||
if session.channel_count >= self.config.max_channels_per_session {
|
||||
return Err(MultiplexError::MaxChannelsReached);
|
||||
}
|
||||
session.channel_count += 1;
|
||||
session.update_activity();
|
||||
connection.total_channels += 1;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(MultiplexError::SessionNotFound)
|
||||
}
|
||||
|
||||
pub async fn remove_channel_from_session(&self, session_id: u64) -> Result<(), MultiplexError> {
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
for connection in connections.values_mut() {
|
||||
let session = connection
|
||||
.sessions
|
||||
.get_mut(&session_id)
|
||||
.ok_or(MultiplexError::SessionNotFound)?;
|
||||
|
||||
session.channel_count = session.channel_count.saturating_sub(1);
|
||||
session.update_activity();
|
||||
connection.total_channels = connection.total_channels.saturating_sub(1);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(MultiplexError::SessionNotFound)
|
||||
}
|
||||
|
||||
pub async fn update_session_activity(&self, session_id: u64) -> Result<(), MultiplexError> {
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
for connection in connections.values_mut() {
|
||||
if let Some(session) = connection.get_session_mut(session_id) {
|
||||
session.update_activity();
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(MultiplexError::SessionNotFound)
|
||||
}
|
||||
|
||||
pub async fn update_bytes(&self, connection_id: u64, sent: u64, received: u64) {
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
if let Some(connection) = connections.get_mut(&connection_id) {
|
||||
connection.bytes_sent += sent;
|
||||
connection.bytes_received += received;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_session(&self, session_id: u64) -> Result<(), MultiplexError> {
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
for connection in connections.values_mut() {
|
||||
if connection.remove_session(session_id).is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(MultiplexError::SessionNotFound)
|
||||
}
|
||||
|
||||
pub async fn remove_connection(&self, connection_id: u64) -> Result<(), MultiplexError> {
|
||||
let mut connections = self.connections.write().await;
|
||||
connections.remove(&connection_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired_sessions(&self) -> usize {
|
||||
let now = Instant::now();
|
||||
let mut connections = self.connections.write().await;
|
||||
let mut total_removed = 0;
|
||||
|
||||
for connection in connections.values_mut() {
|
||||
let expired_session_ids: Vec<u64> = connection
|
||||
.sessions
|
||||
.values()
|
||||
.filter(|s| s.is_expired(now, self.config.session_timeout))
|
||||
.map(|s| s.session_id)
|
||||
.collect();
|
||||
|
||||
for session_id in expired_session_ids {
|
||||
connection.remove_session(session_id);
|
||||
connection.total_channels = connection.total_channels.saturating_sub(1);
|
||||
total_removed += 1;
|
||||
}
|
||||
}
|
||||
|
||||
connections.retain(|_, c| c.session_count() > 0);
|
||||
|
||||
total_removed
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired_connections(&self) -> usize {
|
||||
let now = Instant::now();
|
||||
let mut connections = self.connections.write().await;
|
||||
|
||||
let expired_count = connections.len();
|
||||
connections.retain(|_, c| {
|
||||
c.session_count() > 0
|
||||
|| now.duration_since(c.created_at) <= self.config.control_persist_timeout
|
||||
});
|
||||
let retained_count = connections.len();
|
||||
|
||||
expired_count - retained_count
|
||||
}
|
||||
|
||||
pub async fn get_stats(&self) -> MultiplexStats {
|
||||
let connections = self.connections.read().await;
|
||||
|
||||
let total_connections = connections.len();
|
||||
let total_sessions = connections.values().map(|c| c.session_count()).sum();
|
||||
let total_channels = connections.values().map(|c| c.total_channels).sum();
|
||||
let total_bytes_sent = connections.values().map(|c| c.bytes_sent).sum();
|
||||
let total_bytes_received = connections.values().map(|c| c.bytes_received).sum();
|
||||
|
||||
let authenticated_sessions = connections
|
||||
.values()
|
||||
.flat_map(|c| c.sessions.values())
|
||||
.filter(|s| s.is_authenticated)
|
||||
.count();
|
||||
|
||||
MultiplexStats {
|
||||
total_connections,
|
||||
total_sessions,
|
||||
total_channels,
|
||||
authenticated_sessions,
|
||||
total_bytes_sent,
|
||||
total_bytes_received,
|
||||
max_sessions_per_connection: self.config.max_sessions_per_connection,
|
||||
max_channels_per_session: self.config.max_channels_per_session,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_connection(&self, connection_id: u64) -> Option<MultiplexConnection> {
|
||||
let connections = self.connections.read().await;
|
||||
connections.get(&connection_id).cloned()
|
||||
}
|
||||
|
||||
pub async fn get_session(&self, session_id: u64) -> Option<MultiplexSession> {
|
||||
let connections = self.connections.read().await;
|
||||
|
||||
for connection in connections.values() {
|
||||
if let Some(session) = connection.get_session(session_id) {
|
||||
return Some(session.clone());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiplexStats {
|
||||
pub total_connections: usize,
|
||||
pub total_sessions: usize,
|
||||
pub total_channels: usize,
|
||||
pub authenticated_sessions: usize,
|
||||
pub total_bytes_sent: u64,
|
||||
pub total_bytes_received: u64,
|
||||
pub max_sessions_per_connection: usize,
|
||||
pub max_channels_per_session: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MultiplexError {
|
||||
Disabled,
|
||||
MaxConnectionsReached,
|
||||
MaxSessionsReached,
|
||||
MaxChannelsReached,
|
||||
ConnectionNotFound,
|
||||
SessionNotFound,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MultiplexError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MultiplexError::Disabled => write!(f, "Multiplexing is disabled"),
|
||||
MultiplexError::MaxConnectionsReached => write!(f, "Maximum connections reached"),
|
||||
MultiplexError::MaxSessionsReached => write!(f, "Maximum sessions per connection reached"),
|
||||
MultiplexError::MaxChannelsReached => write!(f, "Maximum channels per session reached"),
|
||||
MultiplexError::ConnectionNotFound => write!(f, "Connection not found"),
|
||||
MultiplexError::SessionNotFound => write!(f, "Session not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for MultiplexError {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
|
||||
fn test_addr() -> SocketAddr {
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 22)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_connection() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
assert_ne!(conn_id, 0);
|
||||
|
||||
let conn = manager.get_connection(conn_id).await.unwrap();
|
||||
assert_eq!(conn.client_addr, addr);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_session() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
let session_id = manager
|
||||
.register_session(conn_id, addr, Some("testuser".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_ne!(session_id, 0);
|
||||
|
||||
let session = manager.get_session(session_id).await.unwrap();
|
||||
assert_eq!(session.username, Some("testuser".to_string()));
|
||||
assert!(!session.is_authenticated);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authenticate_session() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
|
||||
manager.authenticate_session(session_id).await.unwrap();
|
||||
|
||||
let session = manager.get_session(session_id).await.unwrap();
|
||||
assert!(session.is_authenticated);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_remove_channel() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
|
||||
manager.add_channel_to_session(session_id).await.unwrap();
|
||||
manager.add_channel_to_session(session_id).await.unwrap();
|
||||
|
||||
let session = manager.get_session(session_id).await.unwrap();
|
||||
assert_eq!(session.channel_count, 2);
|
||||
|
||||
manager.remove_channel_from_session(session_id).await.unwrap();
|
||||
|
||||
let session = manager.get_session(session_id).await.unwrap();
|
||||
assert_eq!(session.channel_count, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_max_sessions_per_connection() {
|
||||
let config = MultiplexConfig {
|
||||
max_sessions_per_connection: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let manager = MultiplexManager::new(config);
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
|
||||
let result = manager.register_session(conn_id, addr, None).await;
|
||||
assert!(matches!(result, Err(MultiplexError::MaxSessionsReached)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_max_channels_per_session() {
|
||||
let config = MultiplexConfig {
|
||||
max_channels_per_session: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let manager = MultiplexManager::new(config);
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
|
||||
manager.add_channel_to_session(session_id).await.unwrap();
|
||||
manager.add_channel_to_session(session_id).await.unwrap();
|
||||
|
||||
let result = manager.add_channel_to_session(session_id).await;
|
||||
assert!(matches!(result, Err(MultiplexError::MaxChannelsReached)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_session() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
|
||||
manager.remove_session(session_id).await.unwrap();
|
||||
|
||||
let session = manager.get_session(session_id).await;
|
||||
assert!(session.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_connection() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
manager.remove_connection(conn_id).await.unwrap();
|
||||
|
||||
let conn = manager.get_connection(conn_id).await;
|
||||
assert!(conn.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_expired_sessions() {
|
||||
let config = MultiplexConfig {
|
||||
session_timeout: Duration::from_millis(100),
|
||||
..Default::default()
|
||||
};
|
||||
let manager = MultiplexManager::new(config);
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||
|
||||
let removed = manager.cleanup_expired_sessions().await;
|
||||
assert_eq!(removed, 1);
|
||||
|
||||
let session = manager.get_session(session_id).await;
|
||||
assert!(session.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_stats() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
|
||||
manager.authenticate_session(session_id).await.unwrap();
|
||||
manager.add_channel_to_session(session_id).await.unwrap();
|
||||
|
||||
let stats = manager.get_stats().await;
|
||||
assert_eq!(stats.total_connections, 1);
|
||||
assert_eq!(stats.total_sessions, 1);
|
||||
assert_eq!(stats.total_channels, 1);
|
||||
assert_eq!(stats.authenticated_sessions, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_bytes() {
|
||||
let manager = MultiplexManager::default();
|
||||
let addr = test_addr();
|
||||
|
||||
let conn_id = manager.register_connection(addr).await.unwrap();
|
||||
manager.update_bytes(conn_id, 100, 50).await;
|
||||
|
||||
let conn = manager.get_connection(conn_id).await.unwrap();
|
||||
assert_eq!(conn.bytes_sent, 100);
|
||||
assert_eq!(conn.bytes_received, 50);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disabled_multiplexing() {
|
||||
let config = MultiplexConfig {
|
||||
enable_multiplexing: false,
|
||||
..Default::default()
|
||||
};
|
||||
let manager = MultiplexManager::new(config);
|
||||
let addr = test_addr();
|
||||
|
||||
let result = manager.register_connection(addr).await;
|
||||
assert!(matches!(result, Err(MultiplexError::Disabled)));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user