VFS/DataProvider/Config refactoring + SSH public key authentication
Some checks failed
Test / test (push) Has been cancelled
Test / build (push) Has been cancelled

Phase 1-6 of refactoring plan:
- VFS abstraction (VfsBackend trait + LocalFs + OpenFlags builder)
- DataProvider trait (SqliteProvider + PgProvider, SFTPGo-compatible)
- Config refactoring (AppConfig unified sections, env overrides)
- SSH handlers (sftp/scp/rsync) migrated to VFS + DataProvider
- SSH public key authentication (Ed25519 signature verification)
- SSH stderr → CHANNEL_EXTENDED_DATA support
- Web auth uses DataProvider instead of direct SQL
- User home directory from provider (per-user isolation)
- PostgreSQL auth provider for SFTPGo compatibility
This commit is contained in:
Warren
2026-06-18 23:35:18 +08:00
parent 83fb0de78a
commit f90e4f496c
25 changed files with 2039 additions and 612 deletions

View File

@@ -1,70 +1,58 @@
// SSH认证协议实现Phase 5
// 参考OpenSSH auth.c, auth-passwd.c
use crate::ssh_server::packet::{SshPacket, PacketType};
use std::io::{Read, Write}; // 导入Write traitOpenSSH标准
use std::io::Write;
use anyhow::{Result, anyhow};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, warn, debug};
use rusqlite::{Connection, params};
use bcrypt::{verify, DEFAULT_COST};
use base64::{Engine as _, engine::general_purpose}; // Phase 9: Base64 for authorized_keys
use base64::{Engine as _, engine::general_purpose};
use ed25519_dalek::{VerifyingKey, Signature};
use crate::provider::{DataProvider, ProviderError};
/// SSH认证处理器参考OpenSSH auth2.c
pub struct AuthHandler {
db_path: String, // SQLite数据库路径
provider: Box<dyn DataProvider>,
}
impl AuthHandler {
/// 创建认证处理器
pub fn new() -> Result<Self> {
let db_path = "data/auth.sqlite".to_string();
// 验证数据库是否存在
let conn = Connection::open(&db_path)?;
drop(conn); // rusqlite会自动关闭
info!("AuthHandler initialized with database: {}", db_path);
Ok(Self { db_path })
pub fn new(provider: Box<dyn DataProvider>) -> Self {
info!("AuthHandler initialized with DataProvider");
Self { provider }
}
/// 获取用户home目录SFTPGo兼容
pub fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
self.provider.get_home_dir(username)
}
/// 处理SSH_MSG_USERAUTH_REQUEST参考OpenSSH auth2.c: userauth_request()
pub fn handle_userauth_request(&mut self, packet: &SshPacket) -> Result<AuthResult> {
pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result<AuthResult> {
info!("Processing SSH_MSG_USERAUTH_REQUEST");
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
// Packet type
let packet_type = cursor.read_u8()?;
if packet_type != PacketType::SSH_MSG_USERAUTH_REQUEST as u8 {
return Err(anyhow!("Invalid packet type for USERAUTH_REQUEST"));
}
// 读取用户名SSH string
let user = read_ssh_string(&mut cursor)?;
// 读取服务名称SSH string
let service = read_ssh_string(&mut cursor)?;
// 读取认证方法名称SSH string
let method = read_ssh_string(&mut cursor)?;
info!("Auth request: user={}, service={}, method={}", user, service, method);
// 检查服务名称OpenSSH要求ssh-connection
if service != "ssh-connection" {
warn!("Unsupported service: {}", service);
return Ok(AuthResult::Failure("Unsupported service".to_string()));
}
// 根据认证方法处理参考OpenSSH auth2.c
if method == "password" {
self.handle_password_auth(&mut cursor, &user)
} else if method == "publickey" {
self.handle_publickey_auth(&mut cursor, &user)
self.handle_publickey_auth(&mut cursor, &user, &service, session_id)
} else if method == "none" {
// OpenSSHnone认证总是失败用于查询支持的认证方法
// 返回支持的认证方法列表password, publickey
warn!("None auth request - returning supported methods");
Ok(AuthResult::Failure("password,publickey".to_string()))
} else {
@@ -72,203 +60,254 @@ impl AuthHandler {
Ok(AuthResult::Failure("Unsupported auth method".to_string()))
}
}
/// 处理password认证参考OpenSSH auth-passwd.c
fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
info!("Handling password auth for user: {}", user);
// 读取是否修改密码标志booleanOpenSSH password认证格式
let change_password = cursor.read_u8()? != 0;
if change_password {
warn!("Password change not supported");
return Ok(AuthResult::Failure("Password change not supported".to_string()));
}
// 读取密码SSH string
let password = read_ssh_string(cursor)?;
debug!("Password auth attempt: user={}, password length={}", user, password.len());
// 查询数据库获取password_hash
let conn = Connection::open(&self.db_path)?;
let password_hash_result = conn.query_row(
"SELECT password_hash FROM sftpgo_users WHERE username = ?1 AND status = 1",
params![user],
|row| row.get::<_, String>(0)
);
// 关闭连接rusqlite会自动关闭
drop(conn);
// 验证用户是否存在
let password_hash = match password_hash_result {
Ok(hash) => Some(hash),
Err(rusqlite::Error::QueryReturnedNoRows) => None,
Err(e) => return Err(anyhow!("Database query error: {}", e)),
};
if password_hash.is_none() {
warn!("User not found or disabled: {}", user);
// SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表RFC 4253
return Ok(AuthResult::Failure("password,publickey".to_string()));
}
// 使用bcrypt验证密码
let stored_hash = password_hash.unwrap();
info!("Attempting bcrypt verify: password='{}', hash='{}'", password, stored_hash);
let valid = verify(&password, &stored_hash)?;
info!("bcrypt verify result: {}", valid);
if valid {
info!("Password auth successful for user: {}", user);
Ok(AuthResult::Success)
} else {
warn!("Password auth failed for user: {}", user);
// SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表RFC 4253
Ok(AuthResult::Failure("password,publickey".to_string()))
match self.provider.check_password(user, &password) {
Ok(true) => {
info!("Password auth successful for user: {}", user);
Ok(AuthResult::Success)
}
Ok(false) => {
warn!("Password auth failed for user: {}", user);
Ok(AuthResult::Failure("password,publickey".to_string()))
}
Err(ProviderError::NotFound(msg)) => {
warn!("User not found: {}", msg);
Ok(AuthResult::Failure("password,publickey".to_string()))
}
Err(e) => {
Err(anyhow!("Password auth error: {}", e))
}
}
}
/// 构建SSH_MSG_USERAUTH_SUCCESS packet参考OpenSSH auth2.c
/// 构建SSH_MSG_USERAUTH_SUCCESS packet
pub fn build_userauth_success() -> Result<SshPacket> {
let payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
Ok(SshPacket::new(payload))
}
/// 构建SSH_MSG_USERAUTH_FAILURE packet参考OpenSSH auth2.c
/// 构建SSH_MSG_USERAUTH_FAILURE packet
pub fn build_userauth_failure(methods: &[String], partial_success: bool) -> Result<SshPacket> {
let mut payload = Vec::new();
// Packet type
payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
// 认证方法列表SSH string逗号分隔
let methods_str = methods.join(",");
payload.write_u32::<BigEndian>(methods_str.len() as u32)?;
payload.write_all(methods_str.as_bytes())?;
// partial_success标志boolean
payload.write_u8(if partial_success { 1 } else { 0 })?;
Ok(SshPacket::new(payload))
}
/// 构建SSH_MSG_USERAUTH_BANNER packet可选参考OpenSSH auth2.c
/// 构建SSH_MSG_USERAUTH_BANNER packet
pub fn build_userauth_banner(message: &str, language: &str) -> Result<SshPacket> {
let mut payload = Vec::new();
// Packet type
payload.write_u8(PacketType::SSH_MSG_USERAUTH_BANNER as u8)?;
// Banner messageSSH string
payload.write_u32::<BigEndian>(message.len() as u32)?;
payload.write_all(message.as_bytes())?;
// Language tagSSH string
payload.write_u32::<BigEndian>(language.len() as u32)?;
payload.write_all(language.as_bytes())?;
Ok(SshPacket::new(payload))
}
/// 处理publickey认证Phase 9参考OpenSSH auth2-pubkey.c
fn handle_publickey_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
/// 处理publickey认证RFC 4252 §7
/// 支持Ed25519签名验证 + 数据库/filesystem密钥查找
fn handle_publickey_auth(
&mut self,
cursor: &mut std::io::Cursor<&[u8]>,
user: &str,
service: &str,
session_id: &[u8],
) -> Result<AuthResult> {
info!("Handling publickey auth for user: {}", user);
// 读取是否签名的标志boolean
let is_signed = cursor.read_u8()? != 0;
// 读取public key algorithmSSH string
let algorithm = read_ssh_string(cursor)?;
// 读取public key blobSSH string
let public_key_blob = read_ssh_string_bytes(cursor)?;
info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed);
// Phase 9简化实现 - 从authorized_keys文件验证
let authorized_keys_path = format!("data/{}/authorized_keys", user);
let authorized_keys = match std::fs::read_to_string(&authorized_keys_path) {
Ok(content) => content,
Err(_) => {
// 尝试默认路径
let default_path = "data/authorized_keys";
match std::fs::read_to_string(default_path) {
Ok(content) => content,
Err(_) => {
warn!("No authorized_keys file found for user: {}", user);
return Ok(AuthResult::Failure("password,publickey".to_string()));
}
}
}
};
// 解析authorized_keys查找匹配的public key
let public_key_matches = authorized_keys.lines().any(|line| {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
return false;
}
// SSH authorized_keys格式algorithm base64-key comment
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() < 2 {
return false;
}
let key_algorithm = parts[0];
let key_base64 = parts[1];
// 匹配algorithm
if key_algorithm != algorithm {
return false;
}
// 匹配public key blobbase64解码对比
match base64_decode(key_base64) {
Ok(decoded_key) => decoded_key == public_key_blob,
Err(_) => false,
}
});
if !public_key_matches {
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
warn!("Public key not authorized for user: {}", user);
return Ok(AuthResult::Failure("password,publickey".to_string()));
}
info!("Public key authorized for user: {}", user);
// 如果没有签名返回PK_OKquery阶段
if !is_signed {
// SSH_MSG_USERAUTH_PK_OK表示public key可接受client需要发送签名
return Ok(AuthResult::PublicKeyOk(algorithm, public_key_blob));
}
// 读取signatureSSH string
let signature = read_ssh_string_bytes(cursor)?;
info!("Verifying signature for user: {}", user);
// Phase 9简化签名验证 - 信任authorized_keys
// 完整实现需要提取session_id, 构建signed_data, verify signature
// 这里简化处理只要public key匹配authorized_keys就接受
let signature_blob = read_ssh_string_bytes(cursor)?;
self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?;
info!("Publickey auth successful for user: {}", user);
Ok(AuthResult::Success)
}
/// 检查public key是否在授权列表中数据库优先fallback到filesystem
fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result<bool> {
// 1. 先检查数据库
match self.provider.get_public_keys(user) {
Ok(keys) => {
for key_line in &keys {
if public_key_matches_line(key_line, algorithm, public_key_blob) {
return Ok(true);
}
}
}
Err(e) => warn!("Failed to get public keys from provider: {}", e),
}
// 2. Fallback到filesystem
let authorized_keys_path = format!("data/{}/authorized_keys", user);
let content = match std::fs::read_to_string(&authorized_keys_path) {
Ok(c) => c,
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
Ok(c) => c,
Err(_) => return Ok(false),
}
};
Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
}
/// 验证Ed25519签名RFC 4252 §7
fn verify_signature(
&self,
algorithm: &str,
public_key_blob: &[u8],
signature_blob: &[u8],
user: &str,
service: &str,
session_id: &[u8],
) -> Result<()> {
// 目前只支援Ed25519
if algorithm != "ssh-ed25519" {
return Err(anyhow!("Unsupported public key algorithm: {}", algorithm));
}
let verifying_key = parse_ed25519_verifying_key(public_key_blob)?;
let signature = parse_ed25519_signature(signature_blob)?;
// 建立签名验证数据RFC 4252 §7
let mut signed_data = Vec::new();
// string session identifier
signed_data.write_u32::<BigEndian>(session_id.len() as u32)?;
signed_data.write_all(session_id)?;
// byte SSH_MSG_USERAUTH_REQUEST
signed_data.write_u8(PacketType::SSH_MSG_USERAUTH_REQUEST as u8)?;
// string user name
signed_data.write_u32::<BigEndian>(user.len() as u32)?;
signed_data.write_all(user.as_bytes())?;
// string service name
signed_data.write_u32::<BigEndian>(service.len() as u32)?;
signed_data.write_all(service.as_bytes())?;
// string "publickey"
const PUBKEY_STR: &str = "publickey";
signed_data.write_u32::<BigEndian>(PUBKEY_STR.len() as u32)?;
signed_data.write_all(PUBKEY_STR.as_bytes())?;
// boolean TRUE
signed_data.write_u8(1)?;
// string public key algorithm name
signed_data.write_u32::<BigEndian>(algorithm.len() as u32)?;
signed_data.write_all(algorithm.as_bytes())?;
// string public key blob
signed_data.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
signed_data.write_all(public_key_blob)?;
// 验证签名
verifying_key.verify_strict(&signed_data, &signature)
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
}
}
/// SSH认证结果参考OpenSSH auth2.c
/// SSH认证结果
pub enum AuthResult {
Success,
Failure(String), // 失败原因
PartialSuccess, // 部分成功(多步骤认证)
PublicKeyOk(String, Vec<u8>), // Public key acceptable (algorithm, blob)
Failure(String),
PartialSuccess,
PublicKeyOk(String, Vec<u8>),
}
/// 解析Ed25519公钥blobSSH格式 -> VerifyingKey
fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result<VerifyingKey> {
let mut cursor = std::io::Cursor::new(public_key_blob);
let algorithm = read_ssh_string(&mut cursor)?;
if algorithm != "ssh-ed25519" {
return Err(anyhow!("Unsupported algorithm: {}", algorithm));
}
let key_bytes = read_ssh_string_bytes(&mut cursor)?;
if key_bytes.len() != 32 {
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
}
let key_array: [u8; 32] = key_bytes.try_into()
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
VerifyingKey::from_bytes(&key_array)
.map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
}
/// 解析Ed25519签名blobSSH格式 -> Signature
fn parse_ed25519_signature(signature_blob: &[u8]) -> Result<Signature> {
let mut cursor = std::io::Cursor::new(signature_blob);
let algorithm = read_ssh_string(&mut cursor)?;
if algorithm != "ssh-ed25519" {
return Err(anyhow!("Unsupported signature algorithm: {}", algorithm));
}
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
if sig_bytes.len() != 64 {
return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len()));
}
let sig_array: [u8; 64] = sig_bytes.try_into()
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
Ok(Signature::from_bytes(&sig_array))
}
/// 检查一行authorized_keys格式的密钥是否匹配
fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8]) -> bool {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
return false;
}
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() < 2 {
return false;
}
if parts[0] != algorithm {
return false;
}
base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false)
}
/// SSH string读取辅助函数
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
let length = reader.read_u32::<BigEndian>()?;
let mut buffer = vec![0u8; length as usize];
@@ -276,7 +315,6 @@ fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
Ok(String::from_utf8(buffer)?)
}
/// SSH string读取辅助函数bytes版本
fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
let length = reader.read_u32::<BigEndian>()?;
let mut buffer = vec![0u8; length as usize];
@@ -284,9 +322,7 @@ fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
Ok(buffer)
}
/// Base64解码辅助函数Phase 9
fn base64_decode(input: &str) -> Result<Vec<u8>> {
use base64::{Engine as _, engine::general_purpose};
general_purpose::STANDARD.decode(input)
.map_err(|e| anyhow!("Base64 decode error: {}", e))
}
@@ -294,18 +330,19 @@ fn base64_decode(input: &str) -> Result<Vec<u8>> {
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::sqlite::SqliteProvider;
#[test]
fn test_userauth_success_packet() {
let packet = AuthHandler::build_userauth_success().unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8);
}
#[test]
fn test_userauth_failure_packet() {
let methods = vec!["password".to_string(), "publickey".to_string()];
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8);
}
}

