VFS/DataProvider/Config refactoring + SSH public key authentication
Some checks failed
Test / test (push) Has been cancelled
Test / build (push) Has been cancelled

Phase 1-6 of refactoring plan:
- VFS abstraction (VfsBackend trait + LocalFs + OpenFlags builder)
- DataProvider trait (SqliteProvider + PgProvider, SFTPGo-compatible)
- Config refactoring (AppConfig unified sections, env overrides)
- SSH handlers (sftp/scp/rsync) migrated to VFS + DataProvider
- SSH public key authentication (Ed25519 signature verification)
- SSH stderr → CHANNEL_EXTENDED_DATA support
- Web auth uses DataProvider instead of direct SQL
- User home directory from provider (per-user isolation)
- PostgreSQL auth provider for SFTPGo compatibility
This commit is contained in:
Warren
2026-06-18 23:35:18 +08:00
parent 83fb0de78a
commit f90e4f496c
25 changed files with 2039 additions and 612 deletions

View File

@@ -5,6 +5,8 @@ use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use uuid::Uuid;
use crate::provider::{DataProvider, ProviderError};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub user_id: String,
@@ -66,13 +68,13 @@ pub struct AuthState {
pub users: Arc<Mutex<HashMap<String, User>>>,
pub auth_db: Option<crate::sync::AuthDb>,
pub admin_sessions: Arc<Mutex<HashMap<String, AdminSession>>>,
pub provider: Option<Arc<dyn DataProvider>>,
}
impl AuthState {
pub fn new() -> Self {
let mut users = HashMap::new();
// Create default demo user
let password_hash = hash("demo123", DEFAULT_COST).unwrap();
users.insert(
"demo".to_string(),
@@ -89,6 +91,7 @@ impl AuthState {
users: Arc::new(Mutex::new(users)),
auth_db: None,
admin_sessions: Arc::new(Mutex::new(HashMap::new())),
provider: None,
}
}
@@ -100,6 +103,17 @@ impl AuthState {
users: Arc::new(Mutex::new(HashMap::new())),
auth_db,
admin_sessions: Arc::new(Mutex::new(HashMap::new())),
provider: None,
}
}
pub fn with_provider(provider: Box<dyn DataProvider>) -> Self {
AuthState {
sessions: Arc::new(Mutex::new(HashMap::new())),
users: Arc::new(Mutex::new(HashMap::new())),
auth_db: None,
admin_sessions: Arc::new(Mutex::new(HashMap::new())),
provider: Some(Arc::from(provider)),
}
}
@@ -208,8 +222,12 @@ impl AuthState {
}
pub fn login_with_sync(&self, username: &str, password: &str) -> Option<LoginResponse> {
// Prefer provider over auth_db
if let Some(provider) = &self.provider {
return self.login_with_provider(&**provider, username, password);
}
if let Some(auth_db) = &self.auth_db {
// Get user from auth.sqlite
let user = match auth_db.get_user(username) {
Ok(Some(user)) => user,
Ok(None) => {
@@ -266,11 +284,70 @@ impl AuthState {
}
}
fn login_with_provider(&self, provider: &dyn DataProvider, username: &str, password: &str) -> Option<LoginResponse> {
match provider.get_user(username) {
Ok(Some(user)) => {
if user.status != 1 {
log::warn!("User {} is disabled or not found", username);
return None;
}
match provider.check_password(username, password) {
Ok(true) => {
let groups = provider.get_user_groups(username).unwrap_or_default();
let token = Uuid::new_v4().to_string();
let now = Utc::now();
let expires_at = now + Duration::hours(24);
let session = Session {
token: token.clone(),
user_id: username.to_string(),
username: username.to_string(),
created_at: now.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
groups: groups.clone(),
permissions: user.permissions.clone(),
};
let mut sessions = self.sessions.lock().unwrap();
sessions.insert(token.clone(), session);
log::info!("User {} logged in via DataProvider", username);
Some(LoginResponse {
token,
expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
user_id: username.to_string(),
groups,
permissions: user.permissions,
})
}
Ok(false) => {
log::warn!("Invalid password for user {}", username);
None
}
Err(e) => {
log::error!("Password check error for {}: {}", username, e);
None
}
}
}
Ok(None) => {
log::warn!("User {} not found", username);
None
}
Err(e) => {
log::error!("Provider error for {}: {}", username, e);
None
}
}
}
pub fn verify_token(&self, token: &str) -> Option<Session> {
let sessions = self.sessions.lock().unwrap();
let session = sessions.get(token)?;
// Check expiration
let expires_at = chrono::DateTime::parse_from_rfc3339(&session.expires_at)
.ok()?
.with_timezone(&Utc);

View File

@@ -6,20 +6,29 @@ pub enum SshCommand {
Start {
#[arg(short, long, default_value = "2024")]
port: u16,
/// PostgreSQL connection string for SFTPGo-compatible auth (e.g. "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026")
#[arg(long)]
pg_conn: Option<String>,
},
}
pub async fn handle_ssh_command(cmd: SshCommand) -> anyhow::Result<()> {
match cmd {
SshCommand::Start { port } => {
SshCommand::Start { port, pg_conn } => {
println!("=== MarkBase SSH Server (Hand-written Implementation) ===");
println!("Port: {}", port);
println!("Implementation: SSH-2.0-MarkBaseSSH_1.0");
println!("Features: SSH + SFTP + SCP + rsync");
if pg_conn.is_some() {
println!("Auth Provider: PostgreSQL (SFTPGo-compatible)");
} else {
println!("Auth Provider: SQLite");
}
println!("Security: ⭐⭐⭐⭐⭐ (RustCrypto authoritative libraries)");
println!();
crate::ssh_server::server::run_ssh_server(Some(port))?;
crate::ssh_server::server::run_ssh_server(Some(port), pg_conn.as_deref())?;
}
}
Ok(())

View File

@@ -0,0 +1,233 @@
pub mod web;
use anyhow::Result;
use serde::{Deserialize, Serialize};
// Re-export web config for backward compatibility
pub use web::*;
/// Unified application configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
#[serde(default)]
pub web: WebSection,
#[serde(default)]
pub s3: S3Section,
#[serde(default)]
pub sftp: SftpSection,
#[serde(default)]
pub ssh: SshSection,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSection {
pub host: String,
pub port: u16,
pub log_level: String,
pub auth_db_path: String,
pub users_db_dir: String,
}
impl Default for WebSection {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 11438,
log_level: "info".to_string(),
auth_db_path: "data/auth.sqlite".to_string(),
users_db_dir: "data/users".to_string(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct S3Section {
pub enabled: bool,
pub endpoint: String,
pub region: String,
pub require_auth: bool,
pub default_access_key: String,
pub default_secret_key: String,
pub keys_db_path: String,
}
impl Default for S3Section {
fn default() -> Self {
Self {
enabled: true,
endpoint: "http://localhost:11438/s3".to_string(),
region: "us-east-1".to_string(),
require_auth: false,
default_access_key: "markbase_access_key_001".to_string(),
default_secret_key: "markbase_secret_key_xyz123".to_string(),
keys_db_path: "data/s3_keys.json".to_string(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SftpSection {
pub enabled: bool,
pub port: u16,
pub base_path: String,
pub auth_db_path: String,
pub max_connections: usize,
pub chunk_size: usize,
pub require_path_validation: bool,
pub audit_logging: bool,
pub path_traversal_protection: bool,
pub symlink_check: bool,
}
impl Default for SftpSection {
fn default() -> Self {
Self {
enabled: true,
port: 2023,
base_path: "/Users/accusys/momentry/var/sftpgo/data".to_string(),
auth_db_path: "data/auth.sqlite".to_string(),
max_connections: 100,
chunk_size: 65536,
require_path_validation: true,
audit_logging: true,
path_traversal_protection: true,
symlink_check: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SshSection {
pub enabled: bool,
pub port: u16,
pub bind_address: String,
pub security_config_path: String,
}
impl Default for SshSection {
fn default() -> Self {
Self {
enabled: true,
port: 2024,
bind_address: "127.0.0.1".to_string(),
security_config_path: "data/ssh_config.json".to_string(),
}
}
}
impl AppConfig {
pub fn load(path: &str) -> Result<Self> {
let config_path = std::path::PathBuf::from(path);
if !config_path.exists() {
log::warn!("Config file not found: {}, using defaults", path);
return Ok(Self::default());
}
let content = std::fs::read_to_string(&config_path)?;
let config: AppConfig = toml::from_str(&content)?;
log::info!("App config loaded from: {}", path);
Ok(config)
}
pub fn load_default() -> Result<Self> {
Self::load("config/app.toml")
}
pub fn save(&self, path: &str) -> Result<()> {
let config_path = std::path::PathBuf::from(path);
if config_path.exists() {
let backup_path = config_path.with_extension("toml.bak");
std::fs::copy(&config_path, &backup_path)?;
log::info!("Backup created: {}", backup_path.display());
}
let content = toml::to_string_pretty(self)?;
std::fs::write(&config_path, content)?;
log::info!("App config saved to: {}", path);
Ok(())
}
pub fn merge_env(&mut self) {
if let Ok(v) = std::env::var("MB_WEB_HOST") {
self.web.host = v;
}
if let Ok(v) = std::env::var("MB_WEB_PORT") {
if let Ok(p) = v.parse() { self.web.port = p; }
}
if let Ok(v) = std::env::var("MB_SSH_PORT") {
if let Ok(p) = v.parse() { self.ssh.port = p; }
}
if let Ok(v) = std::env::var("MB_SFTP_PORT") {
if let Ok(p) = v.parse() { self.sftp.port = p; }
}
if let Ok(v) = std::env::var("MB_S3_ENABLED") {
self.s3.enabled = v == "true" || v == "1";
}
if let Ok(v) = std::env::var("MB_AUTH_DB") {
self.web.auth_db_path = v.clone();
self.sftp.auth_db_path = v;
}
}
}
impl Default for AppConfig {
fn default() -> Self {
Self {
web: WebSection::default(),
s3: S3Section::default(),
sftp: SftpSection::default(),
ssh: SshSection::default(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn test_default_config() {
let config = AppConfig::default();
assert_eq!(config.web.port, 11438);
assert_eq!(config.ssh.port, 2024);
assert_eq!(config.sftp.port, 2023);
assert_eq!(config.s3.region, "us-east-1");
}
#[test]
fn test_load_missing() {
let config = AppConfig::load("/tmp/nonexistent/config.toml").unwrap();
assert_eq!(config.web.port, 11438);
}
#[test]
fn test_merge_env() {
std::env::set_var("MB_WEB_PORT", "9090");
std::env::set_var("MB_SSH_PORT", "2222");
let mut config = AppConfig::default();
config.merge_env();
assert_eq!(config.web.port, 9090);
assert_eq!(config.ssh.port, 2222);
std::env::remove_var("MB_WEB_PORT");
std::env::remove_var("MB_SSH_PORT");
}
#[test]
fn test_save_and_load() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("test.toml");
let path_str = path.to_string_lossy().to_string();
let config = AppConfig::default();
config.save(&path_str).unwrap();
let loaded = AppConfig::load(&path_str).unwrap();
assert_eq!(loaded.web.port, 11438);
}
}

View File

@@ -67,13 +67,12 @@ impl MarkBaseConfig {
}
pub fn save(&self, path: &Path) -> Result<()> {
// Create backup before saving
if path.exists() {
let backup_path = path.with_extension("toml.bak");
std::fs::copy(path, &backup_path)?;
log::info!("Backup created: {}", backup_path.display());
}
let content = toml::to_string_pretty(self)?;
std::fs::write(path, content)?;
log::info!("Configuration saved to: {}", path.display());

View File

@@ -23,6 +23,8 @@ pub mod import_markdown; // Category View Module - 双视图管理Phase 1
// pub mod ssh2_mod; // ssh2辅助模块已禁用
pub mod ssh_server; // SSH服务器Phase 1-9完成正在修复编译错误⭐⭐⭐⭐⭐
pub mod sync;
pub mod provider; // DataProvider抽象层Phase 5
pub mod vfs; // VFS抽象层Phase 1-6重构计划
// Re-export from external filetree crate
pub use filetree::node::FileNode;

View File

@@ -0,0 +1,65 @@
pub mod sqlite;
pub mod pg;
use std::path::PathBuf;
/// 用户信息
#[derive(Debug, Clone)]
pub struct User {
pub username: String,
pub password_hash: String,
pub home_dir: PathBuf,
pub uid: u32,
pub gid: u32,
pub permissions: String,
pub status: i32,
}
/// Provider 错误类型
#[derive(Debug)]
pub enum ProviderError {
NotFound(String),
AuthFailed(String),
Internal(String),
}
impl std::fmt::Display for ProviderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProviderError::NotFound(msg) => write!(f, "Not found: {}", msg),
ProviderError::AuthFailed(msg) => write!(f, "Authentication failed: {}", msg),
ProviderError::Internal(msg) => write!(f, "Internal error: {}", msg),
}
}
}
impl std::error::Error for ProviderError {}
/// 数据提供者 trait用户认证和配置
pub trait DataProvider: Send + Sync {
/// 获取用户信息
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError>;
/// 验证用户密码
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError>;
/// 获取用户主目录
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError>;
/// 获取用户组列表
fn get_user_groups(&self, username: &str) -> Result<Vec<String>, ProviderError> {
let _ = username;
Ok(Vec::new())
}
/// 检查用户是否存在且启用
fn user_exists(&self, username: &str) -> Result<bool, ProviderError> {
Ok(self.get_user(username)?.map(|u| u.status == 1).unwrap_or(false))
}
/// 获取用户的公开密钥列表OpenSSH authorized_keys格式
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
let _ = username;
Ok(Vec::new())
}
}

View File

@@ -0,0 +1,184 @@
use std::path::PathBuf;
use postgres::{Client, NoTls};
use bcrypt::verify;
use super::{DataProvider, ProviderError, User};
/// PostgreSQL 数据提供者(兼容 SFTPGo 的 users 表)
pub struct PgProvider {
conn_str: String,
}
impl PgProvider {
/// 从连接字符串创建 PgProvider
///
/// 连接字符串格式host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026
pub fn new(conn_str: &str) -> Result<Self, ProviderError> {
Ok(Self { conn_str: conn_str.to_string() })
}
pub fn from_params(
host: &str,
port: u16,
dbname: &str,
user: &str,
password: &str,
) -> Result<Self, ProviderError> {
let conn_str = format!(
"host={} port={} dbname={} user={} password={}",
host, port, dbname, user, password
);
Ok(Self { conn_str })
}
fn open_conn(&self) -> Result<Client, ProviderError> {
Client::connect(&self.conn_str, NoTls)
.map_err(|e| ProviderError::Internal(format!("PostgreSQL connect failed: {}", e)))
}
}
impl DataProvider for PgProvider {
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
let mut conn = self.open_conn()?;
let result = conn.query_opt(
"SELECT username, password, home_dir, permissions, uid, gid, status
FROM users WHERE username = $1 AND status = 1",
&[&username],
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
match result {
Some(row) => Ok(Some(User {
username: row.get(0),
password_hash: row.get::<_, Option<String>>(1).unwrap_or_default(),
home_dir: PathBuf::from(row.get::<_, String>(2)),
permissions: row.get::<_, Option<String>>(3).unwrap_or_else(|| "*".to_string()),
uid: row.get::<_, i64>(4) as u32,
gid: row.get::<_, i64>(5) as u32,
status: row.get(6),
})),
None => Ok(None),
}
}
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError> {
let hash = match self.get_user(username)? {
Some(user) => user.password_hash,
None => return Ok(false),
};
if hash.is_empty() {
return Ok(false);
}
verify(password, &hash)
.map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e)))
}
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
}
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
let mut conn = self.open_conn()?;
let result = conn.query_opt(
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
&[&username],
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
match result {
Some(row) => {
let json_str: Option<String> = row.get(0);
match json_str {
Some(s) if !s.is_empty() => {
let keys: Vec<serde_json::Value> = serde_json::from_str(&s)
.map_err(|e| ProviderError::Internal(format!("JSON parse error: {}", e)))?;
Ok(keys.iter()
.filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string()))
.collect())
}
_ => Ok(Vec::new()),
}
}
None => Ok(Vec::new()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pg_provider_connection() {
// 仅当 SFTPGo PostgreSQL 可用时运行
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
);
assert!(provider.is_ok(), "Should connect to SFTPGo PostgreSQL");
}
#[test]
fn test_pg_get_user_demo() {
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
).unwrap();
let user = provider.get_user("demo").unwrap();
assert!(user.is_some(), "Demo user should exist");
assert_eq!(user.unwrap().username, "demo");
}
#[test]
fn test_pg_get_user_momentry() {
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
).unwrap();
let user = provider.get_user("momentry").unwrap();
assert!(user.is_some(), "Momentry user should exist");
}
#[test]
fn test_pg_get_user_warren() {
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
).unwrap();
let user = provider.get_user("warren").unwrap();
assert!(user.is_some(), "Warren user should exist");
}
#[test]
fn test_pg_check_password_demo() {
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
).unwrap();
let valid = provider.check_password("demo", "demo123").unwrap();
assert!(valid, "Password should be valid");
}
#[test]
fn test_pg_check_password_invalid() {
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
).unwrap();
let valid = provider.check_password("demo", "wrong").unwrap();
assert!(!valid, "Wrong password should fail");
}
#[test]
fn test_pg_get_home_dir() {
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
).unwrap();
let dir = provider.get_home_dir("demo").unwrap();
assert!(dir.is_some());
assert!(dir.unwrap().contains("momentry"));
}
#[test]
fn test_pg_nonexistent_user() {
let provider = PgProvider::new(
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
).unwrap();
let user = provider.get_user("__nonexistent__").unwrap();
assert!(user.is_none());
}
}

View File

@@ -0,0 +1,135 @@
use std::path::PathBuf;
use rusqlite::{Connection, params};
use bcrypt::verify;
use super::{DataProvider, ProviderError, User};
/// SQLite 数据提供者
pub struct SqliteProvider {
db_path: PathBuf,
}
impl SqliteProvider {
pub fn new(db_path: &str) -> Result<Self, ProviderError> {
let path = PathBuf::from(db_path);
if !path.exists() {
return Err(ProviderError::NotFound(format!(
"Database not found: {}", db_path
)));
}
Ok(Self { db_path: path })
}
fn open_conn(&self) -> Result<Connection, ProviderError> {
Connection::open(&self.db_path)
.map_err(|e| ProviderError::Internal(format!("Failed to open database: {}", e)))
}
}
impl DataProvider for SqliteProvider {
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
let conn = self.open_conn()?;
let result = conn.query_row(
"SELECT username, password_hash, home_dir, permissions, uid, gid, status
FROM sftpgo_users WHERE username = ?1 AND status = 1",
params![username],
|row| {
Ok(User {
username: row.get(0)?,
password_hash: row.get(1)?,
home_dir: PathBuf::from(row.get::<_, String>(2)?),
permissions: row.get(3)?,
uid: row.get::<_, i64>(4)? as u32,
gid: row.get::<_, i64>(5)? as u32,
status: row.get(6)?,
})
},
);
match result {
Ok(user) => Ok(Some(user)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(ProviderError::Internal(format!(
"Database query error: {}", e
))),
}
}
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError> {
let hash = match self.get_user(username)? {
Some(user) => user.password_hash,
None => return Ok(false),
};
verify(password, &hash)
.map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e)))
}
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
}
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
let _ = username;
Ok(Vec::new())
}
fn get_user_groups(&self, username: &str) -> Result<Vec<String>, ProviderError> {
let conn = self.open_conn()?;
let groups: Vec<String> = conn
.prepare("SELECT group_name FROM users_groups_mapping WHERE username = ?1")
.map_err(|e| ProviderError::Internal(format!("Query prepare error: {}", e)))?
.query_map(params![username], |row| row.get(0))
.map_err(|e| ProviderError::Internal(format!("Query map error: {}", e)))?
.filter_map(|r| r.ok())
.collect();
Ok(groups)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_not_found() {
let provider = SqliteProvider::new("/tmp/nonexistent.db");
assert!(provider.is_err());
}
#[test]
fn test_provider_get_user() {
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
let user = provider.get_user("demo").unwrap();
assert!(user.is_some());
assert_eq!(user.unwrap().username, "demo");
}
#[test]
fn test_provider_nonexistent_user() {
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
let user = provider.get_user("__nonexistent__").unwrap();
assert!(user.is_none());
}
#[test]
fn test_check_password_valid() {
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
let valid = provider.check_password("demo", "demo123").unwrap();
assert!(valid);
}
#[test]
fn test_check_password_invalid() {
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
let valid = provider.check_password("demo", "wrong").unwrap();
assert!(!valid);
}
#[test]
fn test_get_home_dir() {
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
let dir = provider.get_home_dir("demo").unwrap();
assert!(dir.is_some());
}
}

View File

@@ -13,6 +13,7 @@ use std::sync::{Arc, Mutex};
use crate::audio;
use crate::auth::{AuthState, LoginRequest};
use crate::provider::sqlite::SqliteProvider;
use crate::render;
use crate::download;
use crate::archive::{self, ArchiveFormat, ArchiveProcessor, FormatDetector, ArchiveConfig, ProcessorRegistry};
@@ -57,7 +58,10 @@ pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
}))),
labels: Arc::new(Mutex::new(vec![])),
db_dir: "data/users".to_string(),
auth: AuthState::with_sync("data/auth.sqlite"),
auth: AuthState::with_provider(Box::new(
SqliteProvider::new("data/auth.sqlite")
.map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))?
)),
auth_db_path: "data/auth.sqlite".to_string(),
s3_keys: Arc::new(Mutex::new(load_s3_keys())),
};

