diff --git a/markbase-core/src/ssh_server/channel.rs b/markbase-core/src/ssh_server/channel.rs index 05b4fa2..658998d 100644 --- a/markbase-core/src/ssh_server/channel.rs +++ b/markbase-core/src/ssh_server/channel.rs @@ -1359,6 +1359,35 @@ impl ChannelManager { false } + /// Keep-alive: Get first session channel for keepalive request + pub fn get_first_session_channel(&self) -> Option { + for (&id, channel) in &self.channels { + if channel.channel_type == "session" { + return Some(id); + } + } + None + } + + /// Keep-alive: Build keepalive@openssh.com channel request + pub fn build_keepalive_request(&self, channel_id: u32) -> Result { + let mut payload = Vec::new(); + use byteorder::{BigEndian, WriteBytesExt}; + + payload.write_u8(PacketType::SSH_MSG_CHANNEL_REQUEST as u8)?; + payload.write_u32::(channel_id)?; + + // Request type: keepalive@openssh.com (SSH string) + let keepalive_type = "keepalive@openssh.com"; + payload.write_u32::(keepalive_type.len() as u32)?; + payload.write_all(keepalive_type.as_bytes())?; + + // want_reply = true + payload.write_u8(1)?; + + Ok(SshPacket::new(payload)) + } + /// Phase 17: 关闭所有子进程stdin(收到CHANNEL_EOF时调用) /// SCP upload需要:scp -t 等待EOF on stdin才知道数据传输完毕 pub fn close_child_stdin(&mut self) { diff --git a/markbase-core/src/ssh_server/server.rs b/markbase-core/src/ssh_server/server.rs index 15aa284..02c96e4 100644 --- a/markbase-core/src/ssh_server/server.rs +++ b/markbase-core/src/ssh_server/server.rs @@ -470,11 +470,42 @@ fn handle_ssh_service_loop( ) -> Result<()> { info!("Starting SSH service loop (Phase 14.2: unified poll + child status)"); + // Keep-alive tracking + let keep_alive_interval = security_config.lock().unwrap().keep_alive_interval; + let keep_alive_max_count = security_config.lock().unwrap().keep_alive_max_count; + let mut last_activity = std::time::Instant::now(); + let mut keep_alive_failures = 0; + loop { // ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测 let (stdout_packets, client_has_data, child_exited) = channel_manager.poll_exec_stdout_and_client(stream)?; + // Update activity timestamp on any data transfer + if stdout_packets.is_some() || client_has_data { + last_activity = std::time::Instant::now(); + keep_alive_failures = 0; + } + + // Keep-alive check: send if idle for too long + let idle_duration = last_activity.elapsed().as_secs(); + if idle_duration >= keep_alive_interval && keep_alive_failures < keep_alive_max_count { + info!("Sending keepalive (idle {}s)", idle_duration); + if let Some(channel_id) = channel_manager.get_first_session_channel() { + let keepalive_packet = channel_manager.build_keepalive_request(channel_id)?; + let encrypted_keepalive = EncryptedPacket::new(&keepalive_packet.payload, encryption_ctx, true)?; + encrypted_keepalive.write(stream)?; + keep_alive_failures += 1; + last_activity = std::time::Instant::now(); + } + } + + // Disconnect if too many keepalive failures + if keep_alive_failures >= keep_alive_max_count { + warn!("Connection timed out (keepalive failures: {})", keep_alive_failures); + return Err(anyhow!("Connection timed out")); + } + // 1. 发送stdout/stderr数据(如果有) if let Some(packets) = stdout_packets { // Phase 4: Batch encrypt all packets in parallel diff --git a/markbase-core/src/ssh_server/ssh_security_config.rs b/markbase-core/src/ssh_server/ssh_security_config.rs index e4eba94..a697ddc 100644 --- a/markbase-core/src/ssh_server/ssh_security_config.rs +++ b/markbase-core/src/ssh_server/ssh_security_config.rs @@ -33,6 +33,14 @@ pub struct SshSecurityConfig { /// 连接超时设置,防止悬挂连接 pub connect_timeout: u64, + /// KeepAliveInterval(秒) + /// 心跳间隔,防止连接超时断开 + pub keep_alive_interval: u64, + + /// KeepAliveMaxCount + /// 最大心跳失败次数,超过则断开连接 + pub keep_alive_max_count: u32, + /// 活动会话数(运行时状态) pub active_sessions: u32, } @@ -47,6 +55,8 @@ impl SshSecurityConfig { allow_tcp_forwarding: true, // 允许TCP转发 max_sessions: 10, // 最多10个会话 connect_timeout: 30, // 30秒超时 + keep_alive_interval: 15, // 15秒心跳间隔 + keep_alive_max_count: 3, // 3次失败后断开 active_sessions: 0, // 运行时状态 } } @@ -59,6 +69,8 @@ impl SshSecurityConfig { allow_tcp_forwarding: true, max_sessions: 20, // 开发:更多会话 connect_timeout: 60, // 开发:更长超时 + keep_alive_interval: 30, // 开发:更宽松心跳 + keep_alive_max_count: 5, // 开发:更多失败容忍 active_sessions: 0, } } @@ -105,6 +117,15 @@ impl SshSecurityConfig { .get("connect_timeout") .and_then(|v| v.as_u64()) .unwrap_or(30), + keep_alive_interval: security + .get("keep_alive_interval") + .and_then(|v| v.as_u64()) + .unwrap_or(15), + keep_alive_max_count: security + .get("keep_alive_max_count") + .and_then(|v| v.as_u64()) + .map(|v| v as u32) + .unwrap_or(3), active_sessions: 0, }) }