Implement SSH Multiplexing: Connection/Session/Channel management with expiration and cleanup
Some checks failed
Test / test (push) Has been cancelled
Test / build (push) Has been cancelled

This commit is contained in:
Warren
2026-06-21 05:31:06 +08:00
parent 30c1e5fff9
commit d368a7a4c0
2 changed files with 594 additions and 0 deletions

View File

@@ -12,6 +12,7 @@ pub mod kex;
pub mod kex_complete; pub mod kex_complete;
pub mod kex_exchange; pub mod kex_exchange;
pub mod known_hosts; pub mod known_hosts;
pub mod multiplex;
pub mod packet; pub mod packet;
pub mod port_forward; pub mod port_forward;
pub mod port_forward_listener; pub mod port_forward_listener;

View File

@@ -0,0 +1,593 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
#[derive(Debug, Clone)]
pub struct MultiplexConfig {
pub max_sessions_per_connection: usize,
pub max_channels_per_session: usize,
pub session_timeout: Duration,
pub control_persist_timeout: Duration,
pub enable_multiplexing: bool,
}
impl Default for MultiplexConfig {
fn default() -> Self {
Self {
max_sessions_per_connection: 10,
max_channels_per_session: 100,
session_timeout: Duration::from_secs(3600),
control_persist_timeout: Duration::from_secs(120),
enable_multiplexing: true,
}
}
}
#[derive(Debug, Clone)]
pub struct MultiplexSession {
pub session_id: u64,
pub connection_id: u64,
pub client_addr: SocketAddr,
pub created_at: Instant,
pub last_activity: Instant,
pub channel_count: usize,
pub username: Option<String>,
pub is_authenticated: bool,
}
impl MultiplexSession {
pub fn is_expired(&self, now: Instant, timeout: Duration) -> bool {
now.duration_since(self.last_activity) > timeout
}
pub fn update_activity(&mut self) {
self.last_activity = Instant::now();
}
}
#[derive(Debug, Clone)]
pub struct MultiplexConnection {
pub connection_id: u64,
pub client_addr: SocketAddr,
pub sessions: HashMap<u64, MultiplexSession>,
pub created_at: Instant,
pub total_channels: usize,
pub bytes_sent: u64,
pub bytes_received: u64,
}
impl MultiplexConnection {
pub fn session_count(&self) -> usize {
self.sessions.len()
}
pub fn add_session(&mut self, session: MultiplexSession) {
self.sessions.insert(session.session_id, session);
}
pub fn remove_session(&mut self, session_id: u64) -> Option<MultiplexSession> {
self.sessions.remove(&session_id)
}
pub fn get_session(&self, session_id: u64) -> Option<&MultiplexSession> {
self.sessions.get(&session_id)
}
pub fn get_session_mut(&mut self, session_id: u64) -> Option<&mut MultiplexSession> {
self.sessions.get_mut(&session_id)
}
}
pub struct MultiplexManager {
config: MultiplexConfig,
connections: RwLock<HashMap<u64, MultiplexConnection>>,
next_connection_id: RwLock<u64>,
next_session_id: RwLock<u64>,
}
impl MultiplexManager {
pub fn new(config: MultiplexConfig) -> Self {
Self {
config,
connections: RwLock::new(HashMap::new()),
next_connection_id: RwLock::new(1),
next_session_id: RwLock::new(1),
}
}
pub fn default() -> Self {
Self::new(MultiplexConfig::default())
}
pub async fn alloc_connection_id(&self) -> u64 {
let mut next_id = self.next_connection_id.write().await;
let id = *next_id;
*next_id += 1;
id
}
pub async fn alloc_session_id(&self) -> u64 {
let mut next_id = self.next_session_id.write().await;
let id = *next_id;
*next_id += 1;
id
}
pub async fn register_connection(&self, client_addr: SocketAddr) -> Result<u64, MultiplexError> {
if !self.config.enable_multiplexing {
return Err(MultiplexError::Disabled);
}
let connections = self.connections.read().await;
if connections.len() >= self.config.max_sessions_per_connection * 10 {
return Err(MultiplexError::MaxConnectionsReached);
}
drop(connections);
let connection_id = self.alloc_connection_id().await;
let connection = MultiplexConnection {
connection_id,
client_addr,
sessions: HashMap::new(),
created_at: Instant::now(),
total_channels: 0,
bytes_sent: 0,
bytes_received: 0,
};
let mut connections = self.connections.write().await;
connections.insert(connection_id, connection);
Ok(connection_id)
}
pub async fn register_session(
&self,
connection_id: u64,
client_addr: SocketAddr,
username: Option<String>,
) -> Result<u64, MultiplexError> {
let mut connections = self.connections.write().await;
let connection = connections
.get_mut(&connection_id)
.ok_or(MultiplexError::ConnectionNotFound)?;
if connection.session_count() >= self.config.max_sessions_per_connection {
return Err(MultiplexError::MaxSessionsReached);
}
let session_id = self.alloc_session_id().await;
let session = MultiplexSession {
session_id,
connection_id,
client_addr,
created_at: Instant::now(),
last_activity: Instant::now(),
channel_count: 0,
username,
is_authenticated: false,
};
connection.add_session(session);
Ok(session_id)
}
pub async fn authenticate_session(&self, session_id: u64) -> Result<(), MultiplexError> {
let mut connections = self.connections.write().await;
for connection in connections.values_mut() {
if let Some(session) = connection.get_session_mut(session_id) {
session.is_authenticated = true;
session.update_activity();
return Ok(());
}
}
Err(MultiplexError::SessionNotFound)
}
pub async fn add_channel_to_session(&self, session_id: u64) -> Result<(), MultiplexError> {
let mut connections = self.connections.write().await;
for connection in connections.values_mut() {
let session = connection
.sessions
.get_mut(&session_id)
.ok_or(MultiplexError::SessionNotFound)?;
if session.channel_count >= self.config.max_channels_per_session {
return Err(MultiplexError::MaxChannelsReached);
}
session.channel_count += 1;
session.update_activity();
connection.total_channels += 1;
return Ok(());
}
Err(MultiplexError::SessionNotFound)
}
pub async fn remove_channel_from_session(&self, session_id: u64) -> Result<(), MultiplexError> {
let mut connections = self.connections.write().await;
for connection in connections.values_mut() {
let session = connection
.sessions
.get_mut(&session_id)
.ok_or(MultiplexError::SessionNotFound)?;
session.channel_count = session.channel_count.saturating_sub(1);
session.update_activity();
connection.total_channels = connection.total_channels.saturating_sub(1);
return Ok(());
}
Err(MultiplexError::SessionNotFound)
}
pub async fn update_session_activity(&self, session_id: u64) -> Result<(), MultiplexError> {
let mut connections = self.connections.write().await;
for connection in connections.values_mut() {
if let Some(session) = connection.get_session_mut(session_id) {
session.update_activity();
return Ok(());
}
}
Err(MultiplexError::SessionNotFound)
}
pub async fn update_bytes(&self, connection_id: u64, sent: u64, received: u64) {
let mut connections = self.connections.write().await;
if let Some(connection) = connections.get_mut(&connection_id) {
connection.bytes_sent += sent;
connection.bytes_received += received;
}
}
pub async fn remove_session(&self, session_id: u64) -> Result<(), MultiplexError> {
let mut connections = self.connections.write().await;
for connection in connections.values_mut() {
if connection.remove_session(session_id).is_some() {
return Ok(());
}
}
Err(MultiplexError::SessionNotFound)
}
pub async fn remove_connection(&self, connection_id: u64) -> Result<(), MultiplexError> {
let mut connections = self.connections.write().await;
connections.remove(&connection_id);
Ok(())
}
pub async fn cleanup_expired_sessions(&self) -> usize {
let now = Instant::now();
let mut connections = self.connections.write().await;
let mut total_removed = 0;
for connection in connections.values_mut() {
let expired_session_ids: Vec<u64> = connection
.sessions
.values()
.filter(|s| s.is_expired(now, self.config.session_timeout))
.map(|s| s.session_id)
.collect();
for session_id in expired_session_ids {
connection.remove_session(session_id);
connection.total_channels = connection.total_channels.saturating_sub(1);
total_removed += 1;
}
}
connections.retain(|_, c| c.session_count() > 0);
total_removed
}
pub async fn cleanup_expired_connections(&self) -> usize {
let now = Instant::now();
let mut connections = self.connections.write().await;
let expired_count = connections.len();
connections.retain(|_, c| {
c.session_count() > 0
|| now.duration_since(c.created_at) <= self.config.control_persist_timeout
});
let retained_count = connections.len();
expired_count - retained_count
}
pub async fn get_stats(&self) -> MultiplexStats {
let connections = self.connections.read().await;
let total_connections = connections.len();
let total_sessions = connections.values().map(|c| c.session_count()).sum();
let total_channels = connections.values().map(|c| c.total_channels).sum();
let total_bytes_sent = connections.values().map(|c| c.bytes_sent).sum();
let total_bytes_received = connections.values().map(|c| c.bytes_received).sum();
let authenticated_sessions = connections
.values()
.flat_map(|c| c.sessions.values())
.filter(|s| s.is_authenticated)
.count();
MultiplexStats {
total_connections,
total_sessions,
total_channels,
authenticated_sessions,
total_bytes_sent,
total_bytes_received,
max_sessions_per_connection: self.config.max_sessions_per_connection,
max_channels_per_session: self.config.max_channels_per_session,
}
}
pub async fn get_connection(&self, connection_id: u64) -> Option<MultiplexConnection> {
let connections = self.connections.read().await;
connections.get(&connection_id).cloned()
}
pub async fn get_session(&self, session_id: u64) -> Option<MultiplexSession> {
let connections = self.connections.read().await;
for connection in connections.values() {
if let Some(session) = connection.get_session(session_id) {
return Some(session.clone());
}
}
None
}
}
#[derive(Debug, Clone)]
pub struct MultiplexStats {
pub total_connections: usize,
pub total_sessions: usize,
pub total_channels: usize,
pub authenticated_sessions: usize,
pub total_bytes_sent: u64,
pub total_bytes_received: u64,
pub max_sessions_per_connection: usize,
pub max_channels_per_session: usize,
}
#[derive(Debug, Clone)]
pub enum MultiplexError {
Disabled,
MaxConnectionsReached,
MaxSessionsReached,
MaxChannelsReached,
ConnectionNotFound,
SessionNotFound,
}
impl std::fmt::Display for MultiplexError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MultiplexError::Disabled => write!(f, "Multiplexing is disabled"),
MultiplexError::MaxConnectionsReached => write!(f, "Maximum connections reached"),
MultiplexError::MaxSessionsReached => write!(f, "Maximum sessions per connection reached"),
MultiplexError::MaxChannelsReached => write!(f, "Maximum channels per session reached"),
MultiplexError::ConnectionNotFound => write!(f, "Connection not found"),
MultiplexError::SessionNotFound => write!(f, "Session not found"),
}
}
}
impl std::error::Error for MultiplexError {}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
fn test_addr() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 22)
}
#[tokio::test]
async fn test_register_connection() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
assert_ne!(conn_id, 0);
let conn = manager.get_connection(conn_id).await.unwrap();
assert_eq!(conn.client_addr, addr);
}
#[tokio::test]
async fn test_register_session() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
let session_id = manager
.register_session(conn_id, addr, Some("testuser".to_string()))
.await
.unwrap();
assert_ne!(session_id, 0);
let session = manager.get_session(session_id).await.unwrap();
assert_eq!(session.username, Some("testuser".to_string()));
assert!(!session.is_authenticated);
}
#[tokio::test]
async fn test_authenticate_session() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
manager.authenticate_session(session_id).await.unwrap();
let session = manager.get_session(session_id).await.unwrap();
assert!(session.is_authenticated);
}
#[tokio::test]
async fn test_add_remove_channel() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
manager.add_channel_to_session(session_id).await.unwrap();
manager.add_channel_to_session(session_id).await.unwrap();
let session = manager.get_session(session_id).await.unwrap();
assert_eq!(session.channel_count, 2);
manager.remove_channel_from_session(session_id).await.unwrap();
let session = manager.get_session(session_id).await.unwrap();
assert_eq!(session.channel_count, 1);
}
#[tokio::test]
async fn test_max_sessions_per_connection() {
let config = MultiplexConfig {
max_sessions_per_connection: 2,
..Default::default()
};
let manager = MultiplexManager::new(config);
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
manager.register_session(conn_id, addr, None).await.unwrap();
manager.register_session(conn_id, addr, None).await.unwrap();
let result = manager.register_session(conn_id, addr, None).await;
assert!(matches!(result, Err(MultiplexError::MaxSessionsReached)));
}
#[tokio::test]
async fn test_max_channels_per_session() {
let config = MultiplexConfig {
max_channels_per_session: 2,
..Default::default()
};
let manager = MultiplexManager::new(config);
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
manager.add_channel_to_session(session_id).await.unwrap();
manager.add_channel_to_session(session_id).await.unwrap();
let result = manager.add_channel_to_session(session_id).await;
assert!(matches!(result, Err(MultiplexError::MaxChannelsReached)));
}
#[tokio::test]
async fn test_remove_session() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
manager.remove_session(session_id).await.unwrap();
let session = manager.get_session(session_id).await;
assert!(session.is_none());
}
#[tokio::test]
async fn test_remove_connection() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
manager.remove_connection(conn_id).await.unwrap();
let conn = manager.get_connection(conn_id).await;
assert!(conn.is_none());
}
#[tokio::test]
async fn test_cleanup_expired_sessions() {
let config = MultiplexConfig {
session_timeout: Duration::from_millis(100),
..Default::default()
};
let manager = MultiplexManager::new(config);
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
tokio::time::sleep(Duration::from_millis(150)).await;
let removed = manager.cleanup_expired_sessions().await;
assert_eq!(removed, 1);
let session = manager.get_session(session_id).await;
assert!(session.is_none());
}
#[tokio::test]
async fn test_get_stats() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
let session_id = manager.register_session(conn_id, addr, None).await.unwrap();
manager.authenticate_session(session_id).await.unwrap();
manager.add_channel_to_session(session_id).await.unwrap();
let stats = manager.get_stats().await;
assert_eq!(stats.total_connections, 1);
assert_eq!(stats.total_sessions, 1);
assert_eq!(stats.total_channels, 1);
assert_eq!(stats.authenticated_sessions, 1);
}
#[tokio::test]
async fn test_update_bytes() {
let manager = MultiplexManager::default();
let addr = test_addr();
let conn_id = manager.register_connection(addr).await.unwrap();
manager.update_bytes(conn_id, 100, 50).await;
let conn = manager.get_connection(conn_id).await.unwrap();
assert_eq!(conn.bytes_sent, 100);
assert_eq!(conn.bytes_received, 50);
}
#[tokio::test]
async fn test_disabled_multiplexing() {
let config = MultiplexConfig {
enable_multiplexing: false,
..Default::default()
};
let manager = MultiplexManager::new(config);
let addr = test_addr();
let result = manager.register_connection(addr).await;
assert!(matches!(result, Err(MultiplexError::Disabled)));
}
}