View File

@@ -1,9 +1,9 @@
// ssh2 Server核心实现
// 替代russh提供完整的SSH/SFTP/SCP/rsync支持
use crate::sftp::auth::SftpAuth;
use crate::provider::sqlite::SqliteProvider;
use crate::sftp::config::SftpConfig;
use anyhow::{Result, anyhow};
use anyhow::{Result, anyhow, Context};
use log::{info, warn, error};
use ssh2::Session;
use std::net::{TcpListener, TcpStream};
@@ -106,9 +106,10 @@ fn authenticate_client(session: &Session, config: &Arc<SftpConfig>) -> Result<St
let user = "warren";
let password = "demo123";
// 使用SftpAuth验证(复用现有认证系统)
let auth = SftpAuth::new(&config.auth_db_path)?;
if auth.verify_password(user, password)? {
// 使用SqliteProvider验证(复用现有认证系统)
let provider = SqliteProvider::new(&config.auth_db_path)
.context("Failed to init SqliteProvider for ssh2_server")?;
if provider.check_password(user, password)? {
info!("Password auth successful for user: {}", user);
session.userauth_password(user, password)?;
return Ok(user.to_string());

View File

@@ -1,70 +1,58 @@
// SSH认证协议实现Phase 5
// 参考OpenSSH auth.c, auth-passwd.c
use crate::ssh_server::packet::{SshPacket, PacketType};
use std::io::{Read, Write}; // 导入Write traitOpenSSH标准
use std::io::Write;
use anyhow::{Result, anyhow};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, warn, debug};
use rusqlite::{Connection, params};
use bcrypt::{verify, DEFAULT_COST};
use base64::{Engine as _, engine::general_purpose}; // Phase 9: Base64 for authorized_keys
use base64::{Engine as _, engine::general_purpose};
use ed25519_dalek::{VerifyingKey, Signature};
use crate::provider::{DataProvider, ProviderError};
/// SSH认证处理器参考OpenSSH auth2.c
pub struct AuthHandler {
db_path: String, // SQLite数据库路径
provider: Box<dyn DataProvider>,
}
impl AuthHandler {
/// 创建认证处理器
pub fn new() -> Result<Self> {
let db_path = "data/auth.sqlite".to_string();
// 验证数据库是否存在
let conn = Connection::open(&db_path)?;
drop(conn); // rusqlite会自动关闭
info!("AuthHandler initialized with database: {}", db_path);
Ok(Self { db_path })
pub fn new(provider: Box<dyn DataProvider>) -> Self {
info!("AuthHandler initialized with DataProvider");
Self { provider }
}
/// 获取用户home目录SFTPGo兼容
pub fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
self.provider.get_home_dir(username)
}
/// 处理SSH_MSG_USERAUTH_REQUEST参考OpenSSH auth2.c: userauth_request()
pub fn handle_userauth_request(&mut self, packet: &SshPacket) -> Result<AuthResult> {
pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result<AuthResult> {
info!("Processing SSH_MSG_USERAUTH_REQUEST");
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
// Packet type
let packet_type = cursor.read_u8()?;
if packet_type != PacketType::SSH_MSG_USERAUTH_REQUEST as u8 {
return Err(anyhow!("Invalid packet type for USERAUTH_REQUEST"));
}
// 读取用户名SSH string
let user = read_ssh_string(&mut cursor)?;
// 读取服务名称SSH string
let service = read_ssh_string(&mut cursor)?;
// 读取认证方法名称SSH string
let method = read_ssh_string(&mut cursor)?;
info!("Auth request: user={}, service={}, method={}", user, service, method);
// 检查服务名称OpenSSH要求ssh-connection
if service != "ssh-connection" {
warn!("Unsupported service: {}", service);
return Ok(AuthResult::Failure("Unsupported service".to_string()));
}
// 根据认证方法处理参考OpenSSH auth2.c
if method == "password" {
self.handle_password_auth(&mut cursor, &user)
} else if method == "publickey" {
self.handle_publickey_auth(&mut cursor, &user)
self.handle_publickey_auth(&mut cursor, &user, &service, session_id)
} else if method == "none" {
// OpenSSHnone认证总是失败用于查询支持的认证方法
// 返回支持的认证方法列表password, publickey
warn!("None auth request - returning supported methods");
Ok(AuthResult::Failure("password,publickey".to_string()))
} else {
@@ -72,203 +60,254 @@ impl AuthHandler {
Ok(AuthResult::Failure("Unsupported auth method".to_string()))
}
}
/// 处理password认证参考OpenSSH auth-passwd.c
fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
info!("Handling password auth for user: {}", user);
// 读取是否修改密码标志booleanOpenSSH password认证格式
let change_password = cursor.read_u8()? != 0;
if change_password {
warn!("Password change not supported");
return Ok(AuthResult::Failure("Password change not supported".to_string()));
}
// 读取密码SSH string
let password = read_ssh_string(cursor)?;
debug!("Password auth attempt: user={}, password length={}", user, password.len());
// 查询数据库获取password_hash
let conn = Connection::open(&self.db_path)?;
let password_hash_result = conn.query_row(
"SELECT password_hash FROM sftpgo_users WHERE username = ?1 AND status = 1",
params![user],
|row| row.get::<_, String>(0)
);
// 关闭连接rusqlite会自动关闭
drop(conn);
// 验证用户是否存在
let password_hash = match password_hash_result {
Ok(hash) => Some(hash),
Err(rusqlite::Error::QueryReturnedNoRows) => None,
Err(e) => return Err(anyhow!("Database query error: {}", e)),
};
if password_hash.is_none() {
warn!("User not found or disabled: {}", user);
// SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表RFC 4253
return Ok(AuthResult::Failure("password,publickey".to_string()));
}
// 使用bcrypt验证密码
let stored_hash = password_hash.unwrap();
info!("Attempting bcrypt verify: password='{}', hash='{}'", password, stored_hash);
let valid = verify(&password, &stored_hash)?;
info!("bcrypt verify result: {}", valid);
if valid {
info!("Password auth successful for user: {}", user);
Ok(AuthResult::Success)
} else {
warn!("Password auth failed for user: {}", user);
// SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表RFC 4253
Ok(AuthResult::Failure("password,publickey".to_string()))
match self.provider.check_password(user, &password) {
Ok(true) => {
info!("Password auth successful for user: {}", user);
Ok(AuthResult::Success)
}
Ok(false) => {
warn!("Password auth failed for user: {}", user);
Ok(AuthResult::Failure("password,publickey".to_string()))
}
Err(ProviderError::NotFound(msg)) => {
warn!("User not found: {}", msg);
Ok(AuthResult::Failure("password,publickey".to_string()))
}
Err(e) => {
Err(anyhow!("Password auth error: {}", e))
}
}
}
/// 构建SSH_MSG_USERAUTH_SUCCESS packet参考OpenSSH auth2.c
/// 构建SSH_MSG_USERAUTH_SUCCESS packet
pub fn build_userauth_success() -> Result<SshPacket> {
let payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
Ok(SshPacket::new(payload))
}
/// 构建SSH_MSG_USERAUTH_FAILURE packet参考OpenSSH auth2.c
/// 构建SSH_MSG_USERAUTH_FAILURE packet
pub fn build_userauth_failure(methods: &[String], partial_success: bool) -> Result<SshPacket> {
let mut payload = Vec::new();
// Packet type
payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
// 认证方法列表SSH string逗号分隔
let methods_str = methods.join(",");
payload.write_u32::<BigEndian>(methods_str.len() as u32)?;
payload.write_all(methods_str.as_bytes())?;
// partial_success标志boolean
payload.write_u8(if partial_success { 1 } else { 0 })?;
Ok(SshPacket::new(payload))
}
/// 构建SSH_MSG_USERAUTH_BANNER packet可选参考OpenSSH auth2.c
/// 构建SSH_MSG_USERAUTH_BANNER packet
pub fn build_userauth_banner(message: &str, language: &str) -> Result<SshPacket> {
let mut payload = Vec::new();
// Packet type
payload.write_u8(PacketType::SSH_MSG_USERAUTH_BANNER as u8)?;
// Banner messageSSH string
payload.write_u32::<BigEndian>(message.len() as u32)?;
payload.write_all(message.as_bytes())?;
// Language tagSSH string
payload.write_u32::<BigEndian>(language.len() as u32)?;
payload.write_all(language.as_bytes())?;
Ok(SshPacket::new(payload))
}
/// 处理publickey认证Phase 9参考OpenSSH auth2-pubkey.c
fn handle_publickey_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
/// 处理publickey认证RFC 4252 §7
/// 支持Ed25519签名验证 + 数据库/filesystem密钥查找
fn handle_publickey_auth(
&mut self,
cursor: &mut std::io::Cursor<&[u8]>,
user: &str,
service: &str,
session_id: &[u8],
) -> Result<AuthResult> {
info!("Handling publickey auth for user: {}", user);
// 读取是否签名的标志boolean
let is_signed = cursor.read_u8()? != 0;
// 读取public key algorithmSSH string
let algorithm = read_ssh_string(cursor)?;
// 读取public key blobSSH string
let public_key_blob = read_ssh_string_bytes(cursor)?;
info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed);
// Phase 9简化实现 - 从authorized_keys文件验证
let authorized_keys_path = format!("data/{}/authorized_keys", user);
let authorized_keys = match std::fs::read_to_string(&authorized_keys_path) {
Ok(content) => content,
Err(_) => {
// 尝试默认路径
let default_path = "data/authorized_keys";
match std::fs::read_to_string(default_path) {
Ok(content) => content,
Err(_) => {
warn!("No authorized_keys file found for user: {}", user);
return Ok(AuthResult::Failure("password,publickey".to_string()));
}
}
}
};
// 解析authorized_keys查找匹配的public key
let public_key_matches = authorized_keys.lines().any(|line| {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
return false;
}
// SSH authorized_keys格式algorithm base64-key comment
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() < 2 {
return false;
}
let key_algorithm = parts[0];
let key_base64 = parts[1];
// 匹配algorithm
if key_algorithm != algorithm {
return false;
}
// 匹配public key blobbase64解码对比
match base64_decode(key_base64) {
Ok(decoded_key) => decoded_key == public_key_blob,
Err(_) => false,
}
});
if !public_key_matches {
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
warn!("Public key not authorized for user: {}", user);
return Ok(AuthResult::Failure("password,publickey".to_string()));
}
info!("Public key authorized for user: {}", user);
// 如果没有签名返回PK_OKquery阶段
if !is_signed {
// SSH_MSG_USERAUTH_PK_OK表示public key可接受client需要发送签名
return Ok(AuthResult::PublicKeyOk(algorithm, public_key_blob));
}
// 读取signatureSSH string
let signature = read_ssh_string_bytes(cursor)?;
info!("Verifying signature for user: {}", user);
// Phase 9简化签名验证 - 信任authorized_keys
// 完整实现需要提取session_id, 构建signed_data, verify signature
// 这里简化处理只要public key匹配authorized_keys就接受
let signature_blob = read_ssh_string_bytes(cursor)?;
self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?;
info!("Publickey auth successful for user: {}", user);
Ok(AuthResult::Success)
}
/// 检查public key是否在授权列表中数据库优先fallback到filesystem
fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result<bool> {
// 1. 先检查数据库
match self.provider.get_public_keys(user) {
Ok(keys) => {
for key_line in &keys {
if public_key_matches_line(key_line, algorithm, public_key_blob) {
return Ok(true);
}
}
}
Err(e) => warn!("Failed to get public keys from provider: {}", e),
}
// 2. Fallback到filesystem
let authorized_keys_path = format!("data/{}/authorized_keys", user);
let content = match std::fs::read_to_string(&authorized_keys_path) {
Ok(c) => c,
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
Ok(c) => c,
Err(_) => return Ok(false),
}
};
Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
}
/// 验证Ed25519签名RFC 4252 §7
fn verify_signature(
&self,
algorithm: &str,
public_key_blob: &[u8],
signature_blob: &[u8],
user: &str,
service: &str,
session_id: &[u8],
) -> Result<()> {
// 目前只支援Ed25519
if algorithm != "ssh-ed25519" {
return Err(anyhow!("Unsupported public key algorithm: {}", algorithm));
}
let verifying_key = parse_ed25519_verifying_key(public_key_blob)?;
let signature = parse_ed25519_signature(signature_blob)?;
// 建立签名验证数据RFC 4252 §7
let mut signed_data = Vec::new();
// string session identifier
signed_data.write_u32::<BigEndian>(session_id.len() as u32)?;
signed_data.write_all(session_id)?;
// byte SSH_MSG_USERAUTH_REQUEST
signed_data.write_u8(PacketType::SSH_MSG_USERAUTH_REQUEST as u8)?;
// string user name
signed_data.write_u32::<BigEndian>(user.len() as u32)?;
signed_data.write_all(user.as_bytes())?;
// string service name
signed_data.write_u32::<BigEndian>(service.len() as u32)?;
signed_data.write_all(service.as_bytes())?;
// string "publickey"
const PUBKEY_STR: &str = "publickey";
signed_data.write_u32::<BigEndian>(PUBKEY_STR.len() as u32)?;
signed_data.write_all(PUBKEY_STR.as_bytes())?;
// boolean TRUE
signed_data.write_u8(1)?;
// string public key algorithm name
signed_data.write_u32::<BigEndian>(algorithm.len() as u32)?;
signed_data.write_all(algorithm.as_bytes())?;
// string public key blob
signed_data.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
signed_data.write_all(public_key_blob)?;
// 验证签名
verifying_key.verify_strict(&signed_data, &signature)
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
}
}
/// SSH认证结果参考OpenSSH auth2.c
/// SSH认证结果
pub enum AuthResult {
Success,
Failure(String), // 失败原因
PartialSuccess, // 部分成功(多步骤认证)
PublicKeyOk(String, Vec<u8>), // Public key acceptable (algorithm, blob)
Failure(String),
PartialSuccess,
PublicKeyOk(String, Vec<u8>),
}
/// 解析Ed25519公钥blobSSH格式 -> VerifyingKey
fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result<VerifyingKey> {
let mut cursor = std::io::Cursor::new(public_key_blob);
let algorithm = read_ssh_string(&mut cursor)?;
if algorithm != "ssh-ed25519" {
return Err(anyhow!("Unsupported algorithm: {}", algorithm));
}
let key_bytes = read_ssh_string_bytes(&mut cursor)?;
if key_bytes.len() != 32 {
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
}
let key_array: [u8; 32] = key_bytes.try_into()
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
VerifyingKey::from_bytes(&key_array)
.map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
}
/// 解析Ed25519签名blobSSH格式 -> Signature
fn parse_ed25519_signature(signature_blob: &[u8]) -> Result<Signature> {
let mut cursor = std::io::Cursor::new(signature_blob);
let algorithm = read_ssh_string(&mut cursor)?;
if algorithm != "ssh-ed25519" {
return Err(anyhow!("Unsupported signature algorithm: {}", algorithm));
}
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
if sig_bytes.len() != 64 {
return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len()));
}
let sig_array: [u8; 64] = sig_bytes.try_into()
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
Ok(Signature::from_bytes(&sig_array))
}
/// 检查一行authorized_keys格式的密钥是否匹配
fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8]) -> bool {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
return false;
}
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() < 2 {
return false;
}
if parts[0] != algorithm {
return false;
}
base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false)
}
/// SSH string读取辅助函数
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
let length = reader.read_u32::<BigEndian>()?;
let mut buffer = vec![0u8; length as usize];
@@ -276,7 +315,6 @@ fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
Ok(String::from_utf8(buffer)?)
}
/// SSH string读取辅助函数bytes版本
fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
let length = reader.read_u32::<BigEndian>()?;
let mut buffer = vec![0u8; length as usize];
@@ -284,9 +322,7 @@ fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
Ok(buffer)
}
/// Base64解码辅助函数Phase 9
fn base64_decode(input: &str) -> Result<Vec<u8>> {
use base64::{Engine as _, engine::general_purpose};
general_purpose::STANDARD.decode(input)
.map_err(|e| anyhow!("Base64 decode error: {}", e))
}
@@ -294,18 +330,19 @@ fn base64_decode(input: &str) -> Result<Vec<u8>> {
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::sqlite::SqliteProvider;
#[test]
fn test_userauth_success_packet() {
let packet = AuthHandler::build_userauth_success().unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8);
}
#[test]
fn test_userauth_failure_packet() {
let methods = vec!["password".to_string(), "publickey".to_string()];
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8);
}
}

