**Problem Fixed**:
- SFTP "packet incomplete" errors solved
- Large file transfers (≥ 8 KB) now work
- SSH splits large SFTP packets across multiple CHANNEL_DATA messages

**Implementation**:
- `sftp_input_buffer: Vec<u8>` accumulation field
- Accumulate CHANNEL_DATA until a complete SFTP packet is buffered
- Parse the 4-byte length field to determine the packet size
- Process the packet once the buffer holds at least `expected_total` bytes
- Clear the buffer, or keep any remaining data

**Testing Results** ⭐⭐⭐⭐⭐:
- SFTP 1MB upload: SUCCESS ✅ (MD5: 38fd6536467443dfdc91f89c0fd573d8)
- SCP 1MB transfer: SUCCESS ✅ (MD5: 38fd6536467443dfdc91f89c0fd573d8)
- rsync 1MB transfer: SUCCESS ✅ (53.84MB/s)
- rsync 2MB transfer: FAILED ❌ (rsync protocol issue, separate from accumulation)

**Code Changes**:
- handle_channel_data(): 40 lines modified
- Accumulation logic with buffer management
- Multiple packet handling (remaining data preserved)

**Key Achievement**:
- SFTP/SCP large file support complete
- Only the rsync protocol still needs the Phase 8 implementation