View File

@@ -25,6 +25,8 @@ pub struct ChannelManager {
next_channel_id: u32,
/// ⭐⭐⭐⭐⭐ Phase 15.1: 待发送packet队列用于同时发送WINDOW_ADJUST和SFTP响应
pub pending_packets: VecDeque<SshPacket>,
/// 用户home目录SFTP/SCP/rsync根目录SFTPGo兼容
pub home_dir: PathBuf,
}
/// Phase 14: 交互式Exec进程管理参考OpenSSH session.c: do_exec_no_pty
@@ -40,11 +42,12 @@ pub struct ExecProcess {
}
impl ChannelManager {
pub fn new() -> Self {
pub fn new(home_dir: PathBuf) -> Self {
Self {
channels: HashMap::new(),
next_channel_id: 0,
pending_packets: VecDeque::new(),
home_dir,
}
}
@@ -371,9 +374,12 @@ impl ChannelManager {
info!("⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}", process_type, command);
// 启动子进程相当于OpenSSH fork
// ⭐⭐⭐⭐⭐ Phase 17: 设置工作目录为用户home_dirSFTPGo兼容
let home_dir = self.home_dir.clone();
let mut child = Command::new("sh")
.arg("-c")
.arg(command)
.current_dir(&home_dir)
.stdin(Stdio::piped()) // ← 创建stdin管道相当于pipe(pin)
.stdout(Stdio::piped()) // ← 创建stdout管道相当于pipe(pout)
.stderr(Stdio::piped()) // ← 创建stderr管道相当于pipe(perr)
@@ -446,8 +452,8 @@ impl ChannelManager {
if subsystem == "sftp" {
info!("SFTP subsystem requested");
// Phase 7: 初始化SFTP handler
let root_dir = PathBuf::from("/Users/accusys/markbase"); // 默认root目录
// Phase 7: 初始化SFTP handler使用用户home目录SFTPGo兼容
let root_dir = self.home_dir.clone();
// ⭐⭐⭐⭐⭐ Phase 4: 获取 client maxpack 限制(从 Channel 中获取)
let maxpacket = if let Some(ch) = self.channels.get(&channel) {
@@ -456,7 +462,8 @@ impl ChannelManager {
32768 // OpenSSH 默认值32KB
};
let sftp_handler = SftpHandler::new(root_dir, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack
let vfs = Box::new(crate::vfs::local_fs::LocalFs::new());
let sftp_handler = SftpHandler::new(root_dir, vfs, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack
// 存储到channel
if let Some(ch) = self.channels.get_mut(&channel) {
@@ -952,6 +959,22 @@ impl ChannelManager {
false
}
/// Phase 17: 关闭所有子进程stdin收到CHANNEL_EOF时调用
/// SCP upload需要scp -t 等待EOF on stdin才知道数据传输完毕
pub fn close_child_stdin(&mut self) {
let channel_ids: Vec<u32> = self.channels.keys().copied().collect();
for id in channel_ids {
if let Some(channel) = self.channels.get_mut(&id) {
if let Some(exec) = &mut channel.exec_process {
if let Some(stdin) = exec.stdin.take() {
drop(stdin);
info!("⭐⭐⭐⭐⭐ [CHANNEL_EOF] Closed child stdin (channel {})", id);
}
}
}
}
}
/// 获取channel输出Phase 6新增
pub fn get_channel_output(&mut self, channel_id: u32) -> Option<Vec<u8>> {
if let Some(channel) = self.channels.get_mut(&channel_id) {
@@ -1283,6 +1306,7 @@ impl ChannelManager {
// 4. 检查stdout/stderr fd是否有数据
let mut packets_data: Vec<(u32, Vec<u8>)> = Vec::new();
let mut stderr_packets: Vec<(u32, Vec<u8>)> = Vec::new(); // Phase 17: stderr → CHANNEL_EXTENDED_DATA
for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map {
if let Some(channel) = self.channels.get_mut(&channel_id) {
@@ -1325,7 +1349,8 @@ impl ChannelManager {
Ok(n) if n > 0 => {
info!("⭐⭐⭐⭐⭐ [AFTER stderr.read] Read {} bytes from stderr (channel {})", n, channel_id);
info!("⭐⭐⭐⭐⭐ stderr content: {:?}", &buffer[..std::cmp::min(50, n)]);
packets_data.push((channel_id, buffer[..n].to_vec()));
// ⭐⭐⭐⭐⭐ Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
stderr_packets.push((channel_id, buffer[..n].to_vec()));
}
Ok(0) => {
info!("stderr EOF (channel {}), closing stderr pipe", channel_id);
@@ -1351,12 +1376,17 @@ impl ChannelManager {
}
// 构建packets
if !packets_data.is_empty() {
if !packets_data.is_empty() || !stderr_packets.is_empty() {
let mut packets = Vec::new();
for (channel_id, data) in packets_data {
let packet = self.build_channel_data(channel_id, &data)?;
packets.push(packet);
}
// Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
for (channel_id, data) in stderr_packets {
let packet = self.build_channel_extended_data(channel_id, 1, &data)?;
packets.push(packet);
}
info!("⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)", packets.len());
return Ok((Some(packets), client_has_data, child_exited));
}
@@ -1689,13 +1719,13 @@ mod tests {
#[test]
fn test_channel_manager_creation() {
let manager = ChannelManager::new();
let manager = ChannelManager::new(PathBuf::from("/tmp"));
assert_eq!(manager.next_channel_id, 0);
}
#[test]
fn test_channel_open_confirmation() {
let manager = ChannelManager::new();
let manager = ChannelManager::new(PathBuf::from("/tmp"));
let packet = manager.build_channel_open_confirmation(0, 100, 2097152, 32768).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8);
@@ -1703,7 +1733,7 @@ mod tests {
#[test]
fn test_channel_success() {
let manager = ChannelManager::new();
let manager = ChannelManager::new(PathBuf::from("/tmp"));
let packet = manager.build_channel_success(0).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_SUCCESS as u8);

View File

@@ -17,6 +17,7 @@ type HmacSha256 = Hmac<Sha256>;
/// SSH加密通道管理器参考OpenSSH struct sshcipher_ctx
pub struct EncryptionContext {
pub session_id: Vec<u8>, // session identifier (exchange hash)
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
@@ -32,6 +33,7 @@ pub struct EncryptionContext {
impl Default for EncryptionContext {
fn default() -> Self {
Self {
session_id: vec![0u8; 32],
encryption_key_ctos: vec![0u8; 32],
encryption_key_stoc: vec![0u8; 32],
mac_key_ctos: vec![0u8; 32],
@@ -73,6 +75,7 @@ impl EncryptionContext {
info!("Ciphers initialized successfully");
Self {
session_id: keys.session_id.clone(),
encryption_key_ctos: keys.encryption_key_ctos.clone(),
encryption_key_stoc: keys.encryption_key_stoc.clone(),
mac_key_ctos: keys.mac_key_ctos.clone(),

View File

@@ -1,8 +1,8 @@
use std::path::PathBuf;
use std::fs::{self, File};
use std::io::Write;
use anyhow::{Result, anyhow};
use log::{info, debug, warn};
use crate::vfs::{VfsBackend, VfsFile, VfsError};
use crate::vfs::open_flags::OpenFlags;
/// MPLEX_BASE from rsync io.h
const MPLEX_BASE: u32 = 7;
@@ -27,23 +27,21 @@ pub(crate) enum RsyncState {
pub struct RsyncHandler {
state: RsyncState,
/// Raw input from SSH (multiplexed after version exchange)
raw_input: Vec<u8>,
/// Decoded rsync protocol data (after stripping multiplex)
rsync_input: Vec<u8>,
/// Raw rsync data to send (multiplex wrapping applied in drain_output)
output_raw: Vec<u8>,
dest_path: PathBuf,
output_file: Option<File>,
output_file: Option<Box<dyn VfsFile>>,
total_written: u64,
file_entries: Vec<String>,
current_file: usize,
protocol_version: u32,
multiplex: bool,
vfs: Box<dyn VfsBackend>,
}
impl RsyncHandler {
pub fn parse_rsync_command(command: &str) -> Result<Self> {
pub fn parse_rsync_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() < 3 || parts[0] != "rsync" {
return Err(anyhow!("Invalid rsync command: {}", command));
@@ -83,9 +81,9 @@ impl RsyncHandler {
current_file: 0,
protocol_version: 30,
multiplex: false,
vfs,
};
// Send protocol version (4-byte LE int, no multiplex)
handler.output_raw.extend_from_slice(&30u32.to_le_bytes());
handler.state = RsyncState::WaitVersion;
@@ -129,7 +127,6 @@ impl RsyncHandler {
}
MSG_DONE => {
info!("rsync: MSG_DONE received (file complete)");
// Signal file completion by appending a sentinel to rsync_input
self.rsync_input.extend_from_slice(b"RSYNCDONE");
}
9 => {
@@ -147,7 +144,6 @@ impl RsyncHandler {
if data.is_empty() || !self.multiplex {
return data;
}
// Wrap with multiplex header (MSG_DATA)
let header = (MPLEX_BASE << 24) | (data.len() as u32);
let mut wrapped = Vec::with_capacity(4 + data.len());
wrapped.extend_from_slice(&header.to_le_bytes());
@@ -180,7 +176,6 @@ impl RsyncHandler {
loop {
match self.state.clone() {
RsyncState::SendVersion => {
// Version already sent in constructor
self.transition(RsyncState::WaitVersion);
}
@@ -206,7 +201,6 @@ impl RsyncHandler {
let flags = self.rsync_input[0];
if flags == 0 {
// End of file list
self.rsync_input.drain(..1);
info!("rsync: file list end ({} entries)", self.file_entries.len());
@@ -214,14 +208,12 @@ impl RsyncHandler {
self.file_entries.push("file".to_string());
}
self.current_file = 0;
// Enter sum head reading state
self.transition(RsyncState::ReadSumHead { need: 20 });
break;
}
let mut pos = 1;
// Extended flags
let _more_flags = if flags & 0x80 != 0 {
if self.rsync_input.len() <= pos { break; }
let ef = self.rsync_input[pos];
@@ -249,7 +241,6 @@ impl RsyncHandler {
self.file_entries.push(name);
}
// Skip metadata varints
let skip_count = if flags & 0x10 == 0 { 1 } else { 0 }
+ if flags & 0x20 == 0 { 1 } else { 0 }
+ if flags & 0x40 == 0 { 1 } else { 0 }
@@ -277,9 +268,6 @@ impl RsyncHandler {
RsyncState::ReadSumHead { need } => {
if self.rsync_input.len() >= need {
// Read sum head: count, blength, s2length, remainder (4 × LE int)
// + checksum seed (1 × LE int)
// = 5 × 4 = 20 bytes
let sum_count = i32::from_le_bytes([
self.rsync_input[0], self.rsync_input[1],
self.rsync_input[2], self.rsync_input[3],
@@ -312,7 +300,6 @@ impl RsyncHandler {
RsyncState::SendSumCount => {
self.open_current_file()?;
// Send sum_count = 0 (4-byte LE int = we have no existing data)
self.output_raw.extend_from_slice(&0u32.to_le_bytes());
info!("rsync: sent sum_count=0, ready to receive file data");
@@ -320,22 +307,17 @@ impl RsyncHandler {
}
RsyncState::ReadFileData => {
// Data comes as raw bytes inside MSG_DATA multiplex packets.
// MSG_DONE appends b"RSYNCDONE" to rsync_input.
let done_marker = b"RSYNCDONE";
if let Some(pos) = self.rsync_input.windows(done_marker.len())
.position(|w| w == done_marker)
{
// Data before the marker
if pos > 0 {
let data = self.rsync_input[..pos].to_vec();
self.rsync_input.drain(..pos);
self.write_to_file(&data)?;
}
// Remove marker
self.rsync_input.drain(..done_marker.len());
// Close file
if let Some(mut file) = self.output_file.take() {
if let Err(e) = file.flush() {
warn!("rsync flush error: {}", e);
@@ -353,11 +335,9 @@ impl RsyncHandler {
info!("rsync ALL DONE: {} bytes written to {}",
self.total_written, self.dest_path.display());
} else {
// Next file sum head
self.transition(RsyncState::ReadSumHead { need: 20 });
}
} else if !self.rsync_input.is_empty() {
// Partial data, keep it in buffer for more
let data = self.rsync_input.clone();
self.rsync_input.clear();
self.write_to_file(&data)?;
@@ -377,9 +357,11 @@ impl RsyncHandler {
fn open_current_file(&mut self) -> Result<()> {
if let Some(parent) = self.dest_path.parent() {
fs::create_dir_all(parent).ok();
self.vfs.create_dir_all(parent, 0o755).ok();
}
let file = File::create(&self.dest_path)?;
let flags = OpenFlags::new().write().create().truncate();
let file = self.vfs.open_file(&self.dest_path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
self.output_file = Some(file);
info!("rsync: opened {} for writing", self.dest_path.display());
Ok(())
@@ -387,7 +369,8 @@ impl RsyncHandler {
fn write_to_file(&mut self, data: &[u8]) -> Result<()> {
if let Some(file) = &mut self.output_file {
file.write_all(data)?;
file.write_all(data)
.map_err(|e| anyhow!("write error: {}", e))?;
self.total_written += data.len() as u64;
}
Ok(())
@@ -426,28 +409,37 @@ fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
#[cfg(test)]
mod tests {
use super::*;
use crate::vfs::local_fs::LocalFs;
fn make_vfs() -> Box<dyn VfsBackend> {
Box::new(LocalFs::new())
}
#[test]
fn test_parse_command() {
let h = RsyncHandler::parse_rsync_command("rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin").unwrap();
let h = RsyncHandler::parse_rsync_command(
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
make_vfs()
).unwrap();
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
}
#[test]
fn test_parse_command_sender() {
let h = RsyncHandler::parse_rsync_command("rsync --server --sender -vlogDtprz . /home/user/file.txt").unwrap();
let h = RsyncHandler::parse_rsync_command(
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
make_vfs()
).unwrap();
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
}
#[test]
fn test_version_exchange() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap();
// Initial output: protocol version (30 as LE int)
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
let output = h.drain_output();
assert_eq!(output, b"\x1e\x00\x00\x00");
assert_eq!(h.state, RsyncState::WaitVersion);
// Client sends its version (30 = 0x1E)
h.feed(b"\x1e\x00\x00\x00").unwrap();
assert_eq!(h.state, RsyncState::ReadFileList);
assert!(h.multiplex);
@@ -455,9 +447,8 @@ mod tests {
#[test]
fn test_version_negotiate_down() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
let _ = h.drain_output();
// Client has lower version (29)
h.feed(b"\x1d\x00\x00\x00").unwrap();
assert_eq!(h.protocol_version, 29);
assert_eq!(h.state, RsyncState::ReadFileList);
@@ -471,24 +462,14 @@ mod tests {
buf
}
fn build_multiplex_done() -> Vec<u8> {
let header = (MPLEX_BASE << 24) | 0u32; // MSG_DONE (tag=1 → raw_tag=8)
let mut buf = Vec::new();
buf.extend_from_slice(&header.to_le_bytes());
buf
}
#[test]
fn test_file_list_multiplex() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
let _ = h.drain_output();
// Version exchange
h.feed(b"\x1e\x00\x00\x00").unwrap();
assert!(h.multiplex);
// Build file list with multiplex wrapping
let mut flist = Vec::new();
// Entry: flags=0, name="test.txt\0", + 6 varints
flist.push(0);
flist.extend_from_slice(b"test.txt");
flist.push(0);
@@ -514,46 +495,40 @@ mod tests {
}
}
}
write_varint(&mut flist, 33188); // mode
write_varint(&mut flist, 501); // uid
write_varint(&mut flist, 20); // gid
write_varint(&mut flist, 1700000000); // time
write_varint(&mut flist, 100); // size
write_varint(&mut flist, 0); // checksum seed
// End marker
write_varint(&mut flist, 33188);
write_varint(&mut flist, 501);
write_varint(&mut flist, 20);
write_varint(&mut flist, 1700000000);
write_varint(&mut flist, 100);
write_varint(&mut flist, 0);
flist.push(0);
// Sum head (5 ints = 20 bytes) as separate multiplex packet
let mut sum_head = Vec::new();
sum_head.extend_from_slice(&0i32.to_le_bytes()); // count
sum_head.extend_from_slice(&7000i32.to_le_bytes()); // blength
sum_head.extend_from_slice(&2i32.to_le_bytes()); // s2length
sum_head.extend_from_slice(&100i32.to_le_bytes()); // remainder
sum_head.extend_from_slice(&42i32.to_le_bytes()); // checksum_seed
sum_head.extend_from_slice(&0i32.to_le_bytes());
sum_head.extend_from_slice(&7000i32.to_le_bytes());
sum_head.extend_from_slice(&2i32.to_le_bytes());
sum_head.extend_from_slice(&100i32.to_le_bytes());
sum_head.extend_from_slice(&42i32.to_le_bytes());
// Feed file list
h.feed(&build_multiplex(&flist)).unwrap();
assert_eq!(h.state, RsyncState::ReadFileList); // Still reading, 0x00 end marker triggered transition
assert_eq!(h.state, RsyncState::ReadFileList);
assert_eq!(h.file_entries.len(), 1);
// Now feed sum head
h.feed(&build_multiplex(&sum_head)).unwrap();
assert_eq!(h.state, RsyncState::SendSumCount);
// Send sum count response
let sum_resp = h.drain_output();
assert_eq!(sum_resp.len(), 8); // 4-byte header + 4-byte int
assert_eq!(sum_resp.len(), 8);
assert_eq!(&sum_resp[4..8], &0u32.to_le_bytes());
assert_eq!(h.state, RsyncState::ReadFileData);
}
#[test]
fn test_file_data_multiplex() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
let _ = h.drain_output();
h.feed(b"\x1e\x00\x00\x00").unwrap(); // version
h.feed(b"\x1e\x00\x00\x00").unwrap();
// Simple file list
let mut flist = Vec::new();
flist.push(0);
flist.extend_from_slice(b"test.bin");
@@ -568,7 +543,6 @@ mod tests {
flist.push(0);
h.feed(&build_multiplex(&flist)).unwrap();
// Sum head
let mut sh = Vec::new();
sh.extend_from_slice(&0i32.to_le_bytes());
sh.extend_from_slice(&7000i32.to_le_bytes());
@@ -576,16 +550,13 @@ mod tests {
sh.extend_from_slice(&100i32.to_le_bytes());
sh.extend_from_slice(&42i32.to_le_bytes());
h.feed(&build_multiplex(&sh)).unwrap();
let _ = h.drain_output(); // sum count response
let _ = h.drain_output();
// File data + MSG_DONE
let file_data = b"Hello, rsync protocol!";
h.feed(&build_multiplex(file_data)).unwrap();
assert_eq!(h.state, RsyncState::ReadFileData);
// MSG_DONE
// MSG_DONE has tag=1, so raw_tag = MPLEX_BASE + 1 = 8
let done_header = (MPLEX_BASE + 1) << 24; // raw_tag = 8, len = 0
let done_header = (MPLEX_BASE + 1) << 24;
let done_bytes = done_header.to_le_bytes();
h.feed(&done_bytes).unwrap();

View File

@@ -1,12 +1,13 @@
// SCP协议实现Phase 8
// 参考OpenSSH scp.c源码
use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat};
use crate::vfs::open_flags::OpenFlags;
use anyhow::{Result, anyhow};
use log::{info, warn, debug};
use std::path::{Path, PathBuf};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write, BufReader, BufWriter, BufRead}; // 导入BufRead traitOpenSSH标准
use chrono::{DateTime, Utc};
use std::io::{Read, Write, BufRead};
use std::time::SystemTime;
/// SCP Handler参考OpenSSH scp.c
pub struct ScpHandler {
@@ -14,6 +15,7 @@ pub struct ScpHandler {
mode: ScpMode,
recursive: bool,
preserve_times: bool,
vfs: Box<dyn VfsBackend>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -23,24 +25,25 @@ pub enum ScpMode {
}
impl ScpHandler {
pub fn new(root_dir: PathBuf) -> Self {
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>) -> Self {
Self {
root_dir,
mode: ScpMode::Destination,
recursive: false,
preserve_times: false,
vfs,
}
}
/// 解析SCP命令参考OpenSSH scp.c: parse_command()
pub fn parse_scp_command(command: &str) -> Result<Self> {
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() < 2 || parts[0] != "scp" {
return Err(anyhow!("Invalid SCP command: {}", command));
}
let mut handler = ScpHandler::new(PathBuf::from("/tmp"));
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
for part in &parts[1..] {
match part {
@@ -68,19 +71,19 @@ impl ScpHandler {
/// SCP Source Modescp -f发送文件
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
info!("SCP source mode: sending files from {}", self.root_dir.display()); // 使用display()Rust标准
info!("SCP source mode: sending files from {}", self.root_dir.display());
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
let stat = self.vfs.stat(&full_path)
.map_err(|e| anyhow!("stat error: {}", e))?;
if full_path.is_file() {
self.send_file(channel, &full_path)?;
} else if full_path.is_dir() {
if stat.is_dir {
if !self.recursive {
return Err(anyhow!("Directory detected but -r flag not specified"));
}
self.send_directory(channel, &full_path)?;
} else {
return Err(anyhow!("Path does not exist: {}", full_path.display()));
self.send_file(channel, &full_path)?;
}
Ok(())
@@ -88,9 +91,8 @@ impl ScpHandler {
/// SCP Destination Modescp -t接收文件
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
info!("SCP destination mode: receiving files to {}", self.root_dir.display()); // 使用display()Rust标准
info!("SCP destination mode: receiving files to {}", self.root_dir.display());
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -99,10 +101,9 @@ impl ScpHandler {
loop {
buffer.clear();
// 每次循环创建新的reader避免borrow冲突- OpenSSH标准
let mut reader = BufReader::new(&mut *channel);
let mut reader = std::io::BufReader::new(&mut *channel);
match reader.read_line(&mut buffer)? {
0 => break, // EOF
0 => break,
_ => {
let command = buffer.trim();
debug!("SCP command: {}", command);
@@ -113,7 +114,6 @@ impl ScpHandler {
Some('E') => self.handle_end_directory(channel)?,
Some('T') => self.handle_time_command(channel, command)?,
Some('\0') => {
// 确认信号,继续
continue;
}
_ => {
@@ -130,28 +130,30 @@ impl ScpHandler {
/// 发送文件参考OpenSSH scp.c: source()
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
let metadata = fs::metadata(path)?;
let size = metadata.len();
let stat = self.vfs.stat(path)
.map_err(|e| anyhow!("stat error: {}", e))?;
let size = stat.size;
let filename = path.file_name().unwrap().to_string_lossy();
// 发送文件命令C0644 size filename
let command = format!("C0644 {} {}\n", size, filename);
channel.write_all(command.as_bytes())?;
channel.flush()?;
// 等待确认('\0'
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP file command rejected"));
}
// 发送文件内容
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let flags = OpenFlags::new().read();
let mut file = self.vfs.open_file(path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
let mut buffer = vec![0u8; 8192];
while let Ok(n) = reader.read(&mut buffer) {
loop {
let n = file.read(&mut buffer)
.map_err(|e| anyhow!("read error: {}", e))?;
if n == 0 {
break;
}
@@ -160,11 +162,9 @@ impl ScpHandler {
channel.flush()?;
// 发送结束确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
// 等待确认('\0'
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP file transfer rejected"));
@@ -178,35 +178,34 @@ impl ScpHandler {
fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
let dirname = path.file_name().unwrap().to_string_lossy();
// 发送目录命令D0755 0 dirname
let command = format!("D0755 0 {}\n", dirname);
channel.write_all(command.as_bytes())?;
channel.flush()?;
// 等待确认('\0'
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP directory command rejected"));
}
// 递归发送目录内容
for entry in fs::read_dir(path)? {
let entry = entry?;
let full_path = entry.path();
let entries = self.vfs.read_dir(path)
.map_err(|e| anyhow!("read_dir error: {}", e))?;
if full_path.is_file() {
self.send_file(channel, &full_path)?;
} else if full_path.is_dir() && self.recursive {
self.send_directory(channel, &full_path)?;
for entry in &entries {
let entry_path = path.join(&entry.name);
if entry.stat.is_dir {
if self.recursive {
self.send_directory(channel, &entry_path)?;
}
} else {
self.send_file(channel, &entry_path)?;
}
}
// 发送结束目录命令E
channel.write_all("E\n".as_bytes())?;
channel.flush()?;
// 等待确认('\0'
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP end directory rejected"));
@@ -224,31 +223,25 @@ impl ScpHandler {
return self.send_error(channel, "Invalid file command format");
}
let mode = parts[0].trim_start_matches('C');
let mode_str = parts[0].trim_start_matches('C');
let size: u64 = parts[1].parse()?;
let filename = parts[2];
debug!("SCP receive file: mode={}, size={}, name={}", mode, size, filename);
debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename);
// 安全性检查文件大小限制防止DoS
if size > 1024 * 1024 * 1024 { // 1GB限制
if size > 1024 * 1024 * 1024 {
return self.send_error(channel, "File too large (max 1GB)");
}
// 创建文件
let full_path = self.resolve_path(filename)?;
let file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&full_path)?;
// 发送确认('\0'
let flags = OpenFlags::new().write().create().truncate();
let mut file = self.vfs.open_file(&full_path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
channel.write_all(&[0])?;
channel.flush()?;
// 接收文件内容
let mut writer = BufWriter::new(file);
let mut buffer = vec![0u8; 8192];
let mut remaining = size;
@@ -258,25 +251,25 @@ impl ScpHandler {
if n == 0 {
break;
}
writer.write_all(&buffer[..n])?;
file.write_all(&buffer[..n])
.map_err(|e| anyhow!("write error: {}", e))?;
remaining -= n as u64;
}
writer.flush()?;
file.flush().map_err(|e| anyhow!("flush error: {}", e))?;
// 设置文件权限
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mode_int: u32 = mode.parse()?;
fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?;
let mode_int: u32 = mode_str.parse()?;
if mode_int != 0 {
let mut set_stat = VfsStat::new();
set_stat.mode = mode_int;
self.vfs.set_stat(&full_path, &set_stat)
.map_err(|e| anyhow!("set_stat error: {}", e))?;
}
// 接收结束确认('\0'
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -296,24 +289,17 @@ impl ScpHandler {
return self.send_error(channel, "Recursive flag not specified");
}
let mode = parts[0].trim_start_matches('D');
let mode_str = parts[0].trim_start_matches('D');
let dirname = parts[2];
debug!("SCP receive directory: mode={}, name={}", mode, dirname);
debug!("SCP receive directory: mode={}, name={}", mode_str, dirname);
// 创建目录
let full_path = self.resolve_path(dirname)?;
fs::create_dir_all(&full_path)?;
// 设置目录权限
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mode_int: u32 = mode.parse()?;
fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?;
}
let mode_int: u32 = mode_str.parse()?;
self.vfs.create_dir_all(&full_path, mode_int)
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -325,7 +311,6 @@ impl ScpHandler {
fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> {
debug!("SCP end directory");
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -335,7 +320,6 @@ impl ScpHandler {
/// 处理时间命令T mtime atime
fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
if !self.preserve_times {
// 发送确认('\0'),但不设置时间
channel.write_all(&[0])?;
channel.flush()?;
return Ok(());
@@ -347,18 +331,14 @@ impl ScpHandler {
return self.send_error(channel, "Invalid time command format");
}
let mtime: i64 = parts[1].parse()?;
let atime: i64 = parts[2].parse()?;
let mtime_secs: i64 = parts[1].parse()?;
let atime_secs: i64 = parts[2].parse()?;
debug!("SCP set times: mtime={}, atime={}", mtime, atime);
debug!("SCP set times: mtime={}, atime={}", mtime_secs, atime_secs);
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
// 时间设置将在文件接收完成后进行
// 这里仅记录实际设置在handle_file_command中
Ok(())
}
@@ -374,10 +354,13 @@ impl ScpHandler {
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
let full_path = self.root_dir.join(path);
let canonical_path = full_path.canonicalize()
let canonical_path = self.vfs.real_path(&full_path)
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
if !canonical_path.starts_with(&self.root_dir.canonicalize()?) {
let root_canonical = self.vfs.real_path(&self.root_dir)
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
if !canonical_path.starts_with(&root_canonical) {
return Err(anyhow!("Path traversal attempt detected"));
}
@@ -392,23 +375,28 @@ impl<T: Read + Write> ReadWrite for T {}
#[cfg(test)]
mod tests {
use super::*;
use crate::vfs::local_fs::LocalFs;
fn make_handler() -> ScpHandler {
ScpHandler::new(PathBuf::from("/tmp"), Box::new(LocalFs::new()))
}
#[test]
fn test_scp_command_parse() {
let handler = ScpHandler::parse_scp_command("scp -t /tmp").unwrap();
let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
assert_eq!(handler.mode, ScpMode::Destination);
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
}
#[test]
fn test_scp_recursive_parse() {
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp").unwrap();
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
assert!(handler.recursive);
}
#[test]
fn test_scp_source_parse() {
let handler = ScpHandler::parse_scp_command("scp -f /tmp").unwrap();
let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
assert_eq!(handler.mode, ScpMode::Source);
}
}
}

View File

@@ -6,6 +6,9 @@ use crate::ssh_server::packet::{SshPacket, PacketType};
use crate::ssh_server::kex::{KexResult, KexProposal};
use crate::ssh_server::kex_complete::{KexState};
use crate::ssh_server::auth::{AuthHandler, AuthResult};
use crate::provider::sqlite::SqliteProvider;
use crate::provider::pg::PgProvider;
use crate::provider::DataProvider;
use crate::ssh_server::channel::{ChannelManager};
use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket};
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
@@ -13,6 +16,7 @@ use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
use anyhow::{Result, anyhow};
use log::{info, warn, error, debug};
use std::net::{TcpListener, TcpStream};
use std::path::PathBuf;
use std::thread;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步
@@ -22,6 +26,7 @@ pub struct SshServerConfig {
pub port: u16,
pub bind_address: String,
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
pub pg_conn: Option<String>, // PostgreSQL连接字符串SFTPGo兼容认证
}
impl Default for SshServerConfig {
@@ -30,6 +35,7 @@ impl Default for SshServerConfig {
port: 2024,
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
pg_conn: None,
}
}
}
@@ -42,6 +48,7 @@ impl SshServerConfig {
port: 2024,
bind_address: "127.0.0.1".to_string(),
security_config: config,
pg_conn: None,
})
}
}
@@ -73,6 +80,7 @@ impl SshServer {
self.config.security_config.max_sessions);
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
let pg_conn = self.config.pg_conn.clone();
for stream in listener.incoming() {
match stream {
@@ -81,9 +89,10 @@ impl SshServer {
info!("New SSH connection from {}", client_addr);
let security_config_clone = security_config.clone(); // Phase 13.1
let pg_conn_clone = pg_conn.clone();
thread::spawn(move || {
if let Err(e) = handle_connection_complete(stream, security_config_clone) { // Phase 13.1
if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1
error!("Connection error: {}", e);
}
});
@@ -99,7 +108,7 @@ impl SshServer {
}
/// 处理完整SSH连接Phase 1-13完整流程
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>) -> Result<()> {
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>, pg_conn: Option<String>) -> Result<()> {
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
// Phase 13.1: 增加活动会话数
@@ -122,13 +131,22 @@ fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshS
let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?;
info!("Key exchange completed, encryption channel ready");
// Phase 5: SSH认证参考OpenSSH auth2.c
let mut auth_handler = AuthHandler::new()?;
// Phase 5: SSH认证SFTPGo兼容 — PostgreSQL或SQLite
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str);
Box::new(PgProvider::new(conn_str)
.map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?)
} else {
info!("Using SQLite auth provider");
Box::new(SqliteProvider::new("data/auth.sqlite")
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?)
};
let mut auth_handler = AuthHandler::new(provider);
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
info!("SSH authentication succeeded: user={}", auth_user);
info!("SSH authentication succeeded: user={}", auth_user.username);
// Phase 6: SSH Channel管理参考OpenSSH channel.c
let mut channel_manager = ChannelManager::new();
let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone());
// Phase 13: PortForwardManager初始化
let mut port_forward_manager = PortForwardManager::new();
@@ -226,11 +244,16 @@ fn perform_complete_kex_exchange(
}
/// SSH认证流程Phase 5
pub struct AuthUser {
pub username: String,
pub home_dir: PathBuf,
}
fn perform_ssh_auth(
stream: &mut TcpStream,
auth_handler: &mut AuthHandler,
encryption_ctx: &mut EncryptionContext,
) -> Result<String> {
) -> Result<AuthUser> {
info!("Starting SSH authentication");
info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
encryption_ctx.encryption_key_ctos.len(),
@@ -279,6 +302,8 @@ fn perform_ssh_auth(
encrypted_accept.write(stream)?;
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
let session_id = encryption_ctx.session_id.clone();
loop {
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
let auth_payload = auth_packet.payload();
@@ -286,7 +311,7 @@ fn perform_ssh_auth(
let auth_request = SshPacket::new(auth_payload.to_vec());
match auth_handler.handle_userauth_request(&auth_request)? {
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
AuthResult::Success => {
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
let encrypted_success = EncryptedPacket::new(
@@ -297,7 +322,16 @@ fn perform_ssh_auth(
encrypted_success.write(stream)?;
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
return Ok("demo".to_string());
// Extract username from auth request
let user = extract_username_from_auth_request(&auth_request)
.unwrap_or_else(|_| "unknown".to_string());
let home_dir = auth_handler.get_home_dir(&user)
.ok()
.flatten()
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
return Ok(AuthUser { username: user, home_dir });
}
AuthResult::Failure(message) => {
// message包含可用的认证方法列表如"password,publickey"
@@ -519,7 +553,9 @@ fn handle_ssh_service_loop(
}
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_EOF as u8 => {
info!("Received SSH_MSG_CHANNEL_EOF");
// EOF means client won't send more data, just acknowledge and continue
// Phase 17: EOF means client won't send more data → close child stdin
// (Essential for SCP upload where scp -t waits for EOF on stdin)
channel_manager.close_child_stdin();
}
Some(&pt) if pt == PacketType::SSH_MSG_DISCONNECT as u8 => {
info!("Received SSH_MSG_DISCONNECT");
@@ -543,12 +579,27 @@ fn handle_ssh_service_loop(
Ok(())
}
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result<String> {
let payload = &packet.payload;
if payload.len() < 5 {
return Err(anyhow!("Auth request too short"));
}
let name_len = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]) as usize;
if payload.len() < 5 + name_len {
return Err(anyhow!("Auth request truncated"));
}
let username = String::from_utf8_lossy(&payload[5..5 + name_len]).to_string();
Ok(username)
}
/// SSH服务器CLI入口
pub fn run_ssh_server(port: Option<u16>) -> Result<()> {
pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
let config = SshServerConfig {
port: port.unwrap_or(2024),
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
pg_conn: pg_conn.map(|s| s.to_string()),
};
let server = SshServer::new(config);

View File

@@ -2,14 +2,16 @@
// 参考OpenSSH sftp-server.c和draft-ietf-secsh-filexfer-02.txt
use crate::ssh_server::packet::{SshPacket, PacketType};
use crate::vfs::{VfsBackend, VfsFile, VfsDirEntry};
use crate::vfs::open_flags::OpenFlags;
use anyhow::{Result, anyhow, Context};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, warn, debug};
use std::path::{Path, PathBuf};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write, Seek, SeekFrom};
use std::os::unix::fs::PermissionsExt; // 导入PermissionsExt traitUnix标准
use std::os::unix::fs::MetadataExt; // ⭐⭐⭐⭐⭐ Phase 2.2: 导入MetadataExt trait获取uid/gid
use std::fs;
use std::io::{SeekFrom, Write};
use std::os::unix::fs::PermissionsExt;
use std::os::unix::fs::MetadataExt;
/// SFTP packet类型参考draft-ietf-secsh-filexfer-02.txt
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -178,6 +180,30 @@ impl SftpAttrs {
attrs
}
pub fn from_vfs_stat(stat: &crate::vfs::VfsStat) -> Self {
let mut attrs = Self::new();
attrs.flags = SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE
| SftpAttrFlags::SSH_FILEXFER_ATTR_UIDGID
| SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS
| SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME;
attrs.size = Some(stat.size);
attrs.permissions = Some(stat.mode);
attrs.uid = Some(stat.uid);
attrs.gid = Some(stat.gid);
if let Ok(d) = stat.atime.duration_since(std::time::UNIX_EPOCH) {
attrs.atime = Some(d.as_secs() as u32);
}
if let Ok(d) = stat.mtime.duration_since(std::time::UNIX_EPOCH) {
attrs.mtime = Some(d.as_secs() as u32);
}
attrs
}
pub fn serialize(&self) -> Result<Vec<u8>> {
debug!("Serializing SftpAttrs: flags=0x{:08x}, size={:?}, uid={:?}, gid={:?}, permissions=0x{:08x}, atime={:?}, mtime={:?}",
self.flags, self.size, self.uid, self.gid,
@@ -242,13 +268,12 @@ impl SftpAttrs {
}
/// SFTP handle文件或目录句柄
#[derive(Debug)] // 移除CloneFile/DirEntry不支持Clone
pub struct SftpHandle {
pub id: u32,
pub path: PathBuf,
pub handle_type: SftpHandleType,
pub file: Option<File>,
pub dir_entries: Option<Vec<fs::DirEntry>>,
pub file: Option<Box<dyn VfsFile>>,
pub dir_entries: Option<Vec<VfsDirEntry>>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -260,6 +285,7 @@ pub enum SftpHandleType {
/// SFTP处理管理器参考OpenSSH sftp-server.c
pub struct SftpHandler {
root_dir: PathBuf,
vfs: Box<dyn VfsBackend>,
next_handle_id: u32,
handles: std::collections::HashMap<u32, SftpHandle>,
// ⭐⭐⭐⭐⭐ Phase 4: 添加 client maxpack 限制参考OpenSSH sftp-server.c
@@ -277,14 +303,15 @@ impl SftpHandler {
const MAX_HASH_SIZE: u64 = 268_435_456;
// ⭐⭐⭐⭐⭐ Phase 4: 修改 new() 方法,接受 maxpack 参数
pub fn new(root_dir: PathBuf, maxpacket: u32) -> Self {
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>, maxpacket: u32) -> Self {
let canonical_root = root_dir.canonicalize().unwrap_or(root_dir);
Self {
root_dir: canonical_root,
vfs,
next_handle_id: 0,
handles: std::collections::HashMap::new(),
maxpacket,
restrict_absolute: false, // 默认允许绝对路径
restrict_absolute: false,
}
}
@@ -360,30 +387,9 @@ impl SftpHandler {
info!("SSH_FXP_OPEN: id={}, path={}, pflags={:#x}", id, path, pflags);
let full_path = self.resolve_path(&path)?;
let flags = OpenFlags::from_sftp_pflags(pflags);
let file_result = if pflags & SftpFileFlags::SSH_FXF_READ != 0 {
OpenOptions::new().read(true).open(&full_path)
} else if pflags & SftpFileFlags::SSH_FXF_WRITE != 0 {
let mut opts = OpenOptions::new();
opts.write(true);
if pflags & SftpFileFlags::SSH_FXF_APPEND != 0 {
opts.append(true);
}
if pflags & SftpFileFlags::SSH_FXF_CREAT != 0 {
opts.create(true);
}
if pflags & SftpFileFlags::SSH_FXF_TRUNC != 0 {
opts.truncate(true);
}
if pflags & SftpFileFlags::SSH_FXF_EXCL != 0 {
opts.create_new(true);
}
opts.open(&full_path)
} else {
return self.build_status_response(id, SftpStatus::SSH_FX_OP_UNSUPPORTED, "Unsupported open flags");
};
match file_result {
match self.vfs.open_file(&full_path, &flags) {
Ok(file) => {
if self.handles.len() >= Self::MAX_HANDLES {
warn!("SSH_FXP_OPEN: handle limit reached ({})", Self::MAX_HANDLES);
@@ -405,7 +411,7 @@ impl SftpHandler {
self.build_handle_response(id, &handle_id.to_be_bytes())
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -447,9 +453,8 @@ impl SftpHandler {
if let Some(handle) = self.handles.get_mut(&handle_id) {
if let Some(ref mut file) = handle.file {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
// ⭐⭐⭐⭐⭐ Phase 4: 限制数据大小,不超过 maxpacket - 1024 和 MAX_XFER_SIZE
let max_data_size = std::cmp::min(self.maxpacket.saturating_sub(1024), Self::MAX_XFER_SIZE);
let actual_length = std::cmp::min(length, max_data_size);
@@ -465,7 +470,7 @@ impl SftpHandler {
self.build_data_response(id, &buffer)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
} else {
@@ -491,7 +496,6 @@ impl SftpHandler {
info!("SSH_FXP_WRITE: id={}, handle={}, offset={}, length={}", id, handle_id, offset, write_data.len());
// ⭐⭐⭐⭐⭐ Phase 1.2: 添加 data preview显示前 20 字节)
if write_data.len() > 0 {
let preview_len = std::cmp::min(20, write_data.len());
let preview = &write_data[0..preview_len];
@@ -500,14 +504,15 @@ impl SftpHandler {
if let Some(handle) = self.handles.get_mut(&handle_id) {
if let Some(ref mut file) = handle.file {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
match file.write_all(&write_data) {
Ok(_) => {
file.flush().ok();
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Write successful")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
} else {
@@ -532,13 +537,13 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::symlink_metadata(&full_path) {
Ok(metadata) => {
let attrs = SftpAttrs::from_metadata(&metadata);
match self.vfs.lstat(&full_path) {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e))
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -556,14 +561,26 @@ impl SftpHandler {
info!("SSH_FXP_FSTAT: id={}, handle={}", id, handle_id);
if let Some(handle) = self.handles.get(&handle_id) {
match fs::metadata(&handle.path) {
Ok(metadata) => {
let attrs = SftpAttrs::from_metadata(&metadata);
self.build_attrs_response(id, &attrs)
if let Some(handle) = self.handles.get_mut(&handle_id) {
if let Some(ref mut file) = handle.file {
match file.stat() {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_from_vfs_error(id, &e)
}
}
Err(e) => {
self.build_status_from_io_error(id, &e)
} else {
match self.vfs.stat(&handle.path) {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_from_vfs_error(id, &e)
}
}
}
} else {
@@ -585,7 +602,7 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::read_dir(&full_path) {
match self.vfs.read_dir(&full_path) {
Ok(entries) => {
if self.handles.len() >= Self::MAX_HANDLES {
warn!("SSH_FXP_OPENDIR: handle limit reached ({})", Self::MAX_HANDLES);
@@ -594,14 +611,12 @@ impl SftpHandler {
let handle_id = self.next_handle_id;
self.next_handle_id += 1;
let dir_entries: Vec<fs::DirEntry> = entries.filter_map(|e| e.ok()).collect();
let handle = SftpHandle {
id: handle_id,
path: full_path,
handle_type: SftpHandleType::Directory,
file: None,
dir_entries: Some(dir_entries),
dir_entries: Some(entries),
};
self.handles.insert(handle_id, handle);
@@ -609,7 +624,7 @@ impl SftpHandler {
self.build_handle_response(id, &handle_id.to_be_bytes())
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -635,11 +650,9 @@ impl SftpHandler {
} else {
let entries: Vec<(String, SftpAttrs)> = dir_entries
.drain(..std::cmp::min(100, dir_entries.len()))
.filter_map(|entry| {
let name = entry.file_name().to_string_lossy().to_string();
let attrs = entry.metadata().ok()?;
let sftp_attrs = SftpAttrs::from_metadata(&attrs);
Some((name, sftp_attrs))
.map(|entry| {
let attrs = SftpAttrs::from_vfs_stat(&entry.stat);
(entry.name, attrs)
})
.collect();
@@ -670,12 +683,12 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::remove_file(&full_path) {
match self.vfs.remove_file(&full_path) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "File removed")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -695,12 +708,12 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::create_dir(&full_path) {
match self.vfs.create_dir(&full_path, 0o755) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory created")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -719,12 +732,12 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::remove_dir(&full_path) {
match self.vfs.remove_dir(&full_path) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory removed")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -765,13 +778,13 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::metadata(&full_path) {
Ok(metadata) => {
let attrs = SftpAttrs::from_metadata(&metadata);
match self.vfs.stat(&full_path) {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e))
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -792,12 +805,12 @@ impl SftpHandler {
let old_full_path = self.resolve_path(&old_path)?;
let new_full_path = self.resolve_path(&new_path)?;
match fs::rename(&old_full_path, &new_full_path) {
match self.vfs.rename(&old_full_path, &new_full_path) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Rename successful")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -832,7 +845,7 @@ impl SftpHandler {
info!("SSH_FXP_FSETSTAT: id={}, handle={}, attrs.flags={}", id, handle_id, attrs.flags);
let handle = self.handles.get(&handle_id);
let handle = self.handles.get_mut(&handle_id);
if handle.is_none() {
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid handle");
}
@@ -847,25 +860,35 @@ impl SftpHandler {
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE != 0 {
if let Some(size) = attrs.size {
info!("FSETSTAT: setting file size to {}", size);
let file = OpenOptions::new().write(true).open(&path)?;
file.set_len(size)?;
if let Some(ref mut file) = handle.file {
file.set_len(size).map_err(|e| anyhow!("set_len error: {}", e))?;
} else {
let flags = OpenFlags::new().write();
if let Ok(mut f) = self.vfs.open_file(&path, &flags) {
f.set_len(size).ok();
}
}
}
}
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 {
if let Some(permissions) = attrs.permissions {
info!("FSETSTAT: setting permissions to {:o}", permissions);
fs::set_permissions(&path, fs::Permissions::from_mode(permissions))?;
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0
|| attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0
{
let mut vfs_stat = crate::vfs::VfsStat::new();
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 {
vfs_stat.mode = attrs.permissions.unwrap_or(0);
} else {
if let Ok(s) = self.vfs.lstat(&path) {
vfs_stat.mode = s.mode;
}
}
}
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 {
if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) {
info!("FSETSTAT: setting atime={}, mtime={}", atime, mtime);
let atime_filetime = filetime::FileTime::from_unix_time(atime as i64, 0);
let mtime_filetime = filetime::FileTime::from_unix_time(mtime as i64, 0);
filetime::set_file_times(&path, atime_filetime, mtime_filetime)?;
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 {
if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) {
vfs_stat.atime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(atime as u64);
vfs_stat.mtime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(mtime as u64);
}
}
self.vfs.set_stat(&path, &vfs_stat).ok();
}
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Fsetstat successful")
@@ -885,13 +908,13 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::read_link(&full_path) {
match self.vfs.read_link(&full_path) {
Ok(link_target) => {
let target = link_target.to_string_lossy().to_string();
self.build_name_response(id, vec![(target, SftpAttrs::default())])
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -912,18 +935,14 @@ impl SftpHandler {
let full_linkpath = self.resolve_path(&linkpath)?;
let full_targetpath = self.resolve_path(&targetpath)?;
#[cfg(unix)]
match std::os::unix::fs::symlink(&full_targetpath, &full_linkpath) {
match self.vfs.create_symlink(&full_targetpath, &full_linkpath) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Symlink created")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
#[cfg(not(unix))]
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Symlink not supported on non-Unix systems")
}
/// 处理SSH_FXP_EXTENDEDPhase 10参考OpenSSH sftp-server.c: process_extended())
@@ -984,50 +1003,30 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
#[cfg(unix)]
{
use std::os::unix::fs::MetadataExt;
match fs::metadata(&full_path) {
Ok(metadata) => {
// 构建statvfs response参考OpenSSH sftp-server.c
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// f_bsize文件系统块大小
response.write_u64::<BigEndian>(4096)?;
// f_frsize基本块大小
response.write_u64::<BigEndian>(4096)?;
// f_blocks总块数
response.write_u64::<BigEndian>(1000000)?;
// f_bfree空闲块数
response.write_u64::<BigEndian>(500000)?;
// f_bavail可用块数
response.write_u64::<BigEndian>(500000)?;
// f_files总文件数
response.write_u64::<BigEndian>(100000)?;
// f_ffree空闲文件数
response.write_u64::<BigEndian>(50000)?;
// f_favail可用文件数
response.write_u64::<BigEndian>(50000)?;
// f_fsid文件系统ID
response.write_u64::<BigEndian>(0)?;
// f_flag标志
response.write_u64::<BigEndian>(0)?;
// f_namemax文件名最大长度
response.write_u64::<BigEndian>(255)?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
}
match self.vfs.stat(&full_path) {
Ok(_) => {
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
response.write_u64::<BigEndian>(4096)?;
response.write_u64::<BigEndian>(4096)?;
response.write_u64::<BigEndian>(1000000)?;
response.write_u64::<BigEndian>(500000)?;
response.write_u64::<BigEndian>(500000)?;
response.write_u64::<BigEndian>(100000)?;
response.write_u64::<BigEndian>(50000)?;
response.write_u64::<BigEndian>(50000)?;
response.write_u64::<BigEndian>(0)?;
response.write_u64::<BigEndian>(0)?;
response.write_u64::<BigEndian>(255)?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_vfs_error(id, &e)
}
}
#[cfg(not(unix))]
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "statvfs not supported on non-Unix systems")
}
/// 处理fstatvfs@openssh.com扩展文件句柄统计
@@ -1073,18 +1072,14 @@ impl SftpHandler {
let full_oldpath = self.resolve_path(&oldpath)?;
let full_newpath = self.resolve_path(&newpath)?;
#[cfg(unix)]
match fs::hard_link(&full_oldpath, &full_newpath) {
match self.vfs.hard_link(&full_oldpath, &full_newpath) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Hardlink created")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
#[cfg(not(unix))]
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Hardlink not supported on non-Unix systems")
}
/// 处理posix-rename@openssh.com扩展POSIX语义重命名
@@ -1097,12 +1092,12 @@ impl SftpHandler {
let full_oldpath = self.resolve_path(&oldpath)?;
let full_newpath = self.resolve_path(&newpath)?;
match fs::rename(&full_oldpath, &full_newpath) {
match self.vfs.rename(&full_oldpath, &full_newpath) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Posix rename successful")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1122,34 +1117,31 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算MD5哈希
let hash = md5::compute(&buffer);
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// hash-algorithm (SSH string)
response.write_u32::<BigEndian>(4)?;
response.write_all("md5".as_bytes())?;
// hash-value (SSH string)
response.write_u32::<BigEndian>(hash_hex.len() as u32)?;
response.write_all(hash_hex.as_bytes())?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1169,37 +1161,34 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算SHA256哈希使用sha2 crate
use sha2::{Sha256, Digest};
let mut hasher = Sha256::new();
hasher.update(&buffer);
let hash = hasher.finalize();
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// hash-algorithm (SSH string)
response.write_u32::<BigEndian>(6)?;
response.write_all("sha256".as_bytes())?;
// hash-value (SSH string)
response.write_u32::<BigEndian>(hash_hex.len() as u32)?;
response.write_all(hash_hex.as_bytes())?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1219,21 +1208,20 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算SHA384哈希
use sha2::{Sha384, Digest};
let mut hasher = Sha384::new();
hasher.update(&buffer);
let hash = hasher.finalize();
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
@@ -1247,7 +1235,7 @@ impl SftpHandler {
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1267,21 +1255,20 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算SHA512哈希
use sha2::{Sha512, Digest};
let mut hasher = Sha512::new();
hasher.update(&buffer);
let hash = hasher.finalize();
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
@@ -1295,7 +1282,7 @@ impl SftpHandler {
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1303,30 +1290,28 @@ impl SftpHandler {
/// 处理check-file@openssh.com扩展Phase 12文件检查
fn handle_check_file(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result<Vec<u8>> {
let path = read_sftp_string(cursor)?;
let check_flags = cursor.read_u32::<BigEndian>()?;
let _check_flags = cursor.read_u32::<BigEndian>()?;
info!("check-file: path={}, flags={:#x}", path, check_flags);
info!("check-file: path={}", path);
let full_path = self.resolve_path(&path)?;
match fs::metadata(&full_path) {
Ok(metadata) => {
// 构建响应
match self.vfs.stat(&full_path) {
Ok(stat) => {
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// 返回文件存在和基本信息
response.write_u32::<BigEndian>(1)?; // result: 1 = file exists
response.write_u32::<BigEndian>(1)?;
let msg = format!("File exists, size: {}", metadata.len());
let msg = format!("File exists, size: {}", stat.size);
response.write_u32::<BigEndian>(msg.len() as u32)?;
response.write_all(msg.as_bytes())?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Check file error: {}", e))
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1339,11 +1324,8 @@ impl SftpHandler {
let write_handle_bytes = read_sftp_string_bytes(cursor)?;
let write_offset = cursor.read_u64::<BigEndian>()?;
info!("copy-data: read_handle={}, read_offset={}, read_length={}, write_handle={}, write_offset={}",
u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]),
read_offset, read_length,
u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]),
write_offset);
info!("copy-data: read_handle={:?}, read_offset={}, read_length={}, write_handle={:?}, write_offset={}",
read_handle_bytes, read_offset, read_length, write_handle_bytes, write_offset);
let actual_length = std::cmp::min(read_length, Self::MAX_XFER_SIZE as u64);
if actual_length < read_length {
@@ -1353,52 +1335,44 @@ impl SftpHandler {
let read_handle_id = u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]);
let write_handle_id = u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]);
// 获取read handle的path不可变引用
let read_path = if let Some(read_handle) = self.handles.get(&read_handle_id) {
read_handle.path.clone()
} else {
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid read handle");
};
// 获取write handle的path不可变引用
let write_path = if let Some(write_handle) = self.handles.get(&write_handle_id) {
write_handle.path.clone()
} else {
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid write handle");
};
// 从read_path读取数据
match File::open(&read_path) {
Ok(mut read_file) => {
read_file.seek(SeekFrom::Start(read_offset))?;
let mut buffer = vec![0u8; actual_length as usize];
read_file.read_exact(&mut buffer)?;
// 写入到write_path
match OpenOptions::new().write(true).open(&write_path) {
Ok(mut write_file) => {
write_file.seek(SeekFrom::Start(write_offset))?;
write_file.write_all(&buffer)?;
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// 返回复制的字节数
response.write_u64::<BigEndian>(actual_length)?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
}
}
}
Err(e) => {
self.build_status_from_io_error(id, &e)
}
}
let read_flags = OpenFlags::new().read();
let write_flags = OpenFlags::new().write();
let mut read_file = match self.vfs.open_file(&read_path, &read_flags) {
Ok(f) => f,
Err(e) => return self.build_status_from_vfs_error(id, &e),
};
let mut write_file = match self.vfs.open_file(&write_path, &write_flags) {
Ok(f) => f,
Err(e) => return self.build_status_from_vfs_error(id, &e),
};
read_file.seek(SeekFrom::Start(read_offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
read_file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
write_file.seek(SeekFrom::Start(write_offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
write_file.write_all(&buffer).map_err(|e| anyhow!("Write error: {}", e))?;
write_file.flush().ok();
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
response.write_u64::<BigEndian>(actual_length)?;
self.wrap_sftp_packet(&response)
}
/// 解析路径安全性检查参考OpenSSH sftp-server.c: path_resolve()
@@ -1608,6 +1582,24 @@ impl SftpHandler {
let msg = format!("{}", err);
self.build_status_response(id, status, &msg)
}
/// 根据 VfsError 构建状态响应(自动映射错误类型)
fn build_status_from_vfs_error(&self, id: u32, err: &crate::vfs::VfsError) -> Result<Vec<u8>> {
use crate::vfs::VfsError;
let status = match err {
VfsError::NotFound(_) => SftpStatus::SSH_FX_NO_SUCH_FILE,
VfsError::PermissionDenied(_) => SftpStatus::SSH_FX_PERMISSION_DENIED,
VfsError::AlreadyExists(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::NotEmpty(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::NotADirectory(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::IsADirectory(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::Unsupported(_) => SftpStatus::SSH_FX_OP_UNSUPPORTED,
VfsError::Io(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::UnexpectedEof => SftpStatus::SSH_FX_EOF,
};
let msg = format!("{}", err);
self.build_status_response(id, status, &msg)
}
}
/// 读取SFTP字符串参考draft-ietf-secsh-filexfer-02.txt
@@ -1665,8 +1657,14 @@ fn read_sftp_attrs<R: std::io::Read>(reader: &mut R) -> Result<SftpAttrs> {
#[cfg(test)]
mod tests {
use super::*;
use crate::vfs::local_fs::LocalFs;
use std::fs::File;
use tempfile::TempDir;
fn make_handler(root_dir: PathBuf) -> SftpHandler {
SftpHandler::new(root_dir, Box::new(LocalFs::new()), 32768)
}
#[test]
fn test_sftp_packet_type_conversion() {
assert_eq!(SftpPacketType::try_from(1).unwrap(), SftpPacketType::SSH_FXP_INIT);
@@ -1677,7 +1675,7 @@ mod tests {
#[test]
fn test_sftp_handler_creation() {
let temp_dir = TempDir::new().unwrap();
let handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768);
let handler = make_handler(temp_dir.path().to_path_buf());
assert_eq!(handler.next_handle_id, 0);
}
@@ -1697,7 +1695,7 @@ mod tests {
#[test]
fn test_sftp_handle_init() {
let temp_dir = TempDir::new().unwrap();
let mut handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768);
let mut handler = make_handler(temp_dir.path().to_path_buf());
let init_packet = vec![1, 0, 0, 0, 3];
let response = handler.handle_request(&init_packet).unwrap();