diff --git a/markbase-core/src/ssh_server/channel.rs b/markbase-core/src/ssh_server/channel.rs index 51cd221..29551eb 100644 --- a/markbase-core/src/ssh_server/channel.rs +++ b/markbase-core/src/ssh_server/channel.rs @@ -8,7 +8,7 @@ use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准) use anyhow::{Result, anyhow}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use log::{info, warn, debug, error}; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::{Arc, Mutex}; use crate::ssh_server::sftp_handler::SftpHandler; // Phase 7: SFTP handler use crate::ssh_server::scp_handler::ScpHandler; // Phase 8: SCP handler @@ -23,6 +23,8 @@ use nix::poll::{poll, PollFd, PollFlags}; // Phase 14: poll机制(OpenSSH风 pub struct ChannelManager { channels: HashMap, next_channel_id: u32, + /// ⭐⭐⭐⭐⭐ Phase 15.1: 待发送packet队列(用于同时发送WINDOW_ADJUST和SFTP响应) + pub pending_packets: VecDeque, } /// Phase 14: 交互式Exec进程管理(参考OpenSSH session.c: do_exec_no_pty) @@ -42,6 +44,7 @@ impl ChannelManager { Self { channels: HashMap::new(), next_channel_id: 0, + pending_packets: VecDeque::new(), } } @@ -641,57 +644,119 @@ impl ChannelManager { } // Phase 7: 检查是否是SFTP channel(⭐⭐⭐⭐⭐ Phase 14.3: packet accumulation) - if let Some(sftp_handler) = &mut channel.sftp_handler { + // Extract SFTP result from channel borrow, then send outside + let sftp_result = if let Some(sftp_handler) = &mut channel.sftp_handler { info!("Processing SFTP request ({} bytes)", data.len()); + // ⭐⭐⭐⭐⭐ Window Control: decrease local_window + channel.local_window -= data.len() as u32; + channel.local_consumed += data.len() as u32; + // ⭐⭐⭐⭐⭐ Critical修复:累积SFTP packet数据 channel.sftp_input_buffer.extend_from_slice(&data); info!("SFTP buffer accumulated: {} bytes total", channel.sftp_input_buffer.len()); - // 检查buffer是否有足够数据解析packet length - if channel.sftp_input_buffer.len() < 4 { - info!("SFTP buffer too short for length field, waiting for more data"); - return Ok(None); // 继续累积 + // ⭐⭐⭐⭐⭐ Process ALL complete SFTP packets from buffer (not just one) + let mut all_responses: Vec> = Vec::new(); + loop { + if channel.sftp_input_buffer.len() < 4 { + info!("SFTP buffer too short for length field, waiting for more data"); + break; + } + + let sftp_length = u32::from_be_bytes([ + channel.sftp_input_buffer[0], + channel.sftp_input_buffer[1], + channel.sftp_input_buffer[2], + channel.sftp_input_buffer[3] + ]) as usize; + info!("SFTP packet length field: {}", sftp_length); + + let expected_total = 4 + sftp_length; + if channel.sftp_input_buffer.len() < expected_total { + info!("SFTP packet incomplete: expected {} bytes, have {} bytes in buffer, waiting for more", + expected_total, channel.sftp_input_buffer.len()); + break; + } + + let sftp_packet = channel.sftp_input_buffer[4..expected_total].to_vec(); + info!("SFTP packet complete: {} bytes, processing", sftp_packet.len()); + + let response = sftp_handler.handle_request(&sftp_packet)?; + info!("SFTP response: {} bytes", response.len()); + + if channel.sftp_input_buffer.len() > expected_total { + let remaining = channel.sftp_input_buffer[expected_total..].to_vec(); + channel.sftp_input_buffer = remaining; + info!("SFTP buffer has remaining {} bytes after processing", channel.sftp_input_buffer.len()); + } else { + channel.sftp_input_buffer.clear(); + info!("SFTP buffer cleared after processing"); + } + + all_responses.push(response); } - // 解析SFTP packet length(前4 bytes) - let sftp_length = u32::from_be_bytes([ - channel.sftp_input_buffer[0], - channel.sftp_input_buffer[1], - channel.sftp_input_buffer[2], - channel.sftp_input_buffer[3] - ]) as usize; - info!("SFTP packet length field: {}", sftp_length); + Some(all_responses) + } else { + None + }; + + if let Some(responses) = sftp_result { + // ⭐⭐⭐⭐⭐ Channel borrow is dropped; now we can use self freely - let expected_total = 4 + sftp_length; - if channel.sftp_input_buffer.len() < expected_total { - info!("SFTP packet incomplete: expected {} bytes, have {} bytes in buffer, waiting for more", - expected_total, channel.sftp_input_buffer.len()); - return Ok(None); // 继续累积 + // All responses except the last go to pending_packets + for i in 0..responses.len().saturating_sub(1) { + let pending = self.build_channel_data(recipient_channel, &responses[i])?; + self.pending_packets.push_back(pending); } - // ⭐⭐⭐⭐⭐ Buffer足够,解析完整SFTP packet - let sftp_packet = &channel.sftp_input_buffer[4..expected_total]; - info!("SFTP packet complete: {} bytes, processing", sftp_packet.len()); - info!("SFTP packet content (first 20 bytes): {:?}", &sftp_packet[..std::cmp::min(20, sftp_packet.len())]); - - let response = sftp_handler.handle_request(sftp_packet)?; - info!("SFTP response: {} bytes", response.len()); - - // ⭐⭐⭐⭐⭐ 处理完后,清空buffer或保留剩余数据 - if channel.sftp_input_buffer.len() > expected_total { - // 有剩余数据(多个packets的情况) - let remaining = channel.sftp_input_buffer[expected_total..].to_vec(); - channel.sftp_input_buffer = remaining; - info!("SFTP buffer has remaining {} bytes after processing", channel.sftp_input_buffer.len()); + // Last response is returned (possibly with WINDOW_ADJUST) + if let Some(last_response) = responses.into_iter().last() { + // ⭐⭐⭐⭐⭐ Check window adjust (re-borrow channel briefly) + let (needs_window, consumed) = if let Some(ch) = self.channels.get_mut(&recipient_channel) { + let window_used = ch.local_window_max - ch.local_window; + let need = (window_used > ch.local_maxpacket * 3) || + (ch.local_window < ch.local_window_max / 2); + (need, ch.local_consumed) + } else { + (false, 0) + }; + + if needs_window && consumed > 0 { + info!("⭐⭐⭐⭐⭐ [SFTP_WINDOW] Sending WINDOW_ADJUST before SFTP response"); + let window_adjust = build_window_adjust(recipient_channel, consumed); + // Update window state + if let Some(ch) = self.channels.get_mut(&recipient_channel) { + ch.local_window += consumed; + ch.local_consumed = 0; + } + // Use standalone builder to avoid self borrow conflict + use std::io::Write; + let mut payload = Vec::new(); + payload.write_u8(PacketType::SSH_MSG_CHANNEL_DATA as u8)?; + payload.write_u32::(recipient_channel)?; + payload.write_u32::(last_response.len() as u32)?; + payload.write_all(&last_response)?; + let sftp_packet = SshPacket::new(payload); + self.pending_packets.push_back(window_adjust); + self.pending_packets.push_back(sftp_packet); + return Ok(None); + } + return Ok(Some(self.build_channel_data(recipient_channel, &last_response)?)); } else { - // 清空buffer - channel.sftp_input_buffer.clear(); - info!("SFTP buffer cleared after processing"); + // No SFTP packets were complete, but maybe we need window adjust + if let Some(ch) = self.channels.get_mut(&recipient_channel) { + let window_used = ch.local_window_max - ch.local_window; + if (window_used > ch.local_maxpacket * 3 || ch.local_window < ch.local_window_max / 2) && ch.local_consumed > 0 { + let window_adjust = build_window_adjust(recipient_channel, ch.local_consumed); + ch.local_window += ch.local_consumed; + ch.local_consumed = 0; + self.pending_packets.push_back(window_adjust); + } + } + return Ok(None); } - - // 构建SSH_MSG_CHANNEL_DATA返回SFTP响应(需要SSH string格式) - return Ok(Some(self.build_channel_data(recipient_channel, &response)?)); } } @@ -699,6 +764,15 @@ impl ChannelManager { Ok(None) } + /// ⭐⭐⭐⭐⭐ Phase 13.5: 处理 client 发送的 SSH_MSG_CHANNEL_WINDOW_ADJUST + pub fn adjust_remote_window(&mut self, recipient_channel: u32, bytes_to_add: u32) { + if let Some(channel) = self.channels.get_mut(&recipient_channel) { + channel.remote_window = channel.remote_window.saturating_add(bytes_to_add); + info!("⭐⭐⭐⭐⭐ [ADJUST_REMOTE_WINDOW] channel {} remote_window increased by {} (new: {})", + recipient_channel, bytes_to_add, channel.remote_window); + } + } + /// Phase 14: 构建SSH_MSG_CHANNEL_EXTENDED_DATA(参考OpenSSH channel.c) fn build_channel_extended_data(&self, channel: u32, data_type: u32, data: &[u8]) -> Result { let mut buffer = Vec::new(); diff --git a/markbase-core/src/ssh_server/server.rs b/markbase-core/src/ssh_server/server.rs index f80605d..24b1563 100644 --- a/markbase-core/src/ssh_server/server.rs +++ b/markbase-core/src/ssh_server/server.rs @@ -501,6 +501,13 @@ fn handle_ssh_service_loop( encrypted_response.write(stream)?; info!("Sent SSH_MSG_CHANNEL_DATA (SFTP response)"); } + + // ⭐⭐⭐⭐⭐ Phase 15.1: Drain pending packets (e.g. WINDOW_ADJUST + delayed SFTP response) + while let Some(pending) = channel_manager.pending_packets.pop_front() { + let encrypted_pending = EncryptedPacket::new(&pending.payload, encryption_ctx, true)?; + encrypted_pending.write(stream)?; + info!("Sent pending packet (type {})", pending.payload.first().unwrap_or(&0)); + } } Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_CLOSE as u8 => { info!("Received SSH_MSG_CHANNEL_CLOSE"); @@ -518,6 +525,15 @@ fn handle_ssh_service_loop( info!("Received SSH_MSG_DISCONNECT"); break; } + Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8 => { + let payload = &packet.payload; + if payload.len() >= 9 { + // Format: uint32 recipient_channel || uint32 bytes_to_add + let recipient_channel = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]); + let bytes_to_add = u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]); + channel_manager.adjust_remote_window(recipient_channel, bytes_to_add); + } + } _ => { warn!("Unknown packet type: {:?}", packet.payload.first()); } diff --git a/markbase-core/src/ssh_server/sftp_handler.rs b/markbase-core/src/ssh_server/sftp_handler.rs index 113921b..f948207 100644 --- a/markbase-core/src/ssh_server/sftp_handler.rs +++ b/markbase-core/src/ssh_server/sftp_handler.rs @@ -341,21 +341,6 @@ impl SftpHandler { let version = cursor.read_u32::()?; info!("Client SFTP version: {}", version); - // Read any extension data client sent (SSH_FXP_INIT may contain extensions) - let pos = cursor.position() as usize; - let inner = cursor.get_ref(); - if inner.len() > pos && (inner.len() - pos) >= 4 { - let ext_count = match cursor.read_u32::() { - Ok(n) => n, - Err(_) => 0, - }; - for i in 0..ext_count { - let ext_name = read_sftp_string(&mut cursor).unwrap_or_default(); - let ext_data = read_sftp_string(&mut cursor).unwrap_or_default(); - debug!("Client extension[{}]: {} = {}", i, ext_name, ext_data); - } - } - let response = self.build_version_response(3)?; Ok(response) } @@ -376,8 +361,8 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - let file = if pflags & SftpFileFlags::SSH_FXF_READ != 0 { - OpenOptions::new().read(true).open(&full_path).ok() + let file_result = if pflags & SftpFileFlags::SSH_FXF_READ != 0 { + OpenOptions::new().read(true).open(&full_path) } else if pflags & SftpFileFlags::SSH_FXF_WRITE != 0 { let mut opts = OpenOptions::new(); opts.write(true); @@ -393,13 +378,13 @@ impl SftpHandler { if pflags & SftpFileFlags::SSH_FXF_EXCL != 0 { opts.create_new(true); } - opts.open(&full_path).ok() + opts.open(&full_path) } else { - None + return self.build_status_response(id, SftpStatus::SSH_FX_OP_UNSUPPORTED, "Unsupported open flags"); }; - match file { - Some(file) => { + match file_result { + Ok(file) => { if self.handles.len() >= Self::MAX_HANDLES { warn!("SSH_FXP_OPEN: handle limit reached ({})", Self::MAX_HANDLES); return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Handle limit reached"); @@ -419,8 +404,8 @@ impl SftpHandler { self.build_handle_response(id, &handle_id.to_be_bytes()) } - None => { - self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Failed to open file") + Err(e) => { + self.build_status_from_io_error(id, &e) } } } @@ -1478,14 +1463,25 @@ impl SftpHandler { } } -/// 构建SSH_FXP_VERSION响应,包含扩展声明(参考OpenSSH sftp-server.c) +/// 构建SSH_FXP_VERSION响应,包含扩展声明(参考OpenSSH sftp-server.c: process_init()) + /// + /// SFTP协议格式(draft-ietf-secsh-filexfer-02): + /// uint32 length + /// uint8 type (SSH_FXP_VERSION = 2) + /// uint32 version + /// // extensions: NO count field, simply paired strings until buffer empty + /// string extension_name (= uint32(len_with_nul) + data + \0) + /// string extension_data (= uint32(len_with_nul) + data + \0) + /// + /// OpenSSH uses sshbuf_put_cstring() which includes NUL terminator. + /// Client reads with sshbuf_get_cstring() which expects \0 at end. fn build_version_response(&self, version: u32) -> Result> { let mut buffer = Vec::new(); buffer.write_u8(SftpPacketType::SSH_FXP_VERSION as u8)?; buffer.write_u32::(version)?; - // 扩展声明(OpenSSH sftp-server.c: process_init() 中声明支持的扩展) + // 扩展声明 — OpenSSH sftp-server.c: process_init() style, NO count field let extensions: &[(&str, &str)] = &[ ("posix-rename@openssh.com", "1"), ("hardlink@openssh.com", "1"), @@ -1498,13 +1494,14 @@ impl SftpHandler { ("sha384-hash@openssh.com", "1"), ("sha512-hash@openssh.com", "1"), ]; - - buffer.write_u32::(extensions.len() as u32)?; for (name, data) in extensions { - buffer.write_u32::(name.len() as u32)?; + // sshbuf_put_cstring(buf, s) → sshbuf_put_string(buf, s, strlen(s)+1) + buffer.write_u32::((name.len() + 1) as u32)?; buffer.write_all(name.as_bytes())?; - buffer.write_u32::(data.len() as u32)?; + buffer.write_u8(0)?; + buffer.write_u32::((data.len() + 1) as u32)?; buffer.write_all(data.as_bytes())?; + buffer.write_u8(0)?; } self.wrap_sftp_packet(&buffer)