**Progress**: SSH 96% complete, SFTP/SCP subsystems fixed
1370 lines
61 KiB
Rust
1370 lines
61 KiB
Rust
// SSH Channel协议实现(Phase 6 + Phase 13端口转发)
|
||
// 参考OpenSSH channel.c
|
||
|
||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.3: 安全配置
|
||
use crate::ssh_server::port_forward::{PortForwardManager, DirectTcpipChannel, ForwardedTcpipChannel}; // Phase 13.3
|
||
use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准)
|
||
use anyhow::{Result, anyhow};
|
||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||
use log::{info, warn, debug};
|
||
use std::collections::HashMap;
|
||
use std::sync::{Arc, Mutex};
|
||
use crate::ssh_server::sftp_handler::SftpHandler; // Phase 7: SFTP handler
|
||
use crate::ssh_server::scp_handler::ScpHandler; // Phase 8: SCP handler
|
||
use crate::ssh_server::rsync_handler::RsyncHandler; // Phase 8: rsync handler
|
||
use std::path::PathBuf; // Phase 7-8: Path for SFTP/SCP/rsync root directory
|
||
use std::process::{Child, ChildStdin, ChildStdout, ChildStderr}; // Phase 14: 交互式exec
|
||
use std::os::unix::io::{AsRawFd, RawFd}; // Phase 14: OpenSSH风格poll机制(需要RawFd)
|
||
use nix::fcntl::{fcntl, FcntlArg, OFlag}; // Phase 14: 非阻塞I/O(OpenSSH风格)
|
||
use nix::poll::{poll, PollFd, PollFlags}; // Phase 14: poll机制(OpenSSH风格)
|
||
|
||
/// SSH Channel管理器(参考OpenSSH channel.c: struct channel)
|
||
pub struct ChannelManager {
|
||
channels: HashMap<u32, Channel>,
|
||
next_channel_id: u32,
|
||
}
|
||
|
||
/// Phase 14: an interactive child process attached to an exec channel
/// (cf. OpenSSH session.c: do_exec_no_pty).
///
/// Instead of reader threads, the child's stdout/stderr pipes and their raw
/// descriptors are kept here so the main loop can `poll()` them directly
/// (OpenSSH style, non-blocking I/O).
///
/// Field order is kept as-is on purpose: struct fields are dropped in
/// declaration order, which decides the order the pipes are closed.
pub struct ExecProcess {
    /// The spawned child process (e.g. rsync/scp).
    pub child: Child,
    /// Pipe into the child's stdin (SSH client -> child).
    pub stdin: Option<ChildStdin>,
    /// Pipe from the child's stdout; polled directly, no thread.
    pub stdout: Option<ChildStdout>,
    /// Pipe from the child's stderr; polled directly, no thread.
    pub stderr: Option<ChildStderr>,
    /// Raw descriptor of `stdout`, used for `poll()`.
    pub stdout_fd: RawFd,
    /// Raw descriptor of `stderr`, used for `poll()`.
    pub stderr_fd: RawFd,
}
|
||
|
||
impl ChannelManager {
|
||
pub fn new() -> Self {
|
||
Self {
|
||
channels: HashMap::new(),
|
||
next_channel_id: 0,
|
||
}
|
||
}
|
||
|
||
/// 处理SSH_MSG_CHANNEL_OPEN(参考OpenSSH channel.c: channel_open())
|
||
/// Phase 13.3: 支持direct-tcpip和forwarded-tcpip channel
|
||
pub fn handle_channel_open(&mut self, packet: &SshPacket, security_config: Option<&SshSecurityConfig>) -> Result<SshPacket> {
|
||
info!("Processing SSH_MSG_CHANNEL_OPEN");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_OPEN as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_OPEN"));
|
||
}
|
||
|
||
// 读取channel类型(SSH string)
|
||
let channel_type = read_ssh_string(&mut cursor)?;
|
||
|
||
// 读取sender channel ID(u32)
|
||
let sender_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取初始窗口大小(u32)
|
||
let initial_window_size = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取最大packet大小(u32)
|
||
let maximum_packet_size = cursor.read_u32::<BigEndian>()?;
|
||
|
||
info!("Channel open: type={}, sender_channel={}, window={}, max_packet={}",
|
||
channel_type, sender_channel, initial_window_size, maximum_packet_size);
|
||
|
||
// Phase 13.3: 检查channel类型(支持session、direct-tcpip、forwarded-tcpip)
|
||
match channel_type.as_str() {
|
||
"session" => {
|
||
// 传统的session channel(Phase 6)
|
||
self.handle_session_channel_open(sender_channel, initial_window_size, maximum_packet_size)
|
||
}
|
||
|
||
"direct-tcpip" => {
|
||
// Phase 13.3: Remote port forwarding channel
|
||
info!("Received direct-tcpip channel open (Remote port forwarding)");
|
||
self.handle_direct_tcpip_channel_open(packet, sender_channel, initial_window_size, maximum_packet_size, security_config)
|
||
}
|
||
|
||
"forwarded-tcpip" => {
|
||
// Phase 13.3: Local port forwarding channel
|
||
info!("Received forwarded-tcpip channel open (Local port forwarding)");
|
||
self.handle_forwarded_tcpip_channel_open(packet, sender_channel, initial_window_size, maximum_packet_size)
|
||
}
|
||
|
||
_ => {
|
||
warn!("Unsupported channel type: {}", channel_type);
|
||
self.build_channel_open_failure(
|
||
sender_channel,
|
||
3, // SSH_OPEN_UNKNOWN_CHANNEL_TYPE
|
||
"Unsupported channel type",
|
||
"en"
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
|
||
/// 处理session channel open(Phase 6)
|
||
fn handle_session_channel_open(&mut self, sender_channel: u32, initial_window_size: u32, maximum_packet_size: u32) -> Result<SshPacket> {
|
||
info!("Processing session channel open");
|
||
|
||
let server_channel = self.next_channel_id;
|
||
self.next_channel_id += 1;
|
||
|
||
let channel = Channel {
|
||
server_channel,
|
||
sender_channel,
|
||
channel_type: "session".to_string(),
|
||
window_size: initial_window_size,
|
||
maximum_packet_size,
|
||
state: ChannelState::Open,
|
||
output_buffer: None,
|
||
sftp_handler: None,
|
||
scp_handler: None,
|
||
rsync_handler: None,
|
||
exec_process: None, // Phase 14: 交互式exec
|
||
sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复:SFTP packet累积
|
||
direct_tcpip: None,
|
||
forwarded_tcpip: None,
|
||
};
|
||
|
||
self.channels.insert(server_channel, channel);
|
||
|
||
info!("Session channel created: server_channel={}, sender_channel={}", server_channel, sender_channel);
|
||
|
||
self.build_channel_open_confirmation(server_channel, sender_channel, initial_window_size, maximum_packet_size)
|
||
}
|
||
|
||
/// 处理direct-tcpip channel open(Phase 13.3: Remote port forwarding)
|
||
fn handle_direct_tcpip_channel_open(
|
||
&mut self,
|
||
packet: &SshPacket,
|
||
sender_channel: u32,
|
||
initial_window_size: u32,
|
||
maximum_packet_size: u32,
|
||
security_config: Option<&SshSecurityConfig>,
|
||
) -> Result<SshPacket> {
|
||
info!("Processing direct-tcpip channel open");
|
||
|
||
// 解析direct-tcpip参数
|
||
let mut port_forward_manager = PortForwardManager::new();
|
||
let direct_tcpip = port_forward_manager.handle_direct_tcpip_channel(&packet.payload)?;
|
||
|
||
// Phase 13.3: 安全配置验证
|
||
if let Some(security) = security_config {
|
||
if let Err(e) = security.validate_direct_tcpip_channel(&direct_tcpip.host_to_connect, direct_tcpip.port_to_connect) {
|
||
warn!("direct-tcpip security validation failed: {}", e);
|
||
return self.build_channel_open_failure(
|
||
sender_channel,
|
||
2, // SSH_OPEN_CONNECT_FAILED
|
||
"Security validation failed",
|
||
"en"
|
||
);
|
||
}
|
||
info!("direct-tcpip security validation passed");
|
||
}
|
||
|
||
let server_channel = self.next_channel_id;
|
||
self.next_channel_id += 1;
|
||
|
||
let channel = Channel {
|
||
server_channel,
|
||
sender_channel,
|
||
channel_type: "direct-tcpip".to_string(),
|
||
window_size: initial_window_size,
|
||
maximum_packet_size,
|
||
state: ChannelState::Open,
|
||
output_buffer: None,
|
||
sftp_handler: None,
|
||
scp_handler: None,
|
||
rsync_handler: None,
|
||
exec_process: None, // Phase 14: 交互式exec
|
||
sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复
|
||
direct_tcpip: Some(direct_tcpip),
|
||
forwarded_tcpip: None,
|
||
};
|
||
|
||
self.channels.insert(server_channel, channel);
|
||
|
||
info!("direct-tcpip channel created: server_channel={}, host={}, port={}",
|
||
server_channel,
|
||
self.channels.get(&server_channel).unwrap().direct_tcpip.as_ref().unwrap().host_to_connect,
|
||
self.channels.get(&server_channel).unwrap().direct_tcpip.as_ref().unwrap().port_to_connect);
|
||
|
||
self.build_channel_open_confirmation(server_channel, sender_channel, initial_window_size, maximum_packet_size)
|
||
}
|
||
|
||
/// 处理forwarded-tcpip channel open(Phase 13.3: Local port forwarding)
|
||
fn handle_forwarded_tcpip_channel_open(
|
||
&mut self,
|
||
packet: &SshPacket,
|
||
sender_channel: u32,
|
||
initial_window_size: u32,
|
||
maximum_packet_size: u32,
|
||
) -> Result<SshPacket> {
|
||
info!("Processing forwarded-tcpip channel open");
|
||
|
||
// 解析forwarded-tcpip参数
|
||
let mut port_forward_manager = PortForwardManager::new();
|
||
let forwarded_tcpip = port_forward_manager.handle_forwarded_tcpip_channel(&packet.payload)?;
|
||
|
||
let server_channel = self.next_channel_id;
|
||
self.next_channel_id += 1;
|
||
|
||
let channel = Channel {
|
||
server_channel,
|
||
sender_channel,
|
||
channel_type: "forwarded-tcpip".to_string(),
|
||
window_size: initial_window_size,
|
||
maximum_packet_size,
|
||
state: ChannelState::Open,
|
||
output_buffer: None,
|
||
sftp_handler: None,
|
||
scp_handler: None,
|
||
rsync_handler: None,
|
||
exec_process: None, // Phase 14: 交互式exec
|
||
sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复
|
||
direct_tcpip: None,
|
||
forwarded_tcpip: Some(forwarded_tcpip),
|
||
};
|
||
|
||
self.channels.insert(server_channel, channel);
|
||
|
||
info!("forwarded-tcpip channel created: server_channel={}, bind={}, originator={}",
|
||
server_channel,
|
||
self.channels.get(&server_channel).unwrap().forwarded_tcpip.as_ref().unwrap().bind_port,
|
||
self.channels.get(&server_channel).unwrap().forwarded_tcpip.as_ref().unwrap().originator_address);
|
||
|
||
self.build_channel_open_confirmation(server_channel, sender_channel, initial_window_size, maximum_packet_size)
|
||
}
|
||
/// 处理SSH_MSG_CHANNEL_REQUEST(参考OpenSSH channel.c: channel_request())
|
||
pub fn handle_channel_request(&mut self, packet: &SshPacket) -> Result<Option<SshPacket>> {
|
||
info!("Processing SSH_MSG_CHANNEL_REQUEST");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_REQUEST as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_REQUEST"));
|
||
}
|
||
|
||
// 读取recipient channel(u32)
|
||
let recipient_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取请求类型(SSH string)
|
||
let request_type = read_ssh_string(&mut cursor)?;
|
||
|
||
// 读取want reply标志(boolean)
|
||
let want_reply = cursor.read_u8()? != 0;
|
||
|
||
info!("Channel request: channel={}, type={}, want_reply={}",
|
||
recipient_channel, request_type, want_reply);
|
||
|
||
// 处理不同请求类型(参考OpenSSH channel.c)
|
||
if request_type == "exec" {
|
||
self.handle_exec_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符(返回Option不是Result)
|
||
} else if request_type == "subsystem" {
|
||
self.handle_subsystem_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符
|
||
} else if request_type == "shell" {
|
||
self.handle_shell_request(recipient_channel, want_reply) // 移除?操作符
|
||
} else if request_type == "env" {
|
||
self.handle_env_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符
|
||
} else if request_type == "pty-req" {
|
||
self.handle_pty_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符
|
||
} else {
|
||
warn!("Unsupported channel request: {}", request_type);
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_failure(recipient_channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 处理exec请求(参考OpenSSH channel.c: channel_request_exec() + session.c: do_exec_no_pty)
|
||
fn handle_exec_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result<Option<SshPacket>> {
|
||
info!("Handling exec request for channel {}", channel);
|
||
|
||
// 读取命令(SSH string)
|
||
let command = read_ssh_string(cursor)?;
|
||
|
||
info!("Exec command: {}", command);
|
||
|
||
// Phase 14: 检测rsync命令,启动交互式进程
|
||
if command.starts_with("rsync --server") || command.contains("rsync") {
|
||
info!("Detected rsync command, starting interactive process");
|
||
self.handle_rsync_exec(&command, channel)?;
|
||
} else {
|
||
// Phase 6: 普通命令使用非交互式执行
|
||
let output = self.execute_command(&command)?;
|
||
|
||
// 存储输出,等待后续发送CHANNEL_DATA
|
||
if let Some(ch) = self.channels.get_mut(&channel) {
|
||
ch.output_buffer = Some(output);
|
||
}
|
||
}
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// Phase 14: 处理rsync交互式exec(参考OpenSSH session.c: do_exec_no_pty)
|
||
/// ⭐⭐⭐⭐⭐ OpenSSH风格:使用poll()替代thread::spawn(非阻塞I/O)
|
||
fn handle_rsync_exec(&mut self, command: &str, channel_id: u32) -> Result<()> {
|
||
use std::process::{Command, Stdio};
|
||
use std::os::unix::io::AsRawFd;
|
||
|
||
info!("Starting interactive process for rsync (OpenSSH poll style): {}", command);
|
||
|
||
// 启动子进程(相当于OpenSSH fork)
|
||
let mut child = Command::new("sh")
|
||
.arg("-c")
|
||
.arg(command)
|
||
.stdin(Stdio::piped()) // ← 创建stdin管道(相当于pipe(pin))
|
||
.stdout(Stdio::piped()) // ← 创建stdout管道(相当于pipe(pout))
|
||
.stderr(Stdio::piped()) // ← 创建stderr管道(相当于pipe(perr))
|
||
.spawn()?;
|
||
|
||
info!("Child process spawned, PID: {:?}", child.id());
|
||
|
||
// 提取管道(相当于OpenSSH dup2)
|
||
let stdin = child.stdin.take().ok_or(anyhow!("stdin take failed"))?;
|
||
let stdout = child.stdout.take().ok_or(anyhow!("stdout take failed"))?;
|
||
let stderr = child.stderr.take().ok_or(anyhow!("stderr take failed"))?;
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH关键:设置非阻塞模式(fcntl O_NONBLOCK)
|
||
let stdout_fd = stdout.as_raw_fd();
|
||
let stderr_fd = stderr.as_raw_fd();
|
||
|
||
info!("Setting stdout/stderr to non-blocking mode (OpenSSH style)");
|
||
fcntl(stdout_fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?;
|
||
fcntl(stderr_fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?;
|
||
info!("Non-blocking I/O enabled for stdout (fd {}) and stderr (fd {})", stdout_fd, stderr_fd);
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH风格:不再使用thread::spawn,直接保留File对象用于poll
|
||
// 存储到channel(相当于OpenSSH session_set_fds)
|
||
if let Some(ch) = self.channels.get_mut(&channel_id) {
|
||
ch.exec_process = Some(ExecProcess {
|
||
child,
|
||
stdin: Some(stdin),
|
||
stdout: Some(stdout), // ⭐⭐⭐⭐⭐ 直接保留File对象
|
||
stderr: Some(stderr), // ⭐⭐⭐⭐⭐ 直接保留File对象
|
||
stdout_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll
|
||
stderr_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll
|
||
});
|
||
info!("Interactive process stored for channel {} (poll-ready)", channel_id);
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 执行命令并捕获输出(Phase 6基础实现)
|
||
fn execute_command(&self, command: &str) -> Result<Vec<u8>> {
|
||
use std::process::{Command, Stdio};
|
||
|
||
info!("Executing command: {}", command);
|
||
|
||
// 使用shell执行命令(参考OpenSSH session.c)
|
||
let output = Command::new("sh")
|
||
.arg("-c")
|
||
.arg(command)
|
||
.output()?;
|
||
|
||
// 返回stdout + stderr
|
||
let mut result = output.stdout;
|
||
result.extend_from_slice(&output.stderr);
|
||
|
||
info!("Command output: {} bytes", result.len());
|
||
Ok(result)
|
||
}
|
||
|
||
/// 处理subsystem请求(参考OpenSSH channel.c: channel_request_subsystem())
|
||
fn handle_subsystem_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result<Option<SshPacket>> {
|
||
info!("Handling subsystem request for channel {}", channel);
|
||
|
||
// 读取subsystem名称(SSH string)
|
||
let subsystem = read_ssh_string(cursor)?;
|
||
|
||
info!("Subsystem: {}", subsystem);
|
||
|
||
// 检查subsystem支持(OpenSSH支持:sftp)
|
||
if subsystem == "sftp" {
|
||
info!("SFTP subsystem requested");
|
||
|
||
// Phase 7: 初始化SFTP handler
|
||
let root_dir = PathBuf::from("/Users/accusys/markbase"); // 默认root目录
|
||
let sftp_handler = SftpHandler::new(root_dir);
|
||
|
||
// 存储到channel
|
||
if let Some(ch) = self.channels.get_mut(&channel) {
|
||
ch.sftp_handler = Some(sftp_handler);
|
||
info!("SFTP handler initialized for channel {}", channel);
|
||
}
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
} else {
|
||
warn!("Unsupported subsystem: {}", subsystem);
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_failure(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 处理shell请求(参考OpenSSH channel.c)
|
||
fn handle_shell_request(&mut self, channel: u32, want_reply: bool) -> Result<Option<SshPacket>> {
|
||
info!("Handling shell request for channel {}", channel);
|
||
|
||
// Phase 9将实现shell
|
||
warn!("Shell not implemented in Phase 6");
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_failure(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 处理env请求(参考OpenSSH channel.c)
|
||
fn handle_env_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result<Option<SshPacket>> {
|
||
info!("Handling env request for channel {}", channel);
|
||
|
||
// 读取环境变量名和值
|
||
let name = read_ssh_string(cursor)?;
|
||
let value = read_ssh_string(cursor)?;
|
||
|
||
info!("Env: {}={}", name, value);
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 处理pty请求(参考OpenSSH channel.c)
|
||
fn handle_pty_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result<Option<SshPacket>> {
|
||
info!("Handling pty request for channel {}", channel);
|
||
|
||
// 读取terminal类型(SSH string)
|
||
let term = read_ssh_string(cursor)?;
|
||
|
||
// 读取窗口大小(4个uint32)
|
||
let width = cursor.read_u32::<BigEndian>()?;
|
||
let height = cursor.read_u32::<BigEndian>()?;
|
||
let _pixel_width = cursor.read_u32::<BigEndian>()?;
|
||
let _pixel_height = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取terminal modes(SSH string格式)
|
||
let modes_len = cursor.read_u32::<BigEndian>()?;
|
||
let mut modes = vec![0u8; modes_len as usize];
|
||
cursor.read_exact(&mut modes)?;
|
||
|
||
info!("PTY: term={}, width={}, height={}, modes_len={}", term, width, height, modes_len);
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 处理SSH_MSG_CHANNEL_DATA(参考OpenSSH channel.c: channel_input_data())
|
||
pub fn handle_channel_data(&mut self, packet: &SshPacket) -> Result<Option<SshPacket>> {
|
||
info!("Processing SSH_MSG_CHANNEL_DATA");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_DATA as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_DATA"));
|
||
}
|
||
|
||
// 读取recipient channel
|
||
let recipient_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取数据(SSH string)
|
||
let data_length = cursor.read_u32::<BigEndian>()?;
|
||
let mut data = vec![0u8; data_length as usize];
|
||
cursor.read_exact(&mut data)?;
|
||
|
||
info!("Channel data: channel={}, length={}", recipient_channel, data.len());
|
||
info!("Channel data content (first 20 bytes): {:?}", &data[..std::cmp::min(20, data.len())]);
|
||
|
||
// Phase 14: 检查是否是交互式exec进程
|
||
if let Some(channel) = self.channels.get_mut(&recipient_channel) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
info!("Interactive exec process detected, forwarding data to stdin");
|
||
info!("Channel data content: {:?}", &data);
|
||
info!("Child PID: {:?}", exec_process.child.id());
|
||
|
||
// 检查子进程状态
|
||
match exec_process.child.try_wait() {
|
||
Ok(Some(status)) => {
|
||
warn!("Child process already exited with status: {:?}", status);
|
||
}
|
||
Ok(None) => {
|
||
info!("Child process still running");
|
||
}
|
||
Err(e) => {
|
||
warn!("Failed to check child status: {}", e);
|
||
}
|
||
}
|
||
|
||
// 转发数据到子进程stdin(相当于OpenSSH写fdin)
|
||
if let Some(stdin) = &mut exec_process.stdin {
|
||
use std::io::Write;
|
||
stdin.write_all(&data)?;
|
||
stdin.flush()?;
|
||
info!("Forwarded {} bytes to stdin (OpenSSH style)", data.len());
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH风格:不等待,直接返回None(主循环会通过poll处理stdout)
|
||
info!("stdin forwarded, returning None (main loop will poll stdout/stderr)");
|
||
return Ok(None);
|
||
}
|
||
|
||
// Phase 7: 检查是否是SFTP channel(⭐⭐⭐⭐⭐ Phase 14.3: packet accumulation)
|
||
if let Some(sftp_handler) = &mut channel.sftp_handler {
|
||
info!("Processing SFTP request ({} bytes)", data.len());
|
||
|
||
// ⭐⭐⭐⭐⭐ Critical修复:累积SFTP packet数据
|
||
channel.sftp_input_buffer.extend_from_slice(&data);
|
||
info!("SFTP buffer accumulated: {} bytes total", channel.sftp_input_buffer.len());
|
||
|
||
// 检查buffer是否有足够数据解析packet length
|
||
if channel.sftp_input_buffer.len() < 4 {
|
||
info!("SFTP buffer too short for length field, waiting for more data");
|
||
return Ok(None); // 继续累积
|
||
}
|
||
|
||
// 解析SFTP packet length(前4 bytes)
|
||
let sftp_length = u32::from_be_bytes([
|
||
channel.sftp_input_buffer[0],
|
||
channel.sftp_input_buffer[1],
|
||
channel.sftp_input_buffer[2],
|
||
channel.sftp_input_buffer[3]
|
||
]) as usize;
|
||
info!("SFTP packet length field: {}", sftp_length);
|
||
|
||
let expected_total = 4 + sftp_length;
|
||
if channel.sftp_input_buffer.len() < expected_total {
|
||
info!("SFTP packet incomplete: expected {} bytes, have {} bytes in buffer, waiting for more",
|
||
expected_total, channel.sftp_input_buffer.len());
|
||
return Ok(None); // 继续累积
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Buffer足够,解析完整SFTP packet
|
||
let sftp_packet = &channel.sftp_input_buffer[4..expected_total];
|
||
info!("SFTP packet complete: {} bytes, processing", sftp_packet.len());
|
||
info!("SFTP packet content (first 20 bytes): {:?}", &sftp_packet[..std::cmp::min(20, sftp_packet.len())]);
|
||
|
||
let response = sftp_handler.handle_request(sftp_packet)?;
|
||
info!("SFTP response: {} bytes", response.len());
|
||
|
||
// ⭐⭐⭐⭐⭐ 处理完后,清空buffer或保留剩余数据
|
||
if channel.sftp_input_buffer.len() > expected_total {
|
||
// 有剩余数据(多个packets的情况)
|
||
let remaining = channel.sftp_input_buffer[expected_total..].to_vec();
|
||
channel.sftp_input_buffer = remaining;
|
||
info!("SFTP buffer has remaining {} bytes after processing", channel.sftp_input_buffer.len());
|
||
} else {
|
||
// 清空buffer
|
||
channel.sftp_input_buffer.clear();
|
||
info!("SFTP buffer cleared after processing");
|
||
}
|
||
|
||
// 构建SSH_MSG_CHANNEL_DATA返回SFTP响应(需要SSH string格式)
|
||
return Ok(Some(self.build_channel_data(recipient_channel, &response)?));
|
||
}
|
||
}
|
||
|
||
// 如果不是SFTP或exec_process,返回None
|
||
Ok(None)
|
||
}
|
||
|
||
/// Phase 14: 构建SSH_MSG_CHANNEL_EXTENDED_DATA(参考OpenSSH channel.c)
|
||
fn build_channel_extended_data(&self, channel: u32, data_type: u32, data: &[u8]) -> Result<SshPacket> {
|
||
let mut buffer = Vec::new();
|
||
|
||
buffer.write_u8(PacketType::SSH_MSG_CHANNEL_EXTENDED_DATA as u8)?;
|
||
buffer.write_u32::<BigEndian>(channel)?;
|
||
buffer.write_u32::<BigEndian>(data_type)?; // 1 = stderr, 2 = exit status
|
||
buffer.write_u32::<BigEndian>(data.len() as u32)?;
|
||
buffer.write_all(data)?;
|
||
|
||
Ok(SshPacket {
|
||
packet_length: 0,
|
||
padding_length: 0,
|
||
payload: buffer,
|
||
padding: Vec::new(),
|
||
})
|
||
}
|
||
|
||
/// 处理SSH_MSG_CHANNEL_CLOSE(参考OpenSSH channel.c: channel_input_close())
|
||
pub fn handle_channel_close(&mut self, packet: &SshPacket) -> Result<Option<SshPacket>> {
|
||
info!("Processing SSH_MSG_CHANNEL_CLOSE");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_CLOSE as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_CLOSE"));
|
||
}
|
||
|
||
// 读取recipient channel
|
||
let recipient_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
info!("Channel close: channel={}", recipient_channel);
|
||
|
||
// 移除channel(参考OpenSSH channel.c)
|
||
if let Some(channel) = self.channels.remove(&recipient_channel) {
|
||
info!("Channel {} removed", recipient_channel);
|
||
|
||
// 发送SSH_MSG_CHANNEL_CLOSE回应
|
||
Ok(Some(self.build_channel_close(channel.sender_channel)?))
|
||
} else {
|
||
warn!("Channel {} not found", recipient_channel);
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_OPEN_CONFIRMATION(参考OpenSSH channel.c)
|
||
fn build_channel_open_confirmation(
|
||
&self,
|
||
server_channel: u32,
|
||
sender_channel: u32,
|
||
window_size: u32,
|
||
packet_size: u32,
|
||
) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
// Packet type
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8)?;
|
||
|
||
// Server channel number
|
||
payload.write_u32::<BigEndian>(server_channel)?;
|
||
|
||
// Sender channel number
|
||
payload.write_u32::<BigEndian>(sender_channel)?;
|
||
|
||
// Initial window size
|
||
payload.write_u32::<BigEndian>(window_size)?;
|
||
|
||
// Maximum packet size
|
||
payload.write_u32::<BigEndian>(packet_size)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_OPEN_FAILURE(参考OpenSSH channel.c)
|
||
fn build_channel_open_failure(
|
||
&self,
|
||
sender_channel: u32,
|
||
reason_code: u32,
|
||
description: &str,
|
||
language: &str,
|
||
) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
// Packet type
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_OPEN_FAILURE as u8)?;
|
||
|
||
// Sender channel number
|
||
payload.write_u32::<BigEndian>(sender_channel)?;
|
||
|
||
// Reason code
|
||
payload.write_u32::<BigEndian>(reason_code)?;
|
||
|
||
// Description(SSH string)
|
||
payload.write_u32::<BigEndian>(description.len() as u32)?;
|
||
payload.write_all(description.as_bytes())?;
|
||
|
||
// Language(SSH string)
|
||
payload.write_u32::<BigEndian>(language.len() as u32)?;
|
||
payload.write_all(language.as_bytes())?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_SUCCESS(参考OpenSSH channel.c)
|
||
fn build_channel_success(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_SUCCESS as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_FAILURE(参考OpenSSH channel.c)
|
||
fn build_channel_failure(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_FAILURE as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_CLOSE(参考OpenSSH channel.c)
|
||
pub fn build_channel_close(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_CLOSE as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_DATA(Phase 6新增)
|
||
pub fn build_channel_data(&self, channel: u32, data: &[u8]) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_DATA as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
payload.write_u32::<BigEndian>(data.len() as u32)?;
|
||
payload.write_all(data)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_EOF(Phase 6新增)
|
||
pub fn build_channel_eof(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_EOF as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 获取有输出待发送的channel ID(Phase 6新增)
|
||
pub fn get_channel_with_output(&self) -> Option<u32> {
|
||
for (&id, channel) in &self.channels {
|
||
if channel.output_buffer.is_some() {
|
||
return Some(id);
|
||
}
|
||
}
|
||
None
|
||
}
|
||
|
||
/// 获取channel输出(Phase 6新增)
|
||
pub fn get_channel_output(&mut self, channel_id: u32) -> Option<Vec<u8>> {
|
||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||
channel.output_buffer.take()
|
||
} else {
|
||
None
|
||
}
|
||
}
|
||
|
||
/// 移除channel(Phase 6新增)
|
||
pub fn remove_channel(&mut self, channel_id: u32) {
|
||
self.channels.remove(&channel_id);
|
||
}
|
||
|
||
/// Phase 14: OpenSSH风格poll机制(使用nix::poll监听stdout/stderr fd)
|
||
/// ⭐⭐⭐⭐⭐ 关键:非阻塞读取数据,不等待子进程完成
|
||
/// ⭐⭐⭐⭐⭐ Phase 14.2: 处理child exited(发送EOF + CLOSE)
|
||
/// 参考:OpenSSH session.c: do_exec_no_pty()
|
||
pub fn handle_child_exited(&mut self) -> Result<Vec<SshPacket>> {
|
||
// 1. 收集需要处理的channel IDs
|
||
let channel_ids: Vec<u32> = self.channels
|
||
.iter()
|
||
.filter_map(|(id, channel)| {
|
||
if channel.exec_process.is_some() {
|
||
Some(*id)
|
||
} else {
|
||
None
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
// 2. 构建packets(避免borrow冲突)
|
||
let mut packets = Vec::new();
|
||
for channel_id in &channel_ids {
|
||
// 发送SSH_MSG_CHANNEL_EOF
|
||
let eof_packet = self.build_channel_eof(*channel_id)?;
|
||
packets.push(eof_packet);
|
||
|
||
// 发送SSH_MSG_CHANNEL_CLOSE
|
||
let close_packet = self.build_channel_close(*channel_id)?;
|
||
packets.push(close_packet);
|
||
}
|
||
|
||
// 3. 清除exec_process(mutable borrow)
|
||
for channel_id in &channel_ids {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
channel.exec_process = None;
|
||
}
|
||
}
|
||
|
||
if !channel_ids.is_empty() {
|
||
info!("Child exited, sent EOF + CLOSE for {} channels", channel_ids.len());
|
||
}
|
||
|
||
Ok(packets)
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 14.2: OpenSSH统一poll + child进程状态检测
|
||
/// 参考:OpenSSH session.c: do_exec_no_pty() + channel.c: channel_handle_fd()
|
||
///
|
||
/// 关键改进(Phase 14.2):
|
||
/// - 单次poll()同时监听client socket和子进程输出
|
||
/// - timeout 10ms(非阻塞)
|
||
/// - **添加child进程状态检测**(防止无限spinning)⭐⭐⭐⭐⭐
|
||
/// - **添加max_poll_iterations限制**(最多100次,1秒)
|
||
/// - 返回(stdout_packets, client_has_data, child_exited)
|
||
pub fn poll_exec_stdout_and_client(&mut self, stream: &std::net::TcpStream) -> Result<(Option<Vec<SshPacket>>, bool, bool)> {
|
||
use std::io::Read;
|
||
use std::os::unix::io::{BorrowedFd, AsRawFd};
|
||
use nix::poll::{poll, PollFd, PollFlags};
|
||
|
||
// 收集所有需要poll的fd
|
||
let mut poll_fds_vec = Vec::new();
|
||
let mut client_has_data = false;
|
||
let mut child_exited = false;
|
||
|
||
// 1. 添加client socket fd(监听stdin数据)
|
||
let client_fd = stream.as_raw_fd();
|
||
let client_poll_fd = unsafe {
|
||
BorrowedFd::borrow_raw(client_fd)
|
||
};
|
||
poll_fds_vec.push(PollFd::new(client_poll_fd, PollFlags::POLLIN));
|
||
let client_fd_idx = 0; // client fd总是第一个
|
||
|
||
// 2. 添加所有channel的stdout/stderr fd
|
||
let mut channel_fds_map: HashMap<u32, (usize, usize)> = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx)
|
||
let mut channel_ids_vec = Vec::new(); // 用于后续child状态检查
|
||
|
||
for (channel_id, channel) in &self.channels {
|
||
if let Some(exec_process) = &channel.exec_process {
|
||
channel_ids_vec.push(*channel_id);
|
||
|
||
// stdout fd
|
||
if let Some(_stdout) = &exec_process.stdout {
|
||
let stdout_poll_fd = unsafe {
|
||
BorrowedFd::borrow_raw(exec_process.stdout_fd)
|
||
};
|
||
poll_fds_vec.push(PollFd::new(stdout_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
// stderr fd
|
||
if let Some(_stderr) = &exec_process.stderr {
|
||
let stderr_poll_fd = unsafe {
|
||
BorrowedFd::borrow_raw(exec_process.stderr_fd)
|
||
};
|
||
poll_fds_vec.push(PollFd::new(stderr_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
// 记录索引(相对于client_fd_idx)
|
||
let stdout_idx = poll_fds_vec.len() - 2;
|
||
let stderr_idx = poll_fds_vec.len() - 1;
|
||
channel_fds_map.insert(*channel_id, (stdout_idx, stderr_idx));
|
||
}
|
||
}
|
||
|
||
if poll_fds_vec.len() == 1 {
|
||
// 只有client fd,没有exec_process
|
||
// 直接poll client(short timeout)
|
||
match poll(&mut poll_fds_vec, 10u16) {
|
||
Ok(n) if n > 0 => {
|
||
if let Some(revents) = poll_fds_vec[client_fd_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
client_has_data = true;
|
||
}
|
||
}
|
||
}
|
||
_ => {}
|
||
}
|
||
return Ok((None, client_has_data, false));
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 14.2关键:添加poll轮询限制(防止无限spinning)
|
||
// 最多轮询100次(1秒),如果持续无数据,检查child状态
|
||
let max_poll_iterations = 100;
|
||
let mut poll_iteration = 0;
|
||
let mut found_data = false;
|
||
let mut stdin_closed = false; // ⭐⭐⭐⭐⭐ 新增:跟踪stdin是否已关闭
|
||
|
||
for iteration in 0..max_poll_iterations {
|
||
poll_iteration = iteration;
|
||
|
||
// ⭐⭐⭐⭐⭐ 每10次轮询记录一次日志(减少噪音)
|
||
if iteration % 10 == 0 {
|
||
info!("Polling {} fds (iteration {} of {}, stdin_closed={})", poll_fds_vec.len(), iteration, max_poll_iterations, stdin_closed);
|
||
}
|
||
|
||
match poll(&mut poll_fds_vec, 10u16) {
|
||
Ok(n) if n > 0 => {
|
||
info!("{} fds have data available (iteration {})", n, iteration);
|
||
found_data = true;
|
||
break; // 有数据,立即处理
|
||
}
|
||
Ok(0) => {
|
||
// timeout,无数据
|
||
// ⭐⭐⭐⭐⭐ 关键:每10次检查child进程状态(防止spinning)
|
||
if iteration % 10 == 9 {
|
||
// 检查child是否exited
|
||
for channel_id in &channel_ids_vec {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
match exec_process.child.try_wait() {
|
||
Ok(Some(status)) => {
|
||
info!("Child process exited (channel {}, status: {:?})", channel_id, status);
|
||
child_exited = true;
|
||
|
||
// ⭐⭐⭐⭐⭐ Child exited,读取剩余stdout(如果有)
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
let mut buffer = vec![0u8; 32768];
|
||
match stdout.read(&mut buffer) {
|
||
Ok(n) if n > 0 => {
|
||
info!("Read {} final bytes from stdout (child exited)", n);
|
||
// 构建packet并返回
|
||
let packet = self.build_channel_data(*channel_id, &buffer[..n])?;
|
||
return Ok((Some(vec![packet]), false, true));
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
// 没有剩余数据,返回child_exited标志
|
||
return Ok((None, false, true));
|
||
}
|
||
Ok(None) => {
|
||
// Child still running(正常)
|
||
info!("Child still running (channel {}, iteration {}, stdin_closed={})", channel_id, iteration, stdin_closed);
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 14.2最终修复:主动关闭stdin超时机制
|
||
// 如果stdin未关闭,且超过50次poll(500ms)无数据
|
||
// 强制关闭stdin,发送EOF给rsync
|
||
if !stdin_closed && iteration >= 50 && exec_process.stdin.is_some() {
|
||
info!("⭐⭐⭐⭐⭐ Forcing stdin close after {} iterations ({} ms) - sending EOF to rsync", iteration, iteration * 10);
|
||
exec_process.stdin = None; // Drop stdin,发送EOF
|
||
stdin_closed = true;
|
||
|
||
// ⭐⭐⭐⭐⭐ stdin关闭后,继续等待child处理完成
|
||
// 不要立即返回,给rsync时间处理数据并产生stdout
|
||
info!("stdin closed, continuing to poll for stdout output...");
|
||
}
|
||
}
|
||
Err(e) => {
|
||
warn!("Child try_wait error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 继续轮询(如果iteration < max_poll_iterations)
|
||
}
|
||
Err(e) => {
|
||
warn!("poll error: {}", e);
|
||
return Ok((None, false, false));
|
||
}
|
||
Ok(_) => {
|
||
// 其他情况(不应该发生)
|
||
}
|
||
}
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ 达到max_poll_iterations,检查最终child状态
|
||
if !found_data {
|
||
info!("No data after {} iterations ({} ms), checking child status", max_poll_iterations, max_poll_iterations * 10);
|
||
|
||
for channel_id in &channel_ids_vec {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
match exec_process.child.try_wait() {
|
||
Ok(Some(status)) => {
|
||
info!("Child exited after max iterations (status: {:?})", status);
|
||
child_exited = true;
|
||
|
||
// 读取剩余stdout
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
let mut buffer = vec![0u8; 32768];
|
||
match stdout.read(&mut buffer) {
|
||
Ok(n) if n > 0 => {
|
||
let packet = self.build_channel_data(*channel_id, &buffer[..n])?;
|
||
return Ok((Some(vec![packet]), false, true));
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
return Ok((None, false, true));
|
||
}
|
||
Ok(None) => {
|
||
info!("Child still running after max iterations, returning None");
|
||
// Child还在运行,但无stdout数据
|
||
// 主循环会继续调用此函数
|
||
return Ok((None, false, false));
|
||
}
|
||
Err(e) => {
|
||
warn!("Final child check error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ 处理找到的数据(如果found_data)
|
||
// 3. 检查client fd状态(包括EOF/HUP)
|
||
if let Some(revents) = poll_fds_vec[client_fd_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("Client fd has data (stdin from client)");
|
||
client_has_data = true;
|
||
} else if revents.contains(PollFlags::POLLHUP) {
|
||
info!("Client fd hangup (EOF received from client)");
|
||
// ⭐⭐⭐⭐⭐ Phase 14.2关键修复:关闭stdin pipe,发送EOF给child
|
||
// 参考:OpenSSH session.c: do_exec_no_pty() stdin handling
|
||
for (_, channel) in &mut self.channels {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
if exec_process.stdin.is_some() {
|
||
info!("Closing stdin pipe (sending EOF to child process)");
|
||
exec_process.stdin = None; // Drop stdin,发送EOF给child
|
||
}
|
||
}
|
||
}
|
||
client_has_data = false;
|
||
} else if revents.contains(PollFlags::POLLERR) {
|
||
warn!("Client fd error");
|
||
return Err(anyhow::anyhow!("Client socket error"));
|
||
}
|
||
}
|
||
|
||
// 4. 检查stdout/stderr fd是否有数据
|
||
let mut packets_data: Vec<(u32, Vec<u8>)> = Vec::new();
|
||
|
||
for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map {
|
||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
// 检查stdout
|
||
if let Some(revents) = poll_fds_vec[stdout_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("stdout fd has data (channel {})", channel_id);
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
let mut buffer = vec![0u8; 32768];
|
||
match stdout.read(&mut buffer) {
|
||
Ok(n) if n > 0 => {
|
||
info!("Read {} bytes from stdout (channel {})", n, channel_id);
|
||
packets_data.push((channel_id, buffer[..n].to_vec()));
|
||
}
|
||
Ok(0) => {
|
||
info!("stdout EOF (channel {}), closing stdout pipe", channel_id);
|
||
// ⭐⭐⭐⭐⭐ Critical修复:EOF时关闭pipe,避免无限循环
|
||
exec_process.stdout = None;
|
||
}
|
||
Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => {
|
||
warn!("stdout read error: {}", e);
|
||
exec_process.stdout = None; // 错误时也关闭
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查stderr
|
||
if let Some(revents) = poll_fds_vec[stderr_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("stderr fd has data (channel {})", channel_id);
|
||
if let Some(stderr) = &mut exec_process.stderr {
|
||
let mut buffer = vec![0u8; 32768];
|
||
match stderr.read(&mut buffer) {
|
||
Ok(n) if n > 0 => {
|
||
info!("Read {} bytes from stderr (channel {})", n, channel_id);
|
||
packets_data.push((channel_id, buffer[..n].to_vec()));
|
||
}
|
||
Ok(0) => {
|
||
info!("stderr EOF (channel {}), closing stderr pipe", channel_id);
|
||
// ⭐⭐⭐⭐⭐ Critical修复:EOF时关闭pipe,避免无限循环
|
||
exec_process.stderr = None;
|
||
}
|
||
Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => {
|
||
warn!("stderr read error: {}", e);
|
||
exec_process.stderr = None; // 错误时也关闭
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 构建packets
|
||
if !packets_data.is_empty() {
|
||
let mut packets = Vec::new();
|
||
for (channel_id, data) in packets_data {
|
||
let packet = self.build_channel_data(channel_id, &data)?;
|
||
packets.push(packet);
|
||
}
|
||
return Ok((Some(packets), client_has_data, child_exited));
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 14.2最终修复:stdout/stderr EOF后检查child exited
|
||
// 当stdout和stderr都关闭后,强制检查child状态
|
||
for channel_id in &channel_ids_vec {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
// 检查stdout和stderr是否都已关闭
|
||
if exec_process.stdout.is_none() && exec_process.stderr.is_none() {
|
||
info!("stdout/stderr both closed (channel {}), checking child status", channel_id);
|
||
|
||
// ⭐⭐⭐⭐⭐ 立即检查child是否exited
|
||
match exec_process.child.try_wait() {
|
||
Ok(Some(status)) => {
|
||
info!("⭐⭐⭐⭐⭐ Child exited after stdout/stderr EOF (status: {:?})", status);
|
||
child_exited = true;
|
||
|
||
// ⭐⭐⭐⭐⭐ 关键:立即返回child_exited标志
|
||
// server.rs会发送SSH_MSG_CHANNEL_EOF + CLOSE
|
||
return Ok((None, false, true));
|
||
}
|
||
Ok(None) => {
|
||
// Child still running but stdout/stderr closed
|
||
// 等待child exited
|
||
info!("Child still running after pipes closed, waiting...");
|
||
}
|
||
Err(e) => {
|
||
warn!("Child try_wait error after pipes closed: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 有数据但只有client数据
|
||
Ok((None, client_has_data, child_exited))
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 14.0: 旧版poll(仅监听stdout/stderr,已废弃)
|
||
/// 已废弃:使用poll_exec_stdout_and_client()替代
|
||
#[allow(dead_code)]
|
||
pub fn poll_exec_stdout_with_fds(&mut self) -> Result<Option<Vec<SshPacket>>> {
|
||
use std::io::Read;
|
||
use std::os::unix::io::BorrowedFd;
|
||
|
||
// 遍历所有channel,收集poll_fds
|
||
let mut poll_fds_vec = Vec::new();
|
||
let mut channel_fds_map: HashMap<u32, (usize, usize)> = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx) in poll_fds_vec
|
||
|
||
for (channel_id, channel) in &self.channels {
|
||
if let Some(exec_process) = &channel.exec_process {
|
||
// ⭐⭐⭐⭐⭐ OpenSSH风格:创建PollFd监听stdout/stderr
|
||
// nix 0.29 API: PollFd::new()需要借用fd,不是RawFd
|
||
if let Some(stdout) = &exec_process.stdout {
|
||
let stdout_poll_fd = unsafe {
|
||
// ⭐⭐⭐⭐⭐ 使用BorrowedFd::borrow_raw()(正确API)
|
||
BorrowedFd::borrow_raw(exec_process.stdout_fd)
|
||
};
|
||
poll_fds_vec.push(PollFd::new(stdout_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
if let Some(stderr) = &exec_process.stderr {
|
||
let stderr_poll_fd = unsafe {
|
||
BorrowedFd::borrow_raw(exec_process.stderr_fd)
|
||
};
|
||
poll_fds_vec.push(PollFd::new(stderr_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
// 记录poll_fds_vec中的索引
|
||
let stdout_idx = poll_fds_vec.len() - 2;
|
||
let stderr_idx = poll_fds_vec.len() - 1;
|
||
channel_fds_map.insert(*channel_id, (stdout_idx, stderr_idx));
|
||
}
|
||
}
|
||
|
||
if poll_fds_vec.is_empty() {
|
||
return Ok(None); // 没有exec_process
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH关键:使用poll监听所有fd
|
||
// ⭐⭐⭐⭐⭐ 持续poll机制:最多轮询1000次(给大文件传输足够时间)
|
||
// 大文件传输需要很长时间,增加轮询次数到1000次(总共10秒)
|
||
let max_poll_attempts = 1000;
|
||
let mut poll_attempt = 0;
|
||
let mut found_data = false;
|
||
|
||
for attempt in 0..max_poll_attempts {
|
||
poll_attempt = attempt;
|
||
// 每100次轮询记录一次日志(减少日志噪音)
|
||
if attempt % 100 == 0 {
|
||
info!("Polling {} fds (OpenSSH style, timeout 10ms, attempt {} of {})", poll_fds_vec.len(), attempt, max_poll_attempts);
|
||
}
|
||
match poll(&mut poll_fds_vec, 10u16) { // timeout 10ms
|
||
Ok(n) => {
|
||
if n > 0 {
|
||
info!("{} fds have data available (attempt {})", n, attempt);
|
||
found_data = true;
|
||
break; // 有数据,立即处理
|
||
}
|
||
// 没有数据,继续轮询(最多1000次)
|
||
}
|
||
Err(e) => {
|
||
warn!("poll error: {}", e);
|
||
return Ok(None);
|
||
}
|
||
}
|
||
}
|
||
|
||
if !found_data {
|
||
info!("No data available after {} poll attempts ({} ms), returning None", max_poll_attempts, max_poll_attempts * 10);
|
||
return Ok(None); // 轮询1000次后仍无数据,主循环继续处理client packet
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH风格:根据revents判断哪个fd有数据,立即读取
|
||
let mut packets_data: Vec<(u32, Vec<u8>)> = Vec::new(); // (channel_id, data)
|
||
|
||
for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map {
|
||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
// 检查stdout是否有数据
|
||
if let Some(revents) = poll_fds_vec[stdout_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("stdout fd has data (channel {})", channel_id);
|
||
// ⭐⭐⭐⭐⭐ 非阻塞读取(因为设置了O_NONBLOCK)
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
let mut buffer = vec![0u8; 32768];
|
||
match stdout.read(&mut buffer) {
|
||
Ok(n) => {
|
||
if n > 0 {
|
||
info!("Read {} bytes from stdout (channel {})", n, channel_id);
|
||
packets_data.push((channel_id, buffer[..n].to_vec()));
|
||
} else {
|
||
info!("stdout EOF (channel {})", channel_id);
|
||
}
|
||
}
|
||
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||
// 非阻塞模式,没有数据(正常)
|
||
}
|
||
Err(e) => {
|
||
warn!("stdout read error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查stderr是否有数据(类似处理)
|
||
if let Some(revents) = poll_fds_vec[stderr_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("stderr fd has data (channel {})", channel_id);
|
||
if let Some(stderr) = &mut exec_process.stderr {
|
||
let mut buffer = vec![0u8; 32768];
|
||
match stderr.read(&mut buffer) {
|
||
Ok(n) => {
|
||
if n > 0 {
|
||
info!("Read {} bytes from stderr (channel {})", n, channel_id);
|
||
packets_data.push((channel_id, buffer[..n].to_vec()));
|
||
} else {
|
||
info!("stderr EOF (channel {})", channel_id);
|
||
}
|
||
}
|
||
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||
// 非阻塞模式,没有数据(正常)
|
||
}
|
||
Err(e) => {
|
||
warn!("stderr read error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ 释放mutable borrow后,构建packets(避免borrow冲突)
|
||
let mut packets = Vec::new();
|
||
for (channel_id, data) in packets_data {
|
||
let packet = self.build_channel_data(channel_id, &data)?;
|
||
packets.push(packet);
|
||
}
|
||
|
||
if packets.is_empty() {
|
||
Ok(None)
|
||
} else {
|
||
Ok(Some(packets))
|
||
}
|
||
}
|
||
}
|
||
|
||
/// SSH Channel结构(参考OpenSSH channel.c: struct channel)
|
||
struct Channel {
|
||
server_channel: u32,
|
||
sender_channel: u32,
|
||
channel_type: String,
|
||
window_size: u32,
|
||
maximum_packet_size: u32,
|
||
state: ChannelState,
|
||
output_buffer: Option<Vec<u8>>, // Phase 6: 命令输出缓冲
|
||
sftp_handler: Option<SftpHandler>, // Phase 7: SFTP处理器
|
||
scp_handler: Option<ScpHandler>, // Phase 8: SCP处理器
|
||
rsync_handler: Option<RsyncHandler>, // Phase 8: rsync处理器
|
||
exec_process: Option<ExecProcess>, // Phase 14: 交互式exec进程
|
||
// ⭐⭐⭐⭐⭐ Critical修复:SFTP packet累积buffer
|
||
sftp_input_buffer: Vec<u8>, // Phase 14.2修复:累积不完整的SFTP packets
|
||
// Phase 13.3: 端口转发相关字段
|
||
direct_tcpip: Option<DirectTcpipChannel>, // direct-tcpip channel(Remote forwarding)
|
||
forwarded_tcpip: Option<ForwardedTcpipChannel>, // forwarded-tcpip channel(Local forwarding)
|
||
}
|
||
|
||
/// Lifecycle state of an SSH channel (modeled on OpenSSH channel.c).
enum ChannelState {
    /// The channel is open.
    Open,
    /// The channel is in the process of closing.
    Closing,
    /// The channel is closed.
    Closed,
}
|
||
|
||
/// SSH string读取辅助函数
|
||
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||
let length = reader.read_u32::<BigEndian>()?;
|
||
let mut buffer = vec![0u8; length as usize];
|
||
reader.read_exact(&mut buffer)?;
|
||
Ok(String::from_utf8(buffer)?)
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager has `next_channel_id == 0`.
    #[test]
    fn test_channel_manager_creation() {
        assert_eq!(ChannelManager::new().next_channel_id, 0);
    }

    /// An open-confirmation packet starts with the CHANNEL_OPEN_CONFIRMATION message number.
    #[test]
    fn test_channel_open_confirmation() {
        let expected_type = PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8;
        let packet = ChannelManager::new()
            .build_channel_open_confirmation(0, 100, 2097152, 32768)
            .expect("build_channel_open_confirmation should succeed");
        assert_eq!(packet.payload[0], expected_type);
    }

    /// A CHANNEL_SUCCESS packet starts with the CHANNEL_SUCCESS message number.
    #[test]
    fn test_channel_success() {
        let expected_type = PacketType::SSH_MSG_CHANNEL_SUCCESS as u8;
        let packet = ChannelManager::new()
            .build_channel_success(0)
            .expect("build_channel_success should succeed");
        assert_eq!(packet.payload[0], expected_type);
    }
}
|