View File

@@ -25,6 +25,8 @@ pub struct ChannelManager {
next_channel_id: u32,
/// ⭐⭐⭐⭐⭐ Phase 15.1: 待发送packet队列用于同时发送WINDOW_ADJUST和SFTP响应
pub pending_packets: VecDeque<SshPacket>,
/// 用户home目录SFTP/SCP/rsync根目录SFTPGo兼容
pub home_dir: PathBuf,
}
/// Phase 14: 交互式Exec进程管理参考OpenSSH session.c: do_exec_no_pty
@@ -40,11 +42,12 @@ pub struct ExecProcess {
}
impl ChannelManager {
pub fn new() -> Self {
pub fn new(home_dir: PathBuf) -> Self {
Self {
channels: HashMap::new(),
next_channel_id: 0,
pending_packets: VecDeque::new(),
home_dir,
}
}
@@ -371,9 +374,12 @@ impl ChannelManager {
info!("⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}", process_type, command);
// 启动子进程相当于OpenSSH fork
// ⭐⭐⭐⭐⭐ Phase 17: 设置工作目录为用户home_dirSFTPGo兼容
let home_dir = self.home_dir.clone();
let mut child = Command::new("sh")
.arg("-c")
.arg(command)
.current_dir(&home_dir)
.stdin(Stdio::piped()) // ← 创建stdin管道相当于pipe(pin)
.stdout(Stdio::piped()) // ← 创建stdout管道相当于pipe(pout)
.stderr(Stdio::piped()) // ← 创建stderr管道相当于pipe(perr)
@@ -446,8 +452,8 @@ impl ChannelManager {
if subsystem == "sftp" {
info!("SFTP subsystem requested");
// Phase 7: 初始化SFTP handler
let root_dir = PathBuf::from("/Users/accusys/markbase"); // 默认root目录
// Phase 7: 初始化SFTP handler使用用户home目录SFTPGo兼容
let root_dir = self.home_dir.clone();
// ⭐⭐⭐⭐⭐ Phase 4: 获取 client maxpack 限制(从 Channel 中获取)
let maxpacket = if let Some(ch) = self.channels.get(&channel) {
@@ -456,7 +462,8 @@ impl ChannelManager {
32768 // OpenSSH 默认值32KB
};
let sftp_handler = SftpHandler::new(root_dir, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack
let vfs = Box::new(crate::vfs::local_fs::LocalFs::new());
let sftp_handler = SftpHandler::new(root_dir, vfs, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack
// 存储到channel
if let Some(ch) = self.channels.get_mut(&channel) {
@@ -952,6 +959,22 @@ impl ChannelManager {
false
}
/// Phase 17: 关闭所有子进程stdin收到CHANNEL_EOF时调用
/// SCP upload需要scp -t 等待EOF on stdin才知道数据传输完毕
pub fn close_child_stdin(&mut self) {
let channel_ids: Vec<u32> = self.channels.keys().copied().collect();
for id in channel_ids {
if let Some(channel) = self.channels.get_mut(&id) {
if let Some(exec) = &mut channel.exec_process {
if let Some(stdin) = exec.stdin.take() {
drop(stdin);
info!("⭐⭐⭐⭐⭐ [CHANNEL_EOF] Closed child stdin (channel {})", id);
}
}
}
}
}
/// 获取channel输出Phase 6新增
pub fn get_channel_output(&mut self, channel_id: u32) -> Option<Vec<u8>> {
if let Some(channel) = self.channels.get_mut(&channel_id) {
@@ -1283,6 +1306,7 @@ impl ChannelManager {
// 4. 检查stdout/stderr fd是否有数据
let mut packets_data: Vec<(u32, Vec<u8>)> = Vec::new();
let mut stderr_packets: Vec<(u32, Vec<u8>)> = Vec::new(); // Phase 17: stderr → CHANNEL_EXTENDED_DATA
for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map {
if let Some(channel) = self.channels.get_mut(&channel_id) {
@@ -1325,7 +1349,8 @@ impl ChannelManager {
Ok(n) if n > 0 => {
info!("⭐⭐⭐⭐⭐ [AFTER stderr.read] Read {} bytes from stderr (channel {})", n, channel_id);
info!("⭐⭐⭐⭐⭐ stderr content: {:?}", &buffer[..std::cmp::min(50, n)]);
packets_data.push((channel_id, buffer[..n].to_vec()));
// ⭐⭐⭐⭐⭐ Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
stderr_packets.push((channel_id, buffer[..n].to_vec()));
}
Ok(0) => {
info!("stderr EOF (channel {}), closing stderr pipe", channel_id);
@@ -1351,12 +1376,17 @@ impl ChannelManager {
}
// 构建packets
if !packets_data.is_empty() {
if !packets_data.is_empty() || !stderr_packets.is_empty() {
let mut packets = Vec::new();
for (channel_id, data) in packets_data {
let packet = self.build_channel_data(channel_id, &data)?;
packets.push(packet);
}
// Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
for (channel_id, data) in stderr_packets {
let packet = self.build_channel_extended_data(channel_id, 1, &data)?;
packets.push(packet);
}
info!("⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)", packets.len());
return Ok((Some(packets), client_has_data, child_exited));
}
@@ -1689,13 +1719,13 @@ mod tests {
#[test]
fn test_channel_manager_creation() {
let manager = ChannelManager::new();
let manager = ChannelManager::new(PathBuf::from("/tmp"));
assert_eq!(manager.next_channel_id, 0);
}
#[test]
fn test_channel_open_confirmation() {
let manager = ChannelManager::new();
let manager = ChannelManager::new(PathBuf::from("/tmp"));
let packet = manager.build_channel_open_confirmation(0, 100, 2097152, 32768).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8);
@@ -1703,7 +1733,7 @@ mod tests {
#[test]
fn test_channel_success() {
let manager = ChannelManager::new();
let manager = ChannelManager::new(PathBuf::from("/tmp"));
let packet = manager.build_channel_success(0).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_SUCCESS as u8);

View File

@@ -17,6 +17,7 @@ type HmacSha256 = Hmac<Sha256>;
/// SSH加密通道管理器参考OpenSSH struct sshcipher_ctx
pub struct EncryptionContext {
pub session_id: Vec<u8>, // session identifier (exchange hash)
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
@@ -32,6 +33,7 @@ pub struct EncryptionContext {
impl Default for EncryptionContext {
fn default() -> Self {
Self {
session_id: vec![0u8; 32],
encryption_key_ctos: vec![0u8; 32],
encryption_key_stoc: vec![0u8; 32],
mac_key_ctos: vec![0u8; 32],
@@ -73,6 +75,7 @@ impl EncryptionContext {
info!("Ciphers initialized successfully");
Self {
session_id: keys.session_id.clone(),
encryption_key_ctos: keys.encryption_key_ctos.clone(),
encryption_key_stoc: keys.encryption_key_stoc.clone(),
mac_key_ctos: keys.mac_key_ctos.clone(),

View File

@@ -1,8 +1,8 @@
use std::path::PathBuf;
use std::fs::{self, File};
use std::io::Write;
use anyhow::{Result, anyhow};
use log::{info, debug, warn};
use crate::vfs::{VfsBackend, VfsFile, VfsError};
use crate::vfs::open_flags::OpenFlags;
/// MPLEX_BASE from rsync io.h
const MPLEX_BASE: u32 = 7;
@@ -27,23 +27,21 @@ pub(crate) enum RsyncState {
pub struct RsyncHandler {
state: RsyncState,
/// Raw input from SSH (multiplexed after version exchange)
raw_input: Vec<u8>,
/// Decoded rsync protocol data (after stripping multiplex)
rsync_input: Vec<u8>,
/// Raw rsync data to send (multiplex wrapping applied in drain_output)
output_raw: Vec<u8>,
dest_path: PathBuf,
output_file: Option<File>,
output_file: Option<Box<dyn VfsFile>>,
total_written: u64,
file_entries: Vec<String>,
current_file: usize,
protocol_version: u32,
multiplex: bool,
vfs: Box<dyn VfsBackend>,
}
impl RsyncHandler {
pub fn parse_rsync_command(command: &str) -> Result<Self> {
pub fn parse_rsync_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() < 3 || parts[0] != "rsync" {
return Err(anyhow!("Invalid rsync command: {}", command));
@@ -83,9 +81,9 @@ impl RsyncHandler {
current_file: 0,
protocol_version: 30,
multiplex: false,
vfs,
};
// Send protocol version (4-byte LE int, no multiplex)
handler.output_raw.extend_from_slice(&30u32.to_le_bytes());
handler.state = RsyncState::WaitVersion;
@@ -129,7 +127,6 @@ impl RsyncHandler {
}
MSG_DONE => {
info!("rsync: MSG_DONE received (file complete)");
// Signal file completion by appending a sentinel to rsync_input
self.rsync_input.extend_from_slice(b"RSYNCDONE");
}
9 => {
@@ -147,7 +144,6 @@ impl RsyncHandler {
if data.is_empty() || !self.multiplex {
return data;
}
// Wrap with multiplex header (MSG_DATA)
let header = (MPLEX_BASE << 24) | (data.len() as u32);
let mut wrapped = Vec::with_capacity(4 + data.len());
wrapped.extend_from_slice(&header.to_le_bytes());
@@ -180,7 +176,6 @@ impl RsyncHandler {
loop {
match self.state.clone() {
RsyncState::SendVersion => {
// Version already sent in constructor
self.transition(RsyncState::WaitVersion);
}
@@ -206,7 +201,6 @@ impl RsyncHandler {
let flags = self.rsync_input[0];
if flags == 0 {
// End of file list
self.rsync_input.drain(..1);
info!("rsync: file list end ({} entries)", self.file_entries.len());
@@ -214,14 +208,12 @@ impl RsyncHandler {
self.file_entries.push("file".to_string());
}
self.current_file = 0;
// Enter sum head reading state
self.transition(RsyncState::ReadSumHead { need: 20 });
break;
}
let mut pos = 1;
// Extended flags
let _more_flags = if flags & 0x80 != 0 {
if self.rsync_input.len() <= pos { break; }
let ef = self.rsync_input[pos];
@@ -249,7 +241,6 @@ impl RsyncHandler {
self.file_entries.push(name);
}
// Skip metadata varints
let skip_count = if flags & 0x10 == 0 { 1 } else { 0 }
+ if flags & 0x20 == 0 { 1 } else { 0 }
+ if flags & 0x40 == 0 { 1 } else { 0 }
@@ -277,9 +268,6 @@ impl RsyncHandler {
RsyncState::ReadSumHead { need } => {
if self.rsync_input.len() >= need {
// Read sum head: count, blength, s2length, remainder (4 × LE int)
// + checksum seed (1 × LE int)
// = 5 × 4 = 20 bytes
let sum_count = i32::from_le_bytes([
self.rsync_input[0], self.rsync_input[1],
self.rsync_input[2], self.rsync_input[3],
@@ -312,7 +300,6 @@ impl RsyncHandler {
RsyncState::SendSumCount => {
self.open_current_file()?;
// Send sum_count = 0 (4-byte LE int = we have no existing data)
self.output_raw.extend_from_slice(&0u32.to_le_bytes());
info!("rsync: sent sum_count=0, ready to receive file data");
@@ -320,22 +307,17 @@ impl RsyncHandler {
}
RsyncState::ReadFileData => {
// Data comes as raw bytes inside MSG_DATA multiplex packets.
// MSG_DONE appends b"RSYNCDONE" to rsync_input.
let done_marker = b"RSYNCDONE";
if let Some(pos) = self.rsync_input.windows(done_marker.len())
.position(|w| w == done_marker)
{
// Data before the marker
if pos > 0 {
let data = self.rsync_input[..pos].to_vec();
self.rsync_input.drain(..pos);
self.write_to_file(&data)?;
}
// Remove marker
self.rsync_input.drain(..done_marker.len());
// Close file
if let Some(mut file) = self.output_file.take() {
if let Err(e) = file.flush() {
warn!("rsync flush error: {}", e);
@@ -353,11 +335,9 @@ impl RsyncHandler {
info!("rsync ALL DONE: {} bytes written to {}",
self.total_written, self.dest_path.display());
} else {
// Next file sum head
self.transition(RsyncState::ReadSumHead { need: 20 });
}
} else if !self.rsync_input.is_empty() {
// Partial data, keep it in buffer for more
let data = self.rsync_input.clone();
self.rsync_input.clear();
self.write_to_file(&data)?;
@@ -377,9 +357,11 @@ impl RsyncHandler {
fn open_current_file(&mut self) -> Result<()> {
if let Some(parent) = self.dest_path.parent() {
fs::create_dir_all(parent).ok();
self.vfs.create_dir_all(parent, 0o755).ok();
}
let file = File::create(&self.dest_path)?;
let flags = OpenFlags::new().write().create().truncate();
let file = self.vfs.open_file(&self.dest_path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
self.output_file = Some(file);
info!("rsync: opened {} for writing", self.dest_path.display());
Ok(())
@@ -387,7 +369,8 @@ impl RsyncHandler {
fn write_to_file(&mut self, data: &[u8]) -> Result<()> {
if let Some(file) = &mut self.output_file {
file.write_all(data)?;
file.write_all(data)
.map_err(|e| anyhow!("write error: {}", e))?;
self.total_written += data.len() as u64;
}
Ok(())
@@ -426,28 +409,37 @@ fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
#[cfg(test)]
mod tests {
use super::*;
use crate::vfs::local_fs::LocalFs;
fn make_vfs() -> Box<dyn VfsBackend> {
Box::new(LocalFs::new())
}
#[test]
fn test_parse_command() {
let h = RsyncHandler::parse_rsync_command("rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin").unwrap();
let h = RsyncHandler::parse_rsync_command(
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
make_vfs()
).unwrap();
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
}
#[test]
fn test_parse_command_sender() {
let h = RsyncHandler::parse_rsync_command("rsync --server --sender -vlogDtprz . /home/user/file.txt").unwrap();
let h = RsyncHandler::parse_rsync_command(
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
make_vfs()
).unwrap();
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
}
#[test]
fn test_version_exchange() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap();
// Initial output: protocol version (30 as LE int)
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
let output = h.drain_output();
assert_eq!(output, b"\x1e\x00\x00\x00");
assert_eq!(h.state, RsyncState::WaitVersion);
// Client sends its version (30 = 0x1E)
h.feed(b"\x1e\x00\x00\x00").unwrap();
assert_eq!(h.state, RsyncState::ReadFileList);
assert!(h.multiplex);
@@ -455,9 +447,8 @@ mod tests {
#[test]
fn test_version_negotiate_down() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
let _ = h.drain_output();
// Client has lower version (29)
h.feed(b"\x1d\x00\x00\x00").unwrap();
assert_eq!(h.protocol_version, 29);
assert_eq!(h.state, RsyncState::ReadFileList);
@@ -471,24 +462,14 @@ mod tests {
buf
}
fn build_multiplex_done() -> Vec<u8> {
let header = (MPLEX_BASE << 24) | 0u32; // MSG_DONE (tag=1 → raw_tag=8)
let mut buf = Vec::new();
buf.extend_from_slice(&header.to_le_bytes());
buf
}
#[test]
fn test_file_list_multiplex() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
let _ = h.drain_output();
// Version exchange
h.feed(b"\x1e\x00\x00\x00").unwrap();
assert!(h.multiplex);
// Build file list with multiplex wrapping
let mut flist = Vec::new();
// Entry: flags=0, name="test.txt\0", + 6 varints
flist.push(0);
flist.extend_from_slice(b"test.txt");
flist.push(0);
@@ -514,46 +495,40 @@ mod tests {
}
}
}
write_varint(&mut flist, 33188); // mode
write_varint(&mut flist, 501); // uid
write_varint(&mut flist, 20); // gid
write_varint(&mut flist, 1700000000); // time
write_varint(&mut flist, 100); // size
write_varint(&mut flist, 0); // checksum seed
// End marker
write_varint(&mut flist, 33188);
write_varint(&mut flist, 501);
write_varint(&mut flist, 20);
write_varint(&mut flist, 1700000000);
write_varint(&mut flist, 100);
write_varint(&mut flist, 0);
flist.push(0);
// Sum head (5 ints = 20 bytes) as separate multiplex packet
let mut sum_head = Vec::new();
sum_head.extend_from_slice(&0i32.to_le_bytes()); // count
sum_head.extend_from_slice(&7000i32.to_le_bytes()); // blength
sum_head.extend_from_slice(&2i32.to_le_bytes()); // s2length
sum_head.extend_from_slice(&100i32.to_le_bytes()); // remainder
sum_head.extend_from_slice(&42i32.to_le_bytes()); // checksum_seed
sum_head.extend_from_slice(&0i32.to_le_bytes());
sum_head.extend_from_slice(&7000i32.to_le_bytes());
sum_head.extend_from_slice(&2i32.to_le_bytes());
sum_head.extend_from_slice(&100i32.to_le_bytes());
sum_head.extend_from_slice(&42i32.to_le_bytes());
// Feed file list
h.feed(&build_multiplex(&flist)).unwrap();
assert_eq!(h.state, RsyncState::ReadFileList); // Still reading, 0x00 end marker triggered transition
assert_eq!(h.state, RsyncState::ReadFileList);
assert_eq!(h.file_entries.len(), 1);
// Now feed sum head
h.feed(&build_multiplex(&sum_head)).unwrap();
assert_eq!(h.state, RsyncState::SendSumCount);
// Send sum count response
let sum_resp = h.drain_output();
assert_eq!(sum_resp.len(), 8); // 4-byte header + 4-byte int
assert_eq!(sum_resp.len(), 8);
assert_eq!(&sum_resp[4..8], &0u32.to_le_bytes());
assert_eq!(h.state, RsyncState::ReadFileData);
}
#[test]
fn test_file_data_multiplex() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
let _ = h.drain_output();
h.feed(b"\x1e\x00\x00\x00").unwrap(); // version
h.feed(b"\x1e\x00\x00\x00").unwrap();
// Simple file list
let mut flist = Vec::new();
flist.push(0);
flist.extend_from_slice(b"test.bin");
@@ -568,7 +543,6 @@ mod tests {
flist.push(0);
h.feed(&build_multiplex(&flist)).unwrap();
// Sum head
let mut sh = Vec::new();
sh.extend_from_slice(&0i32.to_le_bytes());
sh.extend_from_slice(&7000i32.to_le_bytes());
@@ -576,16 +550,13 @@ mod tests {
sh.extend_from_slice(&100i32.to_le_bytes());
sh.extend_from_slice(&42i32.to_le_bytes());
h.feed(&build_multiplex(&sh)).unwrap();
let _ = h.drain_output(); // sum count response
let _ = h.drain_output();
// File data + MSG_DONE
let file_data = b"Hello, rsync protocol!";
h.feed(&build_multiplex(file_data)).unwrap();
assert_eq!(h.state, RsyncState::ReadFileData);
// MSG_DONE
// MSG_DONE has tag=1, so raw_tag = MPLEX_BASE + 1 = 8
let done_header = (MPLEX_BASE + 1) << 24; // raw_tag = 8, len = 0
let done_header = (MPLEX_BASE + 1) << 24;
let done_bytes = done_header.to_le_bytes();
h.feed(&done_bytes).unwrap();

View File

@@ -1,12 +1,13 @@
// SCP协议实现Phase 8
// 参考OpenSSH scp.c源码
use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat};
use crate::vfs::open_flags::OpenFlags;
use anyhow::{Result, anyhow};
use log::{info, warn, debug};
use std::path::{Path, PathBuf};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write, BufReader, BufWriter, BufRead}; // 导入BufRead traitOpenSSH标准
use chrono::{DateTime, Utc};
use std::io::{Read, Write, BufRead};
use std::time::SystemTime;
/// SCP Handler参考OpenSSH scp.c
pub struct ScpHandler {
@@ -14,6 +15,7 @@ pub struct ScpHandler {
mode: ScpMode,
recursive: bool,
preserve_times: bool,
vfs: Box<dyn VfsBackend>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -23,24 +25,25 @@ pub enum ScpMode {
}
impl ScpHandler {
pub fn new(root_dir: PathBuf) -> Self {
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>) -> Self {
Self {
root_dir,
mode: ScpMode::Destination,
recursive: false,
preserve_times: false,
vfs,
}
}
/// 解析SCP命令参考OpenSSH scp.c: parse_command()
pub fn parse_scp_command(command: &str) -> Result<Self> {
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() < 2 || parts[0] != "scp" {
return Err(anyhow!("Invalid SCP command: {}", command));
}
let mut handler = ScpHandler::new(PathBuf::from("/tmp"));
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
for part in &parts[1..] {
match part {
@@ -68,19 +71,19 @@ impl ScpHandler {
/// SCP Source Modescp -f发送文件
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
info!("SCP source mode: sending files from {}", self.root_dir.display()); // 使用display()Rust标准
info!("SCP source mode: sending files from {}", self.root_dir.display());
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
let stat = self.vfs.stat(&full_path)
.map_err(|e| anyhow!("stat error: {}", e))?;
if full_path.is_file() {
self.send_file(channel, &full_path)?;
} else if full_path.is_dir() {
if stat.is_dir {
if !self.recursive {
return Err(anyhow!("Directory detected but -r flag not specified"));
}
self.send_directory(channel, &full_path)?;
} else {
return Err(anyhow!("Path does not exist: {}", full_path.display()));
self.send_file(channel, &full_path)?;
}
Ok(())
@@ -88,9 +91,8 @@ impl ScpHandler {
/// SCP Destination Modescp -t接收文件
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
info!("SCP destination mode: receiving files to {}", self.root_dir.display()); // 使用display()Rust标准
info!("SCP destination mode: receiving files to {}", self.root_dir.display());
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -99,10 +101,9 @@ impl ScpHandler {
loop {
buffer.clear();
// 每次循环创建新的reader避免borrow冲突- OpenSSH标准
let mut reader = BufReader::new(&mut *channel);
let mut reader = std::io::BufReader::new(&mut *channel);
match reader.read_line(&mut buffer)? {
0 => break, // EOF
0 => break,
_ => {
let command = buffer.trim();
debug!("SCP command: {}", command);
@@ -113,7 +114,6 @@ impl ScpHandler {
Some('E') => self.handle_end_directory(channel)?,
Some('T') => self.handle_time_command(channel, command)?,
Some('\0') => {
// 确认信号,继续
continue;
}
_ => {
@@ -130,28 +130,30 @@ impl ScpHandler {
/// 发送文件参考OpenSSH scp.c: source()
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
let metadata = fs::metadata(path)?;
let size = metadata.len();
let stat = self.vfs.stat(path)
.map_err(|e| anyhow!("stat error: {}", e))?;
let size = stat.size;
let filename = path.file_name().unwrap().to_string_lossy();
// 发送文件命令C0644 size filename
let command = format!("C0644 {} {}\n", size, filename);
channel.write_all(command.as_bytes())?;
channel.flush()?;
// 等待确认('\0'
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP file command rejected"));
}
// 发送文件内容
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let flags = OpenFlags::new().read();
let mut file = self.vfs.open_file(path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
let mut buffer = vec![0u8; 8192];
while let Ok(n) = reader.read(&mut buffer) {
loop {
let n = file.read(&mut buffer)
.map_err(|e| anyhow!("read error: {}", e))?;
if n == 0 {
break;
}
@@ -160,11 +162,9 @@ impl ScpHandler {
channel.flush()?;
// 发送结束确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
// 等待确认('\0'
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP file transfer rejected"));
@@ -178,35 +178,34 @@ impl ScpHandler {
fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
let dirname = path.file_name().unwrap().to_string_lossy();
// 发送目录命令D0755 0 dirname
let command = format!("D0755 0 {}\n", dirname);
channel.write_all(command.as_bytes())?;
channel.flush()?;
// 等待确认('\0'
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP directory command rejected"));
}
// 递归发送目录内容
for entry in fs::read_dir(path)? {
let entry = entry?;
let full_path = entry.path();
let entries = self.vfs.read_dir(path)
.map_err(|e| anyhow!("read_dir error: {}", e))?;
if full_path.is_file() {
self.send_file(channel, &full_path)?;
} else if full_path.is_dir() && self.recursive {
self.send_directory(channel, &full_path)?;
for entry in &entries {
let entry_path = path.join(&entry.name);
if entry.stat.is_dir {
if self.recursive {
self.send_directory(channel, &entry_path)?;
}
} else {
self.send_file(channel, &entry_path)?;
}
}
// 发送结束目录命令E
channel.write_all("E\n".as_bytes())?;
channel.flush()?;
// 等待确认('\0'
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP end directory rejected"));
@@ -224,31 +223,25 @@ impl ScpHandler {
return self.send_error(channel, "Invalid file command format");
}
let mode = parts[0].trim_start_matches('C');
let mode_str = parts[0].trim_start_matches('C');
let size: u64 = parts[1].parse()?;
let filename = parts[2];
debug!("SCP receive file: mode={}, size={}, name={}", mode, size, filename);
debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename);
// 安全性检查文件大小限制防止DoS
if size > 1024 * 1024 * 1024 { // 1GB限制
if size > 1024 * 1024 * 1024 {
return self.send_error(channel, "File too large (max 1GB)");
}
// 创建文件
let full_path = self.resolve_path(filename)?;
let file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&full_path)?;
// 发送确认('\0'
let flags = OpenFlags::new().write().create().truncate();
let mut file = self.vfs.open_file(&full_path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
channel.write_all(&[0])?;
channel.flush()?;
// 接收文件内容
let mut writer = BufWriter::new(file);
let mut buffer = vec![0u8; 8192];
let mut remaining = size;
@@ -258,25 +251,25 @@ impl ScpHandler {
if n == 0 {
break;
}
writer.write_all(&buffer[..n])?;
file.write_all(&buffer[..n])
.map_err(|e| anyhow!("write error: {}", e))?;
remaining -= n as u64;
}
writer.flush()?;
file.flush().map_err(|e| anyhow!("flush error: {}", e))?;
// 设置文件权限
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mode_int: u32 = mode.parse()?;
fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?;
let mode_int: u32 = mode_str.parse()?;
if mode_int != 0 {
let mut set_stat = VfsStat::new();
set_stat.mode = mode_int;
self.vfs.set_stat(&full_path, &set_stat)
.map_err(|e| anyhow!("set_stat error: {}", e))?;
}
// 接收结束确认('\0'
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -296,24 +289,17 @@ impl ScpHandler {
return self.send_error(channel, "Recursive flag not specified");
}
let mode = parts[0].trim_start_matches('D');
let mode_str = parts[0].trim_start_matches('D');
let dirname = parts[2];
debug!("SCP receive directory: mode={}, name={}", mode, dirname);
debug!("SCP receive directory: mode={}, name={}", mode_str, dirname);
// 创建目录
let full_path = self.resolve_path(dirname)?;
fs::create_dir_all(&full_path)?;
// 设置目录权限
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mode_int: u32 = mode.parse()?;
fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?;
}
let mode_int: u32 = mode_str.parse()?;
self.vfs.create_dir_all(&full_path, mode_int)
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -325,7 +311,6 @@ impl ScpHandler {
fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> {
debug!("SCP end directory");
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
@@ -335,7 +320,6 @@ impl ScpHandler {
/// 处理时间命令T mtime atime
fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
if !self.preserve_times {
// 发送确认('\0'),但不设置时间
channel.write_all(&[0])?;
channel.flush()?;
return Ok(());
@@ -347,18 +331,14 @@ impl ScpHandler {
return self.send_error(channel, "Invalid time command format");
}
let mtime: i64 = parts[1].parse()?;
let atime: i64 = parts[2].parse()?;
let mtime_secs: i64 = parts[1].parse()?;
let atime_secs: i64 = parts[2].parse()?;
debug!("SCP set times: mtime={}, atime={}", mtime, atime);
debug!("SCP set times: mtime={}, atime={}", mtime_secs, atime_secs);
// 发送确认('\0'
channel.write_all(&[0])?;
channel.flush()?;
// 时间设置将在文件接收完成后进行
// 这里仅记录实际设置在handle_file_command中
Ok(())
}
@@ -374,10 +354,13 @@ impl ScpHandler {
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
let full_path = self.root_dir.join(path);
let canonical_path = full_path.canonicalize()
let canonical_path = self.vfs.real_path(&full_path)
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
if !canonical_path.starts_with(&self.root_dir.canonicalize()?) {
let root_canonical = self.vfs.real_path(&self.root_dir)
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
if !canonical_path.starts_with(&root_canonical) {
return Err(anyhow!("Path traversal attempt detected"));
}
@@ -392,23 +375,28 @@ impl<T: Read + Write> ReadWrite for T {}
#[cfg(test)]
mod tests {
use super::*;
use crate::vfs::local_fs::LocalFs;
fn make_handler() -> ScpHandler {
ScpHandler::new(PathBuf::from("/tmp"), Box::new(LocalFs::new()))
}
#[test]
fn test_scp_command_parse() {
let handler = ScpHandler::parse_scp_command("scp -t /tmp").unwrap();
let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
assert_eq!(handler.mode, ScpMode::Destination);
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
}
#[test]
fn test_scp_recursive_parse() {
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp").unwrap();
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
assert!(handler.recursive);
}
#[test]
fn test_scp_source_parse() {
let handler = ScpHandler::parse_scp_command("scp -f /tmp").unwrap();
let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
assert_eq!(handler.mode, ScpMode::Source);
}
}
}

View File

@@ -6,6 +6,9 @@ use crate::ssh_server::packet::{SshPacket, PacketType};
use crate::ssh_server::kex::{KexResult, KexProposal};
use crate::ssh_server::kex_complete::{KexState};
use crate::ssh_server::auth::{AuthHandler, AuthResult};
use crate::provider::sqlite::SqliteProvider;
use crate::provider::pg::PgProvider;
use crate::provider::DataProvider;
use crate::ssh_server::channel::{ChannelManager};
use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket};
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
@@ -13,6 +16,7 @@ use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
use anyhow::{Result, anyhow};
use log::{info, warn, error, debug};
use std::net::{TcpListener, TcpStream};
use std::path::PathBuf;
use std::thread;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步
@@ -22,6 +26,7 @@ pub struct SshServerConfig {
pub port: u16,
pub bind_address: String,
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
pub pg_conn: Option<String>, // PostgreSQL连接字符串SFTPGo兼容认证
}
impl Default for SshServerConfig {
@@ -30,6 +35,7 @@ impl Default for SshServerConfig {
port: 2024,
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
pg_conn: None,
}
}
}
@@ -42,6 +48,7 @@ impl SshServerConfig {
port: 2024,
bind_address: "127.0.0.1".to_string(),
security_config: config,
pg_conn: None,
})
}
}
@@ -73,6 +80,7 @@ impl SshServer {
self.config.security_config.max_sessions);
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
let pg_conn = self.config.pg_conn.clone();
for stream in listener.incoming() {
match stream {
@@ -81,9 +89,10 @@ impl SshServer {
info!("New SSH connection from {}", client_addr);
let security_config_clone = security_config.clone(); // Phase 13.1
let pg_conn_clone = pg_conn.clone();
thread::spawn(move || {
if let Err(e) = handle_connection_complete(stream, security_config_clone) { // Phase 13.1
if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1
error!("Connection error: {}", e);
}
});
@@ -99,7 +108,7 @@ impl SshServer {
}
/// 处理完整SSH连接Phase 1-13完整流程
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>) -> Result<()> {
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>, pg_conn: Option<String>) -> Result<()> {
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
// Phase 13.1: 增加活动会话数
@@ -122,13 +131,22 @@ fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshS
let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?;
info!("Key exchange completed, encryption channel ready");
// Phase 5: SSH认证参考OpenSSH auth2.c
let mut auth_handler = AuthHandler::new()?;
// Phase 5: SSH认证SFTPGo兼容 — PostgreSQL或SQLite
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str);
Box::new(PgProvider::new(conn_str)
.map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?)
} else {
info!("Using SQLite auth provider");
Box::new(SqliteProvider::new("data/auth.sqlite")
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?)
};
let mut auth_handler = AuthHandler::new(provider);
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
info!("SSH authentication succeeded: user={}", auth_user);
info!("SSH authentication succeeded: user={}", auth_user.username);
// Phase 6: SSH Channel管理参考OpenSSH channel.c
let mut channel_manager = ChannelManager::new();
let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone());
// Phase 13: PortForwardManager初始化
let mut port_forward_manager = PortForwardManager::new();
@@ -226,11 +244,16 @@ fn perform_complete_kex_exchange(
}
/// SSH认证流程Phase 5
pub struct AuthUser {
pub username: String,
pub home_dir: PathBuf,
}
fn perform_ssh_auth(
stream: &mut TcpStream,
auth_handler: &mut AuthHandler,
encryption_ctx: &mut EncryptionContext,
) -> Result<String> {
) -> Result<AuthUser> {
info!("Starting SSH authentication");
info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
encryption_ctx.encryption_key_ctos.len(),
@@ -279,6 +302,8 @@ fn perform_ssh_auth(
encrypted_accept.write(stream)?;
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
let session_id = encryption_ctx.session_id.clone();
loop {
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
let auth_payload = auth_packet.payload();
@@ -286,7 +311,7 @@ fn perform_ssh_auth(
let auth_request = SshPacket::new(auth_payload.to_vec());
match auth_handler.handle_userauth_request(&auth_request)? {
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
AuthResult::Success => {
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
let encrypted_success = EncryptedPacket::new(
@@ -297,7 +322,16 @@ fn perform_ssh_auth(
encrypted_success.write(stream)?;
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
return Ok("demo".to_string());
// Extract username from auth request
let user = extract_username_from_auth_request(&auth_request)
.unwrap_or_else(|_| "unknown".to_string());
let home_dir = auth_handler.get_home_dir(&user)
.ok()
.flatten()
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
return Ok(AuthUser { username: user, home_dir });
}
AuthResult::Failure(message) => {
// message包含可用的认证方法列表如"password,publickey"
@@ -519,7 +553,9 @@ fn handle_ssh_service_loop(
}
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_EOF as u8 => {
info!("Received SSH_MSG_CHANNEL_EOF");
// EOF means client won't send more data, just acknowledge and continue
// Phase 17: EOF means client won't send more data → close child stdin
// (Essential for SCP upload where scp -t waits for EOF on stdin)
channel_manager.close_child_stdin();
}
Some(&pt) if pt == PacketType::SSH_MSG_DISCONNECT as u8 => {
info!("Received SSH_MSG_DISCONNECT");
@@ -543,12 +579,27 @@ fn handle_ssh_service_loop(
Ok(())
}
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result<String> {
let payload = &packet.payload;
if payload.len() < 5 {
return Err(anyhow!("Auth request too short"));
}
let name_len = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]) as usize;
if payload.len() < 5 + name_len {
return Err(anyhow!("Auth request truncated"));
}
let username = String::from_utf8_lossy(&payload[5..5 + name_len]).to_string();
Ok(username)
}
/// SSH服务器CLI入口
pub fn run_ssh_server(port: Option<u16>) -> Result<()> {
pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
let config = SshServerConfig {
port: port.unwrap_or(2024),
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
pg_conn: pg_conn.map(|s| s.to_string()),
};
let server = SshServer::new(config);

View File

@@ -2,14 +2,16 @@
// 参考OpenSSH sftp-server.c和draft-ietf-secsh-filexfer-02.txt
use crate::ssh_server::packet::{SshPacket, PacketType};
use crate::vfs::{VfsBackend, VfsFile, VfsDirEntry};
use crate::vfs::open_flags::OpenFlags;
use anyhow::{Result, anyhow, Context};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, warn, debug};
use std::path::{Path, PathBuf};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write, Seek, SeekFrom};
use std::os::unix::fs::PermissionsExt; // 导入PermissionsExt traitUnix标准
use std::os::unix::fs::MetadataExt; // ⭐⭐⭐⭐⭐ Phase 2.2: 导入MetadataExt trait获取uid/gid
use std::fs;
use std::io::{SeekFrom, Write};
use std::os::unix::fs::PermissionsExt;
use std::os::unix::fs::MetadataExt;
/// SFTP packet类型参考draft-ietf-secsh-filexfer-02.txt
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -178,6 +180,30 @@ impl SftpAttrs {
attrs
}
pub fn from_vfs_stat(stat: &crate::vfs::VfsStat) -> Self {
let mut attrs = Self::new();
attrs.flags = SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE
| SftpAttrFlags::SSH_FILEXFER_ATTR_UIDGID
| SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS
| SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME;
attrs.size = Some(stat.size);
attrs.permissions = Some(stat.mode);
attrs.uid = Some(stat.uid);
attrs.gid = Some(stat.gid);
if let Ok(d) = stat.atime.duration_since(std::time::UNIX_EPOCH) {
attrs.atime = Some(d.as_secs() as u32);
}
if let Ok(d) = stat.mtime.duration_since(std::time::UNIX_EPOCH) {
attrs.mtime = Some(d.as_secs() as u32);
}
attrs
}
pub fn serialize(&self) -> Result<Vec<u8>> {
debug!("Serializing SftpAttrs: flags=0x{:08x}, size={:?}, uid={:?}, gid={:?}, permissions=0x{:08x}, atime={:?}, mtime={:?}",
self.flags, self.size, self.uid, self.gid,
@@ -242,13 +268,12 @@ impl SftpAttrs {
}
/// SFTP handle文件或目录句柄
#[derive(Debug)] // 移除CloneFile/DirEntry不支持Clone
pub struct SftpHandle {
pub id: u32,
pub path: PathBuf,
pub handle_type: SftpHandleType,
pub file: Option<File>,
pub dir_entries: Option<Vec<fs::DirEntry>>,
pub file: Option<Box<dyn VfsFile>>,
pub dir_entries: Option<Vec<VfsDirEntry>>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -260,6 +285,7 @@ pub enum SftpHandleType {
/// SFTP处理管理器参考OpenSSH sftp-server.c
pub struct SftpHandler {
root_dir: PathBuf,
vfs: Box<dyn VfsBackend>,
next_handle_id: u32,
handles: std::collections::HashMap<u32, SftpHandle>,
// ⭐⭐⭐⭐⭐ Phase 4: 添加 client maxpack 限制参考OpenSSH sftp-server.c
@@ -277,14 +303,15 @@ impl SftpHandler {
const MAX_HASH_SIZE: u64 = 268_435_456;
// ⭐⭐⭐⭐⭐ Phase 4: 修改 new() 方法,接受 maxpack 参数
pub fn new(root_dir: PathBuf, maxpacket: u32) -> Self {
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>, maxpacket: u32) -> Self {
let canonical_root = root_dir.canonicalize().unwrap_or(root_dir);
Self {
root_dir: canonical_root,
vfs,
next_handle_id: 0,
handles: std::collections::HashMap::new(),
maxpacket,
restrict_absolute: false, // 默认允许绝对路径
restrict_absolute: false,
}
}
@@ -360,30 +387,9 @@ impl SftpHandler {
info!("SSH_FXP_OPEN: id={}, path={}, pflags={:#x}", id, path, pflags);
let full_path = self.resolve_path(&path)?;
let flags = OpenFlags::from_sftp_pflags(pflags);
let file_result = if pflags & SftpFileFlags::SSH_FXF_READ != 0 {
OpenOptions::new().read(true).open(&full_path)
} else if pflags & SftpFileFlags::SSH_FXF_WRITE != 0 {
let mut opts = OpenOptions::new();
opts.write(true);
if pflags & SftpFileFlags::SSH_FXF_APPEND != 0 {
opts.append(true);
}
if pflags & SftpFileFlags::SSH_FXF_CREAT != 0 {
opts.create(true);
}
if pflags & SftpFileFlags::SSH_FXF_TRUNC != 0 {
opts.truncate(true);
}
if pflags & SftpFileFlags::SSH_FXF_EXCL != 0 {
opts.create_new(true);
}
opts.open(&full_path)
} else {
return self.build_status_response(id, SftpStatus::SSH_FX_OP_UNSUPPORTED, "Unsupported open flags");
};
match file_result {
match self.vfs.open_file(&full_path, &flags) {
Ok(file) => {
if self.handles.len() >= Self::MAX_HANDLES {
warn!("SSH_FXP_OPEN: handle limit reached ({})", Self::MAX_HANDLES);
@@ -405,7 +411,7 @@ impl SftpHandler {
self.build_handle_response(id, &handle_id.to_be_bytes())
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -447,9 +453,8 @@ impl SftpHandler {
if let Some(handle) = self.handles.get_mut(&handle_id) {
if let Some(ref mut file) = handle.file {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
// ⭐⭐⭐⭐⭐ Phase 4: 限制数据大小,不超过 maxpacket - 1024 和 MAX_XFER_SIZE
let max_data_size = std::cmp::min(self.maxpacket.saturating_sub(1024), Self::MAX_XFER_SIZE);
let actual_length = std::cmp::min(length, max_data_size);
@@ -465,7 +470,7 @@ impl SftpHandler {
self.build_data_response(id, &buffer)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
} else {
@@ -491,7 +496,6 @@ impl SftpHandler {
info!("SSH_FXP_WRITE: id={}, handle={}, offset={}, length={}", id, handle_id, offset, write_data.len());
// ⭐⭐⭐⭐⭐ Phase 1.2: 添加 data preview显示前 20 字节)
if write_data.len() > 0 {
let preview_len = std::cmp::min(20, write_data.len());
let preview = &write_data[0..preview_len];
@@ -500,14 +504,15 @@ impl SftpHandler {
if let Some(handle) = self.handles.get_mut(&handle_id) {
if let Some(ref mut file) = handle.file {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
match file.write_all(&write_data) {
Ok(_) => {
file.flush().ok();
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Write successful")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
} else {
@@ -532,13 +537,13 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::symlink_metadata(&full_path) {
Ok(metadata) => {
let attrs = SftpAttrs::from_metadata(&metadata);
match self.vfs.lstat(&full_path) {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e))
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -556,14 +561,26 @@ impl SftpHandler {
info!("SSH_FXP_FSTAT: id={}, handle={}", id, handle_id);
if let Some(handle) = self.handles.get(&handle_id) {
match fs::metadata(&handle.path) {
Ok(metadata) => {
let attrs = SftpAttrs::from_metadata(&metadata);
self.build_attrs_response(id, &attrs)
if let Some(handle) = self.handles.get_mut(&handle_id) {
if let Some(ref mut file) = handle.file {
match file.stat() {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_from_vfs_error(id, &e)
}
}
Err(e) => {
self.build_status_from_io_error(id, &e)
} else {
match self.vfs.stat(&handle.path) {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_from_vfs_error(id, &e)
}
}
}
} else {
@@ -585,7 +602,7 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::read_dir(&full_path) {
match self.vfs.read_dir(&full_path) {
Ok(entries) => {
if self.handles.len() >= Self::MAX_HANDLES {
warn!("SSH_FXP_OPENDIR: handle limit reached ({})", Self::MAX_HANDLES);
@@ -594,14 +611,12 @@ impl SftpHandler {
let handle_id = self.next_handle_id;
self.next_handle_id += 1;
let dir_entries: Vec<fs::DirEntry> = entries.filter_map(|e| e.ok()).collect();
let handle = SftpHandle {
id: handle_id,
path: full_path,
handle_type: SftpHandleType::Directory,
file: None,
dir_entries: Some(dir_entries),
dir_entries: Some(entries),
};
self.handles.insert(handle_id, handle);
@@ -609,7 +624,7 @@ impl SftpHandler {
self.build_handle_response(id, &handle_id.to_be_bytes())
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -635,11 +650,9 @@ impl SftpHandler {
} else {
let entries: Vec<(String, SftpAttrs)> = dir_entries
.drain(..std::cmp::min(100, dir_entries.len()))
.filter_map(|entry| {
let name = entry.file_name().to_string_lossy().to_string();
let attrs = entry.metadata().ok()?;
let sftp_attrs = SftpAttrs::from_metadata(&attrs);
Some((name, sftp_attrs))
.map(|entry| {
let attrs = SftpAttrs::from_vfs_stat(&entry.stat);
(entry.name, attrs)
})
.collect();
@@ -670,12 +683,12 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::remove_file(&full_path) {
match self.vfs.remove_file(&full_path) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "File removed")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -695,12 +708,12 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::create_dir(&full_path) {
match self.vfs.create_dir(&full_path, 0o755) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory created")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -719,12 +732,12 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::remove_dir(&full_path) {
match self.vfs.remove_dir(&full_path) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory removed")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -765,13 +778,13 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::metadata(&full_path) {
Ok(metadata) => {
let attrs = SftpAttrs::from_metadata(&metadata);
match self.vfs.stat(&full_path) {
Ok(stat) => {
let attrs = SftpAttrs::from_vfs_stat(&stat);
self.build_attrs_response(id, &attrs)
}
Err(e) => {
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e))
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -792,12 +805,12 @@ impl SftpHandler {
let old_full_path = self.resolve_path(&old_path)?;
let new_full_path = self.resolve_path(&new_path)?;
match fs::rename(&old_full_path, &new_full_path) {
match self.vfs.rename(&old_full_path, &new_full_path) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Rename successful")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -832,7 +845,7 @@ impl SftpHandler {
info!("SSH_FXP_FSETSTAT: id={}, handle={}, attrs.flags={}", id, handle_id, attrs.flags);
let handle = self.handles.get(&handle_id);
let handle = self.handles.get_mut(&handle_id);
if handle.is_none() {
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid handle");
}
@@ -847,25 +860,35 @@ impl SftpHandler {
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE != 0 {
if let Some(size) = attrs.size {
info!("FSETSTAT: setting file size to {}", size);
let file = OpenOptions::new().write(true).open(&path)?;
file.set_len(size)?;
if let Some(ref mut file) = handle.file {
file.set_len(size).map_err(|e| anyhow!("set_len error: {}", e))?;
} else {
let flags = OpenFlags::new().write();
if let Ok(mut f) = self.vfs.open_file(&path, &flags) {
f.set_len(size).ok();
}
}
}
}
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 {
if let Some(permissions) = attrs.permissions {
info!("FSETSTAT: setting permissions to {:o}", permissions);
fs::set_permissions(&path, fs::Permissions::from_mode(permissions))?;
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0
|| attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0
{
let mut vfs_stat = crate::vfs::VfsStat::new();
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 {
vfs_stat.mode = attrs.permissions.unwrap_or(0);
} else {
if let Ok(s) = self.vfs.lstat(&path) {
vfs_stat.mode = s.mode;
}
}
}
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 {
if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) {
info!("FSETSTAT: setting atime={}, mtime={}", atime, mtime);
let atime_filetime = filetime::FileTime::from_unix_time(atime as i64, 0);
let mtime_filetime = filetime::FileTime::from_unix_time(mtime as i64, 0);
filetime::set_file_times(&path, atime_filetime, mtime_filetime)?;
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 {
if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) {
vfs_stat.atime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(atime as u64);
vfs_stat.mtime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(mtime as u64);
}
}
self.vfs.set_stat(&path, &vfs_stat).ok();
}
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Fsetstat successful")
@@ -885,13 +908,13 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match fs::read_link(&full_path) {
match self.vfs.read_link(&full_path) {
Ok(link_target) => {
let target = link_target.to_string_lossy().to_string();
self.build_name_response(id, vec![(target, SftpAttrs::default())])
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -912,18 +935,14 @@ impl SftpHandler {
let full_linkpath = self.resolve_path(&linkpath)?;
let full_targetpath = self.resolve_path(&targetpath)?;
#[cfg(unix)]
match std::os::unix::fs::symlink(&full_targetpath, &full_linkpath) {
match self.vfs.create_symlink(&full_targetpath, &full_linkpath) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Symlink created")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
#[cfg(not(unix))]
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Symlink not supported on non-Unix systems")
}
/// 处理SSH_FXP_EXTENDEDPhase 10参考OpenSSH sftp-server.c: process_extended())
@@ -984,50 +1003,30 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
#[cfg(unix)]
{
use std::os::unix::fs::MetadataExt;
match fs::metadata(&full_path) {
Ok(metadata) => {
// 构建statvfs response参考OpenSSH sftp-server.c
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// f_bsize文件系统块大小
response.write_u64::<BigEndian>(4096)?;
// f_frsize基本块大小
response.write_u64::<BigEndian>(4096)?;
// f_blocks总块数
response.write_u64::<BigEndian>(1000000)?;
// f_bfree空闲块数
response.write_u64::<BigEndian>(500000)?;
// f_bavail可用块数
response.write_u64::<BigEndian>(500000)?;
// f_files总文件数
response.write_u64::<BigEndian>(100000)?;
// f_ffree空闲文件数
response.write_u64::<BigEndian>(50000)?;
// f_favail可用文件数
response.write_u64::<BigEndian>(50000)?;
// f_fsid文件系统ID
response.write_u64::<BigEndian>(0)?;
// f_flag标志
response.write_u64::<BigEndian>(0)?;
// f_namemax文件名最大长度
response.write_u64::<BigEndian>(255)?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
}
match self.vfs.stat(&full_path) {
Ok(_) => {
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
response.write_u64::<BigEndian>(4096)?;
response.write_u64::<BigEndian>(4096)?;
response.write_u64::<BigEndian>(1000000)?;
response.write_u64::<BigEndian>(500000)?;
response.write_u64::<BigEndian>(500000)?;
response.write_u64::<BigEndian>(100000)?;
response.write_u64::<BigEndian>(50000)?;
response.write_u64::<BigEndian>(50000)?;
response.write_u64::<BigEndian>(0)?;
response.write_u64::<BigEndian>(0)?;
response.write_u64::<BigEndian>(255)?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_vfs_error(id, &e)
}
}
#[cfg(not(unix))]
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "statvfs not supported on non-Unix systems")
}
/// 处理fstatvfs@openssh.com扩展文件句柄统计
@@ -1073,18 +1072,14 @@ impl SftpHandler {
let full_oldpath = self.resolve_path(&oldpath)?;
let full_newpath = self.resolve_path(&newpath)?;
#[cfg(unix)]
match fs::hard_link(&full_oldpath, &full_newpath) {
match self.vfs.hard_link(&full_oldpath, &full_newpath) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Hardlink created")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
#[cfg(not(unix))]
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Hardlink not supported on non-Unix systems")
}
/// 处理posix-rename@openssh.com扩展POSIX语义重命名
@@ -1097,12 +1092,12 @@ impl SftpHandler {
let full_oldpath = self.resolve_path(&oldpath)?;
let full_newpath = self.resolve_path(&newpath)?;
match fs::rename(&full_oldpath, &full_newpath) {
match self.vfs.rename(&full_oldpath, &full_newpath) {
Ok(_) => {
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Posix rename successful")
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1122,34 +1117,31 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算MD5哈希
let hash = md5::compute(&buffer);
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// hash-algorithm (SSH string)
response.write_u32::<BigEndian>(4)?;
response.write_all("md5".as_bytes())?;
// hash-value (SSH string)
response.write_u32::<BigEndian>(hash_hex.len() as u32)?;
response.write_all(hash_hex.as_bytes())?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1169,37 +1161,34 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算SHA256哈希使用sha2 crate
use sha2::{Sha256, Digest};
let mut hasher = Sha256::new();
hasher.update(&buffer);
let hash = hasher.finalize();
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// hash-algorithm (SSH string)
response.write_u32::<BigEndian>(6)?;
response.write_all("sha256".as_bytes())?;
// hash-value (SSH string)
response.write_u32::<BigEndian>(hash_hex.len() as u32)?;
response.write_all(hash_hex.as_bytes())?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1219,21 +1208,20 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算SHA384哈希
use sha2::{Sha384, Digest};
let mut hasher = Sha384::new();
hasher.update(&buffer);
let hash = hasher.finalize();
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
@@ -1247,7 +1235,7 @@ impl SftpHandler {
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1267,21 +1255,20 @@ impl SftpHandler {
let full_path = self.resolve_path(&path)?;
match File::open(&full_path) {
let flags = OpenFlags::new().read();
match self.vfs.open_file(&full_path, &flags) {
Ok(mut file) => {
file.seek(SeekFrom::Start(offset))?;
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
file.read_exact(&mut buffer)?;
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
// 计算SHA512哈希
use sha2::{Sha512, Digest};
let mut hasher = Sha512::new();
hasher.update(&buffer);
let hash = hasher.finalize();
let hash_hex = format!("{:x}", hash);
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
@@ -1295,7 +1282,7 @@ impl SftpHandler {
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1303,30 +1290,28 @@ impl SftpHandler {
/// 处理check-file@openssh.com扩展Phase 12文件检查
fn handle_check_file(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result<Vec<u8>> {
let path = read_sftp_string(cursor)?;
let check_flags = cursor.read_u32::<BigEndian>()?;
let _check_flags = cursor.read_u32::<BigEndian>()?;
info!("check-file: path={}, flags={:#x}", path, check_flags);
info!("check-file: path={}", path);
let full_path = self.resolve_path(&path)?;
match fs::metadata(&full_path) {
Ok(metadata) => {
// 构建响应
match self.vfs.stat(&full_path) {
Ok(stat) => {
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// 返回文件存在和基本信息
response.write_u32::<BigEndian>(1)?; // result: 1 = file exists
response.write_u32::<BigEndian>(1)?;
let msg = format!("File exists, size: {}", metadata.len());
let msg = format!("File exists, size: {}", stat.size);
response.write_u32::<BigEndian>(msg.len() as u32)?;
response.write_all(msg.as_bytes())?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Check file error: {}", e))
self.build_status_from_vfs_error(id, &e)
}
}
}
@@ -1339,11 +1324,8 @@ impl SftpHandler {
let write_handle_bytes = read_sftp_string_bytes(cursor)?;
let write_offset = cursor.read_u64::<BigEndian>()?;
info!("copy-data: read_handle={}, read_offset={}, read_length={}, write_handle={}, write_offset={}",
u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]),
read_offset, read_length,
u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]),
write_offset);
info!("copy-data: read_handle={:?}, read_offset={}, read_length={}, write_handle={:?}, write_offset={}",
read_handle_bytes, read_offset, read_length, write_handle_bytes, write_offset);
let actual_length = std::cmp::min(read_length, Self::MAX_XFER_SIZE as u64);
if actual_length < read_length {
@@ -1353,52 +1335,44 @@ impl SftpHandler {
let read_handle_id = u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]);
let write_handle_id = u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]);
// 获取read handle的path不可变引用
let read_path = if let Some(read_handle) = self.handles.get(&read_handle_id) {
read_handle.path.clone()
} else {
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid read handle");
};
// 获取write handle的path不可变引用
let write_path = if let Some(write_handle) = self.handles.get(&write_handle_id) {
write_handle.path.clone()
} else {
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid write handle");
};
// 从read_path读取数据
match File::open(&read_path) {
Ok(mut read_file) => {
read_file.seek(SeekFrom::Start(read_offset))?;
let mut buffer = vec![0u8; actual_length as usize];
read_file.read_exact(&mut buffer)?;
// 写入到write_path
match OpenOptions::new().write(true).open(&write_path) {
Ok(mut write_file) => {
write_file.seek(SeekFrom::Start(write_offset))?;
write_file.write_all(&buffer)?;
// 构建响应
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
// 返回复制的字节数
response.write_u64::<BigEndian>(actual_length)?;
self.wrap_sftp_packet(&response)
}
Err(e) => {
self.build_status_from_io_error(id, &e)
}
}
}
Err(e) => {
self.build_status_from_io_error(id, &e)
}
}
let read_flags = OpenFlags::new().read();
let write_flags = OpenFlags::new().write();
let mut read_file = match self.vfs.open_file(&read_path, &read_flags) {
Ok(f) => f,
Err(e) => return self.build_status_from_vfs_error(id, &e),
};
let mut write_file = match self.vfs.open_file(&write_path, &write_flags) {
Ok(f) => f,
Err(e) => return self.build_status_from_vfs_error(id, &e),
};
read_file.seek(SeekFrom::Start(read_offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
let mut buffer = vec![0u8; actual_length as usize];
read_file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
write_file.seek(SeekFrom::Start(write_offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
write_file.write_all(&buffer).map_err(|e| anyhow!("Write error: {}", e))?;
write_file.flush().ok();
let mut response = Vec::new();
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
response.write_u32::<BigEndian>(id)?;
response.write_u64::<BigEndian>(actual_length)?;
self.wrap_sftp_packet(&response)
}
/// 解析路径安全性检查参考OpenSSH sftp-server.c: path_resolve()
@@ -1608,6 +1582,24 @@ impl SftpHandler {
let msg = format!("{}", err);
self.build_status_response(id, status, &msg)
}
/// 根据 VfsError 构建状态响应(自动映射错误类型)
fn build_status_from_vfs_error(&self, id: u32, err: &crate::vfs::VfsError) -> Result<Vec<u8>> {
use crate::vfs::VfsError;
let status = match err {
VfsError::NotFound(_) => SftpStatus::SSH_FX_NO_SUCH_FILE,
VfsError::PermissionDenied(_) => SftpStatus::SSH_FX_PERMISSION_DENIED,
VfsError::AlreadyExists(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::NotEmpty(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::NotADirectory(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::IsADirectory(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::Unsupported(_) => SftpStatus::SSH_FX_OP_UNSUPPORTED,
VfsError::Io(_) => SftpStatus::SSH_FX_FAILURE,
VfsError::UnexpectedEof => SftpStatus::SSH_FX_EOF,
};
let msg = format!("{}", err);
self.build_status_response(id, status, &msg)
}
}
/// 读取SFTP字符串参考draft-ietf-secsh-filexfer-02.txt
@@ -1665,8 +1657,14 @@ fn read_sftp_attrs<R: std::io::Read>(reader: &mut R) -> Result<SftpAttrs> {
#[cfg(test)]
mod tests {
use super::*;
use crate::vfs::local_fs::LocalFs;
use std::fs::File;
use tempfile::TempDir;
fn make_handler(root_dir: PathBuf) -> SftpHandler {
SftpHandler::new(root_dir, Box::new(LocalFs::new()), 32768)
}
#[test]
fn test_sftp_packet_type_conversion() {
assert_eq!(SftpPacketType::try_from(1).unwrap(), SftpPacketType::SSH_FXP_INIT);
@@ -1677,7 +1675,7 @@ mod tests {
#[test]
fn test_sftp_handler_creation() {
let temp_dir = TempDir::new().unwrap();
let handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768);
let handler = make_handler(temp_dir.path().to_path_buf());
assert_eq!(handler.next_handle_id, 0);
}
@@ -1697,7 +1695,7 @@ mod tests {
#[test]
fn test_sftp_handle_init() {
let temp_dir = TempDir::new().unwrap();
let mut handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768);
let mut handler = make_handler(temp_dir.path().to_path_buf());
let init_packet = vec![1, 0, 0, 0, 3];
let response = handler.handle_request(&init_packet).unwrap();

View File

@@ -0,0 +1,212 @@
use super::util;
use super::open_flags::OpenFlags;
use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::os::unix::fs::{MetadataExt, PermissionsExt};
/// 本地文件系统实现(直接包装 std::fs不做路径解析
/// 路径解析由上层SftpHandler负责
pub struct LocalFs;
impl LocalFs {
pub fn new() -> Self {
Self
}
}
struct LocalFile {
file: File,
}
impl VfsFile for LocalFile {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, VfsError> {
self.file.read(buf).map_err(|e| VfsError::Io(e.to_string()))
}
fn write(&mut self, buf: &[u8]) -> Result<usize, VfsError> {
self.file.write(buf).map_err(|e| VfsError::Io(e.to_string()))
}
fn seek(&mut self, pos: SeekFrom) -> Result<u64, VfsError> {
self.file.seek(pos).map_err(|e| VfsError::Io(e.to_string()))
}
fn flush(&mut self) -> Result<(), VfsError> {
self.file.flush().map_err(|e| VfsError::Io(e.to_string()))
}
fn stat(&mut self) -> Result<VfsStat, VfsError> {
let meta = self.file.metadata().map_err(|e| VfsError::Io(e.to_string()))?;
Ok(util::stat_from_metadata(&meta, false))
}
fn set_len(&mut self, size: u64) -> Result<(), VfsError> {
self.file.set_len(size).map_err(|e| VfsError::Io(e.to_string()))
}
}
impl VfsBackend for LocalFs {
fn read_dir(&self, path: &Path) -> Result<Vec<VfsDirEntry>, VfsError> {
let dir = fs::read_dir(path).map_err(|e| util::map_io_error(path, e))?;
let mut entries = Vec::new();
for entry in dir {
let entry = entry.map_err(|e| util::map_io_error(path, e))?;
let name = entry.file_name().to_string_lossy().to_string();
let file_type = entry.file_type().map_err(|e| util::map_io_error(path, e))?;
let meta = entry.metadata().map_err(|e| util::map_io_error(path, e))?;
let stat = util::stat_from_metadata(&meta, file_type.is_symlink());
let long_name = util::build_long_name(&stat, &name);
entries.push(VfsDirEntry {
name,
long_name,
stat,
});
}
entries.sort_by(|a, b| a.name.cmp(&b.name));
Ok(entries)
}
fn open_file(&self, path: &Path, flags: &OpenFlags) -> Result<Box<dyn VfsFile>, VfsError> {
let mut opts = OpenOptions::new();
opts.read(flags.read);
opts.write(flags.write);
opts.append(flags.append);
opts.create(flags.create);
opts.truncate(flags.truncate);
opts.create_new(flags.exclusive);
let file = opts.open(path).map_err(|e| util::map_io_error(path, e))?;
#[cfg(unix)]
if flags.create && !flags.exclusive {
if let Ok(meta) = file.metadata() {
if flags.mode != 0 && meta.permissions().mode() != flags.mode {
fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode))
.ok();
}
}
}
Ok(Box::new(LocalFile { file }))
}
fn stat(&self, path: &Path) -> Result<VfsStat, VfsError> {
let meta = fs::metadata(path).map_err(|e| util::map_io_error(path, e))?;
Ok(util::stat_from_metadata(&meta, false))
}
fn lstat(&self, path: &Path) -> Result<VfsStat, VfsError> {
let meta = fs::symlink_metadata(path).map_err(|e| util::map_io_error(path, e))?;
let is_symlink = path.is_symlink() || meta.file_type().is_symlink();
Ok(util::stat_from_metadata(&meta, is_symlink))
}
fn create_dir(&self, path: &Path, mode: u32) -> Result<(), VfsError> {
fs::create_dir(path).map_err(|e| util::map_io_error(path, e))?;
#[cfg(unix)]
{
fs::set_permissions(path, std::fs::Permissions::from_mode(mode))
.map_err(|e| util::map_io_error(path, e))?;
}
Ok(())
}
fn create_dir_all(&self, path: &Path, mode: u32) -> Result<(), VfsError> {
fs::create_dir_all(path).map_err(|e| util::map_io_error(path, e))?;
#[cfg(unix)]
{
if mode != 0 {
fs::set_permissions(path, std::fs::Permissions::from_mode(mode))
.map_err(|e| util::map_io_error(path, e))?;
}
}
Ok(())
}
fn remove_dir(&self, path: &Path) -> Result<(), VfsError> {
fs::remove_dir(path).map_err(|e| util::map_io_error(path, e))
}
fn remove_file(&self, path: &Path) -> Result<(), VfsError> {
fs::remove_file(path).map_err(|e| util::map_io_error(path, e))
}
fn rename(&self, from: &Path, to: &Path) -> Result<(), VfsError> {
fs::rename(from, to).map_err(|e| util::map_io_error(from, e))
}
fn set_stat(&self, path: &Path, stat: &VfsStat) -> Result<(), VfsError> {
#[cfg(unix)]
{
if stat.mode != 0 {
fs::set_permissions(path, std::fs::Permissions::from_mode(stat.mode))
.map_err(|e| util::map_io_error(path, e))?;
}
}
if let (Some(atime), Some(mtime)) = (
stat.atime.duration_since(std::time::UNIX_EPOCH).ok(),
stat.mtime.duration_since(std::time::UNIX_EPOCH).ok(),
) {
filetime::set_file_times(path,
filetime::FileTime::from_unix_time(atime.as_secs() as i64, 0),
filetime::FileTime::from_unix_time(mtime.as_secs() as i64, 0),
).map_err(|e| util::map_io_error(path, e))?;
}
Ok(())
}
fn read_link(&self, path: &Path) -> Result<PathBuf, VfsError> {
let target = fs::read_link(path).map_err(|e| util::map_io_error(path, e))?;
Ok(target)
}
fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError> {
#[cfg(unix)]
{
std::os::unix::fs::symlink(target, link)
.map_err(|e| util::map_io_error(link, e))?;
}
#[cfg(not(unix))]
{
std::os::windows::fs::symlink_file(target, link)
.map_err(|e| util::map_io_error(link, e))?;
}
Ok(())
}
fn real_path(&self, path: &Path) -> Result<PathBuf, VfsError> {
let canonical = path.canonicalize().map_err(|e| util::map_io_error(path, e))?;
Ok(canonical)
}
fn exists(&self, path: &Path) -> bool {
path.exists()
}
fn hard_link(&self, original: &Path, link: &Path) -> Result<(), VfsError> {
#[cfg(unix)]
{
fs::hard_link(original, link).map_err(|e| util::map_io_error(original, e))?;
}
#[cfg(not(unix))]
{
return Err(VfsError::Unsupported("hard_link not supported on non-Unix systems".to_string()));
}
Ok(())
}
}

View File

@@ -0,0 +1,160 @@
pub mod open_flags;
pub mod local_fs;
pub mod util;
use std::path::{Path, PathBuf};
use std::time::SystemTime;
/// VFS 错误类型
#[derive(Debug, Clone)]
pub enum VfsError {
NotFound(String),
PermissionDenied(String),
AlreadyExists(String),
NotEmpty(String),
NotADirectory(String),
IsADirectory(String),
Unsupported(String),
Io(String),
UnexpectedEof,
}
impl std::fmt::Display for VfsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VfsError::NotFound(p) => write!(f, "No such file or directory: {}", p),
VfsError::PermissionDenied(p) => write!(f, "Permission denied: {}", p),
VfsError::AlreadyExists(p) => write!(f, "File already exists: {}", p),
VfsError::NotEmpty(p) => write!(f, "Directory not empty: {}", p),
VfsError::NotADirectory(p) => write!(f, "Not a directory: {}", p),
VfsError::IsADirectory(p) => write!(f, "Is a directory: {}", p),
VfsError::Unsupported(msg) => write!(f, "Unsupported: {}", msg),
VfsError::Io(msg) => write!(f, "IO error: {}", msg),
VfsError::UnexpectedEof => write!(f, "Unexpected end of file"),
}
}
}
impl std::error::Error for VfsError {}
/// 文件统计信息(类似 libc::stat
#[derive(Debug, Clone)]
pub struct VfsStat {
pub size: u64,
pub mode: u32,
pub uid: u32,
pub gid: u32,
pub atime: SystemTime,
pub mtime: SystemTime,
pub is_dir: bool,
pub is_symlink: bool,
}
impl VfsStat {
pub fn new() -> Self {
Self {
size: 0,
mode: 0,
uid: 0,
gid: 0,
atime: SystemTime::UNIX_EPOCH,
mtime: SystemTime::UNIX_EPOCH,
is_dir: false,
is_symlink: false,
}
}
}
impl Default for VfsStat {
fn default() -> Self {
Self::new()
}
}
/// 目录条目
#[derive(Debug, Clone)]
pub struct VfsDirEntry {
pub name: String,
pub long_name: String,
pub stat: VfsStat,
}
/// 打开文件的抽象
pub trait VfsFile {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, VfsError>;
fn write(&mut self, buf: &[u8]) -> Result<usize, VfsError>;
fn seek(&mut self, pos: std::io::SeekFrom) -> Result<u64, VfsError>;
fn flush(&mut self) -> Result<(), VfsError>;
fn stat(&mut self) -> Result<VfsStat, VfsError>;
fn set_len(&mut self, size: u64) -> Result<(), VfsError>;
/// Write all bytes (convenience, default loops write() until done)
fn write_all(&mut self, mut buf: &[u8]) -> Result<(), VfsError> {
while !buf.is_empty() {
let n = self.write(buf)?;
if n == 0 {
return Err(VfsError::Io("write returned 0".to_string()));
}
buf = &buf[n..];
}
Ok(())
}
/// Read exactly `buf.len()` bytes (convenience, loops read() until done)
fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<(), VfsError> {
while !buf.is_empty() {
let n = self.read(buf)?;
if n == 0 {
return Err(VfsError::UnexpectedEof);
}
buf = &mut buf[n..];
}
Ok(())
}
}
/// VFS 后端 trait所有文件系统操作
pub trait VfsBackend: Send {
/// 读取目录内容
fn read_dir(&self, path: &Path) -> Result<Vec<VfsDirEntry>, VfsError>;
/// 打开文件(读/写)
fn open_file(&self, path: &Path, flags: &open_flags::OpenFlags) -> Result<Box<dyn VfsFile>, VfsError>;
/// 获取文件/目录元数据
fn stat(&self, path: &Path) -> Result<VfsStat, VfsError>;
fn lstat(&self, path: &Path) -> Result<VfsStat, VfsError>;
/// 创建目录
fn create_dir(&self, path: &Path, mode: u32) -> Result<(), VfsError>;
/// 递归创建目录
fn create_dir_all(&self, path: &Path, mode: u32) -> Result<(), VfsError>;
/// 删除空目录
fn remove_dir(&self, path: &Path) -> Result<(), VfsError>;
/// 删除文件
fn remove_file(&self, path: &Path) -> Result<(), VfsError>;
/// 重命名
fn rename(&self, from: &Path, to: &Path) -> Result<(), VfsError>;
/// 设置文件属性
fn set_stat(&self, path: &Path, stat: &VfsStat) -> Result<(), VfsError>;
/// 读取符号链接目标
fn read_link(&self, path: &Path) -> Result<PathBuf, VfsError>;
/// 创建符号链接
fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError>;
/// 规范化路径
fn real_path(&self, path: &Path) -> Result<PathBuf, VfsError>;
/// 检查路径是否存在
fn exists(&self, path: &Path) -> bool;
/// 创建硬链接
fn hard_link(&self, original: &Path, link: &Path) -> Result<(), VfsError>;
}

View File

@@ -0,0 +1,75 @@
/// 文件打开标志(映射 SSH_FXF_* 和 POSIX open flags
#[derive(Debug, Clone, Default)]
pub struct OpenFlags {
pub read: bool,
pub write: bool,
pub append: bool,
pub create: bool,
pub truncate: bool,
pub exclusive: bool,
pub mode: u32,
}
impl OpenFlags {
pub fn new() -> Self {
Self::default()
}
pub fn read(mut self) -> Self {
self.read = true;
self
}
pub fn write(mut self) -> Self {
self.write = true;
self
}
pub fn append(mut self) -> Self {
self.append = true;
self.write = true;
self
}
pub fn create(mut self) -> Self {
self.create = true;
self.write = true;
self
}
pub fn truncate(mut self) -> Self {
self.truncate = true;
self.write = true;
self
}
pub fn exclusive(mut self) -> Self {
self.exclusive = true;
self
}
pub fn mode(mut self, mode: u32) -> Self {
self.mode = mode;
self
}
/// 从 SFTP 的 pflagsSSH_FXF_*)构建 OpenFlags
pub fn from_sftp_pflags(pflags: u32) -> Self {
let read = pflags & 0x00000001 != 0;
let write = pflags & 0x00000002 != 0;
let append = pflags & 0x00000004 != 0;
let create = pflags & 0x00000008 != 0;
let truncate = pflags & 0x00000010 != 0;
let exclusive = pflags & 0x00000020 != 0;
Self {
read,
write,
append,
create,
truncate,
exclusive,
mode: 0o644,
}
}
}

View File

@@ -0,0 +1,105 @@
use super::{VfsError, VfsStat};
use chrono::Datelike;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
/// 从 std::io::ErrorKind 映射 VfsError
pub fn map_io_error(path: &Path, e: std::io::Error) -> VfsError {
match e.kind() {
std::io::ErrorKind::NotFound => VfsError::NotFound(path.display().to_string()),
std::io::ErrorKind::PermissionDenied => VfsError::PermissionDenied(path.display().to_string()),
std::io::ErrorKind::AlreadyExists => VfsError::AlreadyExists(path.display().to_string()),
std::io::ErrorKind::DirectoryNotEmpty => VfsError::NotEmpty(path.display().to_string()),
std::io::ErrorKind::NotADirectory => VfsError::NotADirectory(path.display().to_string()),
std::io::ErrorKind::IsADirectory => VfsError::IsADirectory(path.display().to_string()),
std::io::ErrorKind::UnexpectedEof => VfsError::UnexpectedEof,
other => VfsError::Io(format!("{}: {}", other, path.display())),
}
}
/// 从 std::fs::Metadata 构建 VfsStat
pub fn stat_from_metadata(meta: &std::fs::Metadata, is_symlink: bool) -> VfsStat {
#[cfg(unix)]
use std::os::unix::fs::MetadataExt;
let mut stat = VfsStat::new();
stat.size = meta.len();
stat.is_dir = meta.is_dir();
stat.is_symlink = is_symlink;
#[cfg(unix)]
{
stat.mode = meta.permissions().mode();
stat.uid = meta.uid();
stat.gid = meta.gid();
}
#[cfg(not(unix))]
{
stat.mode = if meta.is_dir() { 0o40755 } else { 0o100644 };
}
if let Ok(t) = meta.accessed() {
stat.atime = t;
}
if let Ok(t) = meta.modified() {
stat.mtime = t;
}
stat
}
/// 构建目录条目的 long_name类似 ls -l 格式)
pub fn build_long_name(stat: &VfsStat, name: &str) -> String {
let file_type = if stat.is_dir { 'd' } else { '-' };
let perms = format_permissions(stat.mode & 0o777);
let link_count = if stat.is_dir { 3 } else { 1 };
let size = stat.size;
let mtime = match stat.mtime.duration_since(std::time::UNIX_EPOCH) {
Ok(d) => {
let secs = d.as_secs();
format_timestamp(secs)
}
Err(_) => "Jan 1 1970".to_string(),
};
format!(
"{}{} {} {} {} {} {} {}",
file_type, perms,
link_count,
stat.uid,
stat.gid,
size,
mtime,
name
)
}
fn format_permissions(mode: u32) -> String {
let rwx = |n: u32| -> String {
let r = if n & 4 != 0 { 'r' } else { '-' };
let w = if n & 2 != 0 { 'w' } else { '-' };
let x = if n & 1 != 0 { 'x' } else { '-' };
format!("{}{}{}", r, w, x)
};
format!(
"{}{}{}",
rwx((mode >> 6) & 7),
rwx((mode >> 3) & 7),
rwx(mode & 7)
)
}
fn format_timestamp(secs: u64) -> String {
let datetime = match chrono::DateTime::from_timestamp(secs as i64, 0) {
Some(dt) => dt,
None => return "Jan 1 1970".to_string(),
};
let now = chrono::Utc::now();
if datetime.year() == now.year() {
datetime.format("%b %e %H:%M").to_string()
} else {
datetime.format("%b %e %Y").to_string()
}
}