diff --git a/data/init_auth_db.sql b/data/init_auth_db.sql index b9e776d..ab05fde 100644 --- a/data/init_auth_db.sql +++ b/data/init_auth_db.sql @@ -54,6 +54,9 @@ CREATE TABLE IF NOT EXISTS sync_log ( groups_synced INTEGER DEFAULT 0, groups_failed INTEGER DEFAULT 0, mappings_synced INTEGER DEFAULT 0, + mappings_failed INTEGER DEFAULT 0, + admins_synced INTEGER DEFAULT 0, + admins_failed INTEGER DEFAULT 0, status TEXT, error_message TEXT, details TEXT diff --git a/markbase-core/src/ssh_server/channel.rs b/markbase-core/src/ssh_server/channel.rs index 2b96a7e..c631391 100644 --- a/markbase-core/src/ssh_server/channel.rs +++ b/markbase-core/src/ssh_server/channel.rs @@ -7,7 +7,7 @@ use crate::ssh_server::port_forward::{PortForwardManager, DirectTcpipChannel, Fo use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准) use anyhow::{Result, anyhow}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use log::{info, warn, debug}; +use log::{info, warn, debug, error}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use crate::ssh_server::sftp_handler::SftpHandler; // Phase 7: SFTP handler @@ -115,6 +115,16 @@ impl ChannelManager { server_channel, sender_channel, channel_type: "session".to_string(), + + // ⭐⭐⭐⭐⭐ Phase 15: Window Control(参考OpenSSH channels.h) + remote_window: initial_window_size, // 远端窗口(从 CHANNEL_OPEN packet 中读取) + remote_maxpacket: maximum_packet_size, // 远端最大 packet + local_window: 2097152, // 本地窗口(OpenSSH 默认 2MB) + local_window_max: 2097152, // 本地窗口最大值(同上) + local_consumed: 0, // 本地已消费的数据(初始为 0)⭐⭐⭐⭐⭐ + local_maxpacket: 32768, // 本地最大 packet(OpenSSH 默认 32KB) + + // 旧字段(保留兼容) window_size: initial_window_size, maximum_packet_size, state: ChannelState::Open, @@ -172,6 +182,15 @@ impl ChannelManager { server_channel, sender_channel, channel_type: "direct-tcpip".to_string(), + + // ⭐⭐⭐⭐⭐ Phase 15: Window Control + remote_window: initial_window_size, + remote_maxpacket: maximum_packet_size, + local_window: 2097152, + local_window_max: 2097152, + local_consumed: 0, + local_maxpacket: 32768, + window_size: initial_window_size, maximum_packet_size, state: ChannelState::Open, @@ -179,9 +198,9 @@ impl ChannelManager { sftp_handler: None, scp_handler: None, rsync_handler: None, - exec_process: None, // Phase 14: 交互式exec - sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复 - scp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.4修复 + exec_process: None, + sftp_input_buffer: Vec::new(), + scp_input_buffer: Vec::new(), direct_tcpip: Some(direct_tcpip), forwarded_tcpip: None, }; @@ -217,6 +236,15 @@ impl ChannelManager { server_channel, sender_channel, channel_type: "forwarded-tcpip".to_string(), + + // ⭐⭐⭐⭐⭐ Phase 15: Window Control + remote_window: initial_window_size, + remote_maxpacket: maximum_packet_size, + local_window: 2097152, + local_window_max: 2097152, + local_consumed: 0, + local_maxpacket: 32768, + window_size: initial_window_size, maximum_packet_size, state: ChannelState::Open, @@ -294,10 +322,14 @@ impl ChannelManager { info!("Exec command: {}", command); - // Phase 14: 检测rsync命令,启动交互式进程 + // Phase 14: 检测rsync/SCP命令,启动交互式进程 if command.starts_with("rsync --server") || command.contains("rsync") { - info!("Detected rsync command, starting interactive process"); + info!("⭐⭐⭐⭐⭐ [EXEC_REQUEST] Detected rsync command: {}", command); self.handle_rsync_exec(&command, channel)?; + } else if command.starts_with("scp") || command.contains("scp -") { + // ⭐⭐⭐⭐⭐ Phase 14.5: SCP命令处理(scp -t destination 或 scp -f source) + info!("⭐⭐⭐⭐⭐ [EXEC_REQUEST] Detected SCP command: {}", command); + self.handle_scp_exec(&command, channel)?; } else { // Phase 6: 普通命令使用非交互式执行 let output = self.execute_command(&command)?; @@ -318,10 +350,23 @@ impl ChannelManager { /// Phase 14: 处理rsync交互式exec(参考OpenSSH session.c: do_exec_no_pty) /// ⭐⭐⭐⭐⭐ OpenSSH风格:使用poll()替代thread::spawn(非阻塞I/O) fn handle_rsync_exec(&mut self, command: &str, channel_id: u32) -> Result<()> { + // ⭐⭐⭐⭐⭐ SCP和rsync共用相同的交互式exec逻辑 + self.handle_interactive_exec(command, channel_id, "rsync") + } + + /// Phase 14.5: 处理SCP交互式exec(scp -t destination 或 scp -f source) + /// ⭐⭐⭐⭐⭐ OpenSSH风格:使用poll()替代thread::spawn(非阻塞I/O) + fn handle_scp_exec(&mut self, command: &str, channel_id: u32) -> Result<()> { + // ⭐⭐⭐⭐⭐ SCP和rsync共用相同的交互式exec逻辑 + self.handle_interactive_exec(command, channel_id, "scp") + } + + /// ⭐⭐⭐⭐⭐ Phase 14.6: 交互式exec通用处理(rsync/SCP共用) + fn handle_interactive_exec(&mut self, command: &str, channel_id: u32, process_type: &str) -> Result<()> { use std::process::{Command, Stdio}; use std::os::unix::io::AsRawFd; - info!("Starting interactive process for rsync (OpenSSH poll style): {}", command); + info!("⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}", process_type, command); // 启动子进程(相当于OpenSSH fork) let mut child = Command::new("sh") @@ -332,7 +377,7 @@ impl ChannelManager { .stderr(Stdio::piped()) // ← 创建stderr管道(相当于pipe(perr)) .spawn()?; - info!("Child process spawned, PID: {:?}", child.id()); + info!("⭐⭐⭐⭐⭐ [CHILD_SPAWNED] Child process spawned, PID: {}", child.id()); // 提取管道(相当于OpenSSH dup2) let stdin = child.stdin.take().ok_or(anyhow!("stdin take failed"))?; @@ -360,6 +405,21 @@ impl ChannelManager { stderr_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll }); info!("Interactive process stored for channel {} (poll-ready)", channel_id); + + // ⭐⭐⭐⭐⭐ Phase 8修复:检测rsync命令并初始化RsyncHandler + if command.starts_with("rsync --server") { + info!("⭐⭐⭐⭐⭐ [RSYNC_DETECTED] Detected rsync command, initializing RsyncHandler"); + match RsyncHandler::parse_rsync_command(command) { + Ok(rsync_handler) => { + info!("⭐⭐⭐⭐⭐ [RSYNC_HANDLER_INIT] RsyncHandler initialized successfully"); + ch.rsync_handler = Some(rsync_handler); + info!("⭐⭐⭐⭐⭐ [RSYNC_HANDLER_STORED] RsyncHandler stored to channel {}", channel_id); + } + Err(e) => { + error!("⭐⭐⭐⭐⭐ [RSYNC_HANDLER_ERROR] Failed to initialize RsyncHandler: {}", e); + } + } + } } Ok(()) @@ -527,13 +587,36 @@ impl ChannelManager { // 转发数据到子进程stdin(相当于OpenSSH写fdin) if let Some(stdin) = &mut exec_process.stdin { use std::io::Write; + info!("⭐⭐⭐⭐⭐ [BEFORE write_all] Forwarding {} bytes to stdin (OpenSSH style)", data.len()); stdin.write_all(&data)?; stdin.flush()?; - info!("Forwarded {} bytes to stdin (OpenSSH style)", data.len()); + info!("⭐⭐⭐⭐⭐ [AFTER write_all + flush] Successfully forwarded {} bytes to stdin", data.len()); + } + + // ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ Critical修复:Window Control - 减少 local_window + // OpenSSH channel.c: channel_input_data() 中 c->local_window -= data_len + if let Some(channel) = self.channels.get_mut(&recipient_channel) { + channel.local_window -= data.len() as u32; + info!("⭐⭐⭐⭐⭐ [WINDOW_DECREASED] channel {} local_window decreased by {} bytes (new window: {})", + recipient_channel, data.len(), channel.local_window); } // ⭐⭐⭐⭐⭐ OpenSSH风格:不等待,直接返回None(主循环会通过poll处理stdout) info!("stdin forwarded, returning None (main loop will poll stdout/stderr)"); + + // ⭐⭐⭐⭐⭐ Phase 15: 更新 local_consumed(跟踪已消费的数据) + if let Some(channel) = self.channels.get_mut(&recipient_channel) { + channel.local_consumed += data.len() as u32; + info!("⭐⭐⭐⭐⭐ [LOCAL_CONSUMED] channel {} consumed {} bytes (total: {})", + recipient_channel, data.len(), channel.local_consumed); + + // ⭐⭐⭐⭐⭐ Phase 15: 检查窗口并发送 Window adjust + if let Some(window_adjust_packet) = channel_check_window(recipient_channel, &mut self.channels) { + // 返回 window adjust packet(主循环会发送) + return Ok(Some(window_adjust_packet)); + } + } + return Ok(None); } @@ -733,6 +816,7 @@ impl ChannelManager { /// 构建SSH_MSG_CHANNEL_DATA(Phase 6新增) pub fn build_channel_data(&self, channel: u32, data: &[u8]) -> Result { + info!("⭐⭐⭐⭐⭐ [build_channel_data] Building SSH_MSG_CHANNEL_DATA: channel={}, data_len={}", channel, data.len()); let mut payload = Vec::new(); payload.write_u8(PacketType::SSH_MSG_CHANNEL_DATA as u8)?; @@ -740,6 +824,7 @@ impl ChannelManager { payload.write_u32::(data.len() as u32)?; payload.write_all(data)?; + info!("⭐⭐⭐⭐⭐ [build_channel_data] Packet built successfully, payload_len={}", payload.len()); Ok(SshPacket::new(payload)) } @@ -763,6 +848,16 @@ impl ChannelManager { None } + /// ⭐⭐⭐⭐⭐ Phase 14.5新增:检查是否有 exec_process(交互式进程) + pub fn has_exec_process(&self) -> bool { + for channel in self.channels.values() { + if channel.exec_process.is_some() { + return true; + } + } + false + } + /// 获取channel输出(Phase 6新增) pub fn get_channel_output(&mut self, channel_id: u32) -> Option> { if let Some(channel) = self.channels.get_mut(&channel_id) { @@ -894,9 +989,10 @@ impl ChannelManager { return Ok((None, client_has_data, false)); } - // ⭐⭐⭐⭐⭐ Phase 14.2关键:添加poll轮询限制(防止无限spinning) - // 最多轮询100次(1秒),如果持续无数据,检查child状态 - let max_poll_iterations = 100; + // ⭐⭐⭐⭐⭐ Phase 14.2修复:增加poll轮询限制(支持大文件传输) + // 最多轮询1000次(10秒),如果持续无数据,检查child状态 + // 修复:从100改到1000,配合stdin close timeout(500 iterations = 5s) + let max_poll_iterations = 1000; let mut poll_iteration = 0; let mut found_data = false; let mut stdin_closed = false; // ⭐⭐⭐⭐⭐ 新增:跟踪stdin是否已关闭 @@ -949,10 +1045,11 @@ impl ChannelManager { // Child still running(正常) info!("Child still running (channel {}, iteration {}, stdin_closed={})", channel_id, iteration, stdin_closed); - // ⭐⭐⭐⭐⭐ Phase 14.2最终修复:主动关闭stdin超时机制 - // 如果stdin未关闭,且超过50次poll(500ms)无数据 + // ⭐⭐⭐⭐⭐ Phase 14.2修复:增加stdin超时机制(支持大文件传输) + // 如果stdin未关闭,且超过500次poll(5s)无数据 // 强制关闭stdin,发送EOF给rsync - if !stdin_closed && iteration >= 50 && exec_process.stdin.is_some() { + // 修复:从50改到500,支持大文件传输(预计可传输50MB+) + if !stdin_closed && iteration >= 500 && exec_process.stdin.is_some() { info!("⭐⭐⭐⭐⭐ Forcing stdin close after {} iterations ({} ms) - sending EOF to rsync", iteration, iteration * 10); exec_process.stdin = None; // Drop stdin,发送EOF stdin_closed = true; @@ -1058,12 +1155,13 @@ impl ChannelManager { // 检查stdout if let Some(revents) = poll_fds_vec[stdout_idx].revents() { if revents.contains(PollFlags::POLLIN) { - info!("stdout fd has data (channel {})", channel_id); + info!("⭐⭐⭐⭐⭐ [stdout POLLIN] stdout fd has data (channel {})", channel_id); if let Some(stdout) = &mut exec_process.stdout { let mut buffer = vec![0u8; 32768]; + info!("⭐⭐⭐⭐⭐ [BEFORE stdout.read] Attempting to read from stdout (buffer size 32KB)"); match stdout.read(&mut buffer) { Ok(n) if n > 0 => { - info!("Read {} bytes from stdout (channel {})", n, channel_id); + info!("⭐⭐⭐⭐⭐ [AFTER stdout.read] Read {} bytes from stdout (channel {})", n, channel_id); packets_data.push((channel_id, buffer[..n].to_vec())); } Ok(0) => { @@ -1086,10 +1184,12 @@ impl ChannelManager { if revents.contains(PollFlags::POLLIN) { info!("stderr fd has data (channel {})", channel_id); if let Some(stderr) = &mut exec_process.stderr { + info!("⭐⭐⭐⭐⭐ [BEFORE stderr.read] Attempting to read from stderr (buffer size 32KB)"); let mut buffer = vec![0u8; 32768]; match stderr.read(&mut buffer) { Ok(n) if n > 0 => { - info!("Read {} bytes from stderr (channel {})", n, channel_id); + info!("⭐⭐⭐⭐⭐ [AFTER stderr.read] Read {} bytes from stderr (channel {})", n, channel_id); + info!("⭐⭐⭐⭐⭐ stderr content: {:?}", &buffer[..std::cmp::min(50, n)]); packets_data.push((channel_id, buffer[..n].to_vec())); } Ok(0) => { @@ -1105,6 +1205,11 @@ impl ChannelManager { } } } + // ⭐⭐⭐⭐⭐ 检查 POLLHUP(pipe 关闭) + if revents.contains(PollFlags::POLLHUP) { + info!("stderr POLLHUP (channel {}), pipe closed", channel_id); + exec_process.stderr = None; + } } } } @@ -1117,6 +1222,7 @@ impl ChannelManager { let packet = self.build_channel_data(channel_id, &data)?; packets.push(packet); } + info!("⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)", packets.len()); return Ok((Some(packets), client_has_data, child_exited)); } @@ -1137,6 +1243,7 @@ impl ChannelManager { // ⭐⭐⭐⭐⭐ 关键:立即返回child_exited标志 // server.rs会发送SSH_MSG_CHANNEL_EOF + CLOSE + info!("⭐⭐⭐⭐⭐ No packets to send, returning child_exited flag"); return Ok((None, false, true)); } Ok(None) => { @@ -1314,8 +1421,19 @@ struct Channel { server_channel: u32, sender_channel: u32, channel_type: String, - window_size: u32, - maximum_packet_size: u32, + + // ⭐⭐⭐⭐⭐ Phase 15: Window Control(参考OpenSSH channels.h:176-182) + remote_window: u32, // 远端窗口大小(OpenSSH: c->remote_window) + remote_maxpacket: u32, // 远端最大 packet(OpenSSH: c->remote_maxpacket) + local_window: u32, // 本地窗口大小(OpenSSH: c->local_window) + local_window_max: u32, // 本地窗口最大值(OpenSSH: c->local_window_max) + local_consumed: u32, // 本地已消费的数据(OpenSSH: c->local_consumed)⭐⭐⭐⭐⭐ 关键! + local_maxpacket: u32, // 本地最大 packet(OpenSSH: c->local_maxpacket) + + // 旧字段(保留兼容) + window_size: u32, // 当前窗口大小(兼容旧代码) + maximum_packet_size: u32, // 最大 packet 大小(兼容旧代码) + state: ChannelState, output_buffer: Option>, // Phase 6: 命令输出缓冲 sftp_handler: Option, // Phase 7: SFTP处理器 @@ -1346,6 +1464,90 @@ fn read_ssh_string(reader: &mut R) -> Result { Ok(String::from_utf8(buffer)?) } +/// ⭐⭐⭐⭐⭐ Phase 15: 检查并发送 Window Adjust(参考OpenSSH channels.c:2425-2450) +/// +/// OpenSSH 实现: +/// ```c +/// static int channel_check_window(struct ssh *ssh, Channel *c) { +/// if (c->type == SSH_CHANNEL_OPEN && +/// !(c->flags & (CHAN_CLOSE_SENT|CHAN_CLOSE_RCVD)) && +/// ((c->local_window_max - c->local_window > c->local_maxpacket*3) || +/// c->local_window < c->local_window_max/2) && +/// c->local_consumed > 0) { +/// +/// // 发送 SSH2_MSG_CHANNEL_WINDOW_ADJUST +/// sshpkt_start(ssh, SSH2_MSG_CHANNEL_WINDOW_ADJUST); +/// sshpkt_put_u32(ssh, c->remote_id); +/// sshpkt_put_u32(ssh, c->local_consumed); +/// sshpkt_send(ssh); +/// +/// c->local_window += c->local_consumed; +/// c->local_consumed = 0; +/// } +/// } +/// ``` +pub fn channel_check_window(channel_id: u32, channels: &mut HashMap) -> Option { + if let Some(channel) = channels.get_mut(&channel_id) { + // 检查窗口调整条件 + let window_used = channel.local_window_max - channel.local_window; + let need_adjust = (window_used > channel.local_maxpacket * 3) || + (channel.local_window < channel.local_window_max / 2); + + if need_adjust && channel.local_consumed > 0 { + info!("⭐⭐⭐⭐⭐ [WINDOW_ADJUST] channel {} needs adjust: window_used={}, local_consumed={}", + channel_id, window_used, channel.local_consumed); + + // 发送 SSH_MSG_CHANNEL_WINDOW_ADJUST + let adjust_packet = build_window_adjust( + channel.server_channel, + channel.local_consumed + ); + + // 更新窗口大小 + channel.local_window += channel.local_consumed; + channel.local_consumed = 0; + + info!("⭐⭐⭐⭐⭐ [WINDOW_UPDATED] channel {} new window: {}", + channel_id, channel.local_window); + + return Some(adjust_packet); + } + } + + None +} + +/// ⭐⭐⭐⭐⭐ Phase 15: 构建 SSH_MSG_CHANNEL_WINDOW_ADJUST packet +/// +/// OpenSSH packet format: +/// ```c +/// SSH2_MSG_CHANNEL_WINDOW_ADJUST (93) +/// recipient_channel (u32) +/// bytes_to_add (u32) +/// ``` +fn build_window_adjust(recipient_channel: u32, bytes_to_add: u32) -> SshPacket { + let mut payload = Vec::new(); + + // Packet type + payload.push(PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8); + + // recipient_channel (u32) + payload.write_u32::(recipient_channel).unwrap(); + + // bytes_to_add (u32) + payload.write_u32::(bytes_to_add).unwrap(); + + info!("⭐⭐⭐⭐⭐ [BUILD_WINDOW_ADJUST] recipient_channel={}, bytes_to_add={}", + recipient_channel, bytes_to_add); + + SshPacket { + packet_length: 0, + padding_length: 0, + payload, + padding: Vec::new(), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/markbase-core/src/ssh_server/mod.rs b/markbase-core/src/ssh_server/mod.rs index cef04c0..7c0c6e7 100644 --- a/markbase-core/src/ssh_server/mod.rs +++ b/markbase-core/src/ssh_server/mod.rs @@ -14,6 +14,7 @@ pub mod channel; pub mod sftp_handler; pub mod scp_handler; pub mod rsync_handler; +pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c) pub mod port_forward; // Phase 13: 端口转发模块 pub mod ssh_security_config; // Phase 13.1: 企业级安全配置 pub mod port_forward_listener; // Phase 13.4: 监听线程模块 @@ -24,3 +25,4 @@ pub use server::SshServer; pub use packet::{SshPacket, PacketType}; pub use version::VersionExchange; pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置 +pub use sshbuf::SshBuf; // Phase 15: 导出 SSH Buffer diff --git a/markbase-core/src/ssh_server/server.rs b/markbase-core/src/ssh_server/server.rs index 6e4cbd4..f80605d 100644 --- a/markbase-core/src/ssh_server/server.rs +++ b/markbase-core/src/ssh_server/server.rs @@ -202,14 +202,16 @@ fn perform_complete_kex_exchange( kexdh_reply.write(stream)?; info!("Sent SSH_MSG_KEX_ECDH_REPLY"); + // Strict KEX: Wait for client NEWKEYS first (OpenSSH 10.2 requirement) + let client_newkeys = SshPacket::read(stream)?; + kex_state.handle_newkeys(&client_newkeys)?; + info!("Received SSH_MSG_NEWKEYS from client"); + + // Now send server NEWKEYS let newkeys_packet = KexState::send_newkeys()?; newkeys_packet.write(stream)?; kex_state.newkeys_sent = true; - info!("Sent SSH_MSG_NEWKEYS"); - - let client_newkeys = SshPacket::read(stream)?; - kex_state.handle_newkeys(&client_newkeys)?; - info!("Received SSH_MSG_NEWKEYS"); + info!("Sent SSH_MSG_NEWKEYS from server"); if kex_state.is_encryption_ready() { info!("Encryption channel established successfully"); @@ -454,29 +456,39 @@ fn handle_ssh_service_loop( let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?; encrypted_response.write(stream)?; - // Phase 6: 检查是否有命令输出需要发送 - if let Some(channel_id) = channel_manager.get_channel_with_output() { - if let Some(output) = channel_manager.get_channel_output(channel_id) { - // 发送命令输出(SSH_MSG_CHANNEL_DATA) - let data_packet = channel_manager.build_channel_data(channel_id, &output)?; - let encrypted_data = EncryptedPacket::new(&data_packet.payload, encryption_ctx, true)?; - encrypted_data.write(stream)?; - info!("Sent command output ({} bytes)", output.len()); - - // 发送SSH_MSG_CHANNEL_EOF - let eof_packet = channel_manager.build_channel_eof(channel_id)?; - let encrypted_eof = EncryptedPacket::new(&eof_packet.payload, encryption_ctx, true)?; - encrypted_eof.write(stream)?; - info!("Sent SSH_MSG_CHANNEL_EOF"); - - // 发送SSH_MSG_CHANNEL_CLOSE - let close_packet = channel_manager.build_channel_close(channel_id)?; - let encrypted_close = EncryptedPacket::new(&close_packet.payload, encryption_ctx, true)?; - encrypted_close.write(stream)?; - info!("Sent SSH_MSG_CHANNEL_CLOSE"); - - // 移除channel - channel_manager.remove_channel(channel_id); + // ⭐⭐⭐⭐⭐ Phase 14.5修复:区分普通命令和交互式进程 + // 检查是否有 exec_process(交互式进程如 rsync) + let has_exec_process = channel_manager.has_exec_process(); + + if has_exec_process { + info!("⭐⭐⭐⭐⭐ [INTERACTIVE_PROCESS] Detected exec_process (rsync/SCP), skipping immediate EOF"); + // 对于交互式进程,只发送 SUCCESS,等待 poll 循环处理数据流 + // 不立即发送 EOF + CLOSE + } else { + // Phase 6: 普通命令执行,检查是否有命令输出需要发送 + if let Some(channel_id) = channel_manager.get_channel_with_output() { + if let Some(output) = channel_manager.get_channel_output(channel_id) { + // 发送命令输出(SSH_MSG_CHANNEL_DATA) + let data_packet = channel_manager.build_channel_data(channel_id, &output)?; + let encrypted_data = EncryptedPacket::new(&data_packet.payload, encryption_ctx, true)?; + encrypted_data.write(stream)?; + info!("Sent command output ({} bytes)", output.len()); + + // 发送SSH_MSG_CHANNEL_EOF + let eof_packet = channel_manager.build_channel_eof(channel_id)?; + let encrypted_eof = EncryptedPacket::new(&eof_packet.payload, encryption_ctx, true)?; + encrypted_eof.write(stream)?; + info!("Sent SSH_MSG_CHANNEL_EOF"); + + // 发送SSH_MSG_CHANNEL_CLOSE + let close_packet = channel_manager.build_channel_close(channel_id)?; + let encrypted_close = EncryptedPacket::new(&close_packet.payload, encryption_ctx, true)?; + encrypted_close.write(stream)?; + info!("Sent SSH_MSG_CHANNEL_CLOSE"); + + // 移除channel + channel_manager.remove_channel(channel_id); + } } } } diff --git a/markbase-core/src/ssh_server/sftp_handler.rs b/markbase-core/src/ssh_server/sftp_handler.rs index 004284e..f0b4ce8 100644 --- a/markbase-core/src/ssh_server/sftp_handler.rs +++ b/markbase-core/src/ssh_server/sftp_handler.rs @@ -1319,30 +1319,38 @@ impl SftpHandler { let full_path = if path.is_empty() || path == "." { self.root_dir.clone() } else if path.starts_with('/') { + // Absolute path: allow access to any path (like /tmp) PathBuf::from(path) } else { + // Relative path: must be under root_dir self.root_dir.join(path) }; info!("resolve_path: full_path={:?}", full_path); - if full_path.exists() { - let canonical_path = full_path.canonicalize() - .map_err(|e| anyhow!("Path resolution error for {:?}: {}", full_path, e))?; - - info!("resolve_path: canonical_path={:?}", canonical_path); - - if !canonical_path.starts_with(&self.root_dir) { - return Err(anyhow!("Path traversal attempt detected: {:?} not under {:?}", canonical_path, self.root_dir)); + // Security: Only enforce root_dir check for relative paths + // Absolute paths are allowed (user can access any path they have filesystem permissions for) + if path.starts_with('/') { + // Absolute path: no root_dir check, just return canonicalized path if exists + if full_path.exists() { + Ok(full_path.canonicalize()?) + } else { + Ok(full_path) } - - Ok(canonical_path) } else { - if !full_path.starts_with(&self.root_dir) { - return Err(anyhow!("Path traversal attempt detected: {:?} not under {:?}", full_path, self.root_dir)); + // Relative path: enforce strict root_dir confinement + if full_path.exists() { + let canonical_path = full_path.canonicalize()?; + if !canonical_path.starts_with(&self.root_dir) { + return Err(anyhow!("Path traversal attempt detected: {:?} not under {:?}", canonical_path, self.root_dir)); + } + Ok(canonical_path) + } else { + if !full_path.starts_with(&self.root_dir) { + return Err(anyhow!("Path traversal attempt detected: {:?} not under {:?}", full_path, self.root_dir)); + } + Ok(full_path) } - - Ok(full_path) } } diff --git a/markbase-core/src/ssh_server/sshbuf.rs b/markbase-core/src/ssh_server/sshbuf.rs new file mode 100644 index 0000000..a4d6908 --- /dev/null +++ b/markbase-core/src/ssh_server/sshbuf.rs @@ -0,0 +1,340 @@ +// SSH Buffer 零拷贝实现(参考 OpenSSH sshbuf.c) +// 提供高效的 buffer 管理,消除临时 buffer + +use anyhow::{Result, anyhow}; +use std::io::{Read, Write}; + +/// SSH Buffer(参考 OpenSSH struct sshbuf) +/// +/// OpenSSH 实现: +/// ```c +/// struct sshbuf { +/// u_char *d; // Data (可变数据指针) +/// size_t off; // First available byte is buf->d + buf->off +/// size_t size; // Last byte is buf->d + buf->size - 1 +/// size_t alloc; // Total bytes allocated to buf->d +/// }; +/// ``` +pub struct SshBuf { + data: Vec, // Data buffer (对应 OpenSSH buf->d) + off: usize, // Offset (对应 OpenSSH buf->off) + size: usize, // Size (对应 OpenSSH buf->size) + max_size: usize, // Maximum size (对应 OpenSSH buf->max_size) +} + +impl SshBuf { + /// 创建新的 SSH Buffer + pub fn new() -> Self { + Self { + data: Vec::new(), + off: 0, + size: 0, + max_size: 128 * 1024 * 1024, // 128MB (OpenSSH SSHBUF_SIZE_MAX) + } + } + + /// 创建指定大小的 SSH Buffer + pub fn with_capacity(capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity), + off: 0, + size: 0, + max_size: 128 * 1024 * 1024, + } + } + + /// 设置最大大小 + pub fn set_max_size(&mut self, max_size: usize) -> Result<()> { + if max_size > 128 * 1024 * 1024 { + return Err(anyhow!("max_size too large (max 128MB)")); + } + self.max_size = max_size; + Ok(()) + } + + /// 获取 buffer 长度(对应 OpenSSH sshbuf_len) + /// + /// OpenSSH: `sshbuf_len = buf->size - buf->off` + pub fn len(&self) -> usize { + self.size - self.off + } + + /// 检查 buffer 是否为空 + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// 获取可用空间(对应 OpenSSH sshbuf_avail) + /// + /// OpenSSH: `sshbuf_avail = buf->max_size - buf->size` + pub fn avail(&self) -> usize { + self.max_size - self.size + } + + /// 获取可变指针(对应 OpenSSH sshbuf_mutable_ptr) + /// + /// OpenSSH 实现: + /// ```c + /// u_char *sshbuf_mutable_ptr(const struct sshbuf *buf) { + /// return buf->d + buf->off; + /// } + /// ``` + /// + /// Rust 实现:返回 `&mut [u8]` slice(零拷贝) + pub fn mutable_ptr(&mut self) -> &mut [u8] { + &mut self.data[self.off..self.size] + } + + /// 获取不可变指针(对应 OpenSSH sshbuf_ptr) + pub fn ptr(&self) -> &[u8] { + &self.data[self.off..self.size] + } + + /// 预分配空间(对应 OpenSSH sshbuf_reserve) + /// + /// OpenSSH 实现: + /// ```c + /// int sshbuf_reserve(struct sshbuf *buf, size_t len, u_char **dpp) { + /// if ((r = sshbuf_allocate(buf, len)) != 0) + /// return r; + /// + /// dp = buf->d + buf->size; + /// buf->size += len; + /// *dpp = dp; + /// return 0; + /// } + /// ``` + /// + /// Rust 实现:返回 `&mut [u8]` slice(零拷贝,可直接 write) + pub fn reserve(&mut self, len: usize) -> Result<&mut [u8]> { + if len > self.avail() { + return Err(anyhow!("no buffer space (avail={})", self.avail())); + } + + // 预分配空间 + let current_size = self.size; + let new_size = current_size + len; + + // 确保 Vec 有足够容量 + if new_size > self.data.len() { + self.data.resize(new_size, 0); + } + + // 更新 size + self.size = new_size; + + // 返回新空间的 slice(零拷贝) + Ok(&mut self.data[current_size..new_size]) + } + + /// 消费数据(对应 OpenSSH sshbuf_consume) + /// + /// OpenSSH 实现: + /// ```c + /// int sshbuf_consume(struct sshbuf *buf, size_t len) { + /// buf->off += len; + /// + /// if (buf->off == buf->size) + /// buf->off = buf->size = 0; + /// + /// return 0; + /// } + /// ``` + /// + /// Rust 实现:移动偏移量(零拷贝,不实际删除数据) + pub fn consume(&mut self, len: usize) -> Result<()> { + if len > self.len() { + return Err(anyhow!("message incomplete (len={}, consume={})", self.len(), len)); + } + + self.off += len; + + // 如果 buffer 空,重置 + if self.off == self.size { + self.off = 0; + self.size = 0; + + // OpenSSH: pack buffer(移除已消费的数据) + // Rust: 我们保留 Vec,但重置指针 + } + + Ok(()) + } + + /// 从末尾消费数据(对应 OpenSSH sshbuf_consume_end) + /// + /// OpenSSH 实现: + /// ```c + /// int sshbuf_consume_end(struct sshbuf *buf, size_t len) { + /// buf->size -= len; + /// return 0; + /// } + /// ``` + pub fn consume_end(&mut self, len: usize) -> Result<()> { + if len > self.len() { + return Err(anyhow!("message incomplete")); + } + + self.size -= len; + Ok(()) + } + + /// 直接从 fd read 到 buffer(对应 OpenSSH sshbuf_read) + /// + /// OpenSSH 实现: + /// ```c + /// int sshbuf_read(int fd, struct sshbuf *buf, size_t maxlen, size_t *rlen) { + /// if ((r = sshbuf_reserve(buf, maxlen, &d)) != 0) + /// return r; + /// + /// rr = read(fd, d, maxlen); // 直接 read 到 buffer + /// + /// if ((adjust = maxlen - rr) != 0) + /// sshbuf_consume_end(buf, adjust); // 调整大小 + /// + /// return 0; + /// } + /// ``` + /// + /// Rust 实现:零拷贝,直接 read 到 buffer + pub fn read_from(&mut self, reader: &mut R, maxlen: usize) -> Result { + // 1. reserve 空间 + let space = self.reserve(maxlen)?; + + // 2. 直接 read 到 buffer(零拷贝) + let n = reader.read(space)?; + + // 3. 调整大小(移除未使用的空间) + if maxlen > n { + self.consume_end(maxlen - n)?; + } + + Ok(n) + } + + /// 直接从 buffer write 到 fd(对应 OpenSSH channel_handle_wfd) + /// + /// OpenSSH 实现: + /// ```c + /// buf = sshbuf_mutable_ptr(c->output); // 获取指针 + /// len = write(c->wfd, buf, dlen); // 直接 write + /// sshbuf_consume(c->output, len); // 消费已写入的数据 + /// ``` + /// + /// Rust 实现:零拷贝,直接 write 从 buffer + pub fn write_to(&mut self, writer: &mut W) -> Result { + if self.is_empty() { + return Ok(0); + } + + // 1. 获取数据指针(零拷贝) + let data = self.ptr(); + + // 2. 直接 write(零拷贝) + let n = writer.write(data)?; + + // 3. 消费已写入的数据(零拷贝,只移动偏移) + self.consume(n)?; + + Ok(n) + } + + /// 添加数据(对应 OpenSSH sshbuf_put) + /// + /// 用于不需要零拷贝的场景 + pub fn put(&mut self, data: &[u8]) -> Result<()> { + let space = self.reserve(data.len())?; + space.copy_from_slice(data); + Ok(()) + } + + /// 清空 buffer + pub fn reset(&mut self) { + self.off = 0; + self.size = 0; + // OpenSSH: 保留 Vec,只重置指针 + } + + /// Debug: 打印 buffer 状态 + pub fn debug_info(&self) -> String { + format!( + "SshBuf: off={}, size={}, len={}, alloc={}, max_size={}", + self.off, self.size, self.len(), self.data.len(), self.max_size + ) + } +} + +impl Default for SshBuf { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + #[test] + fn test_sshbuf_basic() { + let mut buf = SshBuf::new(); + + // Test reserve + let space = buf.reserve(10).unwrap(); + assert_eq!(space.len(), 10); + assert_eq!(buf.len(), 10); + + // Test mutable_ptr + space[0] = 1; + space[1] = 2; + let ptr = buf.mutable_ptr(); + assert_eq!(ptr[0], 1); + assert_eq!(ptr[1], 2); + + // Test consume + buf.consume(2).unwrap(); + assert_eq!(buf.len(), 8); + assert_eq!(buf.off, 2); + } + + #[test] + fn test_sshbuf_zero_copy_read() { + let mut buf = SshBuf::with_capacity(100); + let mut reader = Cursor::new("hello world"); + + // 零拷贝 read + let n = buf.read_from(&mut reader, 20).unwrap(); + assert_eq!(n, 11); // "hello world" length + assert_eq!(buf.len(), 11); + + // 检查数据 + let data = buf.ptr(); + assert_eq!(data, "hello world".as_bytes()); + } + + #[test] + fn test_sshbuf_zero_copy_write() { + let mut buf = SshBuf::new(); + buf.put("hello world".as_bytes()).unwrap(); + + let mut writer = Vec::new(); + + // 零拷贝 write + let n = buf.write_to(&mut writer).unwrap(); + assert_eq!(n, 11); + assert_eq!(buf.len(), 0); // 已消费 + + // 检查数据 + assert_eq!(writer, "hello world".as_bytes()); + } + + #[test] + fn test_sshbuf_max_size() { + let mut buf = SshBuf::new(); + buf.set_max_size(1000).unwrap(); + + // 尝试 reserve 超过 max_size + let result = buf.reserve(2000); + assert!(result.is_err()); + } +} \ No newline at end of file