Implement SSH Keep-alive support
Some checks failed
Test / build (push) Has been cancelled
Test / test (push) Has been cancelled

- Add keep_alive_interval and keep_alive_max_count to SshSecurityConfig
- Enterprise default: 15s interval, 3 max failures
- Development default: 30s interval, 5 max failures
- Track last_activity timestamp in service loop
- Send keepalive@openssh.com channel request when idle
- Disconnect after max keepalive failures
- Add build_keepalive_request() and get_first_session_channel()
- Prevents connection timeout on idle SSH sessions

All 229 tests pass.
This commit is contained in:
Warren
2026-06-20 23:29:14 +08:00
parent 82ff713b24
commit 783356852e
3 changed files with 81 additions and 0 deletions

View File

@@ -1359,6 +1359,35 @@ impl ChannelManager {
false false
} }
/// Keep-alive: Get first session channel for keepalive request
pub fn get_first_session_channel(&self) -> Option<u32> {
for (&id, channel) in &self.channels {
if channel.channel_type == "session" {
return Some(id);
}
}
None
}
/// Keep-alive: Build keepalive@openssh.com channel request
pub fn build_keepalive_request(&self, channel_id: u32) -> Result<SshPacket> {
let mut payload = Vec::new();
use byteorder::{BigEndian, WriteBytesExt};
payload.write_u8(PacketType::SSH_MSG_CHANNEL_REQUEST as u8)?;
payload.write_u32::<BigEndian>(channel_id)?;
// Request type: keepalive@openssh.com (SSH string)
let keepalive_type = "keepalive@openssh.com";
payload.write_u32::<BigEndian>(keepalive_type.len() as u32)?;
payload.write_all(keepalive_type.as_bytes())?;
// want_reply = true
payload.write_u8(1)?;
Ok(SshPacket::new(payload))
}
/// Phase 17: 关闭所有子进程stdin收到CHANNEL_EOF时调用 /// Phase 17: 关闭所有子进程stdin收到CHANNEL_EOF时调用
/// SCP upload需要scp -t 等待EOF on stdin才知道数据传输完毕 /// SCP upload需要scp -t 等待EOF on stdin才知道数据传输完毕
pub fn close_child_stdin(&mut self) { pub fn close_child_stdin(&mut self) {

View File

@@ -470,11 +470,42 @@ fn handle_ssh_service_loop(
) -> Result<()> { ) -> Result<()> {
info!("Starting SSH service loop (Phase 14.2: unified poll + child status)"); info!("Starting SSH service loop (Phase 14.2: unified poll + child status)");
// Keep-alive tracking
let keep_alive_interval = security_config.lock().unwrap().keep_alive_interval;
let keep_alive_max_count = security_config.lock().unwrap().keep_alive_max_count;
let mut last_activity = std::time::Instant::now();
let mut keep_alive_failures = 0;
loop { loop {
// ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测 // ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测
let (stdout_packets, client_has_data, child_exited) = let (stdout_packets, client_has_data, child_exited) =
channel_manager.poll_exec_stdout_and_client(stream)?; channel_manager.poll_exec_stdout_and_client(stream)?;
// Update activity timestamp on any data transfer
if stdout_packets.is_some() || client_has_data {
last_activity = std::time::Instant::now();
keep_alive_failures = 0;
}
// Keep-alive check: send if idle for too long
let idle_duration = last_activity.elapsed().as_secs();
if idle_duration >= keep_alive_interval && keep_alive_failures < keep_alive_max_count {
info!("Sending keepalive (idle {}s)", idle_duration);
if let Some(channel_id) = channel_manager.get_first_session_channel() {
let keepalive_packet = channel_manager.build_keepalive_request(channel_id)?;
let encrypted_keepalive = EncryptedPacket::new(&keepalive_packet.payload, encryption_ctx, true)?;
encrypted_keepalive.write(stream)?;
keep_alive_failures += 1;
last_activity = std::time::Instant::now();
}
}
// Disconnect if too many keepalive failures
if keep_alive_failures >= keep_alive_max_count {
warn!("Connection timed out (keepalive failures: {})", keep_alive_failures);
return Err(anyhow!("Connection timed out"));
}
// 1. 发送stdout/stderr数据如果有 // 1. 发送stdout/stderr数据如果有
if let Some(packets) = stdout_packets { if let Some(packets) = stdout_packets {
// Phase 4: Batch encrypt all packets in parallel // Phase 4: Batch encrypt all packets in parallel

View File

@@ -33,6 +33,14 @@ pub struct SshSecurityConfig {
/// 连接超时设置,防止悬挂连接 /// 连接超时设置,防止悬挂连接
pub connect_timeout: u64, pub connect_timeout: u64,
/// KeepAliveInterval
/// 心跳间隔,防止连接超时断开
pub keep_alive_interval: u64,
/// KeepAliveMaxCount
/// 最大心跳失败次数,超过则断开连接
pub keep_alive_max_count: u32,
/// 活动会话数(运行时状态) /// 活动会话数(运行时状态)
pub active_sessions: u32, pub active_sessions: u32,
} }
@@ -47,6 +55,8 @@ impl SshSecurityConfig {
allow_tcp_forwarding: true, // 允许TCP转发 allow_tcp_forwarding: true, // 允许TCP转发
max_sessions: 10, // 最多10个会话 max_sessions: 10, // 最多10个会话
connect_timeout: 30, // 30秒超时 connect_timeout: 30, // 30秒超时
keep_alive_interval: 15, // 15秒心跳间隔
keep_alive_max_count: 3, // 3次失败后断开
active_sessions: 0, // 运行时状态 active_sessions: 0, // 运行时状态
} }
} }
@@ -59,6 +69,8 @@ impl SshSecurityConfig {
allow_tcp_forwarding: true, allow_tcp_forwarding: true,
max_sessions: 20, // 开发:更多会话 max_sessions: 20, // 开发:更多会话
connect_timeout: 60, // 开发:更长超时 connect_timeout: 60, // 开发:更长超时
keep_alive_interval: 30, // 开发:更宽松心跳
keep_alive_max_count: 5, // 开发:更多失败容忍
active_sessions: 0, active_sessions: 0,
} }
} }
@@ -105,6 +117,15 @@ impl SshSecurityConfig {
.get("connect_timeout") .get("connect_timeout")
.and_then(|v| v.as_u64()) .and_then(|v| v.as_u64())
.unwrap_or(30), .unwrap_or(30),
keep_alive_interval: security
.get("keep_alive_interval")
.and_then(|v| v.as_u64())
.unwrap_or(15),
keep_alive_max_count: security
.get("keep_alive_max_count")
.and_then(|v| v.as_u64())
.map(|v| v as u32)
.unwrap_or(3),
active_sessions: 0, active_sessions: 0,
}) })
} }