// SSH channel protocol implementation (Phase 6 + Phase 13 port forwarding).
// Modeled on OpenSSH channels.c.
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||
use crate::ssh_server::port_forward::{
|
||
DirectTcpipChannel, ForwardedTcpipChannel, PortForwardManager,
|
||
};
|
||
use crate::ssh_server::rsync_handler::RsyncHandler;
|
||
use crate::ssh_server::scp_handler::ScpHandler;
|
||
use crate::ssh_server::sftp_handler::SftpHandler;
|
||
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
||
use crate::ssh_server::upload_hook::UploadHook;
|
||
use anyhow::{anyhow, Result};
|
||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||
use log::{info, warn};
|
||
use nix::fcntl::{fcntl, FcntlArg, OFlag}; // Phase 14: 非阻塞I/O(OpenSSH风格)
|
||
use nix::poll::{poll, PollFd, PollFlags};
|
||
use std::collections::{HashMap, VecDeque};
|
||
use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准)
|
||
use std::os::unix::io::{AsRawFd, RawFd}; // Phase 14: OpenSSH风格poll机制(需要RawFd)
|
||
use std::path::PathBuf; // Phase 7-8: Path for SFTP/SCP/rsync root directory
|
||
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout}; // Phase 14: 交互式exec
|
||
// Phase 14: poll机制(OpenSSH风格)
|
||
|
||
/// SSH channel manager (modeled on OpenSSH channels.c: `struct channel`).
///
/// Owns every open channel keyed by its server-side channel id, plus the
/// per-user state (home directory, upload hook, user id) that is handed to the
/// SFTP/SCP/exec handlers created for those channels.
pub struct ChannelManager {
    /// Open channels, keyed by the server-side (local) channel id.
    channels: HashMap<u32, Channel>,
    /// Next server-side channel id to hand out; incremented on each accepted open.
    next_channel_id: u32,
    /// Outgoing packets queued by handlers in addition to the packet a handler
    /// returns directly.
    pub pending_packets: VecDeque<SshPacket>,
    /// User's home directory: SFTP/SCP root and working directory for exec'd commands.
    pub home_dir: PathBuf,
    /// Optional upload hook passed to every SFTP handler.
    pub upload_hook: Option<std::sync::Arc<UploadHook>>,
    /// User identifier passed to every SFTP handler.
    pub user_uuid: String,
}
|
||
|
||
/// Phase 14: state of an interactive exec'd child process (modeled on OpenSSH
/// session.c: do_exec_no_pty).
///
/// stdout/stderr are kept as pipe handles plus their raw fds so the session
/// loop can poll them (OpenSSH style) instead of draining them on threads.
pub struct ExecProcess {
    pub child: Child, // spawned child process (rsync/scp/arbitrary command)
    pub stdin: Option<ChildStdin>, // stdin pipe: data flows SSH client -> child
    pub stdout: Option<ChildStdout>, // stdout pipe, read directly by polling (no thread)
    pub stderr: Option<ChildStderr>, // stderr pipe, read directly by polling (no thread)
    // NOTE(review): the two raw fds below duplicate the pipe handles above and
    // are only valid while the matching Option still owns its pipe.
    pub stdout_fd: RawFd, // raw fd of `stdout`, used for poll()
    pub stderr_fd: RawFd, // raw fd of `stderr`, used for poll()
    pub command: String, // Phase 16.2: the exec'd command line (used for SCP detection)
    pub reuse_buf: Vec<u8>, // Phase 2a: reusable scratch buffer for CHANNEL_DATA content
    pub read_buf: Vec<u8>, // Phase 2b: reusable buffer for stdout/stderr reads (presumably sized to 32KB by the reader)
}
|
||
|
||
impl ChannelManager {
|
||
pub fn new(
|
||
home_dir: PathBuf,
|
||
upload_hook: Option<std::sync::Arc<UploadHook>>,
|
||
user_uuid: String,
|
||
) -> Self {
|
||
Self {
|
||
channels: HashMap::new(),
|
||
next_channel_id: 0,
|
||
pending_packets: VecDeque::new(),
|
||
home_dir,
|
||
upload_hook,
|
||
user_uuid,
|
||
}
|
||
}
|
||
|
||
/// 处理SSH_MSG_CHANNEL_OPEN(参考OpenSSH channel.c: channel_open())
|
||
/// Phase 13.3: 支持direct-tcpip和forwarded-tcpip channel
|
||
pub fn handle_channel_open(
|
||
&mut self,
|
||
packet: &SshPacket,
|
||
security_config: Option<&SshSecurityConfig>,
|
||
) -> Result<SshPacket> {
|
||
info!("Processing SSH_MSG_CHANNEL_OPEN");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_OPEN as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_OPEN"));
|
||
}
|
||
|
||
// 读取channel类型(SSH string)
|
||
let channel_type = read_ssh_string(&mut cursor)?;
|
||
|
||
// 读取sender channel ID(u32)
|
||
let sender_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取初始窗口大小(u32)
|
||
let initial_window_size = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取最大packet大小(u32)
|
||
let maximum_packet_size = cursor.read_u32::<BigEndian>()?;
|
||
|
||
info!(
|
||
"Channel open: type={}, sender_channel={}, window={}, max_packet={}",
|
||
channel_type, sender_channel, initial_window_size, maximum_packet_size
|
||
);
|
||
|
||
// Phase 13.3: 检查channel类型(支持session、direct-tcpip、forwarded-tcpip)
|
||
match channel_type.as_str() {
|
||
"session" => {
|
||
// 传统的session channel(Phase 6)
|
||
self.handle_session_channel_open(
|
||
sender_channel,
|
||
initial_window_size,
|
||
maximum_packet_size,
|
||
)
|
||
}
|
||
|
||
"direct-tcpip" => {
|
||
// Phase 13.3: Remote port forwarding channel
|
||
info!("Received direct-tcpip channel open (Remote port forwarding)");
|
||
self.handle_direct_tcpip_channel_open(
|
||
packet,
|
||
sender_channel,
|
||
initial_window_size,
|
||
maximum_packet_size,
|
||
security_config,
|
||
)
|
||
}
|
||
|
||
"forwarded-tcpip" => {
|
||
// Phase 13.3: Local port forwarding channel
|
||
info!("Received forwarded-tcpip channel open (Local port forwarding)");
|
||
self.handle_forwarded_tcpip_channel_open(
|
||
packet,
|
||
sender_channel,
|
||
initial_window_size,
|
||
maximum_packet_size,
|
||
)
|
||
}
|
||
|
||
_ => {
|
||
warn!("Unsupported channel type: {}", channel_type);
|
||
self.build_channel_open_failure(
|
||
sender_channel,
|
||
3, // SSH_OPEN_UNKNOWN_CHANNEL_TYPE
|
||
"Unsupported channel type",
|
||
"en",
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 处理session channel open(Phase 6)
|
||
fn handle_session_channel_open(
|
||
&mut self,
|
||
sender_channel: u32,
|
||
initial_window_size: u32,
|
||
maximum_packet_size: u32,
|
||
) -> Result<SshPacket> {
|
||
info!("Processing session channel open");
|
||
|
||
let server_channel = self.next_channel_id;
|
||
self.next_channel_id += 1;
|
||
|
||
let channel = Channel {
|
||
server_channel,
|
||
sender_channel,
|
||
channel_type: "session".to_string(),
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 15: Window Control(参考OpenSSH channels.h)
|
||
remote_window: initial_window_size, // 远端窗口(从 CHANNEL_OPEN packet 中读取)
|
||
remote_maxpacket: maximum_packet_size, // 远端最大 packet
|
||
local_window: 2097152, // 本地窗口(OpenSSH 默认 2MB)
|
||
local_window_max: 2097152, // 本地窗口最大值(同上)
|
||
local_consumed: 0, // 本地已消费的数据(初始为 0)⭐⭐⭐⭐⭐
|
||
local_maxpacket: 32768, // 本地最大 packet(OpenSSH 默认 32KB)
|
||
|
||
// 旧字段(保留兼容)
|
||
window_size: initial_window_size,
|
||
maximum_packet_size,
|
||
state: ChannelState::Open,
|
||
output_buffer: None,
|
||
sftp_handler: None,
|
||
scp_handler: None,
|
||
rsync_handler: None,
|
||
exec_process: None, // Phase 14: 交互式exec
|
||
sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复:SFTP packet累积
|
||
scp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.4修复:SCP packet累积
|
||
direct_tcpip: None,
|
||
forwarded_tcpip: None,
|
||
};
|
||
|
||
self.channels.insert(server_channel, channel);
|
||
|
||
info!(
|
||
"Session channel created: server_channel={}, sender_channel={}",
|
||
server_channel, sender_channel
|
||
);
|
||
|
||
self.build_channel_open_confirmation(
|
||
server_channel,
|
||
sender_channel,
|
||
initial_window_size,
|
||
maximum_packet_size,
|
||
)
|
||
}
|
||
|
||
/// 处理direct-tcpip channel open(Phase 13.3: Remote port forwarding)
|
||
fn handle_direct_tcpip_channel_open(
|
||
&mut self,
|
||
packet: &SshPacket,
|
||
sender_channel: u32,
|
||
initial_window_size: u32,
|
||
maximum_packet_size: u32,
|
||
security_config: Option<&SshSecurityConfig>,
|
||
) -> Result<SshPacket> {
|
||
info!("Processing direct-tcpip channel open");
|
||
|
||
// 解析direct-tcpip参数
|
||
let mut port_forward_manager = PortForwardManager::new();
|
||
let direct_tcpip = port_forward_manager.handle_direct_tcpip_channel(&packet.payload)?;
|
||
|
||
// Phase 13.3: 安全配置验证
|
||
if let Some(security) = security_config {
|
||
if let Err(e) = security.validate_direct_tcpip_channel(
|
||
&direct_tcpip.host_to_connect,
|
||
direct_tcpip.port_to_connect,
|
||
) {
|
||
warn!("direct-tcpip security validation failed: {}", e);
|
||
return self.build_channel_open_failure(
|
||
sender_channel,
|
||
2, // SSH_OPEN_CONNECT_FAILED
|
||
"Security validation failed",
|
||
"en",
|
||
);
|
||
}
|
||
info!("direct-tcpip security validation passed");
|
||
}
|
||
|
||
let server_channel = self.next_channel_id;
|
||
self.next_channel_id += 1;
|
||
|
||
let channel = Channel {
|
||
server_channel,
|
||
sender_channel,
|
||
channel_type: "direct-tcpip".to_string(),
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 15: Window Control
|
||
remote_window: initial_window_size,
|
||
remote_maxpacket: maximum_packet_size,
|
||
local_window: 2097152,
|
||
local_window_max: 2097152,
|
||
local_consumed: 0,
|
||
local_maxpacket: 32768,
|
||
|
||
window_size: initial_window_size,
|
||
maximum_packet_size,
|
||
state: ChannelState::Open,
|
||
output_buffer: None,
|
||
sftp_handler: None,
|
||
scp_handler: None,
|
||
rsync_handler: None,
|
||
exec_process: None,
|
||
sftp_input_buffer: Vec::new(),
|
||
scp_input_buffer: Vec::new(),
|
||
direct_tcpip: Some(direct_tcpip),
|
||
forwarded_tcpip: None,
|
||
};
|
||
|
||
self.channels.insert(server_channel, channel);
|
||
|
||
info!(
|
||
"direct-tcpip channel created: server_channel={}, host={}, port={}",
|
||
server_channel,
|
||
self.channels
|
||
.get(&server_channel)
|
||
.unwrap()
|
||
.direct_tcpip
|
||
.as_ref()
|
||
.unwrap()
|
||
.host_to_connect,
|
||
self.channels
|
||
.get(&server_channel)
|
||
.unwrap()
|
||
.direct_tcpip
|
||
.as_ref()
|
||
.unwrap()
|
||
.port_to_connect
|
||
);
|
||
|
||
self.build_channel_open_confirmation(
|
||
server_channel,
|
||
sender_channel,
|
||
initial_window_size,
|
||
maximum_packet_size,
|
||
)
|
||
}
|
||
|
||
/// 处理forwarded-tcpip channel open(Phase 13.3: Local port forwarding)
|
||
fn handle_forwarded_tcpip_channel_open(
|
||
&mut self,
|
||
packet: &SshPacket,
|
||
sender_channel: u32,
|
||
initial_window_size: u32,
|
||
maximum_packet_size: u32,
|
||
) -> Result<SshPacket> {
|
||
info!("Processing forwarded-tcpip channel open");
|
||
|
||
// 解析forwarded-tcpip参数
|
||
let mut port_forward_manager = PortForwardManager::new();
|
||
let forwarded_tcpip =
|
||
port_forward_manager.handle_forwarded_tcpip_channel(&packet.payload)?;
|
||
|
||
let server_channel = self.next_channel_id;
|
||
self.next_channel_id += 1;
|
||
|
||
let channel = Channel {
|
||
server_channel,
|
||
sender_channel,
|
||
channel_type: "forwarded-tcpip".to_string(),
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 15: Window Control
|
||
remote_window: initial_window_size,
|
||
remote_maxpacket: maximum_packet_size,
|
||
local_window: 2097152,
|
||
local_window_max: 2097152,
|
||
local_consumed: 0,
|
||
local_maxpacket: 32768,
|
||
|
||
window_size: initial_window_size,
|
||
maximum_packet_size,
|
||
state: ChannelState::Open,
|
||
output_buffer: None,
|
||
sftp_handler: None,
|
||
scp_handler: None,
|
||
rsync_handler: None,
|
||
exec_process: None, // Phase 14: 交互式exec
|
||
sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复
|
||
scp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.4修复
|
||
direct_tcpip: None,
|
||
forwarded_tcpip: Some(forwarded_tcpip),
|
||
};
|
||
|
||
self.channels.insert(server_channel, channel);
|
||
|
||
info!(
|
||
"forwarded-tcpip channel created: server_channel={}, bind={}, originator={}",
|
||
server_channel,
|
||
self.channels
|
||
.get(&server_channel)
|
||
.unwrap()
|
||
.forwarded_tcpip
|
||
.as_ref()
|
||
.unwrap()
|
||
.bind_port,
|
||
self.channels
|
||
.get(&server_channel)
|
||
.unwrap()
|
||
.forwarded_tcpip
|
||
.as_ref()
|
||
.unwrap()
|
||
.originator_address
|
||
);
|
||
|
||
self.build_channel_open_confirmation(
|
||
server_channel,
|
||
sender_channel,
|
||
initial_window_size,
|
||
maximum_packet_size,
|
||
)
|
||
}
|
||
/// 处理SSH_MSG_CHANNEL_REQUEST(参考OpenSSH channel.c: channel_request())
|
||
pub fn handle_channel_request(&mut self, packet: &SshPacket) -> Result<Option<SshPacket>> {
|
||
info!("Processing SSH_MSG_CHANNEL_REQUEST");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_REQUEST as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_REQUEST"));
|
||
}
|
||
|
||
// 读取recipient channel(u32)
|
||
let recipient_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取请求类型(SSH string)
|
||
let request_type = read_ssh_string(&mut cursor)?;
|
||
|
||
// 读取want reply标志(boolean)
|
||
let want_reply = cursor.read_u8()? != 0;
|
||
|
||
info!(
|
||
"Channel request: channel={}, type={}, want_reply={}",
|
||
recipient_channel, request_type, want_reply
|
||
);
|
||
|
||
// 处理不同请求类型(参考OpenSSH channel.c)
|
||
if request_type == "exec" {
|
||
self.handle_exec_request(&mut cursor, recipient_channel, want_reply)
|
||
// 移除?操作符(返回Option不是Result)
|
||
} else if request_type == "subsystem" {
|
||
self.handle_subsystem_request(&mut cursor, recipient_channel, want_reply)
|
||
// 移除?操作符
|
||
} else if request_type == "shell" {
|
||
self.handle_shell_request(recipient_channel, want_reply) // 移除?操作符
|
||
} else if request_type == "env" {
|
||
self.handle_env_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符
|
||
} else if request_type == "pty-req" {
|
||
self.handle_pty_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符
|
||
} else {
|
||
warn!("Unsupported channel request: {}", request_type);
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_failure(recipient_channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 处理exec请求(参考OpenSSH channel.c: channel_request_exec() + session.c: do_exec_no_pty)
|
||
fn handle_exec_request(
|
||
&mut self,
|
||
cursor: &mut std::io::Cursor<&[u8]>,
|
||
channel: u32,
|
||
want_reply: bool,
|
||
) -> Result<Option<SshPacket>> {
|
||
info!("Handling exec request for channel {}", channel);
|
||
|
||
// 读取命令(SSH string)
|
||
let command = read_ssh_string(cursor)?;
|
||
|
||
info!("Exec command: {}", command);
|
||
|
||
// Phase 14: 所有exec命令使用交互式进程(与OpenSSH一致)
|
||
// ⭐⭐⭐⭐⭐ 修复:cat/grep/sed等命令需要stdin数据,必须使用交互式进程
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [EXEC_REQUEST] Starting interactive process for: {}",
|
||
command
|
||
);
|
||
self.handle_interactive_exec(&command, channel, "exec")?;
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 16.5: rsync exec(使用真实rsync子进程,替代in-process handler)
|
||
fn handle_rsync_exec(&mut self, command: &str, channel_id: u32) -> Result<()> {
|
||
self.handle_interactive_exec(command, channel_id, "rsync")
|
||
}
|
||
|
||
/// Phase 14.5: 处理SCP交互式exec(scp -t destination 或 scp -f source)
|
||
/// ⭐⭐⭐⭐⭐ OpenSSH风格:使用poll()替代thread::spawn(非阻塞I/O)
|
||
fn handle_scp_exec(&mut self, command: &str, channel_id: u32) -> Result<()> {
|
||
// ⭐⭐⭐⭐⭐ SCP和rsync共用相同的交互式exec逻辑
|
||
self.handle_interactive_exec(command, channel_id, "scp")
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 14.6: 交互式exec通用处理(rsync/SCP共用)
|
||
fn handle_interactive_exec(
|
||
&mut self,
|
||
command: &str,
|
||
channel_id: u32,
|
||
process_type: &str,
|
||
) -> Result<()> {
|
||
use std::os::unix::io::AsRawFd;
|
||
use std::process::{Command, Stdio};
|
||
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}",
|
||
process_type, command
|
||
);
|
||
|
||
// 启动子进程(相当于OpenSSH fork)
|
||
// ⭐⭐⭐⭐⭐ Phase 17: 设置工作目录为用户home_dir(SFTPGo兼容)
|
||
let home_dir = self.home_dir.clone();
|
||
let mut child = Command::new("sh")
|
||
.arg("-c")
|
||
.arg(command)
|
||
.current_dir(&home_dir)
|
||
.stdin(Stdio::piped()) // ← 创建stdin管道(相当于pipe(pin))
|
||
.stdout(Stdio::piped()) // ← 创建stdout管道(相当于pipe(pout))
|
||
.stderr(Stdio::piped()) // ← 创建stderr管道(相当于pipe(perr))
|
||
.spawn()?;
|
||
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [CHILD_SPAWNED] Child process spawned, PID: {}",
|
||
child.id()
|
||
);
|
||
|
||
// 提取管道(相当于OpenSSH dup2)
|
||
let stdin = child.stdin.take().ok_or(anyhow!("stdin take failed"))?;
|
||
let stdout = child.stdout.take().ok_or(anyhow!("stdout take failed"))?;
|
||
let stderr = child.stderr.take().ok_or(anyhow!("stderr take failed"))?;
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH关键:设置非阻塞模式(fcntl O_NONBLOCK)
|
||
// stdin 保持阻塞模式(write_all 需要阻塞写入)
|
||
let stdout_fd = stdout.as_raw_fd();
|
||
let stderr_fd = stderr.as_raw_fd();
|
||
|
||
info!("Setting stdout/stderr to non-blocking mode (OpenSSH style)");
|
||
fcntl(stdout_fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?;
|
||
fcntl(stderr_fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?;
|
||
info!(
|
||
"Non-blocking I/O enabled for stdout (fd {}) and stderr (fd {})",
|
||
stdout_fd, stderr_fd
|
||
);
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH风格:不再使用thread::spawn,直接保留File对象用于poll
|
||
// 存储到channel(相当于OpenSSH session_set_fds)
|
||
if let Some(ch) = self.channels.get_mut(&channel_id) {
|
||
ch.exec_process = Some(ExecProcess {
|
||
child,
|
||
stdin: Some(stdin),
|
||
stdout: Some(stdout), // ⭐⭐⭐⭐⭐ 直接保留File对象
|
||
stderr: Some(stderr), // ⭐⭐⭐⭐⭐ 直接保留File对象
|
||
stdout_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll
|
||
stderr_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll
|
||
command: command.to_string(), // ⭐⭐⭐⭐⭐ Phase 16.2: 存储exec命令(用于SCP检测)
|
||
reuse_buf: Vec::new(), // Phase 2a: reusable buffer
|
||
read_buf: Vec::new(), // Phase 2b: reusable read buffer
|
||
});
|
||
info!(
|
||
"Interactive process stored for channel {} (poll-ready)",
|
||
channel_id
|
||
);
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 执行命令并捕获输出(Phase 6基础实现)
|
||
fn execute_command(&self, command: &str) -> Result<Vec<u8>> {
|
||
use std::process::Command;
|
||
|
||
info!("Executing command: {}", command);
|
||
|
||
// 使用shell执行命令(参考OpenSSH session.c)
|
||
let output = Command::new("sh").arg("-c").arg(command).output()?;
|
||
|
||
// 返回stdout + stderr
|
||
let mut result = output.stdout;
|
||
result.extend_from_slice(&output.stderr);
|
||
|
||
info!("Command output: {} bytes", result.len());
|
||
Ok(result)
|
||
}
|
||
|
||
/// 处理subsystem请求(参考OpenSSH channel.c: channel_request_subsystem())
|
||
fn handle_subsystem_request(
|
||
&mut self,
|
||
cursor: &mut std::io::Cursor<&[u8]>,
|
||
channel: u32,
|
||
want_reply: bool,
|
||
) -> Result<Option<SshPacket>> {
|
||
info!("Handling subsystem request for channel {}", channel);
|
||
|
||
// 读取subsystem名称(SSH string)
|
||
let subsystem = read_ssh_string(cursor)?;
|
||
|
||
info!("Subsystem: {}", subsystem);
|
||
|
||
// 检查subsystem支持(OpenSSH支持:sftp, scp)
|
||
if subsystem == "sftp" {
|
||
info!("SFTP subsystem requested");
|
||
|
||
// Phase 7: 初始化SFTP handler(使用用户home目录,SFTPGo兼容)
|
||
let root_dir = self.home_dir.clone();
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 4: 获取 client maxpacket 限制(从 Channel 中获取)
|
||
let maxpacket = if let Some(ch) = self.channels.get(&channel) {
|
||
ch.remote_maxpacket // 来自 SSH_MSG_CHANNEL_OPEN 的 maximum_packet_size
|
||
} else {
|
||
32768 // OpenSSH 默认值(32KB)
|
||
};
|
||
|
||
let vfs = Box::new(crate::vfs::local_fs::LocalFs::new());
|
||
let sftp_handler = SftpHandler::new(
|
||
root_dir,
|
||
vfs,
|
||
maxpacket,
|
||
self.upload_hook.clone(),
|
||
self.user_uuid.clone(),
|
||
);
|
||
|
||
// 存储到channel
|
||
if let Some(ch) = self.channels.get_mut(&channel) {
|
||
ch.sftp_handler = Some(sftp_handler);
|
||
info!("SFTP handler initialized for channel {}", channel);
|
||
}
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
} else if subsystem == "scp" {
|
||
info!("SCP subsystem requested");
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 8: 初始化SCP handler(使用用户home目录)
|
||
let root_dir = self.home_dir.clone();
|
||
let vfs = Box::new(crate::vfs::local_fs::LocalFs::new());
|
||
let scp_handler = ScpHandler::new(root_dir, vfs);
|
||
|
||
// 存储到channel
|
||
if let Some(ch) = self.channels.get_mut(&channel) {
|
||
ch.scp_handler = Some(scp_handler);
|
||
info!("SCP handler initialized for channel {}", channel);
|
||
}
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
} else {
|
||
warn!("Unsupported subsystem: {}", subsystem);
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_failure(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 处理shell请求(参考OpenSSH channel.c)
|
||
fn handle_shell_request(
|
||
&mut self,
|
||
channel: u32,
|
||
want_reply: bool,
|
||
) -> Result<Option<SshPacket>> {
|
||
info!("Handling shell request for channel {}", channel);
|
||
|
||
// Phase 9将实现shell
|
||
warn!("Shell not implemented in Phase 6");
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_failure(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 处理env请求(参考OpenSSH channel.c)
|
||
fn handle_env_request(
|
||
&mut self,
|
||
cursor: &mut std::io::Cursor<&[u8]>,
|
||
channel: u32,
|
||
want_reply: bool,
|
||
) -> Result<Option<SshPacket>> {
|
||
info!("Handling env request for channel {}", channel);
|
||
|
||
// 读取环境变量名和值
|
||
let name = read_ssh_string(cursor)?;
|
||
let value = read_ssh_string(cursor)?;
|
||
|
||
info!("Env: {}={}", name, value);
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 处理pty请求(参考OpenSSH channel.c)
|
||
fn handle_pty_request(
|
||
&mut self,
|
||
cursor: &mut std::io::Cursor<&[u8]>,
|
||
channel: u32,
|
||
want_reply: bool,
|
||
) -> Result<Option<SshPacket>> {
|
||
info!("Handling pty request for channel {}", channel);
|
||
|
||
// 读取terminal类型(SSH string)
|
||
let term = read_ssh_string(cursor)?;
|
||
|
||
// 读取窗口大小(4个uint32)
|
||
let width = cursor.read_u32::<BigEndian>()?;
|
||
let height = cursor.read_u32::<BigEndian>()?;
|
||
let _pixel_width = cursor.read_u32::<BigEndian>()?;
|
||
let _pixel_height = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取terminal modes(SSH string格式)
|
||
let modes_len = cursor.read_u32::<BigEndian>()?;
|
||
let mut modes = vec![0u8; modes_len as usize];
|
||
cursor.read_exact(&mut modes)?;
|
||
|
||
info!(
|
||
"PTY: term={}, width={}, height={}, modes_len={}",
|
||
term, width, height, modes_len
|
||
);
|
||
|
||
if want_reply {
|
||
Ok(Some(self.build_channel_success(channel)?))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 处理SSH_MSG_CHANNEL_DATA(参考OpenSSH channel.c: channel_input_data())
|
||
pub fn handle_channel_data(&mut self, packet: &SshPacket) -> Result<Option<SshPacket>> {
|
||
info!("Processing SSH_MSG_CHANNEL_DATA");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_DATA as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_DATA"));
|
||
}
|
||
|
||
// 读取recipient channel
|
||
let recipient_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// 读取数据长度(SSH string — 先读长度,数据稍后读取)
|
||
let data_length = cursor.read_u32::<BigEndian>()?;
|
||
|
||
// Phase 14: 检查是否是交互式exec进程(用reuse buffer避免分配)
|
||
if let Some(channel) = self.channels.get_mut(&recipient_channel) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
// Phase 2a: read into reusable buffer
|
||
exec_process.reuse_buf.resize(data_length as usize, 0);
|
||
cursor.read_exact(&mut exec_process.reuse_buf)?;
|
||
|
||
info!(
|
||
"Channel data: channel={}, length={}",
|
||
recipient_channel,
|
||
exec_process.reuse_buf.len()
|
||
);
|
||
|
||
// 转发数据到子进程stdin
|
||
if let Some(stdin) = &mut exec_process.stdin {
|
||
use std::io::Write;
|
||
info!("⭐⭐⭐⭐⭐ [STDIN_WRITE] Writing {} bytes to child stdin", exec_process.reuse_buf.len());
|
||
stdin.write_all(&exec_process.reuse_buf)?;
|
||
stdin.flush()?;
|
||
info!("⭐⭐⭐⭐⭐ [STDIN_FLUSH] Flushed stdin (channel {})", recipient_channel);
|
||
} else {
|
||
warn!("⚠️⚠️⚠️⚠️⚠️ [STDIN_MISSING] No stdin pipe available for channel {}", recipient_channel);
|
||
}
|
||
|
||
// Window Control — all in one borrow scope (NLL releases after last use)
|
||
let data_len = exec_process.reuse_buf.len() as u32;
|
||
channel.local_window -= data_len;
|
||
channel.local_consumed += data_len;
|
||
|
||
// No more uses of channel or exec_process after this point
|
||
|
||
// 检查窗口并发送 Window adjust
|
||
if let Some(window_adjust_packet) =
|
||
channel_check_window(recipient_channel, &mut self.channels)
|
||
{
|
||
return Ok(Some(window_adjust_packet));
|
||
}
|
||
|
||
return Ok(None);
|
||
}
|
||
|
||
// 非exec_process路径:分配data(供rsync/SFTP handlers使用)
|
||
let mut data = vec![0u8; data_length as usize];
|
||
cursor.read_exact(&mut data)?;
|
||
|
||
info!(
|
||
"Channel data: channel={}, length={}",
|
||
recipient_channel,
|
||
data.len()
|
||
);
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 8: SCP handler (subsystem)
|
||
if let Some(scp_handler) = &mut channel.scp_handler {
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [SCP_DATA] Feeding {} bytes to ScpHandler",
|
||
data.len()
|
||
);
|
||
|
||
// Window Control - decrease local_window
|
||
channel.local_window -= data.len() as u32;
|
||
channel.local_consumed += data.len() as u32;
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 14.4: SCP packet accumulation
|
||
channel.scp_input_buffer.extend_from_slice(&data);
|
||
info!(
|
||
"SCP buffer accumulated: {} bytes total",
|
||
channel.scp_input_buffer.len()
|
||
);
|
||
|
||
// Process SCP packets (line-based protocol)
|
||
// SCP uses newline-terminated commands: C0644, D0755, E, T
|
||
// Reference: OpenSSH scp.c
|
||
|
||
// Find complete lines in buffer
|
||
let mut responses: Vec<Vec<u8>> = Vec::new();
|
||
while let Some(newline_pos) = channel.scp_input_buffer.iter().position(|&b| b == b'\n') {
|
||
let line = channel.scp_input_buffer[..newline_pos].to_vec();
|
||
channel.scp_input_buffer = channel.scp_input_buffer[newline_pos + 1..].to_vec();
|
||
|
||
info!("SCP command: {}", String::from_utf8_lossy(&line));
|
||
|
||
// Process SCP command
|
||
// TODO: Full implementation requires ScpHandler.handle_scp() with ReadWrite trait
|
||
// Current implementation: basic ACK (0 byte)
|
||
responses.push(vec![0]); // SCP ACK
|
||
}
|
||
|
||
// Check for window adjust
|
||
if let Some(window_adjust_packet) =
|
||
channel_check_window(recipient_channel, &mut self.channels)
|
||
{
|
||
return Ok(Some(window_adjust_packet));
|
||
}
|
||
|
||
// Send SCP responses
|
||
if !responses.is_empty() {
|
||
// All responses except last go to pending_packets
|
||
for i in 0..responses.len().saturating_sub(1) {
|
||
let pending = self.build_channel_data(recipient_channel, &responses[i])?;
|
||
self.pending_packets.push_back(pending);
|
||
}
|
||
|
||
// Last response is returned
|
||
if let Some(last_response) = responses.into_iter().last() {
|
||
return Ok(Some(self.build_channel_data(recipient_channel, &last_response)?));
|
||
}
|
||
}
|
||
|
||
return Ok(None);
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 16.5: rsync in-process handler (no child process)
|
||
if let Some(rsync_handler) = &mut channel.rsync_handler {
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [RSYNC_DATA] Feeding {} bytes to RsyncHandler",
|
||
data.len()
|
||
);
|
||
|
||
rsync_handler.feed(&data)?;
|
||
|
||
let output = rsync_handler.drain_output();
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [RSYNC_DATA] RsyncHandler produced {} bytes output, done={}",
|
||
output.len(),
|
||
rsync_handler.is_done()
|
||
);
|
||
|
||
// Window Control - decrease local_window
|
||
channel.local_window -= data.len() as u32;
|
||
channel.local_consumed += data.len() as u32;
|
||
|
||
// Check for window adjust
|
||
if let Some(window_adjust_packet) =
|
||
channel_check_window(recipient_channel, &mut self.channels)
|
||
{
|
||
return Ok(Some(window_adjust_packet));
|
||
}
|
||
|
||
if !output.is_empty() {
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [RSYNC_DATA] Returning {} bytes as CHANNEL_DATA",
|
||
output.len()
|
||
);
|
||
return Ok(Some(self.build_channel_data(recipient_channel, &output)?));
|
||
}
|
||
|
||
return Ok(None);
|
||
}
|
||
|
||
// Phase 7: 检查是否是SFTP channel(⭐⭐⭐⭐⭐ Phase 14.3: packet accumulation)
|
||
// Extract SFTP result from channel borrow, then send outside
|
||
let sftp_result = if let Some(sftp_handler) = &mut channel.sftp_handler {
|
||
info!("Processing SFTP request ({} bytes)", data.len());
|
||
|
||
// ⭐⭐⭐⭐⭐ Window Control: decrease local_window
|
||
channel.local_window -= data.len() as u32;
|
||
channel.local_consumed += data.len() as u32;
|
||
|
||
// ⭐⭐⭐⭐⭐ Critical修复:累积SFTP packet数据
|
||
channel.sftp_input_buffer.extend_from_slice(&data);
|
||
info!(
|
||
"SFTP buffer accumulated: {} bytes total",
|
||
channel.sftp_input_buffer.len()
|
||
);
|
||
|
||
// ⭐⭐⭐⭐⭐ Process ALL complete SFTP packets from buffer (not just one)
|
||
let mut all_responses: Vec<Vec<u8>> = Vec::new();
|
||
loop {
|
||
if channel.sftp_input_buffer.len() < 4 {
|
||
info!("SFTP buffer too short for length field, waiting for more data");
|
||
break;
|
||
}
|
||
|
||
let sftp_length = u32::from_be_bytes([
|
||
channel.sftp_input_buffer[0],
|
||
channel.sftp_input_buffer[1],
|
||
channel.sftp_input_buffer[2],
|
||
channel.sftp_input_buffer[3],
|
||
]) as usize;
|
||
info!("SFTP packet length field: {}", sftp_length);
|
||
|
||
let expected_total = 4 + sftp_length;
|
||
if channel.sftp_input_buffer.len() < expected_total {
|
||
info!("SFTP packet incomplete: expected {} bytes, have {} bytes in buffer, waiting for more",
|
||
expected_total, channel.sftp_input_buffer.len());
|
||
break;
|
||
}
|
||
|
||
let sftp_packet = channel.sftp_input_buffer[4..expected_total].to_vec();
|
||
info!(
|
||
"SFTP packet complete: {} bytes, processing",
|
||
sftp_packet.len()
|
||
);
|
||
|
||
let response = sftp_handler.handle_request(&sftp_packet)?;
|
||
info!("SFTP response: {} bytes", response.len());
|
||
|
||
if channel.sftp_input_buffer.len() > expected_total {
|
||
let remaining = channel.sftp_input_buffer[expected_total..].to_vec();
|
||
channel.sftp_input_buffer = remaining;
|
||
info!(
|
||
"SFTP buffer has remaining {} bytes after processing",
|
||
channel.sftp_input_buffer.len()
|
||
);
|
||
} else {
|
||
channel.sftp_input_buffer.clear();
|
||
info!("SFTP buffer cleared after processing");
|
||
}
|
||
|
||
all_responses.push(response);
|
||
}
|
||
|
||
Some(all_responses)
|
||
} else {
|
||
None
|
||
};
|
||
|
||
if let Some(responses) = sftp_result {
|
||
// ⭐⭐⭐⭐⭐ Channel borrow is dropped; now we can use self freely
|
||
|
||
// All responses except the last go to pending_packets
|
||
for i in 0..responses.len().saturating_sub(1) {
|
||
let pending = self.build_channel_data(recipient_channel, &responses[i])?;
|
||
self.pending_packets.push_back(pending);
|
||
}
|
||
|
||
// Last response is returned (possibly with WINDOW_ADJUST)
|
||
if let Some(last_response) = responses.into_iter().last() {
|
||
// ⭐⭐⭐⭐⭐ Check window adjust (re-borrow channel briefly)
|
||
let (needs_window, consumed) =
|
||
if let Some(ch) = self.channels.get_mut(&recipient_channel) {
|
||
let window_used = ch.local_window_max - ch.local_window;
|
||
let need = (window_used > ch.local_maxpacket * 3)
|
||
|| (ch.local_window < ch.local_window_max / 2);
|
||
(need, ch.local_consumed)
|
||
} else {
|
||
(false, 0)
|
||
};
|
||
|
||
if needs_window && consumed > 0 {
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [SFTP_WINDOW] Sending WINDOW_ADJUST before SFTP response"
|
||
);
|
||
let window_adjust = build_window_adjust(recipient_channel, consumed);
|
||
// Update window state
|
||
if let Some(ch) = self.channels.get_mut(&recipient_channel) {
|
||
ch.local_window += consumed;
|
||
ch.local_consumed = 0;
|
||
}
|
||
// Use standalone builder to avoid self borrow conflict
|
||
use std::io::Write;
|
||
let mut payload = Vec::new();
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_DATA as u8)?;
|
||
payload.write_u32::<BigEndian>(recipient_channel)?;
|
||
payload.write_u32::<BigEndian>(last_response.len() as u32)?;
|
||
payload.write_all(&last_response)?;
|
||
let sftp_packet = SshPacket::new(payload);
|
||
self.pending_packets.push_back(window_adjust);
|
||
self.pending_packets.push_back(sftp_packet);
|
||
return Ok(None);
|
||
}
|
||
return Ok(Some(
|
||
self.build_channel_data(recipient_channel, &last_response)?,
|
||
));
|
||
} else {
|
||
// No SFTP packets were complete, but maybe we need window adjust
|
||
if let Some(ch) = self.channels.get_mut(&recipient_channel) {
|
||
let window_used = ch.local_window_max - ch.local_window;
|
||
if (window_used > ch.local_maxpacket * 3
|
||
|| ch.local_window < ch.local_window_max / 2)
|
||
&& ch.local_consumed > 0
|
||
{
|
||
let window_adjust =
|
||
build_window_adjust(recipient_channel, ch.local_consumed);
|
||
ch.local_window += ch.local_consumed;
|
||
ch.local_consumed = 0;
|
||
self.pending_packets.push_back(window_adjust);
|
||
}
|
||
}
|
||
return Ok(None);
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果不是SFTP或exec_process,返回None
|
||
Ok(None)
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 13.5: 处理 client 发送的 SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||
pub fn adjust_remote_window(&mut self, recipient_channel: u32, bytes_to_add: u32) {
|
||
if let Some(channel) = self.channels.get_mut(&recipient_channel) {
|
||
channel.remote_window = channel.remote_window.saturating_add(bytes_to_add);
|
||
info!("⭐⭐⭐⭐⭐ [ADJUST_REMOTE_WINDOW] channel {} remote_window increased by {} (new: {})",
|
||
recipient_channel, bytes_to_add, channel.remote_window);
|
||
}
|
||
}
|
||
|
||
/// Phase 14: 构建SSH_MSG_CHANNEL_EXTENDED_DATA(参考OpenSSH channel.c)
|
||
fn build_channel_extended_data(
|
||
&self,
|
||
channel: u32,
|
||
data_type: u32,
|
||
data: &[u8],
|
||
) -> Result<SshPacket> {
|
||
let mut buffer = Vec::new();
|
||
|
||
buffer.write_u8(PacketType::SSH_MSG_CHANNEL_EXTENDED_DATA as u8)?;
|
||
buffer.write_u32::<BigEndian>(channel)?;
|
||
buffer.write_u32::<BigEndian>(data_type)?; // 1 = stderr, 2 = exit status
|
||
buffer.write_u32::<BigEndian>(data.len() as u32)?;
|
||
buffer.write_all(data)?;
|
||
|
||
Ok(SshPacket {
|
||
packet_length: 0,
|
||
padding_length: 0,
|
||
payload: buffer,
|
||
padding: Vec::new(),
|
||
})
|
||
}
|
||
|
||
/// 处理SSH_MSG_CHANNEL_CLOSE(参考OpenSSH channel.c: channel_input_close())
|
||
pub fn handle_channel_close(&mut self, packet: &SshPacket) -> Result<Option<SshPacket>> {
|
||
info!("Processing SSH_MSG_CHANNEL_CLOSE");
|
||
|
||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||
|
||
// Packet type
|
||
let packet_type = cursor.read_u8()?;
|
||
if packet_type != PacketType::SSH_MSG_CHANNEL_CLOSE as u8 {
|
||
return Err(anyhow!("Invalid packet type for CHANNEL_CLOSE"));
|
||
}
|
||
|
||
// 读取recipient channel
|
||
let recipient_channel = cursor.read_u32::<BigEndian>()?;
|
||
|
||
info!("Channel close: channel={}", recipient_channel);
|
||
|
||
// 移除channel(参考OpenSSH channel.c)
|
||
if let Some(channel) = self.channels.remove(&recipient_channel) {
|
||
info!("Channel {} removed", recipient_channel);
|
||
|
||
// 发送SSH_MSG_CHANNEL_CLOSE回应
|
||
Ok(Some(self.build_channel_close(channel.sender_channel)?))
|
||
} else {
|
||
warn!("Channel {} not found", recipient_channel);
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_OPEN_CONFIRMATION(参考OpenSSH channel.c)
|
||
fn build_channel_open_confirmation(
|
||
&self,
|
||
server_channel: u32,
|
||
sender_channel: u32,
|
||
window_size: u32,
|
||
packet_size: u32,
|
||
) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
// Packet type
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8)?;
|
||
|
||
// Server channel number
|
||
payload.write_u32::<BigEndian>(server_channel)?;
|
||
|
||
// Sender channel number
|
||
payload.write_u32::<BigEndian>(sender_channel)?;
|
||
|
||
// Initial window size
|
||
payload.write_u32::<BigEndian>(window_size)?;
|
||
|
||
// Maximum packet size
|
||
payload.write_u32::<BigEndian>(packet_size)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_OPEN_FAILURE(参考OpenSSH channel.c)
|
||
fn build_channel_open_failure(
|
||
&self,
|
||
sender_channel: u32,
|
||
reason_code: u32,
|
||
description: &str,
|
||
language: &str,
|
||
) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
// Packet type
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_OPEN_FAILURE as u8)?;
|
||
|
||
// Sender channel number
|
||
payload.write_u32::<BigEndian>(sender_channel)?;
|
||
|
||
// Reason code
|
||
payload.write_u32::<BigEndian>(reason_code)?;
|
||
|
||
// Description(SSH string)
|
||
payload.write_u32::<BigEndian>(description.len() as u32)?;
|
||
payload.write_all(description.as_bytes())?;
|
||
|
||
// Language(SSH string)
|
||
payload.write_u32::<BigEndian>(language.len() as u32)?;
|
||
payload.write_all(language.as_bytes())?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_SUCCESS(参考OpenSSH channel.c)
|
||
fn build_channel_success(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_SUCCESS as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_FAILURE(参考OpenSSH channel.c)
|
||
fn build_channel_failure(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_FAILURE as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_CLOSE(参考OpenSSH channel.c)
|
||
pub fn build_channel_close(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_CLOSE as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_DATA(Phase 6新增)
|
||
pub fn build_channel_data(&self, channel: u32, data: &[u8]) -> Result<SshPacket> {
|
||
info!("⭐⭐⭐⭐⭐ [build_channel_data] Building SSH_MSG_CHANNEL_DATA: channel={}, data_len={}", channel, data.len());
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_DATA as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
payload.write_u32::<BigEndian>(data.len() as u32)?;
|
||
payload.write_all(data)?;
|
||
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [build_channel_data] Packet built successfully, payload_len={}",
|
||
payload.len()
|
||
);
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 构建SSH_MSG_CHANNEL_EOF(Phase 6新增)
|
||
pub fn build_channel_eof(&self, channel: u32) -> Result<SshPacket> {
|
||
let mut payload = Vec::new();
|
||
|
||
payload.write_u8(PacketType::SSH_MSG_CHANNEL_EOF as u8)?;
|
||
payload.write_u32::<BigEndian>(channel)?;
|
||
|
||
Ok(SshPacket::new(payload))
|
||
}
|
||
|
||
/// 获取有输出待发送的channel ID(Phase 6新增)
|
||
pub fn get_channel_with_output(&self) -> Option<u32> {
|
||
for (&id, channel) in &self.channels {
|
||
if channel.output_buffer.is_some() {
|
||
return Some(id);
|
||
}
|
||
}
|
||
None
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 14.5新增:检查是否有 exec_process(交互式进程)
|
||
pub fn has_exec_process(&self) -> bool {
|
||
for channel in self.channels.values() {
|
||
if channel.exec_process.is_some() || channel.rsync_handler.is_some() {
|
||
return true;
|
||
}
|
||
}
|
||
false
|
||
}
|
||
|
||
/// Phase 17: 关闭所有子进程stdin(收到CHANNEL_EOF时调用)
|
||
/// SCP upload需要:scp -t 等待EOF on stdin才知道数据传输完毕
|
||
pub fn close_child_stdin(&mut self) {
|
||
let channel_ids: Vec<u32> = self.channels.keys().copied().collect();
|
||
for id in channel_ids {
|
||
if let Some(channel) = self.channels.get_mut(&id) {
|
||
if let Some(exec) = &mut channel.exec_process {
|
||
if let Some(stdin) = exec.stdin.take() {
|
||
drop(stdin);
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [CHANNEL_EOF] Closed child stdin (channel {})",
|
||
id
|
||
);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 获取channel输出(Phase 6新增)
|
||
pub fn get_channel_output(&mut self, channel_id: u32) -> Option<Vec<u8>> {
|
||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||
channel.output_buffer.take()
|
||
} else {
|
||
None
|
||
}
|
||
}
|
||
|
||
/// 移除channel(Phase 6新增)
|
||
pub fn remove_channel(&mut self, channel_id: u32) {
|
||
self.channels.remove(&channel_id);
|
||
}
|
||
|
||
/// Phase 14: OpenSSH风格poll机制(使用nix::poll监听stdout/stderr fd)
|
||
/// ⭐⭐⭐⭐⭐ 关键:非阻塞读取数据,不等待子进程完成
|
||
/// ⭐⭐⭐⭐⭐ Phase 14.2: 处理child exited(发送EOF + CLOSE)
|
||
/// 参考:OpenSSH session.c: do_exec_no_pty()
|
||
pub fn handle_child_exited(&mut self) -> Result<Vec<SshPacket>> {
|
||
// 1. 收集需要处理的channel IDs (exec_process OR rsync_handler)
|
||
let channel_ids: Vec<u32> = self
|
||
.channels
|
||
.iter()
|
||
.filter_map(|(id, channel)| {
|
||
if channel.exec_process.is_some() || channel.rsync_handler.is_some() {
|
||
Some(*id)
|
||
} else {
|
||
None
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
// 2. 构建packets(避免borrow冲突)
|
||
let mut packets = Vec::new();
|
||
for channel_id in &channel_ids {
|
||
let eof_packet = self.build_channel_eof(*channel_id)?;
|
||
packets.push(eof_packet);
|
||
|
||
let close_packet = self.build_channel_close(*channel_id)?;
|
||
packets.push(close_packet);
|
||
}
|
||
|
||
// 3. 清除exec_process + rsync_handler(mutable borrow)
|
||
for channel_id in &channel_ids {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
channel.exec_process = None;
|
||
channel.rsync_handler = None;
|
||
}
|
||
}
|
||
|
||
if !channel_ids.is_empty() {
|
||
info!(
|
||
"Child/rsync exited, sent EOF + CLOSE for {} channels",
|
||
channel_ids.len()
|
||
);
|
||
}
|
||
|
||
Ok(packets)
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 14.2: OpenSSH统一poll + child进程状态检测
|
||
/// 参考:OpenSSH session.c: do_exec_no_pty() + channel.c: channel_handle_fd()
|
||
///
|
||
/// 关键改进(Phase 14.2):
|
||
/// - 单次poll()同时监听client socket和子进程输出
|
||
/// - timeout 10ms(非阻塞)
|
||
/// - **添加child进程状态检测**(防止无限spinning)⭐⭐⭐⭐⭐
|
||
/// - **添加max_poll_iterations限制**(最多100次,1秒)
|
||
/// - 返回(stdout_packets, client_has_data, child_exited)
|
||
pub fn poll_exec_stdout_and_client(
|
||
&mut self,
|
||
stream: &std::net::TcpStream,
|
||
) -> Result<(Option<Vec<SshPacket>>, bool, bool)> {
|
||
use nix::poll::{poll, PollFd, PollFlags};
|
||
use std::io::Read;
|
||
use std::os::unix::io::{AsRawFd, BorrowedFd};
|
||
|
||
// 收集所有需要poll的fd
|
||
// Phase 3: 预分配 poll_fds_vec(避免频繁扩容)
|
||
let mut poll_fds_vec = Vec::with_capacity(self.channels.len() * 3 + 1); // 最多 (channels * 3) + client fd
|
||
let mut client_has_data = false;
|
||
let mut child_exited = false;
|
||
|
||
// 1. 添加client socket fd(监听stdin数据)
|
||
let client_fd = stream.as_raw_fd();
|
||
let client_poll_fd = unsafe { BorrowedFd::borrow_raw(client_fd) };
|
||
poll_fds_vec.push(PollFd::new(client_poll_fd, PollFlags::POLLIN));
|
||
let client_fd_idx = 0; // client fd总是第一个
|
||
|
||
// 2. 添加所有channel的stdout/stderr fd
|
||
let mut channel_fds_map: HashMap<u32, (usize, usize)> = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx)
|
||
let mut channel_ids_vec = Vec::new(); // 用于后续child状态检查
|
||
|
||
for (channel_id, channel) in &self.channels {
|
||
if let Some(exec_process) = &channel.exec_process {
|
||
channel_ids_vec.push(*channel_id);
|
||
|
||
// stdout fd
|
||
if let Some(_stdout) = &exec_process.stdout {
|
||
let stdout_poll_fd = unsafe { BorrowedFd::borrow_raw(exec_process.stdout_fd) };
|
||
poll_fds_vec.push(PollFd::new(stdout_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
// stderr fd
|
||
if let Some(_stderr) = &exec_process.stderr {
|
||
let stderr_poll_fd = unsafe { BorrowedFd::borrow_raw(exec_process.stderr_fd) };
|
||
poll_fds_vec.push(PollFd::new(stderr_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
// 记录索引(相对于client_fd_idx)
|
||
let stdout_idx = poll_fds_vec.len() - 2;
|
||
let stderr_idx = poll_fds_vec.len() - 1;
|
||
channel_fds_map.insert(*channel_id, (stdout_idx, stderr_idx));
|
||
}
|
||
}
|
||
|
||
if poll_fds_vec.len() == 1 {
|
||
// 只有client fd,没有exec_process
|
||
// ⭐⭐⭐⭐⭐ Phase 16.5: 检查rsync handler的pending output
|
||
|
||
// 检查rsync handler是否done(先收集,避免borrow冲突)
|
||
let mut rsync_is_done = false;
|
||
|
||
// Drain rsync handler output (mutable borrow)
|
||
let mut rsync_items: Vec<(u32, Vec<u8>)> = Vec::new();
|
||
for channel in self.channels.values_mut() {
|
||
if let Some(rsync) = &mut channel.rsync_handler {
|
||
let out = rsync.drain_output();
|
||
if !out.is_empty() {
|
||
let sid = channel.server_channel;
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [RSYNC_POLL] {} bytes pending from rsync handler",
|
||
out.len()
|
||
);
|
||
rsync_items.push((sid, out));
|
||
}
|
||
}
|
||
}
|
||
|
||
// Check rsync done (immutable borrow)
|
||
rsync_is_done = self
|
||
.channels
|
||
.values()
|
||
.any(|ch| ch.rsync_handler.as_ref().is_some_and(|r| r.is_done()));
|
||
|
||
// Directly poll client
|
||
match poll(&mut poll_fds_vec, 10u16) {
|
||
Ok(n) if n > 0 => {
|
||
if let Some(revents) = poll_fds_vec[client_fd_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
client_has_data = true;
|
||
}
|
||
}
|
||
}
|
||
_ => {}
|
||
}
|
||
|
||
if rsync_is_done {
|
||
info!("⭐⭐⭐⭐⭐ [RSYNC_DONE] RsyncHandler is done, signaling child_exited");
|
||
}
|
||
|
||
// Return rsync output if any
|
||
if !rsync_items.is_empty() {
|
||
let mut packets = Vec::new();
|
||
for (channel_id, data) in rsync_items {
|
||
packets.push(self.build_channel_data(channel_id, &data)?);
|
||
}
|
||
return Ok((Some(packets), client_has_data, rsync_is_done));
|
||
}
|
||
|
||
return Ok((None, client_has_data, rsync_is_done));
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 16.4修复:增加poll轮询限制(支持大文件传输)
|
||
// 最多轮询2000次(200秒),poll timeout从10ms改到100ms
|
||
// 修复:从500改到2000,支持50MB+文件传输(预计可传输500MB+)
|
||
let max_poll_iterations = 2000;
|
||
let mut poll_iteration = 0;
|
||
let mut found_data = false;
|
||
let mut stdin_closed = false; // ⭐⭐⭐⭐⭐ 新增:跟踪stdin是否已关闭
|
||
|
||
for iteration in 0..max_poll_iterations {
|
||
poll_iteration = iteration;
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 16.2.1优化:增加poll timeout(减少iteration overhead)
|
||
// 每50次轮询记录一次日志(从10改到50,减少噪音)
|
||
if iteration % 50 == 0 {
|
||
info!(
|
||
"Polling {} fds (iteration {} of {}, stdin_closed={})",
|
||
poll_fds_vec.len(),
|
||
iteration,
|
||
max_poll_iterations,
|
||
stdin_closed
|
||
);
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 16.2.1优化:增加poll timeout(减少iteration overhead)
|
||
match poll(&mut poll_fds_vec, 100u16) {
|
||
Ok(n) if n > 0 => {
|
||
info!("{} fds have data available (iteration {})", n, iteration);
|
||
found_data = true;
|
||
break; // 有数据,立即处理
|
||
}
|
||
Ok(0) => {
|
||
// timeout,无数据
|
||
// ⭐⭐⭐⭐⭐ Phase 16.2.1优化:减少child状态检查频率(每50次)
|
||
if iteration % 50 == 49 {
|
||
// 检查child是否exited
|
||
for channel_id in &channel_ids_vec {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
match exec_process.child.try_wait() {
|
||
Ok(Some(status)) => {
|
||
info!(
|
||
"Child process exited (channel {}, status: {:?})",
|
||
channel_id, status
|
||
);
|
||
child_exited = true;
|
||
|
||
let command_str = exec_process.command.clone();
|
||
let should_trigger_hook = status.success()
|
||
&& (command_str.contains("scp") || command_str.contains("rsync"));
|
||
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
exec_process.read_buf.resize(32768, 0);
|
||
match stdout.read(&mut exec_process.read_buf) {
|
||
Ok(n) if n > 0 => {
|
||
info!("Read {} final bytes from stdout (child exited)", n);
|
||
let data = exec_process.read_buf[..n].to_vec();
|
||
let packet = self.build_channel_data(
|
||
*channel_id,
|
||
&data,
|
||
)?;
|
||
return Ok((
|
||
Some(vec![packet]),
|
||
false,
|
||
true,
|
||
));
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
if should_trigger_hook {
|
||
let dest_path = Self::extract_dest_path_from_command(&command_str, &self.home_dir);
|
||
if let Some(path) = dest_path {
|
||
if let Some(hook) = &self.upload_hook {
|
||
if let Err(e) = hook.trigger(&path, &self.user_uuid) {
|
||
warn!("Upload hook failed for {:?}: {}", path, e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 没有剩余数据,返回child_exited标志
|
||
return Ok((None, false, true));
|
||
}
|
||
Ok(None) => {
|
||
// Child still running(正常)
|
||
info!("Child still running (channel {}, iteration {}, stdin_closed={})", channel_id, iteration, stdin_closed);
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 16.4修复:增加stdin超时机制(支持大文件传输)
|
||
// 如果stdin未关闭,且超过1500次poll(150s)无数据
|
||
// 强制关闭stdin,发送EOF给SCP/rsync
|
||
// ⭐⭐⭐⭐⭐ Phase 16.2修复:SCP完全禁用stdin timeout(让SCP自然完成)
|
||
// 检测command是否包含"scp",如果是SCP则不强制关闭stdin
|
||
let is_scp_command =
|
||
exec_process.command.contains("scp");
|
||
|
||
if !stdin_closed
|
||
&& !is_scp_command
|
||
&& iteration >= 1500
|
||
&& exec_process.stdin.is_some()
|
||
{
|
||
info!("⭐⭐⭐⭐⭐ Forcing stdin close after {} iterations ({} ms) - sending EOF to rsync (SCP excluded)", iteration, iteration * 100);
|
||
exec_process.stdin = None; // Drop stdin,发送EOF
|
||
stdin_closed = true;
|
||
|
||
// ⭐⭐⭐⭐⭐ stdin关闭后,继续等待child处理完成
|
||
// 不要立即返回,给rsync时间处理数据并产生stdout
|
||
info!("stdin closed, continuing to poll for stdout output...");
|
||
}
|
||
}
|
||
Err(e) => {
|
||
warn!("Child try_wait error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 继续轮询(如果iteration < max_poll_iterations)
|
||
}
|
||
Err(e) => {
|
||
warn!("poll error: {}", e);
|
||
return Ok((None, false, false));
|
||
}
|
||
Ok(_) => {
|
||
// 其他情况(不应该发生)
|
||
}
|
||
}
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ 达到max_poll_iterations,检查最终child状态
|
||
if !found_data {
|
||
info!(
|
||
"No data after {} iterations ({} ms), checking child status",
|
||
max_poll_iterations,
|
||
max_poll_iterations * 10
|
||
);
|
||
|
||
for channel_id in &channel_ids_vec {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
match exec_process.child.try_wait() {
|
||
Ok(Some(status)) => {
|
||
info!("Child exited after max iterations (status: {:?})", status);
|
||
child_exited = true;
|
||
|
||
// 读取剩余stdout
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
exec_process.read_buf.resize(32768, 0);
|
||
match stdout.read(&mut exec_process.read_buf) {
|
||
Ok(n) if n > 0 => {
|
||
let data = exec_process.read_buf[..n].to_vec();
|
||
let packet = self.build_channel_data(*channel_id, &data)?;
|
||
return Ok((Some(vec![packet]), false, true));
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
return Ok((None, false, true));
|
||
}
|
||
Ok(None) => {
|
||
info!("Child still running after max iterations, returning None");
|
||
// Child还在运行,但无stdout数据
|
||
// 主循环会继续调用此函数
|
||
return Ok((None, false, false));
|
||
}
|
||
Err(e) => {
|
||
warn!("Final child check error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ 处理找到的数据(如果found_data)
|
||
// 3. 检查client fd状态(包括EOF/HUP)
|
||
if let Some(revents) = poll_fds_vec[client_fd_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("Client fd has data (stdin from client)");
|
||
client_has_data = true;
|
||
} else if revents.contains(PollFlags::POLLHUP) {
|
||
info!("Client fd hangup (EOF received from client)");
|
||
// ⭐⭐⭐⭐⭐ Phase 14.2关键修复:关闭stdin pipe,发送EOF给child
|
||
// 参考:OpenSSH session.c: do_exec_no_pty() stdin handling
|
||
for channel in self.channels.values_mut() {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
if exec_process.stdin.is_some() {
|
||
info!("Closing stdin pipe (sending EOF to child process)");
|
||
exec_process.stdin = None; // Drop stdin,发送EOF给child
|
||
}
|
||
}
|
||
}
|
||
client_has_data = false;
|
||
} else if revents.contains(PollFlags::POLLERR) {
|
||
warn!("Client fd error");
|
||
return Err(anyhow::anyhow!("Client socket error"));
|
||
}
|
||
}
|
||
|
||
// 4. 检查stdout/stderr fd是否有数据
|
||
let mut packets_data: Vec<(u32, Vec<u8>)> = Vec::new();
|
||
let mut stderr_packets: Vec<(u32, Vec<u8>)> = Vec::new(); // Phase 17: stderr → CHANNEL_EXTENDED_DATA
|
||
|
||
for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map {
|
||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
// 检查stdout
|
||
if let Some(revents) = poll_fds_vec[stdout_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [stdout POLLIN] stdout fd has data (channel {})",
|
||
channel_id
|
||
);
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
exec_process.read_buf.resize(32768, 0);
|
||
info!("⭐⭐⭐⭐⭐ [BEFORE stdout.read] Attempting to read from stdout (buffer size 32KB)");
|
||
match stdout.read(&mut exec_process.read_buf) {
|
||
Ok(n) if n > 0 => {
|
||
info!("⭐⭐⭐⭐⭐ [AFTER stdout.read] Read {} bytes from stdout (channel {})", n, channel_id);
|
||
packets_data.push((channel_id, exec_process.read_buf[..n].to_vec()));
|
||
}
|
||
Ok(0) => {
|
||
info!(
|
||
"stdout EOF (channel {}), closing stdout pipe",
|
||
channel_id
|
||
);
|
||
// ⭐⭐⭐⭐⭐ Critical修复:EOF时关闭pipe,避免无限循环
|
||
exec_process.stdout = None;
|
||
}
|
||
Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => {
|
||
warn!("stdout read error: {}", e);
|
||
exec_process.stdout = None; // 错误时也关闭
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查stderr
|
||
if let Some(revents) = poll_fds_vec[stderr_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("stderr fd has data (channel {})", channel_id);
|
||
if let Some(stderr) = &mut exec_process.stderr {
|
||
exec_process.read_buf.resize(32768, 0);
|
||
info!("⭐⭐⭐⭐⭐ [BEFORE stderr.read] Attempting to read from stderr (buffer size 32KB)");
|
||
match stderr.read(&mut exec_process.read_buf) {
|
||
Ok(n) if n > 0 => {
|
||
info!("⭐⭐⭐⭐⭐ [AFTER stderr.read] Read {} bytes from stderr (channel {})", n, channel_id);
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ stderr content: {:?}",
|
||
&exec_process.read_buf[..std::cmp::min(50, n)]
|
||
);
|
||
// ⭐⭐⭐⭐⭐ Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
|
||
stderr_packets.push((channel_id, exec_process.read_buf[..n].to_vec()));
|
||
}
|
||
Ok(0) => {
|
||
info!(
|
||
"stderr EOF (channel {}), closing stderr pipe",
|
||
channel_id
|
||
);
|
||
// ⭐⭐⭐⭐⭐ Critical修复:EOF时关闭pipe,避免无限循环
|
||
exec_process.stderr = None;
|
||
}
|
||
Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => {
|
||
warn!("stderr read error: {}", e);
|
||
exec_process.stderr = None; // 错误时也关闭
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
}
|
||
// ⭐⭐⭐⭐⭐ 检查 POLLHUP(pipe 关闭)
|
||
if revents.contains(PollFlags::POLLHUP) {
|
||
info!("stderr POLLHUP (channel {}), pipe closed", channel_id);
|
||
exec_process.stderr = None;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 构建packets
|
||
if !packets_data.is_empty() || !stderr_packets.is_empty() {
|
||
let mut packets = Vec::new();
|
||
for (channel_id, data) in packets_data {
|
||
let packet = self.build_channel_data(channel_id, &data)?;
|
||
packets.push(packet);
|
||
}
|
||
// Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
|
||
for (channel_id, data) in stderr_packets {
|
||
let packet = self.build_channel_extended_data(channel_id, 1, &data)?;
|
||
packets.push(packet);
|
||
}
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)",
|
||
packets.len()
|
||
);
|
||
return Ok((Some(packets), client_has_data, child_exited));
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 14.2最终修复:stdout/stderr EOF后检查child exited
|
||
// 当stdout和stderr都关闭后,强制检查child状态
|
||
for channel_id in &channel_ids_vec {
|
||
if let Some(channel) = self.channels.get_mut(channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
// 检查stdout和stderr是否都已关闭
|
||
if exec_process.stdout.is_none() && exec_process.stderr.is_none() {
|
||
info!(
|
||
"stdout/stderr both closed (channel {}), checking child status",
|
||
channel_id
|
||
);
|
||
|
||
// ⭐⭐⭐⭐⭐ 立即检查child是否exited
|
||
match exec_process.child.try_wait() {
|
||
Ok(Some(status)) => {
|
||
info!("⭐⭐⭐⭐⭐ Child exited after stdout/stderr EOF (status: {:?})", status);
|
||
child_exited = true;
|
||
|
||
// ⭐⭐⭐⭐⭐ 关键:立即返回child_exited标志
|
||
// server.rs会发送SSH_MSG_CHANNEL_EOF + CLOSE
|
||
info!("⭐⭐⭐⭐⭐ No packets to send, returning child_exited flag");
|
||
return Ok((None, false, true));
|
||
}
|
||
Ok(None) => {
|
||
// Child still running but stdout/stderr closed
|
||
// 等待child exited
|
||
info!("Child still running after pipes closed, waiting...");
|
||
}
|
||
Err(e) => {
|
||
warn!("Child try_wait error after pipes closed: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 有数据但只有client数据
|
||
Ok((None, client_has_data, child_exited))
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ Phase 14.0: 旧版poll(仅监听stdout/stderr,已废弃)
|
||
/// 已废弃:使用poll_exec_stdout_and_client()替代
|
||
#[allow(dead_code)]
|
||
pub fn poll_exec_stdout_with_fds(&mut self) -> Result<Option<Vec<SshPacket>>> {
|
||
use std::io::Read;
|
||
use std::os::unix::io::BorrowedFd;
|
||
|
||
// 遍历所有channel,收集poll_fds
|
||
// Phase 3: 预分配 poll_fds_vec(避免频繁扩容)
|
||
let mut poll_fds_vec = Vec::with_capacity(self.channels.len() * 2); // 最多 channels * 2 (stdout + stderr)
|
||
let mut channel_fds_map: HashMap<u32, (usize, usize)> = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx) in poll_fds_vec
|
||
|
||
for (channel_id, channel) in &self.channels {
|
||
if let Some(exec_process) = &channel.exec_process {
|
||
// ⭐⭐⭐⭐⭐ OpenSSH风格:创建PollFd监听stdout/stderr
|
||
// nix 0.29 API: PollFd::new()需要借用fd,不是RawFd
|
||
if let Some(_stdout) = &exec_process.stdout {
|
||
let stdout_poll_fd = unsafe {
|
||
// ⭐⭐⭐⭐⭐ 使用BorrowedFd::borrow_raw()(正确API)
|
||
BorrowedFd::borrow_raw(exec_process.stdout_fd)
|
||
};
|
||
poll_fds_vec.push(PollFd::new(stdout_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
if let Some(_stderr) = &exec_process.stderr {
|
||
let stderr_poll_fd = unsafe { BorrowedFd::borrow_raw(exec_process.stderr_fd) };
|
||
poll_fds_vec.push(PollFd::new(stderr_poll_fd, PollFlags::POLLIN));
|
||
}
|
||
|
||
// 记录poll_fds_vec中的索引
|
||
let stdout_idx = poll_fds_vec.len() - 2;
|
||
let stderr_idx = poll_fds_vec.len() - 1;
|
||
channel_fds_map.insert(*channel_id, (stdout_idx, stderr_idx));
|
||
}
|
||
}
|
||
|
||
if poll_fds_vec.is_empty() {
|
||
return Ok(None); // 没有exec_process
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH关键:使用poll监听所有fd
|
||
// ⭐⭐⭐⭐⭐ 持续poll机制:最多轮询1000次(给大文件传输足够时间)
|
||
// 大文件传输需要很长时间,增加轮询次数到1000次(总共10秒)
|
||
let max_poll_attempts = 1000;
|
||
let mut poll_attempt = 0;
|
||
let mut found_data = false;
|
||
|
||
for attempt in 0..max_poll_attempts {
|
||
poll_attempt = attempt;
|
||
// 每100次轮询记录一次日志(减少日志噪音)
|
||
if attempt % 100 == 0 {
|
||
info!(
|
||
"Polling {} fds (OpenSSH style, timeout 10ms, attempt {} of {})",
|
||
poll_fds_vec.len(),
|
||
attempt,
|
||
max_poll_attempts
|
||
);
|
||
}
|
||
match poll(&mut poll_fds_vec, 10u16) {
|
||
// timeout 10ms
|
||
Ok(n) => {
|
||
if n > 0 {
|
||
info!("{} fds have data available (attempt {})", n, attempt);
|
||
found_data = true;
|
||
break; // 有数据,立即处理
|
||
}
|
||
// 没有数据,继续轮询(最多1000次)
|
||
}
|
||
Err(e) => {
|
||
warn!("poll error: {}", e);
|
||
return Ok(None);
|
||
}
|
||
}
|
||
}
|
||
|
||
if !found_data {
|
||
info!(
|
||
"No data available after {} poll attempts ({} ms), returning None",
|
||
max_poll_attempts,
|
||
max_poll_attempts * 10
|
||
);
|
||
return Ok(None); // 轮询1000次后仍无数据,主循环继续处理client packet
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ OpenSSH风格:根据revents判断哪个fd有数据,立即读取
|
||
let mut packets_data: Vec<(u32, Vec<u8>)> = Vec::new(); // (channel_id, data)
|
||
|
||
for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map {
|
||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||
if let Some(exec_process) = &mut channel.exec_process {
|
||
// 检查stdout是否有数据
|
||
if let Some(revents) = poll_fds_vec[stdout_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("stdout fd has data (channel {})", channel_id);
|
||
if let Some(stdout) = &mut exec_process.stdout {
|
||
exec_process.read_buf.resize(32768, 0);
|
||
match stdout.read(&mut exec_process.read_buf) {
|
||
Ok(n) => {
|
||
if n > 0 {
|
||
info!("Read {} bytes from stdout (channel {})", n, channel_id);
|
||
packets_data.push((channel_id, exec_process.read_buf[..n].to_vec()));
|
||
} else {
|
||
info!("stdout EOF (channel {})", channel_id);
|
||
}
|
||
}
|
||
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
|
||
Err(e) => {
|
||
warn!("stdout read error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some(revents) = poll_fds_vec[stderr_idx].revents() {
|
||
if revents.contains(PollFlags::POLLIN) {
|
||
info!("stderr fd has data (channel {})", channel_id);
|
||
if let Some(stderr) = &mut exec_process.stderr {
|
||
exec_process.read_buf.resize(32768, 0);
|
||
match stderr.read(&mut exec_process.read_buf) {
|
||
Ok(n) => {
|
||
if n > 0 {
|
||
info!("Read {} bytes from stderr (channel {})", n, channel_id);
|
||
packets_data.push((channel_id, exec_process.read_buf[..n].to_vec()));
|
||
} else {
|
||
info!("stderr EOF (channel {})", channel_id);
|
||
}
|
||
}
|
||
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
|
||
Err(e) => {
|
||
warn!("stderr read error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// ⭐⭐⭐⭐⭐ 释放mutable borrow后,构建packets(避免borrow冲突)
|
||
let mut packets = Vec::new();
|
||
for (channel_id, data) in packets_data {
|
||
let packet = self.build_channel_data(channel_id, &data)?;
|
||
packets.push(packet);
|
||
}
|
||
|
||
if packets.is_empty() {
|
||
Ok(None)
|
||
} else {
|
||
Ok(Some(packets))
|
||
}
|
||
}
|
||
|
||
/// Derives the upload destination from an scp/rsync server command line.
///
/// Returns `Some(home_dir.join(target))` only for uploads:
/// - `rsync ... --server ... <target>` without `--sender` (the client pushes data);
/// - `scp ... -t [--] <target>` (scp "to" mode). OpenSSH may put `-v`, `-r`, `-p` or `-d`
///   before `-t`.
///
/// `target` is the last token that is not an option (and is not the program name). An
/// absolute target replaces `home_dir`, per `Path::join`.
///
/// Returns `None` for:
/// - downloads (`scp -f`, `rsync --server --sender`);
/// - any other command.
///
/// Arguments are split on whitespace, so quoted targets that contain spaces are not
/// supported.
fn extract_dest_path_from_command(command: &str, home_dir: &std::path::Path) -> Option<PathBuf> {
    let tokens: Vec<&str> = command.split_whitespace().collect();

    // `--server` is rsync-specific, so test rsync first. An scp target whose name happens
    // to contain "scp" (or vice versa) then cannot misroute the command.
    let is_upload = if command.contains("rsync") && tokens.contains(&"--server") {
        !tokens.contains(&"--sender")
    } else if command.contains("scp") {
        tokens.contains(&"-t")
    } else {
        false
    };
    if !is_upload {
        return None;
    }

    tokens
        .iter()
        .rev()
        .find(|token| !token.starts_with('-') && !matches!(**token, "scp" | "rsync"))
        .map(|target| home_dir.join(*target))
}
|
||
}
|
||
|
||
/// State of one SSH channel (cf. OpenSSH channels.h `struct Channel`).
struct Channel {
    // Our local channel id. NOTE(review): presumably also the key in
    // `ChannelManager::channels`; confirm against the channel-open handler.
    server_channel: u32,
    // The client's channel id (the "sender channel" field of its CHANNEL_OPEN). Used as the
    // recipient id when replying to CHANNEL_CLOSE.
    sender_channel: u32,
    // Channel type string. NOTE(review): presumably taken from CHANNEL_OPEN
    // (e.g. "session"); confirm.
    channel_type: String,

    // Phase 15: flow control (cf. OpenSSH channels.h window fields).
    remote_window: u32, // Bytes we may still send to the client; grown by client WINDOW_ADJUST (OpenSSH: c->remote_window)
    remote_maxpacket: u32, // Largest data packet the client accepts (OpenSSH: c->remote_maxpacket)
    local_window: u32, // Bytes the client may still send us (OpenSSH: c->local_window)
    local_window_max: u32, // Upper bound for local_window (OpenSSH: c->local_window_max)
    local_consumed: u32, // Bytes consumed since the last WINDOW_ADJUST we sent; handed back in the next one (OpenSSH: c->local_consumed)
    local_maxpacket: u32, // Largest data packet we accept (OpenSSH: c->local_maxpacket)

    // Legacy fields kept for compatibility with older code paths.
    window_size: u32, // Current window size (legacy)
    maximum_packet_size: u32, // Maximum packet size (legacy)

    state: ChannelState,
    output_buffer: Option<Vec<u8>>, // Phase 6: buffered command output, taken by get_channel_output()
    sftp_handler: Option<SftpHandler>, // Phase 7: SFTP subsystem handler
    scp_handler: Option<ScpHandler>, // Phase 8: SCP handler
    rsync_handler: Option<RsyncHandler>, // Phase 8: rsync handler
    exec_process: Option<ExecProcess>, // Phase 14: interactive exec child process
    // Phase 14.2 fix: accumulates incomplete SFTP packets split across CHANNEL_DATA messages.
    sftp_input_buffer: Vec<u8>,
    // Phase 14.4 fix: accumulates incomplete SCP data split across CHANNEL_DATA messages.
    scp_input_buffer: Vec<u8>,
    // Phase 13.3: port forwarding (RFC 4254 §7.2).
    direct_tcpip: Option<DirectTcpipChannel>, // "direct-tcpip": opened by the client for local forwarding (ssh -L)
    forwarded_tcpip: Option<ForwardedTcpipChannel>, // "forwarded-tcpip": opened by the server for remote forwarding (ssh -R)
}
|
||
|
||
/// Lifecycle state of an SSH channel (cf. OpenSSH channels.c).
///
/// NOTE(review): the transitions between these states are not visible in this part of the
/// file; confirm where `Closing` and `Closed` are set.
enum ChannelState {
    // Channel is open.
    Open,
    // Close has been initiated but not completed.
    Closing,
    // Channel is closed.
    Closed,
}
|
||
|
||
/// SSH string读取辅助函数
|
||
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||
let length = reader.read_u32::<BigEndian>()?;
|
||
let mut buffer = vec![0u8; length as usize];
|
||
reader.read_exact(&mut buffer)?;
|
||
Ok(String::from_utf8(buffer)?)
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 15: 检查并发送 Window Adjust(参考OpenSSH channels.c:2425-2450)
|
||
///
|
||
/// OpenSSH 实现:
|
||
/// ```c
|
||
/// static int channel_check_window(struct ssh *ssh, Channel *c) {
|
||
/// if (c->type == SSH_CHANNEL_OPEN &&
|
||
/// !(c->flags & (CHAN_CLOSE_SENT|CHAN_CLOSE_RCVD)) &&
|
||
/// ((c->local_window_max - c->local_window > c->local_maxpacket*3) ||
|
||
/// c->local_window < c->local_window_max/2) &&
|
||
/// c->local_consumed > 0) {
|
||
///
|
||
/// // 发送 SSH2_MSG_CHANNEL_WINDOW_ADJUST
|
||
/// sshpkt_start(ssh, SSH2_MSG_CHANNEL_WINDOW_ADJUST);
|
||
/// sshpkt_put_u32(ssh, c->remote_id);
|
||
/// sshpkt_put_u32(ssh, c->local_consumed);
|
||
/// sshpkt_send(ssh);
|
||
///
|
||
/// c->local_window += c->local_consumed;
|
||
/// c->local_consumed = 0;
|
||
/// }
|
||
/// }
|
||
/// ```
|
||
pub fn channel_check_window(
|
||
channel_id: u32,
|
||
channels: &mut HashMap<u32, Channel>,
|
||
) -> Option<SshPacket> {
|
||
if let Some(channel) = channels.get_mut(&channel_id) {
|
||
// 检查窗口调整条件
|
||
let window_used = channel.local_window_max - channel.local_window;
|
||
let need_adjust = (window_used > channel.local_maxpacket * 3)
|
||
|| (channel.local_window < channel.local_window_max / 2);
|
||
|
||
if need_adjust && channel.local_consumed > 0 {
|
||
info!("⭐⭐⭐⭐⭐ [WINDOW_ADJUST] channel {} needs adjust: window_used={}, local_consumed={}",
|
||
channel_id, window_used, channel.local_consumed);
|
||
|
||
// 发送 SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||
let adjust_packet = build_window_adjust(channel.server_channel, channel.local_consumed);
|
||
|
||
// 更新窗口大小
|
||
channel.local_window += channel.local_consumed;
|
||
channel.local_consumed = 0;
|
||
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [WINDOW_UPDATED] channel {} new window: {}",
|
||
channel_id, channel.local_window
|
||
);
|
||
|
||
return Some(adjust_packet);
|
||
}
|
||
}
|
||
|
||
None
|
||
}
|
||
|
||
/// ⭐⭐⭐⭐⭐ Phase 15: 构建 SSH_MSG_CHANNEL_WINDOW_ADJUST packet
|
||
///
|
||
/// OpenSSH packet format:
|
||
/// ```c
|
||
/// SSH2_MSG_CHANNEL_WINDOW_ADJUST (93)
|
||
/// recipient_channel (u32)
|
||
/// bytes_to_add (u32)
|
||
/// ```
|
||
fn build_window_adjust(recipient_channel: u32, bytes_to_add: u32) -> SshPacket {
|
||
let mut payload = Vec::new();
|
||
|
||
// Packet type
|
||
payload.push(PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8);
|
||
|
||
// recipient_channel (u32)
|
||
payload.write_u32::<BigEndian>(recipient_channel).unwrap();
|
||
|
||
// bytes_to_add (u32)
|
||
payload.write_u32::<BigEndian>(bytes_to_add).unwrap();
|
||
|
||
info!(
|
||
"⭐⭐⭐⭐⭐ [BUILD_WINDOW_ADJUST] recipient_channel={}, bytes_to_add={}",
|
||
recipient_channel, bytes_to_add
|
||
);
|
||
|
||
SshPacket {
|
||
packet_length: 0,
|
||
padding_length: 0,
|
||
payload,
|
||
padding: Vec::new(),
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_channel_manager_creation() {
        let manager = ChannelManager::new(PathBuf::from("/tmp"), None, "test_user".to_string());
        assert_eq!(manager.next_channel_id, 0);
    }

    #[test]
    fn test_channel_open_confirmation() {
        let manager = ChannelManager::new(PathBuf::from("/tmp"), None, "test_user".to_string());
        let packet = manager
            .build_channel_open_confirmation(0, 100, 2097152, 32768)
            .unwrap();

        assert_eq!(
            packet.payload[0],
            PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8
        );
    }

    #[test]
    fn test_channel_success() {
        let manager = ChannelManager::new(PathBuf::from("/tmp"), None, "test_user".to_string());
        let packet = manager.build_channel_success(0).unwrap();

        assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_SUCCESS as u8);
    }

    /// WINDOW_ADJUST payload: type byte, recipient channel, bytes to add
    /// (both big-endian u32), and nothing else.
    #[test]
    fn test_build_window_adjust_payload() {
        let packet = build_window_adjust(7, 32768);

        assert_eq!(packet.payload.len(), 9);
        assert_eq!(
            packet.payload[0],
            PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8
        );
        assert_eq!(&packet.payload[1..5], &7u32.to_be_bytes()[..]);
        assert_eq!(&packet.payload[5..9], &32768u32.to_be_bytes()[..]);
    }

    /// A well-formed SSH string is decoded and exactly its bytes are consumed.
    #[test]
    fn test_read_ssh_string_roundtrip() {
        let mut cursor = std::io::Cursor::new(vec![0, 0, 0, 3, b'a', b'b', b'c', 0xAA]);
        assert_eq!(read_ssh_string(&mut cursor).unwrap(), "abc");
        assert_eq!(cursor.position(), 7);

        let mut empty = std::io::Cursor::new(vec![0, 0, 0, 0]);
        assert_eq!(read_ssh_string(&mut empty).unwrap(), "");
    }

    /// A length prefix larger than the available payload is an error.
    #[test]
    fn test_read_ssh_string_truncated_is_error() {
        let mut cursor = std::io::Cursor::new(vec![0, 0, 0, 5, b'a', b'b']);
        assert!(read_ssh_string(&mut cursor).is_err());
    }
}
|