VFS/DataProvider/Config refactoring + SSH public key authentication
Phase 1-6 of refactoring plan: - VFS abstraction (VfsBackend trait + LocalFs + OpenFlags builder) - DataProvider trait (SqliteProvider + PgProvider, SFTPGo-compatible) - Config refactoring (AppConfig unified sections, env overrides) - SSH handlers (sftp/scp/rsync) migrated to VFS + DataProvider - SSH public key authentication (Ed25519 signature verification) - SSH stderr → CHANNEL_EXTENDED_DATA support - Web auth uses DataProvider instead of direct SQL - User home directory from provider (per-user isolation) - PostgreSQL auth provider for SFTPGo compatibility
This commit is contained in:
@@ -38,6 +38,7 @@ filetime = "0.2"
|
||||
base64 = "0.22"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-postgres = "0.7"
|
||||
postgres = "0.19"
|
||||
russh = "0.61.2"
|
||||
russh-keys = "0.50.0-beta.7"
|
||||
russh-sftp = "2.3.0"
|
||||
|
||||
@@ -5,6 +5,8 @@ use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::provider::{DataProvider, ProviderError};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub user_id: String,
|
||||
@@ -66,13 +68,13 @@ pub struct AuthState {
|
||||
pub users: Arc<Mutex<HashMap<String, User>>>,
|
||||
pub auth_db: Option<crate::sync::AuthDb>,
|
||||
pub admin_sessions: Arc<Mutex<HashMap<String, AdminSession>>>,
|
||||
pub provider: Option<Arc<dyn DataProvider>>,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub fn new() -> Self {
|
||||
let mut users = HashMap::new();
|
||||
|
||||
// Create default demo user
|
||||
let password_hash = hash("demo123", DEFAULT_COST).unwrap();
|
||||
users.insert(
|
||||
"demo".to_string(),
|
||||
@@ -89,6 +91,7 @@ impl AuthState {
|
||||
users: Arc::new(Mutex::new(users)),
|
||||
auth_db: None,
|
||||
admin_sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,6 +103,17 @@ impl AuthState {
|
||||
users: Arc::new(Mutex::new(HashMap::new())),
|
||||
auth_db,
|
||||
admin_sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_provider(provider: Box<dyn DataProvider>) -> Self {
|
||||
AuthState {
|
||||
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
users: Arc::new(Mutex::new(HashMap::new())),
|
||||
auth_db: None,
|
||||
admin_sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider: Some(Arc::from(provider)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,8 +222,12 @@ impl AuthState {
|
||||
}
|
||||
|
||||
pub fn login_with_sync(&self, username: &str, password: &str) -> Option<LoginResponse> {
|
||||
// Prefer provider over auth_db
|
||||
if let Some(provider) = &self.provider {
|
||||
return self.login_with_provider(&**provider, username, password);
|
||||
}
|
||||
|
||||
if let Some(auth_db) = &self.auth_db {
|
||||
// Get user from auth.sqlite
|
||||
let user = match auth_db.get_user(username) {
|
||||
Ok(Some(user)) => user,
|
||||
Ok(None) => {
|
||||
@@ -266,11 +284,70 @@ impl AuthState {
|
||||
}
|
||||
}
|
||||
|
||||
fn login_with_provider(&self, provider: &dyn DataProvider, username: &str, password: &str) -> Option<LoginResponse> {
|
||||
match provider.get_user(username) {
|
||||
Ok(Some(user)) => {
|
||||
if user.status != 1 {
|
||||
log::warn!("User {} is disabled or not found", username);
|
||||
return None;
|
||||
}
|
||||
|
||||
match provider.check_password(username, password) {
|
||||
Ok(true) => {
|
||||
let groups = provider.get_user_groups(username).unwrap_or_default();
|
||||
|
||||
let token = Uuid::new_v4().to_string();
|
||||
let now = Utc::now();
|
||||
let expires_at = now + Duration::hours(24);
|
||||
|
||||
let session = Session {
|
||||
token: token.clone(),
|
||||
user_id: username.to_string(),
|
||||
username: username.to_string(),
|
||||
created_at: now.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
groups: groups.clone(),
|
||||
permissions: user.permissions.clone(),
|
||||
};
|
||||
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
sessions.insert(token.clone(), session);
|
||||
|
||||
log::info!("User {} logged in via DataProvider", username);
|
||||
|
||||
Some(LoginResponse {
|
||||
token,
|
||||
expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
user_id: username.to_string(),
|
||||
groups,
|
||||
permissions: user.permissions,
|
||||
})
|
||||
}
|
||||
Ok(false) => {
|
||||
log::warn!("Invalid password for user {}", username);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Password check error for {}: {}", username, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
log::warn!("User {} not found", username);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Provider error for {}: {}", username, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn verify_token(&self, token: &str) -> Option<Session> {
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions.get(token)?;
|
||||
|
||||
// Check expiration
|
||||
let expires_at = chrono::DateTime::parse_from_rfc3339(&session.expires_at)
|
||||
.ok()?
|
||||
.with_timezone(&Utc);
|
||||
|
||||
@@ -6,20 +6,29 @@ pub enum SshCommand {
|
||||
Start {
|
||||
#[arg(short, long, default_value = "2024")]
|
||||
port: u16,
|
||||
|
||||
/// PostgreSQL connection string for SFTPGo-compatible auth (e.g. "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026")
|
||||
#[arg(long)]
|
||||
pg_conn: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
pub async fn handle_ssh_command(cmd: SshCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
SshCommand::Start { port } => {
|
||||
SshCommand::Start { port, pg_conn } => {
|
||||
println!("=== MarkBase SSH Server (Hand-written Implementation) ===");
|
||||
println!("Port: {}", port);
|
||||
println!("Implementation: SSH-2.0-MarkBaseSSH_1.0");
|
||||
println!("Features: SSH + SFTP + SCP + rsync");
|
||||
if pg_conn.is_some() {
|
||||
println!("Auth Provider: PostgreSQL (SFTPGo-compatible)");
|
||||
} else {
|
||||
println!("Auth Provider: SQLite");
|
||||
}
|
||||
println!("Security: ⭐⭐⭐⭐⭐ (RustCrypto authoritative libraries)");
|
||||
println!();
|
||||
|
||||
crate::ssh_server::server::run_ssh_server(Some(port))?;
|
||||
|
||||
crate::ssh_server::server::run_ssh_server(Some(port), pg_conn.as_deref())?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
||||
233
markbase-core/src/config/mod.rs
Normal file
233
markbase-core/src/config/mod.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
pub mod web;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Re-export web config for backward compatibility
|
||||
pub use web::*;
|
||||
|
||||
/// Unified application configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
#[serde(default)]
|
||||
pub web: WebSection,
|
||||
#[serde(default)]
|
||||
pub s3: S3Section,
|
||||
#[serde(default)]
|
||||
pub sftp: SftpSection,
|
||||
#[serde(default)]
|
||||
pub ssh: SshSection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WebSection {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub log_level: String,
|
||||
pub auth_db_path: String,
|
||||
pub users_db_dir: String,
|
||||
}
|
||||
|
||||
impl Default for WebSection {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 11438,
|
||||
log_level: "info".to_string(),
|
||||
auth_db_path: "data/auth.sqlite".to_string(),
|
||||
users_db_dir: "data/users".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct S3Section {
|
||||
pub enabled: bool,
|
||||
pub endpoint: String,
|
||||
pub region: String,
|
||||
pub require_auth: bool,
|
||||
pub default_access_key: String,
|
||||
pub default_secret_key: String,
|
||||
pub keys_db_path: String,
|
||||
}
|
||||
|
||||
impl Default for S3Section {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
endpoint: "http://localhost:11438/s3".to_string(),
|
||||
region: "us-east-1".to_string(),
|
||||
require_auth: false,
|
||||
default_access_key: "markbase_access_key_001".to_string(),
|
||||
default_secret_key: "markbase_secret_key_xyz123".to_string(),
|
||||
keys_db_path: "data/s3_keys.json".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SftpSection {
|
||||
pub enabled: bool,
|
||||
pub port: u16,
|
||||
pub base_path: String,
|
||||
pub auth_db_path: String,
|
||||
pub max_connections: usize,
|
||||
pub chunk_size: usize,
|
||||
pub require_path_validation: bool,
|
||||
pub audit_logging: bool,
|
||||
pub path_traversal_protection: bool,
|
||||
pub symlink_check: bool,
|
||||
}
|
||||
|
||||
impl Default for SftpSection {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
port: 2023,
|
||||
base_path: "/Users/accusys/momentry/var/sftpgo/data".to_string(),
|
||||
auth_db_path: "data/auth.sqlite".to_string(),
|
||||
max_connections: 100,
|
||||
chunk_size: 65536,
|
||||
require_path_validation: true,
|
||||
audit_logging: true,
|
||||
path_traversal_protection: true,
|
||||
symlink_check: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SshSection {
|
||||
pub enabled: bool,
|
||||
pub port: u16,
|
||||
pub bind_address: String,
|
||||
pub security_config_path: String,
|
||||
}
|
||||
|
||||
impl Default for SshSection {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
port: 2024,
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config_path: "data/ssh_config.json".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
pub fn load(path: &str) -> Result<Self> {
|
||||
let config_path = std::path::PathBuf::from(path);
|
||||
|
||||
if !config_path.exists() {
|
||||
log::warn!("Config file not found: {}, using defaults", path);
|
||||
return Ok(Self::default());
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&config_path)?;
|
||||
let config: AppConfig = toml::from_str(&content)?;
|
||||
log::info!("App config loaded from: {}", path);
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn load_default() -> Result<Self> {
|
||||
Self::load("config/app.toml")
|
||||
}
|
||||
|
||||
pub fn save(&self, path: &str) -> Result<()> {
|
||||
let config_path = std::path::PathBuf::from(path);
|
||||
|
||||
if config_path.exists() {
|
||||
let backup_path = config_path.with_extension("toml.bak");
|
||||
std::fs::copy(&config_path, &backup_path)?;
|
||||
log::info!("Backup created: {}", backup_path.display());
|
||||
}
|
||||
|
||||
let content = toml::to_string_pretty(self)?;
|
||||
std::fs::write(&config_path, content)?;
|
||||
log::info!("App config saved to: {}", path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn merge_env(&mut self) {
|
||||
if let Ok(v) = std::env::var("MB_WEB_HOST") {
|
||||
self.web.host = v;
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_WEB_PORT") {
|
||||
if let Ok(p) = v.parse() { self.web.port = p; }
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_SSH_PORT") {
|
||||
if let Ok(p) = v.parse() { self.ssh.port = p; }
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_SFTP_PORT") {
|
||||
if let Ok(p) = v.parse() { self.sftp.port = p; }
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_S3_ENABLED") {
|
||||
self.s3.enabled = v == "true" || v == "1";
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_AUTH_DB") {
|
||||
self.web.auth_db_path = v.clone();
|
||||
self.sftp.auth_db_path = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
web: WebSection::default(),
|
||||
s3: S3Section::default(),
|
||||
sftp: SftpSection::default(),
|
||||
ssh: SshSection::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = AppConfig::default();
|
||||
assert_eq!(config.web.port, 11438);
|
||||
assert_eq!(config.ssh.port, 2024);
|
||||
assert_eq!(config.sftp.port, 2023);
|
||||
assert_eq!(config.s3.region, "us-east-1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_missing() {
|
||||
let config = AppConfig::load("/tmp/nonexistent/config.toml").unwrap();
|
||||
assert_eq!(config.web.port, 11438);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_env() {
|
||||
std::env::set_var("MB_WEB_PORT", "9090");
|
||||
std::env::set_var("MB_SSH_PORT", "2222");
|
||||
|
||||
let mut config = AppConfig::default();
|
||||
config.merge_env();
|
||||
|
||||
assert_eq!(config.web.port, 9090);
|
||||
assert_eq!(config.ssh.port, 2222);
|
||||
|
||||
std::env::remove_var("MB_WEB_PORT");
|
||||
std::env::remove_var("MB_SSH_PORT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let path = dir.path().join("test.toml");
|
||||
let path_str = path.to_string_lossy().to_string();
|
||||
|
||||
let config = AppConfig::default();
|
||||
config.save(&path_str).unwrap();
|
||||
|
||||
let loaded = AppConfig::load(&path_str).unwrap();
|
||||
assert_eq!(loaded.web.port, 11438);
|
||||
}
|
||||
}
|
||||
@@ -67,13 +67,12 @@ impl MarkBaseConfig {
|
||||
}
|
||||
|
||||
pub fn save(&self, path: &Path) -> Result<()> {
|
||||
// Create backup before saving
|
||||
if path.exists() {
|
||||
let backup_path = path.with_extension("toml.bak");
|
||||
std::fs::copy(path, &backup_path)?;
|
||||
log::info!("Backup created: {}", backup_path.display());
|
||||
}
|
||||
|
||||
|
||||
let content = toml::to_string_pretty(self)?;
|
||||
std::fs::write(path, content)?;
|
||||
log::info!("Configuration saved to: {}", path.display());
|
||||
@@ -23,6 +23,8 @@ pub mod import_markdown; // Category View Module - 双视图管理(Phase 1)
|
||||
// pub mod ssh2_mod; // ssh2辅助模块(已禁用)
|
||||
pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐
|
||||
pub mod sync;
|
||||
pub mod provider; // DataProvider抽象层(Phase 5)
|
||||
pub mod vfs; // VFS抽象层(Phase 1-6重构计划)
|
||||
|
||||
// Re-export from external filetree crate
|
||||
pub use filetree::node::FileNode;
|
||||
|
||||
65
markbase-core/src/provider/mod.rs
Normal file
65
markbase-core/src/provider/mod.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
pub mod sqlite;
|
||||
pub mod pg;
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// 用户信息
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct User {
|
||||
pub username: String,
|
||||
pub password_hash: String,
|
||||
pub home_dir: PathBuf,
|
||||
pub uid: u32,
|
||||
pub gid: u32,
|
||||
pub permissions: String,
|
||||
pub status: i32,
|
||||
}
|
||||
|
||||
/// Provider 错误类型
|
||||
#[derive(Debug)]
|
||||
pub enum ProviderError {
|
||||
NotFound(String),
|
||||
AuthFailed(String),
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ProviderError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ProviderError::NotFound(msg) => write!(f, "Not found: {}", msg),
|
||||
ProviderError::AuthFailed(msg) => write!(f, "Authentication failed: {}", msg),
|
||||
ProviderError::Internal(msg) => write!(f, "Internal error: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ProviderError {}
|
||||
|
||||
/// 数据提供者 trait(用户认证和配置)
|
||||
pub trait DataProvider: Send + Sync {
|
||||
/// 获取用户信息
|
||||
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError>;
|
||||
|
||||
/// 验证用户密码
|
||||
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError>;
|
||||
|
||||
/// 获取用户主目录
|
||||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError>;
|
||||
|
||||
/// 获取用户组列表
|
||||
fn get_user_groups(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||
let _ = username;
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// 检查用户是否存在且启用
|
||||
fn user_exists(&self, username: &str) -> Result<bool, ProviderError> {
|
||||
Ok(self.get_user(username)?.map(|u| u.status == 1).unwrap_or(false))
|
||||
}
|
||||
|
||||
/// 获取用户的公开密钥列表(OpenSSH authorized_keys格式)
|
||||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||
let _ = username;
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
184
markbase-core/src/provider/pg.rs
Normal file
184
markbase-core/src/provider/pg.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
use std::path::PathBuf;
|
||||
use postgres::{Client, NoTls};
|
||||
use bcrypt::verify;
|
||||
use super::{DataProvider, ProviderError, User};
|
||||
|
||||
/// PostgreSQL 数据提供者(兼容 SFTPGo 的 users 表)
|
||||
pub struct PgProvider {
|
||||
conn_str: String,
|
||||
}
|
||||
|
||||
impl PgProvider {
|
||||
/// 从连接字符串创建 PgProvider
|
||||
///
|
||||
/// 连接字符串格式:host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026
|
||||
pub fn new(conn_str: &str) -> Result<Self, ProviderError> {
|
||||
Ok(Self { conn_str: conn_str.to_string() })
|
||||
}
|
||||
|
||||
pub fn from_params(
|
||||
host: &str,
|
||||
port: u16,
|
||||
dbname: &str,
|
||||
user: &str,
|
||||
password: &str,
|
||||
) -> Result<Self, ProviderError> {
|
||||
let conn_str = format!(
|
||||
"host={} port={} dbname={} user={} password={}",
|
||||
host, port, dbname, user, password
|
||||
);
|
||||
Ok(Self { conn_str })
|
||||
}
|
||||
|
||||
fn open_conn(&self) -> Result<Client, ProviderError> {
|
||||
Client::connect(&self.conn_str, NoTls)
|
||||
.map_err(|e| ProviderError::Internal(format!("PostgreSQL connect failed: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
impl DataProvider for PgProvider {
|
||||
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
|
||||
let mut conn = self.open_conn()?;
|
||||
|
||||
let result = conn.query_opt(
|
||||
"SELECT username, password, home_dir, permissions, uid, gid, status
|
||||
FROM users WHERE username = $1 AND status = 1",
|
||||
&[&username],
|
||||
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||
|
||||
match result {
|
||||
Some(row) => Ok(Some(User {
|
||||
username: row.get(0),
|
||||
password_hash: row.get::<_, Option<String>>(1).unwrap_or_default(),
|
||||
home_dir: PathBuf::from(row.get::<_, String>(2)),
|
||||
permissions: row.get::<_, Option<String>>(3).unwrap_or_else(|| "*".to_string()),
|
||||
uid: row.get::<_, i64>(4) as u32,
|
||||
gid: row.get::<_, i64>(5) as u32,
|
||||
status: row.get(6),
|
||||
})),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError> {
|
||||
let hash = match self.get_user(username)? {
|
||||
Some(user) => user.password_hash,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
if hash.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
verify(password, &hash)
|
||||
.map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e)))
|
||||
}
|
||||
|
||||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||||
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||
}
|
||||
|
||||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||
let mut conn = self.open_conn()?;
|
||||
let result = conn.query_opt(
|
||||
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
|
||||
&[&username],
|
||||
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||
|
||||
match result {
|
||||
Some(row) => {
|
||||
let json_str: Option<String> = row.get(0);
|
||||
match json_str {
|
||||
Some(s) if !s.is_empty() => {
|
||||
let keys: Vec<serde_json::Value> = serde_json::from_str(&s)
|
||||
.map_err(|e| ProviderError::Internal(format!("JSON parse error: {}", e)))?;
|
||||
Ok(keys.iter()
|
||||
.filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string()))
|
||||
.collect())
|
||||
}
|
||||
_ => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pg_provider_connection() {
|
||||
// 仅当 SFTPGo PostgreSQL 可用时运行
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
);
|
||||
assert!(provider.is_ok(), "Should connect to SFTPGo PostgreSQL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_get_user_demo() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
let user = provider.get_user("demo").unwrap();
|
||||
assert!(user.is_some(), "Demo user should exist");
|
||||
assert_eq!(user.unwrap().username, "demo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_get_user_momentry() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
let user = provider.get_user("momentry").unwrap();
|
||||
assert!(user.is_some(), "Momentry user should exist");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_get_user_warren() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
let user = provider.get_user("warren").unwrap();
|
||||
assert!(user.is_some(), "Warren user should exist");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_check_password_demo() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
let valid = provider.check_password("demo", "demo123").unwrap();
|
||||
assert!(valid, "Password should be valid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_check_password_invalid() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
let valid = provider.check_password("demo", "wrong").unwrap();
|
||||
assert!(!valid, "Wrong password should fail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_get_home_dir() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
let dir = provider.get_home_dir("demo").unwrap();
|
||||
assert!(dir.is_some());
|
||||
assert!(dir.unwrap().contains("momentry"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_nonexistent_user() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
let user = provider.get_user("__nonexistent__").unwrap();
|
||||
assert!(user.is_none());
|
||||
}
|
||||
}
|
||||
135
markbase-core/src/provider/sqlite.rs
Normal file
135
markbase-core/src/provider/sqlite.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use std::path::PathBuf;
|
||||
use rusqlite::{Connection, params};
|
||||
use bcrypt::verify;
|
||||
use super::{DataProvider, ProviderError, User};
|
||||
|
||||
/// SQLite 数据提供者
|
||||
pub struct SqliteProvider {
|
||||
db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl SqliteProvider {
|
||||
pub fn new(db_path: &str) -> Result<Self, ProviderError> {
|
||||
let path = PathBuf::from(db_path);
|
||||
if !path.exists() {
|
||||
return Err(ProviderError::NotFound(format!(
|
||||
"Database not found: {}", db_path
|
||||
)));
|
||||
}
|
||||
Ok(Self { db_path: path })
|
||||
}
|
||||
|
||||
fn open_conn(&self) -> Result<Connection, ProviderError> {
|
||||
Connection::open(&self.db_path)
|
||||
.map_err(|e| ProviderError::Internal(format!("Failed to open database: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
impl DataProvider for SqliteProvider {
|
||||
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
|
||||
let conn = self.open_conn()?;
|
||||
|
||||
let result = conn.query_row(
|
||||
"SELECT username, password_hash, home_dir, permissions, uid, gid, status
|
||||
FROM sftpgo_users WHERE username = ?1 AND status = 1",
|
||||
params![username],
|
||||
|row| {
|
||||
Ok(User {
|
||||
username: row.get(0)?,
|
||||
password_hash: row.get(1)?,
|
||||
home_dir: PathBuf::from(row.get::<_, String>(2)?),
|
||||
permissions: row.get(3)?,
|
||||
uid: row.get::<_, i64>(4)? as u32,
|
||||
gid: row.get::<_, i64>(5)? as u32,
|
||||
status: row.get(6)?,
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(user) => Ok(Some(user)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(ProviderError::Internal(format!(
|
||||
"Database query error: {}", e
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError> {
|
||||
let hash = match self.get_user(username)? {
|
||||
Some(user) => user.password_hash,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
verify(password, &hash)
|
||||
.map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e)))
|
||||
}
|
||||
|
||||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||||
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||
}
|
||||
|
||||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||
let _ = username;
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
fn get_user_groups(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||
let conn = self.open_conn()?;
|
||||
let groups: Vec<String> = conn
|
||||
.prepare("SELECT group_name FROM users_groups_mapping WHERE username = ?1")
|
||||
.map_err(|e| ProviderError::Internal(format!("Query prepare error: {}", e)))?
|
||||
.query_map(params![username], |row| row.get(0))
|
||||
.map_err(|e| ProviderError::Internal(format!("Query map error: {}", e)))?
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
Ok(groups)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_provider_not_found() {
|
||||
let provider = SqliteProvider::new("/tmp/nonexistent.db");
|
||||
assert!(provider.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_get_user() {
|
||||
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
|
||||
let user = provider.get_user("demo").unwrap();
|
||||
assert!(user.is_some());
|
||||
assert_eq!(user.unwrap().username, "demo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_nonexistent_user() {
|
||||
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
|
||||
let user = provider.get_user("__nonexistent__").unwrap();
|
||||
assert!(user.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_password_valid() {
|
||||
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
|
||||
let valid = provider.check_password("demo", "demo123").unwrap();
|
||||
assert!(valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_password_invalid() {
|
||||
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
|
||||
let valid = provider.check_password("demo", "wrong").unwrap();
|
||||
assert!(!valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_home_dir() {
|
||||
let provider = SqliteProvider::new("data/auth.sqlite").unwrap();
|
||||
let dir = provider.get_home_dir("demo").unwrap();
|
||||
assert!(dir.is_some());
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::audio;
|
||||
use crate::auth::{AuthState, LoginRequest};
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::render;
|
||||
use crate::download;
|
||||
use crate::archive::{self, ArchiveFormat, ArchiveProcessor, FormatDetector, ArchiveConfig, ProcessorRegistry};
|
||||
@@ -57,7 +58,10 @@ pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
|
||||
}))),
|
||||
labels: Arc::new(Mutex::new(vec![])),
|
||||
db_dir: "data/users".to_string(),
|
||||
auth: AuthState::with_sync("data/auth.sqlite"),
|
||||
auth: AuthState::with_provider(Box::new(
|
||||
SqliteProvider::new("data/auth.sqlite")
|
||||
.map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))?
|
||||
)),
|
||||
auth_db_path: "data/auth.sqlite".to_string(),
|
||||
s3_keys: Arc::new(Mutex::new(load_s3_keys())),
|
||||
};
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// ssh2 Server核心实现
|
||||
// 替代russh,提供完整的SSH/SFTP/SCP/rsync支持
|
||||
|
||||
use crate::sftp::auth::SftpAuth;
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::sftp::config::SftpConfig;
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Result, anyhow, Context};
|
||||
use log::{info, warn, error};
|
||||
use ssh2::Session;
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
@@ -106,9 +106,10 @@ fn authenticate_client(session: &Session, config: &Arc<SftpConfig>) -> Result<St
|
||||
let user = "warren";
|
||||
let password = "demo123";
|
||||
|
||||
// 使用SftpAuth验证(复用现有认证系统)
|
||||
let auth = SftpAuth::new(&config.auth_db_path)?;
|
||||
if auth.verify_password(user, password)? {
|
||||
// 使用SqliteProvider验证(复用现有认证系统)
|
||||
let provider = SqliteProvider::new(&config.auth_db_path)
|
||||
.context("Failed to init SqliteProvider for ssh2_server")?;
|
||||
if provider.check_password(user, password)? {
|
||||
info!("Password auth successful for user: {}", user);
|
||||
session.userauth_password(user, password)?;
|
||||
return Ok(user.to_string());
|
||||
|
||||
@@ -1,70 +1,58 @@
|
||||
// SSH认证协议实现(Phase 5)
|
||||
// 参考OpenSSH auth.c, auth-passwd.c
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准)
|
||||
use std::io::Write;
|
||||
use anyhow::{Result, anyhow};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, warn, debug};
|
||||
use rusqlite::{Connection, params};
|
||||
use bcrypt::{verify, DEFAULT_COST};
|
||||
use base64::{Engine as _, engine::general_purpose}; // Phase 9: Base64 for authorized_keys
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
|
||||
use ed25519_dalek::{VerifyingKey, Signature};
|
||||
|
||||
use crate::provider::{DataProvider, ProviderError};
|
||||
|
||||
/// SSH认证处理器(参考OpenSSH auth2.c)
|
||||
pub struct AuthHandler {
|
||||
db_path: String, // SQLite数据库路径
|
||||
provider: Box<dyn DataProvider>,
|
||||
}
|
||||
|
||||
impl AuthHandler {
|
||||
/// 创建认证处理器
|
||||
pub fn new() -> Result<Self> {
|
||||
let db_path = "data/auth.sqlite".to_string();
|
||||
|
||||
// 验证数据库是否存在
|
||||
let conn = Connection::open(&db_path)?;
|
||||
drop(conn); // rusqlite会自动关闭
|
||||
|
||||
info!("AuthHandler initialized with database: {}", db_path);
|
||||
Ok(Self { db_path })
|
||||
pub fn new(provider: Box<dyn DataProvider>) -> Self {
|
||||
info!("AuthHandler initialized with DataProvider");
|
||||
Self { provider }
|
||||
}
|
||||
|
||||
|
||||
/// 获取用户home目录(SFTPGo兼容)
|
||||
pub fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||||
self.provider.get_home_dir(username)
|
||||
}
|
||||
|
||||
/// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request())
|
||||
pub fn handle_userauth_request(&mut self, packet: &SshPacket) -> Result<AuthResult> {
|
||||
pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result<AuthResult> {
|
||||
info!("Processing SSH_MSG_USERAUTH_REQUEST");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||||
|
||||
// Packet type
|
||||
|
||||
let packet_type = cursor.read_u8()?;
|
||||
if packet_type != PacketType::SSH_MSG_USERAUTH_REQUEST as u8 {
|
||||
return Err(anyhow!("Invalid packet type for USERAUTH_REQUEST"));
|
||||
}
|
||||
|
||||
// 读取用户名(SSH string)
|
||||
|
||||
let user = read_ssh_string(&mut cursor)?;
|
||||
|
||||
// 读取服务名称(SSH string)
|
||||
let service = read_ssh_string(&mut cursor)?;
|
||||
|
||||
// 读取认证方法名称(SSH string)
|
||||
let method = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
info!("Auth request: user={}, service={}, method={}", user, service, method);
|
||||
|
||||
// 检查服务名称(OpenSSH要求:ssh-connection)
|
||||
|
||||
if service != "ssh-connection" {
|
||||
warn!("Unsupported service: {}", service);
|
||||
return Ok(AuthResult::Failure("Unsupported service".to_string()));
|
||||
}
|
||||
|
||||
// 根据认证方法处理(参考OpenSSH auth2.c)
|
||||
|
||||
if method == "password" {
|
||||
self.handle_password_auth(&mut cursor, &user)
|
||||
} else if method == "publickey" {
|
||||
self.handle_publickey_auth(&mut cursor, &user)
|
||||
self.handle_publickey_auth(&mut cursor, &user, &service, session_id)
|
||||
} else if method == "none" {
|
||||
// OpenSSH:none认证总是失败(用于查询支持的认证方法)
|
||||
// 返回支持的认证方法列表:password, publickey
|
||||
warn!("None auth request - returning supported methods");
|
||||
Ok(AuthResult::Failure("password,publickey".to_string()))
|
||||
} else {
|
||||
@@ -72,203 +60,254 @@ impl AuthHandler {
|
||||
Ok(AuthResult::Failure("Unsupported auth method".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 处理password认证(参考OpenSSH auth-passwd.c)
|
||||
fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
|
||||
info!("Handling password auth for user: {}", user);
|
||||
|
||||
// 读取是否修改密码标志(boolean,OpenSSH password认证格式)
|
||||
|
||||
let change_password = cursor.read_u8()? != 0;
|
||||
|
||||
if change_password {
|
||||
warn!("Password change not supported");
|
||||
return Ok(AuthResult::Failure("Password change not supported".to_string()));
|
||||
}
|
||||
|
||||
// 读取密码(SSH string)
|
||||
|
||||
let password = read_ssh_string(cursor)?;
|
||||
|
||||
|
||||
debug!("Password auth attempt: user={}, password length={}", user, password.len());
|
||||
|
||||
// 查询数据库获取password_hash
|
||||
let conn = Connection::open(&self.db_path)?;
|
||||
|
||||
let password_hash_result = conn.query_row(
|
||||
"SELECT password_hash FROM sftpgo_users WHERE username = ?1 AND status = 1",
|
||||
params![user],
|
||||
|row| row.get::<_, String>(0)
|
||||
);
|
||||
|
||||
// 关闭连接(rusqlite会自动关闭)
|
||||
drop(conn);
|
||||
|
||||
// 验证用户是否存在
|
||||
let password_hash = match password_hash_result {
|
||||
Ok(hash) => Some(hash),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => None,
|
||||
Err(e) => return Err(anyhow!("Database query error: {}", e)),
|
||||
};
|
||||
|
||||
if password_hash.is_none() {
|
||||
warn!("User not found or disabled: {}", user);
|
||||
// SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表(RFC 4253)
|
||||
return Ok(AuthResult::Failure("password,publickey".to_string()));
|
||||
}
|
||||
|
||||
// 使用bcrypt验证密码
|
||||
let stored_hash = password_hash.unwrap();
|
||||
info!("Attempting bcrypt verify: password='{}', hash='{}'", password, stored_hash);
|
||||
let valid = verify(&password, &stored_hash)?;
|
||||
info!("bcrypt verify result: {}", valid);
|
||||
|
||||
if valid {
|
||||
info!("Password auth successful for user: {}", user);
|
||||
Ok(AuthResult::Success)
|
||||
} else {
|
||||
warn!("Password auth failed for user: {}", user);
|
||||
// SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表(RFC 4253)
|
||||
Ok(AuthResult::Failure("password,publickey".to_string()))
|
||||
|
||||
match self.provider.check_password(user, &password) {
|
||||
Ok(true) => {
|
||||
info!("Password auth successful for user: {}", user);
|
||||
Ok(AuthResult::Success)
|
||||
}
|
||||
Ok(false) => {
|
||||
warn!("Password auth failed for user: {}", user);
|
||||
Ok(AuthResult::Failure("password,publickey".to_string()))
|
||||
}
|
||||
Err(ProviderError::NotFound(msg)) => {
|
||||
warn!("User not found: {}", msg);
|
||||
Ok(AuthResult::Failure("password,publickey".to_string()))
|
||||
}
|
||||
Err(e) => {
|
||||
Err(anyhow!("Password auth error: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 构建SSH_MSG_USERAUTH_SUCCESS packet(参考OpenSSH auth2.c)
|
||||
|
||||
/// 构建SSH_MSG_USERAUTH_SUCCESS packet
|
||||
pub fn build_userauth_success() -> Result<SshPacket> {
|
||||
let payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
/// 构建SSH_MSG_USERAUTH_FAILURE packet(参考OpenSSH auth2.c)
|
||||
|
||||
/// 构建SSH_MSG_USERAUTH_FAILURE packet
|
||||
pub fn build_userauth_failure(methods: &[String], partial_success: bool) -> Result<SshPacket> {
|
||||
let mut payload = Vec::new();
|
||||
|
||||
// Packet type
|
||||
|
||||
payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
|
||||
|
||||
// 认证方法列表(SSH string,逗号分隔)
|
||||
|
||||
let methods_str = methods.join(",");
|
||||
payload.write_u32::<BigEndian>(methods_str.len() as u32)?;
|
||||
payload.write_all(methods_str.as_bytes())?;
|
||||
|
||||
// partial_success标志(boolean)
|
||||
|
||||
payload.write_u8(if partial_success { 1 } else { 0 })?;
|
||||
|
||||
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
/// 构建SSH_MSG_USERAUTH_BANNER packet(可选,参考OpenSSH auth2.c)
|
||||
|
||||
/// 构建SSH_MSG_USERAUTH_BANNER packet
|
||||
pub fn build_userauth_banner(message: &str, language: &str) -> Result<SshPacket> {
|
||||
let mut payload = Vec::new();
|
||||
|
||||
// Packet type
|
||||
|
||||
payload.write_u8(PacketType::SSH_MSG_USERAUTH_BANNER as u8)?;
|
||||
|
||||
// Banner message(SSH string)
|
||||
|
||||
payload.write_u32::<BigEndian>(message.len() as u32)?;
|
||||
payload.write_all(message.as_bytes())?;
|
||||
|
||||
// Language tag(SSH string)
|
||||
|
||||
payload.write_u32::<BigEndian>(language.len() as u32)?;
|
||||
payload.write_all(language.as_bytes())?;
|
||||
|
||||
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
/// 处理publickey认证(Phase 9:参考OpenSSH auth2-pubkey.c)
|
||||
fn handle_publickey_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
|
||||
|
||||
/// 处理publickey认证(RFC 4252 §7)
|
||||
/// 支持Ed25519签名验证 + 数据库/filesystem密钥查找
|
||||
fn handle_publickey_auth(
|
||||
&mut self,
|
||||
cursor: &mut std::io::Cursor<&[u8]>,
|
||||
user: &str,
|
||||
service: &str,
|
||||
session_id: &[u8],
|
||||
) -> Result<AuthResult> {
|
||||
info!("Handling publickey auth for user: {}", user);
|
||||
|
||||
// 读取是否签名的标志(boolean)
|
||||
|
||||
let is_signed = cursor.read_u8()? != 0;
|
||||
|
||||
// 读取public key algorithm(SSH string)
|
||||
let algorithm = read_ssh_string(cursor)?;
|
||||
|
||||
// 读取public key blob(SSH string)
|
||||
let public_key_blob = read_ssh_string_bytes(cursor)?;
|
||||
|
||||
|
||||
info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed);
|
||||
|
||||
// Phase 9:简化实现 - 从authorized_keys文件验证
|
||||
let authorized_keys_path = format!("data/{}/authorized_keys", user);
|
||||
let authorized_keys = match std::fs::read_to_string(&authorized_keys_path) {
|
||||
Ok(content) => content,
|
||||
Err(_) => {
|
||||
// 尝试默认路径
|
||||
let default_path = "data/authorized_keys";
|
||||
match std::fs::read_to_string(default_path) {
|
||||
Ok(content) => content,
|
||||
Err(_) => {
|
||||
warn!("No authorized_keys file found for user: {}", user);
|
||||
return Ok(AuthResult::Failure("password,publickey".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 解析authorized_keys,查找匹配的public key
|
||||
let public_key_matches = authorized_keys.lines().any(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
return false;
|
||||
}
|
||||
|
||||
// SSH authorized_keys格式:algorithm base64-key comment
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.len() < 2 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let key_algorithm = parts[0];
|
||||
let key_base64 = parts[1];
|
||||
|
||||
// 匹配algorithm
|
||||
if key_algorithm != algorithm {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 匹配public key blob(base64解码对比)
|
||||
match base64_decode(key_base64) {
|
||||
Ok(decoded_key) => decoded_key == public_key_blob,
|
||||
Err(_) => false,
|
||||
}
|
||||
});
|
||||
|
||||
if !public_key_matches {
|
||||
|
||||
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
|
||||
warn!("Public key not authorized for user: {}", user);
|
||||
return Ok(AuthResult::Failure("password,publickey".to_string()));
|
||||
}
|
||||
|
||||
|
||||
info!("Public key authorized for user: {}", user);
|
||||
|
||||
// 如果没有签名,返回PK_OK(query阶段)
|
||||
|
||||
if !is_signed {
|
||||
// SSH_MSG_USERAUTH_PK_OK:表示public key可接受,client需要发送签名
|
||||
return Ok(AuthResult::PublicKeyOk(algorithm, public_key_blob));
|
||||
}
|
||||
|
||||
// 读取signature(SSH string)
|
||||
let signature = read_ssh_string_bytes(cursor)?;
|
||||
|
||||
info!("Verifying signature for user: {}", user);
|
||||
|
||||
// Phase 9:简化签名验证 - 信任authorized_keys
|
||||
// 完整实现需要:提取session_id, 构建signed_data, verify signature
|
||||
// 这里简化处理:只要public key匹配authorized_keys就接受
|
||||
|
||||
|
||||
let signature_blob = read_ssh_string_bytes(cursor)?;
|
||||
|
||||
self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?;
|
||||
|
||||
info!("Publickey auth successful for user: {}", user);
|
||||
Ok(AuthResult::Success)
|
||||
}
|
||||
|
||||
/// 检查public key是否在授权列表中(数据库优先,fallback到filesystem)
|
||||
fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result<bool> {
|
||||
// 1. 先检查数据库
|
||||
match self.provider.get_public_keys(user) {
|
||||
Ok(keys) => {
|
||||
for key_line in &keys {
|
||||
if public_key_matches_line(key_line, algorithm, public_key_blob) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to get public keys from provider: {}", e),
|
||||
}
|
||||
|
||||
// 2. Fallback到filesystem
|
||||
let authorized_keys_path = format!("data/{}/authorized_keys", user);
|
||||
let content = match std::fs::read_to_string(&authorized_keys_path) {
|
||||
Ok(c) => c,
|
||||
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
|
||||
Ok(c) => c,
|
||||
Err(_) => return Ok(false),
|
||||
}
|
||||
};
|
||||
|
||||
Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
|
||||
}
|
||||
|
||||
/// 验证Ed25519签名(RFC 4252 §7)
|
||||
fn verify_signature(
|
||||
&self,
|
||||
algorithm: &str,
|
||||
public_key_blob: &[u8],
|
||||
signature_blob: &[u8],
|
||||
user: &str,
|
||||
service: &str,
|
||||
session_id: &[u8],
|
||||
) -> Result<()> {
|
||||
// 目前只支援Ed25519
|
||||
if algorithm != "ssh-ed25519" {
|
||||
return Err(anyhow!("Unsupported public key algorithm: {}", algorithm));
|
||||
}
|
||||
|
||||
let verifying_key = parse_ed25519_verifying_key(public_key_blob)?;
|
||||
let signature = parse_ed25519_signature(signature_blob)?;
|
||||
|
||||
// 建立签名验证数据(RFC 4252 §7)
|
||||
let mut signed_data = Vec::new();
|
||||
|
||||
// string session identifier
|
||||
signed_data.write_u32::<BigEndian>(session_id.len() as u32)?;
|
||||
signed_data.write_all(session_id)?;
|
||||
|
||||
// byte SSH_MSG_USERAUTH_REQUEST
|
||||
signed_data.write_u8(PacketType::SSH_MSG_USERAUTH_REQUEST as u8)?;
|
||||
|
||||
// string user name
|
||||
signed_data.write_u32::<BigEndian>(user.len() as u32)?;
|
||||
signed_data.write_all(user.as_bytes())?;
|
||||
|
||||
// string service name
|
||||
signed_data.write_u32::<BigEndian>(service.len() as u32)?;
|
||||
signed_data.write_all(service.as_bytes())?;
|
||||
|
||||
// string "publickey"
|
||||
const PUBKEY_STR: &str = "publickey";
|
||||
signed_data.write_u32::<BigEndian>(PUBKEY_STR.len() as u32)?;
|
||||
signed_data.write_all(PUBKEY_STR.as_bytes())?;
|
||||
|
||||
// boolean TRUE
|
||||
signed_data.write_u8(1)?;
|
||||
|
||||
// string public key algorithm name
|
||||
signed_data.write_u32::<BigEndian>(algorithm.len() as u32)?;
|
||||
signed_data.write_all(algorithm.as_bytes())?;
|
||||
|
||||
// string public key blob
|
||||
signed_data.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
|
||||
signed_data.write_all(public_key_blob)?;
|
||||
|
||||
// 验证签名
|
||||
verifying_key.verify_strict(&signed_data, &signature)
|
||||
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// SSH认证结果(参考OpenSSH auth2.c)
|
||||
/// SSH认证结果
|
||||
pub enum AuthResult {
|
||||
Success,
|
||||
Failure(String), // 失败原因
|
||||
PartialSuccess, // 部分成功(多步骤认证)
|
||||
PublicKeyOk(String, Vec<u8>), // Public key acceptable (algorithm, blob)
|
||||
Failure(String),
|
||||
PartialSuccess,
|
||||
PublicKeyOk(String, Vec<u8>),
|
||||
}
|
||||
|
||||
/// 解析Ed25519公钥blob(SSH格式 -> VerifyingKey)
|
||||
fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result<VerifyingKey> {
|
||||
let mut cursor = std::io::Cursor::new(public_key_blob);
|
||||
let algorithm = read_ssh_string(&mut cursor)?;
|
||||
if algorithm != "ssh-ed25519" {
|
||||
return Err(anyhow!("Unsupported algorithm: {}", algorithm));
|
||||
}
|
||||
let key_bytes = read_ssh_string_bytes(&mut cursor)?;
|
||||
if key_bytes.len() != 32 {
|
||||
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
|
||||
}
|
||||
let key_array: [u8; 32] = key_bytes.try_into()
|
||||
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
|
||||
VerifyingKey::from_bytes(&key_array)
|
||||
.map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
|
||||
}
|
||||
|
||||
/// 解析Ed25519签名blob(SSH格式 -> Signature)
|
||||
fn parse_ed25519_signature(signature_blob: &[u8]) -> Result<Signature> {
|
||||
let mut cursor = std::io::Cursor::new(signature_blob);
|
||||
let algorithm = read_ssh_string(&mut cursor)?;
|
||||
if algorithm != "ssh-ed25519" {
|
||||
return Err(anyhow!("Unsupported signature algorithm: {}", algorithm));
|
||||
}
|
||||
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
|
||||
if sig_bytes.len() != 64 {
|
||||
return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len()));
|
||||
}
|
||||
let sig_array: [u8; 64] = sig_bytes.try_into()
|
||||
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
|
||||
Ok(Signature::from_bytes(&sig_array))
|
||||
}
|
||||
|
||||
/// 检查一行authorized_keys格式的密钥是否匹配
|
||||
fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8]) -> bool {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
return false;
|
||||
}
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.len() < 2 {
|
||||
return false;
|
||||
}
|
||||
if parts[0] != algorithm {
|
||||
return false;
|
||||
}
|
||||
base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false)
|
||||
}
|
||||
|
||||
/// SSH string读取辅助函数
|
||||
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||
let length = reader.read_u32::<BigEndian>()?;
|
||||
let mut buffer = vec![0u8; length as usize];
|
||||
@@ -276,7 +315,6 @@ fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||
Ok(String::from_utf8(buffer)?)
|
||||
}
|
||||
|
||||
/// SSH string读取辅助函数(bytes版本)
|
||||
fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
|
||||
let length = reader.read_u32::<BigEndian>()?;
|
||||
let mut buffer = vec![0u8; length as usize];
|
||||
@@ -284,9 +322,7 @@ fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// Base64解码辅助函数(Phase 9)
|
||||
fn base64_decode(input: &str) -> Result<Vec<u8>> {
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
general_purpose::STANDARD.decode(input)
|
||||
.map_err(|e| anyhow!("Base64 decode error: {}", e))
|
||||
}
|
||||
@@ -294,18 +330,19 @@ fn base64_decode(input: &str) -> Result<Vec<u8>> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
|
||||
#[test]
|
||||
fn test_userauth_success_packet() {
|
||||
let packet = AuthHandler::build_userauth_success().unwrap();
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_userauth_failure_packet() {
|
||||
let methods = vec!["password".to_string(), "publickey".to_string()];
|
||||
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
|
||||
|
||||
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ pub struct ChannelManager {
|
||||
next_channel_id: u32,
|
||||
/// ⭐⭐⭐⭐⭐ Phase 15.1: 待发送packet队列(用于同时发送WINDOW_ADJUST和SFTP响应)
|
||||
pub pending_packets: VecDeque<SshPacket>,
|
||||
/// 用户home目录(SFTP/SCP/rsync根目录,SFTPGo兼容)
|
||||
pub home_dir: PathBuf,
|
||||
}
|
||||
|
||||
/// Phase 14: 交互式Exec进程管理(参考OpenSSH session.c: do_exec_no_pty)
|
||||
@@ -40,11 +42,12 @@ pub struct ExecProcess {
|
||||
}
|
||||
|
||||
impl ChannelManager {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(home_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
channels: HashMap::new(),
|
||||
next_channel_id: 0,
|
||||
pending_packets: VecDeque::new(),
|
||||
home_dir,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -371,9 +374,12 @@ impl ChannelManager {
|
||||
info!("⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}", process_type, command);
|
||||
|
||||
// 启动子进程(相当于OpenSSH fork)
|
||||
// ⭐⭐⭐⭐⭐ Phase 17: 设置工作目录为用户home_dir(SFTPGo兼容)
|
||||
let home_dir = self.home_dir.clone();
|
||||
let mut child = Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.current_dir(&home_dir)
|
||||
.stdin(Stdio::piped()) // ← 创建stdin管道(相当于pipe(pin))
|
||||
.stdout(Stdio::piped()) // ← 创建stdout管道(相当于pipe(pout))
|
||||
.stderr(Stdio::piped()) // ← 创建stderr管道(相当于pipe(perr))
|
||||
@@ -446,8 +452,8 @@ impl ChannelManager {
|
||||
if subsystem == "sftp" {
|
||||
info!("SFTP subsystem requested");
|
||||
|
||||
// Phase 7: 初始化SFTP handler
|
||||
let root_dir = PathBuf::from("/Users/accusys/markbase"); // 默认root目录
|
||||
// Phase 7: 初始化SFTP handler(使用用户home目录,SFTPGo兼容)
|
||||
let root_dir = self.home_dir.clone();
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 4: 获取 client maxpack 限制(从 Channel 中获取)
|
||||
let maxpacket = if let Some(ch) = self.channels.get(&channel) {
|
||||
@@ -456,7 +462,8 @@ impl ChannelManager {
|
||||
32768 // OpenSSH 默认值(32KB)
|
||||
};
|
||||
|
||||
let sftp_handler = SftpHandler::new(root_dir, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack
|
||||
let vfs = Box::new(crate::vfs::local_fs::LocalFs::new());
|
||||
let sftp_handler = SftpHandler::new(root_dir, vfs, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack
|
||||
|
||||
// 存储到channel
|
||||
if let Some(ch) = self.channels.get_mut(&channel) {
|
||||
@@ -952,6 +959,22 @@ impl ChannelManager {
|
||||
false
|
||||
}
|
||||
|
||||
/// Phase 17: 关闭所有子进程stdin(收到CHANNEL_EOF时调用)
|
||||
/// SCP upload需要:scp -t 等待EOF on stdin才知道数据传输完毕
|
||||
pub fn close_child_stdin(&mut self) {
|
||||
let channel_ids: Vec<u32> = self.channels.keys().copied().collect();
|
||||
for id in channel_ids {
|
||||
if let Some(channel) = self.channels.get_mut(&id) {
|
||||
if let Some(exec) = &mut channel.exec_process {
|
||||
if let Some(stdin) = exec.stdin.take() {
|
||||
drop(stdin);
|
||||
info!("⭐⭐⭐⭐⭐ [CHANNEL_EOF] Closed child stdin (channel {})", id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取channel输出(Phase 6新增)
|
||||
pub fn get_channel_output(&mut self, channel_id: u32) -> Option<Vec<u8>> {
|
||||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||||
@@ -1283,6 +1306,7 @@ impl ChannelManager {
|
||||
|
||||
// 4. 检查stdout/stderr fd是否有数据
|
||||
let mut packets_data: Vec<(u32, Vec<u8>)> = Vec::new();
|
||||
let mut stderr_packets: Vec<(u32, Vec<u8>)> = Vec::new(); // Phase 17: stderr → CHANNEL_EXTENDED_DATA
|
||||
|
||||
for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map {
|
||||
if let Some(channel) = self.channels.get_mut(&channel_id) {
|
||||
@@ -1325,7 +1349,8 @@ impl ChannelManager {
|
||||
Ok(n) if n > 0 => {
|
||||
info!("⭐⭐⭐⭐⭐ [AFTER stderr.read] Read {} bytes from stderr (channel {})", n, channel_id);
|
||||
info!("⭐⭐⭐⭐⭐ stderr content: {:?}", &buffer[..std::cmp::min(50, n)]);
|
||||
packets_data.push((channel_id, buffer[..n].to_vec()));
|
||||
// ⭐⭐⭐⭐⭐ Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
|
||||
stderr_packets.push((channel_id, buffer[..n].to_vec()));
|
||||
}
|
||||
Ok(0) => {
|
||||
info!("stderr EOF (channel {}), closing stderr pipe", channel_id);
|
||||
@@ -1351,12 +1376,17 @@ impl ChannelManager {
|
||||
}
|
||||
|
||||
// 构建packets
|
||||
if !packets_data.is_empty() {
|
||||
if !packets_data.is_empty() || !stderr_packets.is_empty() {
|
||||
let mut packets = Vec::new();
|
||||
for (channel_id, data) in packets_data {
|
||||
let packet = self.build_channel_data(channel_id, &data)?;
|
||||
packets.push(packet);
|
||||
}
|
||||
// Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1)
|
||||
for (channel_id, data) in stderr_packets {
|
||||
let packet = self.build_channel_extended_data(channel_id, 1, &data)?;
|
||||
packets.push(packet);
|
||||
}
|
||||
info!("⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)", packets.len());
|
||||
return Ok((Some(packets), client_has_data, child_exited));
|
||||
}
|
||||
@@ -1689,13 +1719,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_channel_manager_creation() {
|
||||
let manager = ChannelManager::new();
|
||||
let manager = ChannelManager::new(PathBuf::from("/tmp"));
|
||||
assert_eq!(manager.next_channel_id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_open_confirmation() {
|
||||
let manager = ChannelManager::new();
|
||||
let manager = ChannelManager::new(PathBuf::from("/tmp"));
|
||||
let packet = manager.build_channel_open_confirmation(0, 100, 2097152, 32768).unwrap();
|
||||
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8);
|
||||
@@ -1703,7 +1733,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_channel_success() {
|
||||
let manager = ChannelManager::new();
|
||||
let manager = ChannelManager::new(PathBuf::from("/tmp"));
|
||||
let packet = manager.build_channel_success(0).unwrap();
|
||||
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_SUCCESS as u8);
|
||||
|
||||
@@ -17,6 +17,7 @@ type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// SSH加密通道管理器(参考OpenSSH struct sshcipher_ctx)
|
||||
pub struct EncryptionContext {
|
||||
pub session_id: Vec<u8>, // session identifier (exchange hash)
|
||||
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
|
||||
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
|
||||
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
|
||||
@@ -32,6 +33,7 @@ pub struct EncryptionContext {
|
||||
impl Default for EncryptionContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: vec![0u8; 32],
|
||||
encryption_key_stoc: vec![0u8; 32],
|
||||
mac_key_ctos: vec![0u8; 32],
|
||||
@@ -73,6 +75,7 @@ impl EncryptionContext {
|
||||
info!("Ciphers initialized successfully");
|
||||
|
||||
Self {
|
||||
session_id: keys.session_id.clone(),
|
||||
encryption_key_ctos: keys.encryption_key_ctos.clone(),
|
||||
encryption_key_stoc: keys.encryption_key_stoc.clone(),
|
||||
mac_key_ctos: keys.mac_key_ctos.clone(),
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::path::PathBuf;
|
||||
use std::fs::{self, File};
|
||||
use std::io::Write;
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, debug, warn};
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsError};
|
||||
use crate::vfs::open_flags::OpenFlags;
|
||||
|
||||
/// MPLEX_BASE from rsync io.h
|
||||
const MPLEX_BASE: u32 = 7;
|
||||
@@ -27,23 +27,21 @@ pub(crate) enum RsyncState {
|
||||
|
||||
pub struct RsyncHandler {
|
||||
state: RsyncState,
|
||||
/// Raw input from SSH (multiplexed after version exchange)
|
||||
raw_input: Vec<u8>,
|
||||
/// Decoded rsync protocol data (after stripping multiplex)
|
||||
rsync_input: Vec<u8>,
|
||||
/// Raw rsync data to send (multiplex wrapping applied in drain_output)
|
||||
output_raw: Vec<u8>,
|
||||
dest_path: PathBuf,
|
||||
output_file: Option<File>,
|
||||
output_file: Option<Box<dyn VfsFile>>,
|
||||
total_written: u64,
|
||||
file_entries: Vec<String>,
|
||||
current_file: usize,
|
||||
protocol_version: u32,
|
||||
multiplex: bool,
|
||||
vfs: Box<dyn VfsBackend>,
|
||||
}
|
||||
|
||||
impl RsyncHandler {
|
||||
pub fn parse_rsync_command(command: &str) -> Result<Self> {
|
||||
pub fn parse_rsync_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
if parts.len() < 3 || parts[0] != "rsync" {
|
||||
return Err(anyhow!("Invalid rsync command: {}", command));
|
||||
@@ -83,9 +81,9 @@ impl RsyncHandler {
|
||||
current_file: 0,
|
||||
protocol_version: 30,
|
||||
multiplex: false,
|
||||
vfs,
|
||||
};
|
||||
|
||||
// Send protocol version (4-byte LE int, no multiplex)
|
||||
handler.output_raw.extend_from_slice(&30u32.to_le_bytes());
|
||||
handler.state = RsyncState::WaitVersion;
|
||||
|
||||
@@ -129,7 +127,6 @@ impl RsyncHandler {
|
||||
}
|
||||
MSG_DONE => {
|
||||
info!("rsync: MSG_DONE received (file complete)");
|
||||
// Signal file completion by appending a sentinel to rsync_input
|
||||
self.rsync_input.extend_from_slice(b"RSYNCDONE");
|
||||
}
|
||||
9 => {
|
||||
@@ -147,7 +144,6 @@ impl RsyncHandler {
|
||||
if data.is_empty() || !self.multiplex {
|
||||
return data;
|
||||
}
|
||||
// Wrap with multiplex header (MSG_DATA)
|
||||
let header = (MPLEX_BASE << 24) | (data.len() as u32);
|
||||
let mut wrapped = Vec::with_capacity(4 + data.len());
|
||||
wrapped.extend_from_slice(&header.to_le_bytes());
|
||||
@@ -180,7 +176,6 @@ impl RsyncHandler {
|
||||
loop {
|
||||
match self.state.clone() {
|
||||
RsyncState::SendVersion => {
|
||||
// Version already sent in constructor
|
||||
self.transition(RsyncState::WaitVersion);
|
||||
}
|
||||
|
||||
@@ -206,7 +201,6 @@ impl RsyncHandler {
|
||||
|
||||
let flags = self.rsync_input[0];
|
||||
if flags == 0 {
|
||||
// End of file list
|
||||
self.rsync_input.drain(..1);
|
||||
info!("rsync: file list end ({} entries)", self.file_entries.len());
|
||||
|
||||
@@ -214,14 +208,12 @@ impl RsyncHandler {
|
||||
self.file_entries.push("file".to_string());
|
||||
}
|
||||
self.current_file = 0;
|
||||
// Enter sum head reading state
|
||||
self.transition(RsyncState::ReadSumHead { need: 20 });
|
||||
break;
|
||||
}
|
||||
|
||||
let mut pos = 1;
|
||||
|
||||
// Extended flags
|
||||
let _more_flags = if flags & 0x80 != 0 {
|
||||
if self.rsync_input.len() <= pos { break; }
|
||||
let ef = self.rsync_input[pos];
|
||||
@@ -249,7 +241,6 @@ impl RsyncHandler {
|
||||
self.file_entries.push(name);
|
||||
}
|
||||
|
||||
// Skip metadata varints
|
||||
let skip_count = if flags & 0x10 == 0 { 1 } else { 0 }
|
||||
+ if flags & 0x20 == 0 { 1 } else { 0 }
|
||||
+ if flags & 0x40 == 0 { 1 } else { 0 }
|
||||
@@ -277,9 +268,6 @@ impl RsyncHandler {
|
||||
|
||||
RsyncState::ReadSumHead { need } => {
|
||||
if self.rsync_input.len() >= need {
|
||||
// Read sum head: count, blength, s2length, remainder (4 × LE int)
|
||||
// + checksum seed (1 × LE int)
|
||||
// = 5 × 4 = 20 bytes
|
||||
let sum_count = i32::from_le_bytes([
|
||||
self.rsync_input[0], self.rsync_input[1],
|
||||
self.rsync_input[2], self.rsync_input[3],
|
||||
@@ -312,7 +300,6 @@ impl RsyncHandler {
|
||||
RsyncState::SendSumCount => {
|
||||
self.open_current_file()?;
|
||||
|
||||
// Send sum_count = 0 (4-byte LE int = we have no existing data)
|
||||
self.output_raw.extend_from_slice(&0u32.to_le_bytes());
|
||||
info!("rsync: sent sum_count=0, ready to receive file data");
|
||||
|
||||
@@ -320,22 +307,17 @@ impl RsyncHandler {
|
||||
}
|
||||
|
||||
RsyncState::ReadFileData => {
|
||||
// Data comes as raw bytes inside MSG_DATA multiplex packets.
|
||||
// MSG_DONE appends b"RSYNCDONE" to rsync_input.
|
||||
let done_marker = b"RSYNCDONE";
|
||||
if let Some(pos) = self.rsync_input.windows(done_marker.len())
|
||||
.position(|w| w == done_marker)
|
||||
{
|
||||
// Data before the marker
|
||||
if pos > 0 {
|
||||
let data = self.rsync_input[..pos].to_vec();
|
||||
self.rsync_input.drain(..pos);
|
||||
self.write_to_file(&data)?;
|
||||
}
|
||||
// Remove marker
|
||||
self.rsync_input.drain(..done_marker.len());
|
||||
|
||||
// Close file
|
||||
if let Some(mut file) = self.output_file.take() {
|
||||
if let Err(e) = file.flush() {
|
||||
warn!("rsync flush error: {}", e);
|
||||
@@ -353,11 +335,9 @@ impl RsyncHandler {
|
||||
info!("rsync ALL DONE: {} bytes written to {}",
|
||||
self.total_written, self.dest_path.display());
|
||||
} else {
|
||||
// Next file sum head
|
||||
self.transition(RsyncState::ReadSumHead { need: 20 });
|
||||
}
|
||||
} else if !self.rsync_input.is_empty() {
|
||||
// Partial data, keep it in buffer for more
|
||||
let data = self.rsync_input.clone();
|
||||
self.rsync_input.clear();
|
||||
self.write_to_file(&data)?;
|
||||
@@ -377,9 +357,11 @@ impl RsyncHandler {
|
||||
|
||||
fn open_current_file(&mut self) -> Result<()> {
|
||||
if let Some(parent) = self.dest_path.parent() {
|
||||
fs::create_dir_all(parent).ok();
|
||||
self.vfs.create_dir_all(parent, 0o755).ok();
|
||||
}
|
||||
let file = File::create(&self.dest_path)?;
|
||||
let flags = OpenFlags::new().write().create().truncate();
|
||||
let file = self.vfs.open_file(&self.dest_path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
self.output_file = Some(file);
|
||||
info!("rsync: opened {} for writing", self.dest_path.display());
|
||||
Ok(())
|
||||
@@ -387,7 +369,8 @@ impl RsyncHandler {
|
||||
|
||||
fn write_to_file(&mut self, data: &[u8]) -> Result<()> {
|
||||
if let Some(file) = &mut self.output_file {
|
||||
file.write_all(data)?;
|
||||
file.write_all(data)
|
||||
.map_err(|e| anyhow!("write error: {}", e))?;
|
||||
self.total_written += data.len() as u64;
|
||||
}
|
||||
Ok(())
|
||||
@@ -426,28 +409,37 @@ fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::vfs::local_fs::LocalFs;
|
||||
|
||||
fn make_vfs() -> Box<dyn VfsBackend> {
|
||||
Box::new(LocalFs::new())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command() {
|
||||
let h = RsyncHandler::parse_rsync_command("rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin").unwrap();
|
||||
let h = RsyncHandler::parse_rsync_command(
|
||||
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
|
||||
make_vfs()
|
||||
).unwrap();
|
||||
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command_sender() {
|
||||
let h = RsyncHandler::parse_rsync_command("rsync --server --sender -vlogDtprz . /home/user/file.txt").unwrap();
|
||||
let h = RsyncHandler::parse_rsync_command(
|
||||
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
|
||||
make_vfs()
|
||||
).unwrap();
|
||||
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_version_exchange() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap();
|
||||
// Initial output: protocol version (30 as LE int)
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
||||
let output = h.drain_output();
|
||||
assert_eq!(output, b"\x1e\x00\x00\x00");
|
||||
assert_eq!(h.state, RsyncState::WaitVersion);
|
||||
|
||||
// Client sends its version (30 = 0x1E)
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||
assert_eq!(h.state, RsyncState::ReadFileList);
|
||||
assert!(h.multiplex);
|
||||
@@ -455,9 +447,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_version_negotiate_down() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap();
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
||||
let _ = h.drain_output();
|
||||
// Client has lower version (29)
|
||||
h.feed(b"\x1d\x00\x00\x00").unwrap();
|
||||
assert_eq!(h.protocol_version, 29);
|
||||
assert_eq!(h.state, RsyncState::ReadFileList);
|
||||
@@ -471,24 +462,14 @@ mod tests {
|
||||
buf
|
||||
}
|
||||
|
||||
fn build_multiplex_done() -> Vec<u8> {
|
||||
let header = (MPLEX_BASE << 24) | 0u32; // MSG_DONE (tag=1 → raw_tag=8)
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&header.to_le_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_list_multiplex() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap();
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
||||
let _ = h.drain_output();
|
||||
// Version exchange
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||
assert!(h.multiplex);
|
||||
|
||||
// Build file list with multiplex wrapping
|
||||
let mut flist = Vec::new();
|
||||
// Entry: flags=0, name="test.txt\0", + 6 varints
|
||||
flist.push(0);
|
||||
flist.extend_from_slice(b"test.txt");
|
||||
flist.push(0);
|
||||
@@ -514,46 +495,40 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
write_varint(&mut flist, 33188); // mode
|
||||
write_varint(&mut flist, 501); // uid
|
||||
write_varint(&mut flist, 20); // gid
|
||||
write_varint(&mut flist, 1700000000); // time
|
||||
write_varint(&mut flist, 100); // size
|
||||
write_varint(&mut flist, 0); // checksum seed
|
||||
// End marker
|
||||
write_varint(&mut flist, 33188);
|
||||
write_varint(&mut flist, 501);
|
||||
write_varint(&mut flist, 20);
|
||||
write_varint(&mut flist, 1700000000);
|
||||
write_varint(&mut flist, 100);
|
||||
write_varint(&mut flist, 0);
|
||||
flist.push(0);
|
||||
|
||||
// Sum head (5 ints = 20 bytes) as separate multiplex packet
|
||||
let mut sum_head = Vec::new();
|
||||
sum_head.extend_from_slice(&0i32.to_le_bytes()); // count
|
||||
sum_head.extend_from_slice(&7000i32.to_le_bytes()); // blength
|
||||
sum_head.extend_from_slice(&2i32.to_le_bytes()); // s2length
|
||||
sum_head.extend_from_slice(&100i32.to_le_bytes()); // remainder
|
||||
sum_head.extend_from_slice(&42i32.to_le_bytes()); // checksum_seed
|
||||
sum_head.extend_from_slice(&0i32.to_le_bytes());
|
||||
sum_head.extend_from_slice(&7000i32.to_le_bytes());
|
||||
sum_head.extend_from_slice(&2i32.to_le_bytes());
|
||||
sum_head.extend_from_slice(&100i32.to_le_bytes());
|
||||
sum_head.extend_from_slice(&42i32.to_le_bytes());
|
||||
|
||||
// Feed file list
|
||||
h.feed(&build_multiplex(&flist)).unwrap();
|
||||
assert_eq!(h.state, RsyncState::ReadFileList); // Still reading, 0x00 end marker triggered transition
|
||||
assert_eq!(h.state, RsyncState::ReadFileList);
|
||||
assert_eq!(h.file_entries.len(), 1);
|
||||
|
||||
// Now feed sum head
|
||||
h.feed(&build_multiplex(&sum_head)).unwrap();
|
||||
assert_eq!(h.state, RsyncState::SendSumCount);
|
||||
|
||||
// Send sum count response
|
||||
let sum_resp = h.drain_output();
|
||||
assert_eq!(sum_resp.len(), 8); // 4-byte header + 4-byte int
|
||||
assert_eq!(sum_resp.len(), 8);
|
||||
assert_eq!(&sum_resp[4..8], &0u32.to_le_bytes());
|
||||
assert_eq!(h.state, RsyncState::ReadFileData);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_data_multiplex() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap();
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
||||
let _ = h.drain_output();
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap(); // version
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||
|
||||
// Simple file list
|
||||
let mut flist = Vec::new();
|
||||
flist.push(0);
|
||||
flist.extend_from_slice(b"test.bin");
|
||||
@@ -568,7 +543,6 @@ mod tests {
|
||||
flist.push(0);
|
||||
h.feed(&build_multiplex(&flist)).unwrap();
|
||||
|
||||
// Sum head
|
||||
let mut sh = Vec::new();
|
||||
sh.extend_from_slice(&0i32.to_le_bytes());
|
||||
sh.extend_from_slice(&7000i32.to_le_bytes());
|
||||
@@ -576,16 +550,13 @@ mod tests {
|
||||
sh.extend_from_slice(&100i32.to_le_bytes());
|
||||
sh.extend_from_slice(&42i32.to_le_bytes());
|
||||
h.feed(&build_multiplex(&sh)).unwrap();
|
||||
let _ = h.drain_output(); // sum count response
|
||||
let _ = h.drain_output();
|
||||
|
||||
// File data + MSG_DONE
|
||||
let file_data = b"Hello, rsync protocol!";
|
||||
h.feed(&build_multiplex(file_data)).unwrap();
|
||||
assert_eq!(h.state, RsyncState::ReadFileData);
|
||||
|
||||
// MSG_DONE
|
||||
// MSG_DONE has tag=1, so raw_tag = MPLEX_BASE + 1 = 8
|
||||
let done_header = (MPLEX_BASE + 1) << 24; // raw_tag = 8, len = 0
|
||||
let done_header = (MPLEX_BASE + 1) << 24;
|
||||
let done_bytes = done_header.to_le_bytes();
|
||||
h.feed(&done_bytes).unwrap();
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
// SCP协议实现(Phase 8)
|
||||
// 参考OpenSSH scp.c源码
|
||||
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat};
|
||||
use crate::vfs::open_flags::OpenFlags;
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Write, BufReader, BufWriter, BufRead}; // 导入BufRead trait(OpenSSH标准)
|
||||
use chrono::{DateTime, Utc};
|
||||
use std::io::{Read, Write, BufRead};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// SCP Handler(参考OpenSSH scp.c)
|
||||
pub struct ScpHandler {
|
||||
@@ -14,6 +15,7 @@ pub struct ScpHandler {
|
||||
mode: ScpMode,
|
||||
recursive: bool,
|
||||
preserve_times: bool,
|
||||
vfs: Box<dyn VfsBackend>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -23,24 +25,25 @@ pub enum ScpMode {
|
||||
}
|
||||
|
||||
impl ScpHandler {
|
||||
pub fn new(root_dir: PathBuf) -> Self {
|
||||
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>) -> Self {
|
||||
Self {
|
||||
root_dir,
|
||||
mode: ScpMode::Destination,
|
||||
recursive: false,
|
||||
preserve_times: false,
|
||||
vfs,
|
||||
}
|
||||
}
|
||||
|
||||
/// 解析SCP命令(参考OpenSSH scp.c: parse_command())
|
||||
pub fn parse_scp_command(command: &str) -> Result<Self> {
|
||||
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
if parts.len() < 2 || parts[0] != "scp" {
|
||||
return Err(anyhow!("Invalid SCP command: {}", command));
|
||||
}
|
||||
|
||||
let mut handler = ScpHandler::new(PathBuf::from("/tmp"));
|
||||
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
|
||||
|
||||
for part in &parts[1..] {
|
||||
match part {
|
||||
@@ -68,19 +71,19 @@ impl ScpHandler {
|
||||
|
||||
/// SCP Source Mode(scp -f,发送文件)
|
||||
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||
info!("SCP source mode: sending files from {}", self.root_dir.display()); // 使用display()(Rust标准)
|
||||
info!("SCP source mode: sending files from {}", self.root_dir.display());
|
||||
|
||||
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
|
||||
let stat = self.vfs.stat(&full_path)
|
||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||
|
||||
if full_path.is_file() {
|
||||
self.send_file(channel, &full_path)?;
|
||||
} else if full_path.is_dir() {
|
||||
if stat.is_dir {
|
||||
if !self.recursive {
|
||||
return Err(anyhow!("Directory detected but -r flag not specified"));
|
||||
}
|
||||
self.send_directory(channel, &full_path)?;
|
||||
} else {
|
||||
return Err(anyhow!("Path does not exist: {}", full_path.display()));
|
||||
self.send_file(channel, &full_path)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -88,9 +91,8 @@ impl ScpHandler {
|
||||
|
||||
/// SCP Destination Mode(scp -t,接收文件)
|
||||
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||
info!("SCP destination mode: receiving files to {}", self.root_dir.display()); // 使用display()(Rust标准)
|
||||
info!("SCP destination mode: receiving files to {}", self.root_dir.display());
|
||||
|
||||
// 发送确认('\0')
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
@@ -99,10 +101,9 @@ impl ScpHandler {
|
||||
loop {
|
||||
buffer.clear();
|
||||
|
||||
// 每次循环创建新的reader(避免borrow冲突)- OpenSSH标准
|
||||
let mut reader = BufReader::new(&mut *channel);
|
||||
let mut reader = std::io::BufReader::new(&mut *channel);
|
||||
match reader.read_line(&mut buffer)? {
|
||||
0 => break, // EOF
|
||||
0 => break,
|
||||
_ => {
|
||||
let command = buffer.trim();
|
||||
debug!("SCP command: {}", command);
|
||||
@@ -113,7 +114,6 @@ impl ScpHandler {
|
||||
Some('E') => self.handle_end_directory(channel)?,
|
||||
Some('T') => self.handle_time_command(channel, command)?,
|
||||
Some('\0') => {
|
||||
// 确认信号,继续
|
||||
continue;
|
||||
}
|
||||
_ => {
|
||||
@@ -130,28 +130,30 @@ impl ScpHandler {
|
||||
|
||||
/// 发送文件(参考OpenSSH scp.c: source())
|
||||
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
||||
let metadata = fs::metadata(path)?;
|
||||
let size = metadata.len();
|
||||
let stat = self.vfs.stat(path)
|
||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||
let size = stat.size;
|
||||
let filename = path.file_name().unwrap().to_string_lossy();
|
||||
|
||||
// 发送文件命令:C0644 size filename
|
||||
let command = format!("C0644 {} {}\n", size, filename);
|
||||
channel.write_all(command.as_bytes())?;
|
||||
channel.flush()?;
|
||||
|
||||
// 等待确认('\0')
|
||||
let mut ack = [0u8; 1];
|
||||
channel.read_exact(&mut ack)?;
|
||||
if ack[0] != 0 {
|
||||
return Err(anyhow!("SCP file command rejected"));
|
||||
}
|
||||
|
||||
// 发送文件内容
|
||||
let file = File::open(path)?;
|
||||
let mut reader = BufReader::new(file);
|
||||
let flags = OpenFlags::new().read();
|
||||
let mut file = self.vfs.open_file(path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
|
||||
let mut buffer = vec![0u8; 8192];
|
||||
|
||||
while let Ok(n) = reader.read(&mut buffer) {
|
||||
loop {
|
||||
let n = file.read(&mut buffer)
|
||||
.map_err(|e| anyhow!("read error: {}", e))?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
@@ -160,11 +162,9 @@ impl ScpHandler {
|
||||
|
||||
channel.flush()?;
|
||||
|
||||
// 发送结束确认('\0')
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
// 等待确认('\0')
|
||||
channel.read_exact(&mut ack)?;
|
||||
if ack[0] != 0 {
|
||||
return Err(anyhow!("SCP file transfer rejected"));
|
||||
@@ -178,35 +178,34 @@ impl ScpHandler {
|
||||
fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
||||
let dirname = path.file_name().unwrap().to_string_lossy();
|
||||
|
||||
// 发送目录命令:D0755 0 dirname
|
||||
let command = format!("D0755 0 {}\n", dirname);
|
||||
channel.write_all(command.as_bytes())?;
|
||||
channel.flush()?;
|
||||
|
||||
// 等待确认('\0')
|
||||
let mut ack = [0u8; 1];
|
||||
channel.read_exact(&mut ack)?;
|
||||
if ack[0] != 0 {
|
||||
return Err(anyhow!("SCP directory command rejected"));
|
||||
}
|
||||
|
||||
// 递归发送目录内容
|
||||
for entry in fs::read_dir(path)? {
|
||||
let entry = entry?;
|
||||
let full_path = entry.path();
|
||||
let entries = self.vfs.read_dir(path)
|
||||
.map_err(|e| anyhow!("read_dir error: {}", e))?;
|
||||
|
||||
if full_path.is_file() {
|
||||
self.send_file(channel, &full_path)?;
|
||||
} else if full_path.is_dir() && self.recursive {
|
||||
self.send_directory(channel, &full_path)?;
|
||||
for entry in &entries {
|
||||
let entry_path = path.join(&entry.name);
|
||||
|
||||
if entry.stat.is_dir {
|
||||
if self.recursive {
|
||||
self.send_directory(channel, &entry_path)?;
|
||||
}
|
||||
} else {
|
||||
self.send_file(channel, &entry_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
// 发送结束目录命令:E
|
||||
channel.write_all("E\n".as_bytes())?;
|
||||
channel.flush()?;
|
||||
|
||||
// 等待确认('\0')
|
||||
channel.read_exact(&mut ack)?;
|
||||
if ack[0] != 0 {
|
||||
return Err(anyhow!("SCP end directory rejected"));
|
||||
@@ -224,31 +223,25 @@ impl ScpHandler {
|
||||
return self.send_error(channel, "Invalid file command format");
|
||||
}
|
||||
|
||||
let mode = parts[0].trim_start_matches('C');
|
||||
let mode_str = parts[0].trim_start_matches('C');
|
||||
let size: u64 = parts[1].parse()?;
|
||||
let filename = parts[2];
|
||||
|
||||
debug!("SCP receive file: mode={}, size={}, name={}", mode, size, filename);
|
||||
debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename);
|
||||
|
||||
// 安全性检查:文件大小限制(防止DoS)
|
||||
if size > 1024 * 1024 * 1024 { // 1GB限制
|
||||
if size > 1024 * 1024 * 1024 {
|
||||
return self.send_error(channel, "File too large (max 1GB)");
|
||||
}
|
||||
|
||||
// 创建文件
|
||||
let full_path = self.resolve_path(filename)?;
|
||||
let file = OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(&full_path)?;
|
||||
|
||||
// 发送确认('\0')
|
||||
let flags = OpenFlags::new().write().create().truncate();
|
||||
let mut file = self.vfs.open_file(&full_path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
// 接收文件内容
|
||||
let mut writer = BufWriter::new(file);
|
||||
let mut buffer = vec![0u8; 8192];
|
||||
let mut remaining = size;
|
||||
|
||||
@@ -258,25 +251,25 @@ impl ScpHandler {
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
writer.write_all(&buffer[..n])?;
|
||||
file.write_all(&buffer[..n])
|
||||
.map_err(|e| anyhow!("write error: {}", e))?;
|
||||
remaining -= n as u64;
|
||||
}
|
||||
|
||||
writer.flush()?;
|
||||
file.flush().map_err(|e| anyhow!("flush error: {}", e))?;
|
||||
|
||||
// 设置文件权限
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let mode_int: u32 = mode.parse()?;
|
||||
fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?;
|
||||
let mode_int: u32 = mode_str.parse()?;
|
||||
if mode_int != 0 {
|
||||
let mut set_stat = VfsStat::new();
|
||||
set_stat.mode = mode_int;
|
||||
self.vfs.set_stat(&full_path, &set_stat)
|
||||
.map_err(|e| anyhow!("set_stat error: {}", e))?;
|
||||
}
|
||||
|
||||
// 接收结束确认('\0')
|
||||
let mut ack = [0u8; 1];
|
||||
channel.read_exact(&mut ack)?;
|
||||
|
||||
// 发送确认('\0')
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
@@ -296,24 +289,17 @@ impl ScpHandler {
|
||||
return self.send_error(channel, "Recursive flag not specified");
|
||||
}
|
||||
|
||||
let mode = parts[0].trim_start_matches('D');
|
||||
let mode_str = parts[0].trim_start_matches('D');
|
||||
let dirname = parts[2];
|
||||
|
||||
debug!("SCP receive directory: mode={}, name={}", mode, dirname);
|
||||
debug!("SCP receive directory: mode={}, name={}", mode_str, dirname);
|
||||
|
||||
// 创建目录
|
||||
let full_path = self.resolve_path(dirname)?;
|
||||
fs::create_dir_all(&full_path)?;
|
||||
|
||||
// 设置目录权限
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let mode_int: u32 = mode.parse()?;
|
||||
fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?;
|
||||
}
|
||||
let mode_int: u32 = mode_str.parse()?;
|
||||
self.vfs.create_dir_all(&full_path, mode_int)
|
||||
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
|
||||
|
||||
// 发送确认('\0')
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
@@ -325,7 +311,6 @@ impl ScpHandler {
|
||||
fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||
debug!("SCP end directory");
|
||||
|
||||
// 发送确认('\0')
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
@@ -335,7 +320,6 @@ impl ScpHandler {
|
||||
/// 处理时间命令(T mtime atime)
|
||||
fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||||
if !self.preserve_times {
|
||||
// 发送确认('\0'),但不设置时间
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
return Ok(());
|
||||
@@ -347,18 +331,14 @@ impl ScpHandler {
|
||||
return self.send_error(channel, "Invalid time command format");
|
||||
}
|
||||
|
||||
let mtime: i64 = parts[1].parse()?;
|
||||
let atime: i64 = parts[2].parse()?;
|
||||
let mtime_secs: i64 = parts[1].parse()?;
|
||||
let atime_secs: i64 = parts[2].parse()?;
|
||||
|
||||
debug!("SCP set times: mtime={}, atime={}", mtime, atime);
|
||||
debug!("SCP set times: mtime={}, atime={}", mtime_secs, atime_secs);
|
||||
|
||||
// 发送确认('\0')
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
// 时间设置将在文件接收完成后进行
|
||||
// (这里仅记录,实际设置在handle_file_command中)
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -374,10 +354,13 @@ impl ScpHandler {
|
||||
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
|
||||
let full_path = self.root_dir.join(path);
|
||||
|
||||
let canonical_path = full_path.canonicalize()
|
||||
let canonical_path = self.vfs.real_path(&full_path)
|
||||
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
|
||||
|
||||
if !canonical_path.starts_with(&self.root_dir.canonicalize()?) {
|
||||
let root_canonical = self.vfs.real_path(&self.root_dir)
|
||||
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
|
||||
|
||||
if !canonical_path.starts_with(&root_canonical) {
|
||||
return Err(anyhow!("Path traversal attempt detected"));
|
||||
}
|
||||
|
||||
@@ -392,23 +375,28 @@ impl<T: Read + Write> ReadWrite for T {}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::vfs::local_fs::LocalFs;
|
||||
|
||||
fn make_handler() -> ScpHandler {
|
||||
ScpHandler::new(PathBuf::from("/tmp"), Box::new(LocalFs::new()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scp_command_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -t /tmp").unwrap();
|
||||
let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert_eq!(handler.mode, ScpMode::Destination);
|
||||
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scp_recursive_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp").unwrap();
|
||||
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert!(handler.recursive);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scp_source_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -f /tmp").unwrap();
|
||||
let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert_eq!(handler.mode, ScpMode::Source);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::ssh_server::kex::{KexResult, KexProposal};
|
||||
use crate::ssh_server::kex_complete::{KexState};
|
||||
use crate::ssh_server::auth::{AuthHandler, AuthResult};
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::provider::pg::PgProvider;
|
||||
use crate::provider::DataProvider;
|
||||
use crate::ssh_server::channel::{ChannelManager};
|
||||
use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket};
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
|
||||
@@ -13,6 +16,7 @@ use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, error, debug};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::path::PathBuf;
|
||||
use std::thread;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步
|
||||
@@ -22,6 +26,7 @@ pub struct SshServerConfig {
|
||||
pub port: u16,
|
||||
pub bind_address: String,
|
||||
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
|
||||
pub pg_conn: Option<String>, // PostgreSQL连接字符串(SFTPGo兼容认证)
|
||||
}
|
||||
|
||||
impl Default for SshServerConfig {
|
||||
@@ -30,6 +35,7 @@ impl Default for SshServerConfig {
|
||||
port: 2024,
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
|
||||
pg_conn: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -42,6 +48,7 @@ impl SshServerConfig {
|
||||
port: 2024,
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config: config,
|
||||
pg_conn: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -73,6 +80,7 @@ impl SshServer {
|
||||
self.config.security_config.max_sessions);
|
||||
|
||||
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
|
||||
let pg_conn = self.config.pg_conn.clone();
|
||||
|
||||
for stream in listener.incoming() {
|
||||
match stream {
|
||||
@@ -81,9 +89,10 @@ impl SshServer {
|
||||
info!("New SSH connection from {}", client_addr);
|
||||
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1
|
||||
let pg_conn_clone = pg_conn.clone();
|
||||
|
||||
thread::spawn(move || {
|
||||
if let Err(e) = handle_connection_complete(stream, security_config_clone) { // Phase 13.1
|
||||
if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1
|
||||
error!("Connection error: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -99,7 +108,7 @@ impl SshServer {
|
||||
}
|
||||
|
||||
/// 处理完整SSH连接(Phase 1-13完整流程)
|
||||
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>) -> Result<()> {
|
||||
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>, pg_conn: Option<String>) -> Result<()> {
|
||||
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
|
||||
|
||||
// Phase 13.1: 增加活动会话数
|
||||
@@ -122,13 +131,22 @@ fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshS
|
||||
let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?;
|
||||
info!("Key exchange completed, encryption channel ready");
|
||||
|
||||
// Phase 5: SSH认证(参考OpenSSH auth2.c)
|
||||
let mut auth_handler = AuthHandler::new()?;
|
||||
// Phase 5: SSH认证(SFTPGo兼容 — PostgreSQL或SQLite)
|
||||
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
|
||||
info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str);
|
||||
Box::new(PgProvider::new(conn_str)
|
||||
.map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?)
|
||||
} else {
|
||||
info!("Using SQLite auth provider");
|
||||
Box::new(SqliteProvider::new("data/auth.sqlite")
|
||||
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?)
|
||||
};
|
||||
let mut auth_handler = AuthHandler::new(provider);
|
||||
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
|
||||
info!("SSH authentication succeeded: user={}", auth_user);
|
||||
info!("SSH authentication succeeded: user={}", auth_user.username);
|
||||
|
||||
// Phase 6: SSH Channel管理(参考OpenSSH channel.c)
|
||||
let mut channel_manager = ChannelManager::new();
|
||||
let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone());
|
||||
|
||||
// Phase 13: PortForwardManager初始化
|
||||
let mut port_forward_manager = PortForwardManager::new();
|
||||
@@ -226,11 +244,16 @@ fn perform_complete_kex_exchange(
|
||||
}
|
||||
|
||||
/// SSH认证流程(Phase 5)
|
||||
pub struct AuthUser {
|
||||
pub username: String,
|
||||
pub home_dir: PathBuf,
|
||||
}
|
||||
|
||||
fn perform_ssh_auth(
|
||||
stream: &mut TcpStream,
|
||||
auth_handler: &mut AuthHandler,
|
||||
encryption_ctx: &mut EncryptionContext,
|
||||
) -> Result<String> {
|
||||
) -> Result<AuthUser> {
|
||||
info!("Starting SSH authentication");
|
||||
info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
|
||||
encryption_ctx.encryption_key_ctos.len(),
|
||||
@@ -279,6 +302,8 @@ fn perform_ssh_auth(
|
||||
encrypted_accept.write(stream)?;
|
||||
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
|
||||
|
||||
let session_id = encryption_ctx.session_id.clone();
|
||||
|
||||
loop {
|
||||
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
|
||||
let auth_payload = auth_packet.payload();
|
||||
@@ -286,7 +311,7 @@ fn perform_ssh_auth(
|
||||
|
||||
let auth_request = SshPacket::new(auth_payload.to_vec());
|
||||
|
||||
match auth_handler.handle_userauth_request(&auth_request)? {
|
||||
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
|
||||
AuthResult::Success => {
|
||||
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
|
||||
let encrypted_success = EncryptedPacket::new(
|
||||
@@ -297,7 +322,16 @@ fn perform_ssh_auth(
|
||||
encrypted_success.write(stream)?;
|
||||
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
|
||||
|
||||
return Ok("demo".to_string());
|
||||
// Extract username from auth request
|
||||
let user = extract_username_from_auth_request(&auth_request)
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
let home_dir = auth_handler.get_home_dir(&user)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
|
||||
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
|
||||
return Ok(AuthUser { username: user, home_dir });
|
||||
}
|
||||
AuthResult::Failure(message) => {
|
||||
// message包含可用的认证方法列表(如"password,publickey")
|
||||
@@ -519,7 +553,9 @@ fn handle_ssh_service_loop(
|
||||
}
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_EOF as u8 => {
|
||||
info!("Received SSH_MSG_CHANNEL_EOF");
|
||||
// EOF means client won't send more data, just acknowledge and continue
|
||||
// Phase 17: EOF means client won't send more data → close child stdin
|
||||
// (Essential for SCP upload where scp -t waits for EOF on stdin)
|
||||
channel_manager.close_child_stdin();
|
||||
}
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_DISCONNECT as u8 => {
|
||||
info!("Received SSH_MSG_DISCONNECT");
|
||||
@@ -543,12 +579,27 @@ fn handle_ssh_service_loop(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
|
||||
fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result<String> {
|
||||
let payload = &packet.payload;
|
||||
if payload.len() < 5 {
|
||||
return Err(anyhow!("Auth request too short"));
|
||||
}
|
||||
let name_len = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]) as usize;
|
||||
if payload.len() < 5 + name_len {
|
||||
return Err(anyhow!("Auth request truncated"));
|
||||
}
|
||||
let username = String::from_utf8_lossy(&payload[5..5 + name_len]).to_string();
|
||||
Ok(username)
|
||||
}
|
||||
|
||||
/// SSH服务器CLI入口
|
||||
pub fn run_ssh_server(port: Option<u16>) -> Result<()> {
|
||||
pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
|
||||
let config = SshServerConfig {
|
||||
port: port.unwrap_or(2024),
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
|
||||
pg_conn: pg_conn.map(|s| s.to_string()),
|
||||
};
|
||||
|
||||
let server = SshServer::new(config);
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
// 参考OpenSSH sftp-server.c和draft-ietf-secsh-filexfer-02.txt
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsDirEntry};
|
||||
use crate::vfs::open_flags::OpenFlags;
|
||||
use anyhow::{Result, anyhow, Context};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, warn, debug};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Write, Seek, SeekFrom};
|
||||
use std::os::unix::fs::PermissionsExt; // 导入PermissionsExt trait(Unix标准)
|
||||
use std::os::unix::fs::MetadataExt; // ⭐⭐⭐⭐⭐ Phase 2.2: 导入MetadataExt trait(获取uid/gid)
|
||||
use std::fs;
|
||||
use std::io::{SeekFrom, Write};
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::os::unix::fs::MetadataExt;
|
||||
|
||||
/// SFTP packet类型(参考draft-ietf-secsh-filexfer-02.txt)
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -178,6 +180,30 @@ impl SftpAttrs {
|
||||
attrs
|
||||
}
|
||||
|
||||
pub fn from_vfs_stat(stat: &crate::vfs::VfsStat) -> Self {
|
||||
let mut attrs = Self::new();
|
||||
|
||||
attrs.flags = SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE
|
||||
| SftpAttrFlags::SSH_FILEXFER_ATTR_UIDGID
|
||||
| SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS
|
||||
| SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME;
|
||||
|
||||
attrs.size = Some(stat.size);
|
||||
attrs.permissions = Some(stat.mode);
|
||||
attrs.uid = Some(stat.uid);
|
||||
attrs.gid = Some(stat.gid);
|
||||
|
||||
if let Ok(d) = stat.atime.duration_since(std::time::UNIX_EPOCH) {
|
||||
attrs.atime = Some(d.as_secs() as u32);
|
||||
}
|
||||
|
||||
if let Ok(d) = stat.mtime.duration_since(std::time::UNIX_EPOCH) {
|
||||
attrs.mtime = Some(d.as_secs() as u32);
|
||||
}
|
||||
|
||||
attrs
|
||||
}
|
||||
|
||||
pub fn serialize(&self) -> Result<Vec<u8>> {
|
||||
debug!("Serializing SftpAttrs: flags=0x{:08x}, size={:?}, uid={:?}, gid={:?}, permissions=0x{:08x}, atime={:?}, mtime={:?}",
|
||||
self.flags, self.size, self.uid, self.gid,
|
||||
@@ -242,13 +268,12 @@ impl SftpAttrs {
|
||||
}
|
||||
|
||||
/// SFTP handle(文件或目录句柄)
|
||||
#[derive(Debug)] // 移除Clone(File/DirEntry不支持Clone)
|
||||
pub struct SftpHandle {
|
||||
pub id: u32,
|
||||
pub path: PathBuf,
|
||||
pub handle_type: SftpHandleType,
|
||||
pub file: Option<File>,
|
||||
pub dir_entries: Option<Vec<fs::DirEntry>>,
|
||||
pub file: Option<Box<dyn VfsFile>>,
|
||||
pub dir_entries: Option<Vec<VfsDirEntry>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -260,6 +285,7 @@ pub enum SftpHandleType {
|
||||
/// SFTP处理管理器(参考OpenSSH sftp-server.c)
|
||||
pub struct SftpHandler {
|
||||
root_dir: PathBuf,
|
||||
vfs: Box<dyn VfsBackend>,
|
||||
next_handle_id: u32,
|
||||
handles: std::collections::HashMap<u32, SftpHandle>,
|
||||
// ⭐⭐⭐⭐⭐ Phase 4: 添加 client maxpack 限制(参考OpenSSH sftp-server.c)
|
||||
@@ -277,14 +303,15 @@ impl SftpHandler {
|
||||
const MAX_HASH_SIZE: u64 = 268_435_456;
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 4: 修改 new() 方法,接受 maxpack 参数
|
||||
pub fn new(root_dir: PathBuf, maxpacket: u32) -> Self {
|
||||
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>, maxpacket: u32) -> Self {
|
||||
let canonical_root = root_dir.canonicalize().unwrap_or(root_dir);
|
||||
Self {
|
||||
root_dir: canonical_root,
|
||||
vfs,
|
||||
next_handle_id: 0,
|
||||
handles: std::collections::HashMap::new(),
|
||||
maxpacket,
|
||||
restrict_absolute: false, // 默认允许绝对路径
|
||||
restrict_absolute: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,30 +387,9 @@ impl SftpHandler {
|
||||
info!("SSH_FXP_OPEN: id={}, path={}, pflags={:#x}", id, path, pflags);
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
let flags = OpenFlags::from_sftp_pflags(pflags);
|
||||
|
||||
let file_result = if pflags & SftpFileFlags::SSH_FXF_READ != 0 {
|
||||
OpenOptions::new().read(true).open(&full_path)
|
||||
} else if pflags & SftpFileFlags::SSH_FXF_WRITE != 0 {
|
||||
let mut opts = OpenOptions::new();
|
||||
opts.write(true);
|
||||
if pflags & SftpFileFlags::SSH_FXF_APPEND != 0 {
|
||||
opts.append(true);
|
||||
}
|
||||
if pflags & SftpFileFlags::SSH_FXF_CREAT != 0 {
|
||||
opts.create(true);
|
||||
}
|
||||
if pflags & SftpFileFlags::SSH_FXF_TRUNC != 0 {
|
||||
opts.truncate(true);
|
||||
}
|
||||
if pflags & SftpFileFlags::SSH_FXF_EXCL != 0 {
|
||||
opts.create_new(true);
|
||||
}
|
||||
opts.open(&full_path)
|
||||
} else {
|
||||
return self.build_status_response(id, SftpStatus::SSH_FX_OP_UNSUPPORTED, "Unsupported open flags");
|
||||
};
|
||||
|
||||
match file_result {
|
||||
match self.vfs.open_file(&full_path, &flags) {
|
||||
Ok(file) => {
|
||||
if self.handles.len() >= Self::MAX_HANDLES {
|
||||
warn!("SSH_FXP_OPEN: handle limit reached ({})", Self::MAX_HANDLES);
|
||||
@@ -405,7 +411,7 @@ impl SftpHandler {
|
||||
self.build_handle_response(id, &handle_id.to_be_bytes())
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -447,9 +453,8 @@ impl SftpHandler {
|
||||
|
||||
if let Some(handle) = self.handles.get_mut(&handle_id) {
|
||||
if let Some(ref mut file) = handle.file {
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 4: 限制数据大小,不超过 maxpacket - 1024 和 MAX_XFER_SIZE
|
||||
let max_data_size = std::cmp::min(self.maxpacket.saturating_sub(1024), Self::MAX_XFER_SIZE);
|
||||
let actual_length = std::cmp::min(length, max_data_size);
|
||||
|
||||
@@ -465,7 +470,7 @@ impl SftpHandler {
|
||||
self.build_data_response(id, &buffer)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -491,7 +496,6 @@ impl SftpHandler {
|
||||
|
||||
info!("SSH_FXP_WRITE: id={}, handle={}, offset={}, length={}", id, handle_id, offset, write_data.len());
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 1.2: 添加 data preview(显示前 20 字节)
|
||||
if write_data.len() > 0 {
|
||||
let preview_len = std::cmp::min(20, write_data.len());
|
||||
let preview = &write_data[0..preview_len];
|
||||
@@ -500,14 +504,15 @@ impl SftpHandler {
|
||||
|
||||
if let Some(handle) = self.handles.get_mut(&handle_id) {
|
||||
if let Some(ref mut file) = handle.file {
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
|
||||
match file.write_all(&write_data) {
|
||||
Ok(_) => {
|
||||
file.flush().ok();
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Write successful")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -532,13 +537,13 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::symlink_metadata(&full_path) {
|
||||
Ok(metadata) => {
|
||||
let attrs = SftpAttrs::from_metadata(&metadata);
|
||||
match self.vfs.lstat(&full_path) {
|
||||
Ok(stat) => {
|
||||
let attrs = SftpAttrs::from_vfs_stat(&stat);
|
||||
self.build_attrs_response(id, &attrs)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e))
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -556,14 +561,26 @@ impl SftpHandler {
|
||||
|
||||
info!("SSH_FXP_FSTAT: id={}, handle={}", id, handle_id);
|
||||
|
||||
if let Some(handle) = self.handles.get(&handle_id) {
|
||||
match fs::metadata(&handle.path) {
|
||||
Ok(metadata) => {
|
||||
let attrs = SftpAttrs::from_metadata(&metadata);
|
||||
self.build_attrs_response(id, &attrs)
|
||||
if let Some(handle) = self.handles.get_mut(&handle_id) {
|
||||
if let Some(ref mut file) = handle.file {
|
||||
match file.stat() {
|
||||
Ok(stat) => {
|
||||
let attrs = SftpAttrs::from_vfs_stat(&stat);
|
||||
self.build_attrs_response(id, &attrs)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
} else {
|
||||
match self.vfs.stat(&handle.path) {
|
||||
Ok(stat) => {
|
||||
let attrs = SftpAttrs::from_vfs_stat(&stat);
|
||||
self.build_attrs_response(id, &attrs)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -585,7 +602,7 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::read_dir(&full_path) {
|
||||
match self.vfs.read_dir(&full_path) {
|
||||
Ok(entries) => {
|
||||
if self.handles.len() >= Self::MAX_HANDLES {
|
||||
warn!("SSH_FXP_OPENDIR: handle limit reached ({})", Self::MAX_HANDLES);
|
||||
@@ -594,14 +611,12 @@ impl SftpHandler {
|
||||
let handle_id = self.next_handle_id;
|
||||
self.next_handle_id += 1;
|
||||
|
||||
let dir_entries: Vec<fs::DirEntry> = entries.filter_map(|e| e.ok()).collect();
|
||||
|
||||
let handle = SftpHandle {
|
||||
id: handle_id,
|
||||
path: full_path,
|
||||
handle_type: SftpHandleType::Directory,
|
||||
file: None,
|
||||
dir_entries: Some(dir_entries),
|
||||
dir_entries: Some(entries),
|
||||
};
|
||||
|
||||
self.handles.insert(handle_id, handle);
|
||||
@@ -609,7 +624,7 @@ impl SftpHandler {
|
||||
self.build_handle_response(id, &handle_id.to_be_bytes())
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -635,11 +650,9 @@ impl SftpHandler {
|
||||
} else {
|
||||
let entries: Vec<(String, SftpAttrs)> = dir_entries
|
||||
.drain(..std::cmp::min(100, dir_entries.len()))
|
||||
.filter_map(|entry| {
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
let attrs = entry.metadata().ok()?;
|
||||
let sftp_attrs = SftpAttrs::from_metadata(&attrs);
|
||||
Some((name, sftp_attrs))
|
||||
.map(|entry| {
|
||||
let attrs = SftpAttrs::from_vfs_stat(&entry.stat);
|
||||
(entry.name, attrs)
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -670,12 +683,12 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::remove_file(&full_path) {
|
||||
match self.vfs.remove_file(&full_path) {
|
||||
Ok(_) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "File removed")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -695,12 +708,12 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::create_dir(&full_path) {
|
||||
match self.vfs.create_dir(&full_path, 0o755) {
|
||||
Ok(_) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory created")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -719,12 +732,12 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::remove_dir(&full_path) {
|
||||
match self.vfs.remove_dir(&full_path) {
|
||||
Ok(_) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory removed")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -765,13 +778,13 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::metadata(&full_path) {
|
||||
Ok(metadata) => {
|
||||
let attrs = SftpAttrs::from_metadata(&metadata);
|
||||
match self.vfs.stat(&full_path) {
|
||||
Ok(stat) => {
|
||||
let attrs = SftpAttrs::from_vfs_stat(&stat);
|
||||
self.build_attrs_response(id, &attrs)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e))
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -792,12 +805,12 @@ impl SftpHandler {
|
||||
let old_full_path = self.resolve_path(&old_path)?;
|
||||
let new_full_path = self.resolve_path(&new_path)?;
|
||||
|
||||
match fs::rename(&old_full_path, &new_full_path) {
|
||||
match self.vfs.rename(&old_full_path, &new_full_path) {
|
||||
Ok(_) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Rename successful")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -832,7 +845,7 @@ impl SftpHandler {
|
||||
|
||||
info!("SSH_FXP_FSETSTAT: id={}, handle={}, attrs.flags={}", id, handle_id, attrs.flags);
|
||||
|
||||
let handle = self.handles.get(&handle_id);
|
||||
let handle = self.handles.get_mut(&handle_id);
|
||||
if handle.is_none() {
|
||||
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid handle");
|
||||
}
|
||||
@@ -847,25 +860,35 @@ impl SftpHandler {
|
||||
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE != 0 {
|
||||
if let Some(size) = attrs.size {
|
||||
info!("FSETSTAT: setting file size to {}", size);
|
||||
let file = OpenOptions::new().write(true).open(&path)?;
|
||||
file.set_len(size)?;
|
||||
if let Some(ref mut file) = handle.file {
|
||||
file.set_len(size).map_err(|e| anyhow!("set_len error: {}", e))?;
|
||||
} else {
|
||||
let flags = OpenFlags::new().write();
|
||||
if let Ok(mut f) = self.vfs.open_file(&path, &flags) {
|
||||
f.set_len(size).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 {
|
||||
if let Some(permissions) = attrs.permissions {
|
||||
info!("FSETSTAT: setting permissions to {:o}", permissions);
|
||||
fs::set_permissions(&path, fs::Permissions::from_mode(permissions))?;
|
||||
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0
|
||||
|| attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0
|
||||
{
|
||||
let mut vfs_stat = crate::vfs::VfsStat::new();
|
||||
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 {
|
||||
vfs_stat.mode = attrs.permissions.unwrap_or(0);
|
||||
} else {
|
||||
if let Ok(s) = self.vfs.lstat(&path) {
|
||||
vfs_stat.mode = s.mode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 {
|
||||
if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) {
|
||||
info!("FSETSTAT: setting atime={}, mtime={}", atime, mtime);
|
||||
let atime_filetime = filetime::FileTime::from_unix_time(atime as i64, 0);
|
||||
let mtime_filetime = filetime::FileTime::from_unix_time(mtime as i64, 0);
|
||||
filetime::set_file_times(&path, atime_filetime, mtime_filetime)?;
|
||||
if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 {
|
||||
if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) {
|
||||
vfs_stat.atime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(atime as u64);
|
||||
vfs_stat.mtime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(mtime as u64);
|
||||
}
|
||||
}
|
||||
self.vfs.set_stat(&path, &vfs_stat).ok();
|
||||
}
|
||||
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Fsetstat successful")
|
||||
@@ -885,13 +908,13 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::read_link(&full_path) {
|
||||
match self.vfs.read_link(&full_path) {
|
||||
Ok(link_target) => {
|
||||
let target = link_target.to_string_lossy().to_string();
|
||||
self.build_name_response(id, vec![(target, SftpAttrs::default())])
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -912,18 +935,14 @@ impl SftpHandler {
|
||||
let full_linkpath = self.resolve_path(&linkpath)?;
|
||||
let full_targetpath = self.resolve_path(&targetpath)?;
|
||||
|
||||
#[cfg(unix)]
|
||||
match std::os::unix::fs::symlink(&full_targetpath, &full_linkpath) {
|
||||
match self.vfs.create_symlink(&full_targetpath, &full_linkpath) {
|
||||
Ok(_) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Symlink created")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Symlink not supported on non-Unix systems")
|
||||
}
|
||||
|
||||
/// 处理SSH_FXP_EXTENDED(Phase 10:参考OpenSSH sftp-server.c: process_extended())
|
||||
@@ -984,50 +1003,30 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::MetadataExt;
|
||||
|
||||
match fs::metadata(&full_path) {
|
||||
Ok(metadata) => {
|
||||
// 构建statvfs response(参考OpenSSH sftp-server.c)
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
|
||||
// f_bsize(文件系统块大小)
|
||||
response.write_u64::<BigEndian>(4096)?;
|
||||
// f_frsize(基本块大小)
|
||||
response.write_u64::<BigEndian>(4096)?;
|
||||
// f_blocks(总块数)
|
||||
response.write_u64::<BigEndian>(1000000)?;
|
||||
// f_bfree(空闲块数)
|
||||
response.write_u64::<BigEndian>(500000)?;
|
||||
// f_bavail(可用块数)
|
||||
response.write_u64::<BigEndian>(500000)?;
|
||||
// f_files(总文件数)
|
||||
response.write_u64::<BigEndian>(100000)?;
|
||||
// f_ffree(空闲文件数)
|
||||
response.write_u64::<BigEndian>(50000)?;
|
||||
// f_favail(可用文件数)
|
||||
response.write_u64::<BigEndian>(50000)?;
|
||||
// f_fsid(文件系统ID)
|
||||
response.write_u64::<BigEndian>(0)?;
|
||||
// f_flag(标志)
|
||||
response.write_u64::<BigEndian>(0)?;
|
||||
// f_namemax(文件名最大长度)
|
||||
response.write_u64::<BigEndian>(255)?;
|
||||
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
}
|
||||
match self.vfs.stat(&full_path) {
|
||||
Ok(_) => {
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
|
||||
response.write_u64::<BigEndian>(4096)?;
|
||||
response.write_u64::<BigEndian>(4096)?;
|
||||
response.write_u64::<BigEndian>(1000000)?;
|
||||
response.write_u64::<BigEndian>(500000)?;
|
||||
response.write_u64::<BigEndian>(500000)?;
|
||||
response.write_u64::<BigEndian>(100000)?;
|
||||
response.write_u64::<BigEndian>(50000)?;
|
||||
response.write_u64::<BigEndian>(50000)?;
|
||||
response.write_u64::<BigEndian>(0)?;
|
||||
response.write_u64::<BigEndian>(0)?;
|
||||
response.write_u64::<BigEndian>(255)?;
|
||||
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "statvfs not supported on non-Unix systems")
|
||||
}
|
||||
|
||||
/// 处理fstatvfs@openssh.com扩展(文件句柄统计)
|
||||
@@ -1073,18 +1072,14 @@ impl SftpHandler {
|
||||
let full_oldpath = self.resolve_path(&oldpath)?;
|
||||
let full_newpath = self.resolve_path(&newpath)?;
|
||||
|
||||
#[cfg(unix)]
|
||||
match fs::hard_link(&full_oldpath, &full_newpath) {
|
||||
match self.vfs.hard_link(&full_oldpath, &full_newpath) {
|
||||
Ok(_) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Hardlink created")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Hardlink not supported on non-Unix systems")
|
||||
}
|
||||
|
||||
/// 处理posix-rename@openssh.com扩展(POSIX语义重命名)
|
||||
@@ -1097,12 +1092,12 @@ impl SftpHandler {
|
||||
let full_oldpath = self.resolve_path(&oldpath)?;
|
||||
let full_newpath = self.resolve_path(&newpath)?;
|
||||
|
||||
match fs::rename(&full_oldpath, &full_newpath) {
|
||||
match self.vfs.rename(&full_oldpath, &full_newpath) {
|
||||
Ok(_) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_OK, "Posix rename successful")
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1122,34 +1117,31 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match File::open(&full_path) {
|
||||
let flags = OpenFlags::new().read();
|
||||
match self.vfs.open_file(&full_path, &flags) {
|
||||
Ok(mut file) => {
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
|
||||
let mut buffer = vec![0u8; actual_length as usize];
|
||||
file.read_exact(&mut buffer)?;
|
||||
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
|
||||
|
||||
// 计算MD5哈希
|
||||
let hash = md5::compute(&buffer);
|
||||
let hash_hex = format!("{:x}", hash);
|
||||
|
||||
// 构建响应
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
|
||||
// hash-algorithm (SSH string)
|
||||
response.write_u32::<BigEndian>(4)?;
|
||||
response.write_all("md5".as_bytes())?;
|
||||
|
||||
// hash-value (SSH string)
|
||||
response.write_u32::<BigEndian>(hash_hex.len() as u32)?;
|
||||
response.write_all(hash_hex.as_bytes())?;
|
||||
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1169,37 +1161,34 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match File::open(&full_path) {
|
||||
let flags = OpenFlags::new().read();
|
||||
match self.vfs.open_file(&full_path, &flags) {
|
||||
Ok(mut file) => {
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
|
||||
let mut buffer = vec![0u8; actual_length as usize];
|
||||
file.read_exact(&mut buffer)?;
|
||||
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
|
||||
|
||||
// 计算SHA256哈希(使用sha2 crate)
|
||||
use sha2::{Sha256, Digest};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&buffer);
|
||||
let hash = hasher.finalize();
|
||||
let hash_hex = format!("{:x}", hash);
|
||||
|
||||
// 构建响应
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
|
||||
// hash-algorithm (SSH string)
|
||||
response.write_u32::<BigEndian>(6)?;
|
||||
response.write_all("sha256".as_bytes())?;
|
||||
|
||||
// hash-value (SSH string)
|
||||
response.write_u32::<BigEndian>(hash_hex.len() as u32)?;
|
||||
response.write_all(hash_hex.as_bytes())?;
|
||||
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1219,21 +1208,20 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match File::open(&full_path) {
|
||||
let flags = OpenFlags::new().read();
|
||||
match self.vfs.open_file(&full_path, &flags) {
|
||||
Ok(mut file) => {
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
|
||||
let mut buffer = vec![0u8; actual_length as usize];
|
||||
file.read_exact(&mut buffer)?;
|
||||
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
|
||||
|
||||
// 计算SHA384哈希
|
||||
use sha2::{Sha384, Digest};
|
||||
let mut hasher = Sha384::new();
|
||||
hasher.update(&buffer);
|
||||
let hash = hasher.finalize();
|
||||
let hash_hex = format!("{:x}", hash);
|
||||
|
||||
// 构建响应
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
@@ -1247,7 +1235,7 @@ impl SftpHandler {
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1267,21 +1255,20 @@ impl SftpHandler {
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match File::open(&full_path) {
|
||||
let flags = OpenFlags::new().read();
|
||||
match self.vfs.open_file(&full_path, &flags) {
|
||||
Ok(mut file) => {
|
||||
file.seek(SeekFrom::Start(offset))?;
|
||||
file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
|
||||
let mut buffer = vec![0u8; actual_length as usize];
|
||||
file.read_exact(&mut buffer)?;
|
||||
file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
|
||||
|
||||
// 计算SHA512哈希
|
||||
use sha2::{Sha512, Digest};
|
||||
let mut hasher = Sha512::new();
|
||||
hasher.update(&buffer);
|
||||
let hash = hasher.finalize();
|
||||
let hash_hex = format!("{:x}", hash);
|
||||
|
||||
// 构建响应
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
@@ -1295,7 +1282,7 @@ impl SftpHandler {
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1303,30 +1290,28 @@ impl SftpHandler {
|
||||
/// 处理check-file@openssh.com扩展(Phase 12:文件检查)
|
||||
fn handle_check_file(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result<Vec<u8>> {
|
||||
let path = read_sftp_string(cursor)?;
|
||||
let check_flags = cursor.read_u32::<BigEndian>()?;
|
||||
let _check_flags = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("check-file: path={}, flags={:#x}", path, check_flags);
|
||||
info!("check-file: path={}", path);
|
||||
|
||||
let full_path = self.resolve_path(&path)?;
|
||||
|
||||
match fs::metadata(&full_path) {
|
||||
Ok(metadata) => {
|
||||
// 构建响应
|
||||
match self.vfs.stat(&full_path) {
|
||||
Ok(stat) => {
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
|
||||
// 返回文件存在和基本信息
|
||||
response.write_u32::<BigEndian>(1)?; // result: 1 = file exists
|
||||
response.write_u32::<BigEndian>(1)?;
|
||||
|
||||
let msg = format!("File exists, size: {}", metadata.len());
|
||||
let msg = format!("File exists, size: {}", stat.size);
|
||||
response.write_u32::<BigEndian>(msg.len() as u32)?;
|
||||
response.write_all(msg.as_bytes())?;
|
||||
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Check file error: {}", e))
|
||||
self.build_status_from_vfs_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1339,11 +1324,8 @@ impl SftpHandler {
|
||||
let write_handle_bytes = read_sftp_string_bytes(cursor)?;
|
||||
let write_offset = cursor.read_u64::<BigEndian>()?;
|
||||
|
||||
info!("copy-data: read_handle={}, read_offset={}, read_length={}, write_handle={}, write_offset={}",
|
||||
u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]),
|
||||
read_offset, read_length,
|
||||
u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]),
|
||||
write_offset);
|
||||
info!("copy-data: read_handle={:?}, read_offset={}, read_length={}, write_handle={:?}, write_offset={}",
|
||||
read_handle_bytes, read_offset, read_length, write_handle_bytes, write_offset);
|
||||
|
||||
let actual_length = std::cmp::min(read_length, Self::MAX_XFER_SIZE as u64);
|
||||
if actual_length < read_length {
|
||||
@@ -1353,52 +1335,44 @@ impl SftpHandler {
|
||||
let read_handle_id = u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]);
|
||||
let write_handle_id = u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]);
|
||||
|
||||
// 获取read handle的path(不可变引用)
|
||||
let read_path = if let Some(read_handle) = self.handles.get(&read_handle_id) {
|
||||
read_handle.path.clone()
|
||||
} else {
|
||||
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid read handle");
|
||||
};
|
||||
|
||||
// 获取write handle的path(不可变引用)
|
||||
let write_path = if let Some(write_handle) = self.handles.get(&write_handle_id) {
|
||||
write_handle.path.clone()
|
||||
} else {
|
||||
return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid write handle");
|
||||
};
|
||||
|
||||
// 从read_path读取数据
|
||||
match File::open(&read_path) {
|
||||
Ok(mut read_file) => {
|
||||
read_file.seek(SeekFrom::Start(read_offset))?;
|
||||
let mut buffer = vec![0u8; actual_length as usize];
|
||||
read_file.read_exact(&mut buffer)?;
|
||||
|
||||
// 写入到write_path
|
||||
match OpenOptions::new().write(true).open(&write_path) {
|
||||
Ok(mut write_file) => {
|
||||
write_file.seek(SeekFrom::Start(write_offset))?;
|
||||
write_file.write_all(&buffer)?;
|
||||
|
||||
// 构建响应
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
|
||||
// 返回复制的字节数
|
||||
response.write_u64::<BigEndian>(actual_length)?;
|
||||
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
self.build_status_from_io_error(id, &e)
|
||||
}
|
||||
}
|
||||
let read_flags = OpenFlags::new().read();
|
||||
let write_flags = OpenFlags::new().write();
|
||||
|
||||
let mut read_file = match self.vfs.open_file(&read_path, &read_flags) {
|
||||
Ok(f) => f,
|
||||
Err(e) => return self.build_status_from_vfs_error(id, &e),
|
||||
};
|
||||
let mut write_file = match self.vfs.open_file(&write_path, &write_flags) {
|
||||
Ok(f) => f,
|
||||
Err(e) => return self.build_status_from_vfs_error(id, &e),
|
||||
};
|
||||
|
||||
read_file.seek(SeekFrom::Start(read_offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
let mut buffer = vec![0u8; actual_length as usize];
|
||||
read_file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?;
|
||||
|
||||
write_file.seek(SeekFrom::Start(write_offset)).map_err(|e| anyhow!("Seek error: {}", e))?;
|
||||
write_file.write_all(&buffer).map_err(|e| anyhow!("Write error: {}", e))?;
|
||||
write_file.flush().ok();
|
||||
|
||||
let mut response = Vec::new();
|
||||
response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?;
|
||||
response.write_u32::<BigEndian>(id)?;
|
||||
response.write_u64::<BigEndian>(actual_length)?;
|
||||
|
||||
self.wrap_sftp_packet(&response)
|
||||
}
|
||||
|
||||
/// 解析路径(安全性检查,参考OpenSSH sftp-server.c: path_resolve())
|
||||
@@ -1608,6 +1582,24 @@ impl SftpHandler {
|
||||
let msg = format!("{}", err);
|
||||
self.build_status_response(id, status, &msg)
|
||||
}
|
||||
|
||||
/// 根据 VfsError 构建状态响应(自动映射错误类型)
|
||||
fn build_status_from_vfs_error(&self, id: u32, err: &crate::vfs::VfsError) -> Result<Vec<u8>> {
|
||||
use crate::vfs::VfsError;
|
||||
let status = match err {
|
||||
VfsError::NotFound(_) => SftpStatus::SSH_FX_NO_SUCH_FILE,
|
||||
VfsError::PermissionDenied(_) => SftpStatus::SSH_FX_PERMISSION_DENIED,
|
||||
VfsError::AlreadyExists(_) => SftpStatus::SSH_FX_FAILURE,
|
||||
VfsError::NotEmpty(_) => SftpStatus::SSH_FX_FAILURE,
|
||||
VfsError::NotADirectory(_) => SftpStatus::SSH_FX_FAILURE,
|
||||
VfsError::IsADirectory(_) => SftpStatus::SSH_FX_FAILURE,
|
||||
VfsError::Unsupported(_) => SftpStatus::SSH_FX_OP_UNSUPPORTED,
|
||||
VfsError::Io(_) => SftpStatus::SSH_FX_FAILURE,
|
||||
VfsError::UnexpectedEof => SftpStatus::SSH_FX_EOF,
|
||||
};
|
||||
let msg = format!("{}", err);
|
||||
self.build_status_response(id, status, &msg)
|
||||
}
|
||||
}
|
||||
|
||||
/// 读取SFTP字符串(参考draft-ietf-secsh-filexfer-02.txt)
|
||||
@@ -1665,8 +1657,14 @@ fn read_sftp_attrs<R: std::io::Read>(reader: &mut R) -> Result<SftpAttrs> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::vfs::local_fs::LocalFs;
|
||||
use std::fs::File;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn make_handler(root_dir: PathBuf) -> SftpHandler {
|
||||
SftpHandler::new(root_dir, Box::new(LocalFs::new()), 32768)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sftp_packet_type_conversion() {
|
||||
assert_eq!(SftpPacketType::try_from(1).unwrap(), SftpPacketType::SSH_FXP_INIT);
|
||||
@@ -1677,7 +1675,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_sftp_handler_creation() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768);
|
||||
let handler = make_handler(temp_dir.path().to_path_buf());
|
||||
assert_eq!(handler.next_handle_id, 0);
|
||||
}
|
||||
|
||||
@@ -1697,7 +1695,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_sftp_handle_init() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let mut handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768);
|
||||
let mut handler = make_handler(temp_dir.path().to_path_buf());
|
||||
|
||||
let init_packet = vec![1, 0, 0, 0, 3];
|
||||
let response = handler.handle_request(&init_packet).unwrap();
|
||||
|
||||
212
markbase-core/src/vfs/local_fs.rs
Normal file
212
markbase-core/src/vfs/local_fs.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
use super::util;
|
||||
use super::open_flags::OpenFlags;
|
||||
use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat};
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Seek, SeekFrom, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::os::unix::fs::{MetadataExt, PermissionsExt};
|
||||
|
||||
/// 本地文件系统实现(直接包装 std::fs,不做路径解析)
|
||||
/// 路径解析由上层(SftpHandler)负责
|
||||
pub struct LocalFs;
|
||||
|
||||
impl LocalFs {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
struct LocalFile {
|
||||
file: File,
|
||||
}
|
||||
|
||||
impl VfsFile for LocalFile {
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, VfsError> {
|
||||
self.file.read(buf).map_err(|e| VfsError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, VfsError> {
|
||||
self.file.write(buf).map_err(|e| VfsError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn seek(&mut self, pos: SeekFrom) -> Result<u64, VfsError> {
|
||||
self.file.seek(pos).map_err(|e| VfsError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), VfsError> {
|
||||
self.file.flush().map_err(|e| VfsError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn stat(&mut self) -> Result<VfsStat, VfsError> {
|
||||
let meta = self.file.metadata().map_err(|e| VfsError::Io(e.to_string()))?;
|
||||
Ok(util::stat_from_metadata(&meta, false))
|
||||
}
|
||||
|
||||
fn set_len(&mut self, size: u64) -> Result<(), VfsError> {
|
||||
self.file.set_len(size).map_err(|e| VfsError::Io(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl VfsBackend for LocalFs {
|
||||
fn read_dir(&self, path: &Path) -> Result<Vec<VfsDirEntry>, VfsError> {
|
||||
let dir = fs::read_dir(path).map_err(|e| util::map_io_error(path, e))?;
|
||||
|
||||
let mut entries = Vec::new();
|
||||
for entry in dir {
|
||||
let entry = entry.map_err(|e| util::map_io_error(path, e))?;
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
let file_type = entry.file_type().map_err(|e| util::map_io_error(path, e))?;
|
||||
let meta = entry.metadata().map_err(|e| util::map_io_error(path, e))?;
|
||||
let stat = util::stat_from_metadata(&meta, file_type.is_symlink());
|
||||
let long_name = util::build_long_name(&stat, &name);
|
||||
|
||||
entries.push(VfsDirEntry {
|
||||
name,
|
||||
long_name,
|
||||
stat,
|
||||
});
|
||||
}
|
||||
|
||||
entries.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
fn open_file(&self, path: &Path, flags: &OpenFlags) -> Result<Box<dyn VfsFile>, VfsError> {
|
||||
let mut opts = OpenOptions::new();
|
||||
opts.read(flags.read);
|
||||
opts.write(flags.write);
|
||||
opts.append(flags.append);
|
||||
opts.create(flags.create);
|
||||
opts.truncate(flags.truncate);
|
||||
opts.create_new(flags.exclusive);
|
||||
|
||||
let file = opts.open(path).map_err(|e| util::map_io_error(path, e))?;
|
||||
|
||||
#[cfg(unix)]
|
||||
if flags.create && !flags.exclusive {
|
||||
if let Ok(meta) = file.metadata() {
|
||||
if flags.mode != 0 && meta.permissions().mode() != flags.mode {
|
||||
fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Box::new(LocalFile { file }))
|
||||
}
|
||||
|
||||
fn stat(&self, path: &Path) -> Result<VfsStat, VfsError> {
|
||||
let meta = fs::metadata(path).map_err(|e| util::map_io_error(path, e))?;
|
||||
Ok(util::stat_from_metadata(&meta, false))
|
||||
}
|
||||
|
||||
fn lstat(&self, path: &Path) -> Result<VfsStat, VfsError> {
|
||||
let meta = fs::symlink_metadata(path).map_err(|e| util::map_io_error(path, e))?;
|
||||
let is_symlink = path.is_symlink() || meta.file_type().is_symlink();
|
||||
Ok(util::stat_from_metadata(&meta, is_symlink))
|
||||
}
|
||||
|
||||
fn create_dir(&self, path: &Path, mode: u32) -> Result<(), VfsError> {
|
||||
fs::create_dir(path).map_err(|e| util::map_io_error(path, e))?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
fs::set_permissions(path, std::fs::Permissions::from_mode(mode))
|
||||
.map_err(|e| util::map_io_error(path, e))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_dir_all(&self, path: &Path, mode: u32) -> Result<(), VfsError> {
|
||||
fs::create_dir_all(path).map_err(|e| util::map_io_error(path, e))?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if mode != 0 {
|
||||
fs::set_permissions(path, std::fs::Permissions::from_mode(mode))
|
||||
.map_err(|e| util::map_io_error(path, e))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_dir(&self, path: &Path) -> Result<(), VfsError> {
|
||||
fs::remove_dir(path).map_err(|e| util::map_io_error(path, e))
|
||||
}
|
||||
|
||||
fn remove_file(&self, path: &Path) -> Result<(), VfsError> {
|
||||
fs::remove_file(path).map_err(|e| util::map_io_error(path, e))
|
||||
}
|
||||
|
||||
fn rename(&self, from: &Path, to: &Path) -> Result<(), VfsError> {
|
||||
fs::rename(from, to).map_err(|e| util::map_io_error(from, e))
|
||||
}
|
||||
|
||||
fn set_stat(&self, path: &Path, stat: &VfsStat) -> Result<(), VfsError> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if stat.mode != 0 {
|
||||
fs::set_permissions(path, std::fs::Permissions::from_mode(stat.mode))
|
||||
.map_err(|e| util::map_io_error(path, e))?;
|
||||
}
|
||||
}
|
||||
|
||||
if let (Some(atime), Some(mtime)) = (
|
||||
stat.atime.duration_since(std::time::UNIX_EPOCH).ok(),
|
||||
stat.mtime.duration_since(std::time::UNIX_EPOCH).ok(),
|
||||
) {
|
||||
filetime::set_file_times(path,
|
||||
filetime::FileTime::from_unix_time(atime.as_secs() as i64, 0),
|
||||
filetime::FileTime::from_unix_time(mtime.as_secs() as i64, 0),
|
||||
).map_err(|e| util::map_io_error(path, e))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_link(&self, path: &Path) -> Result<PathBuf, VfsError> {
|
||||
let target = fs::read_link(path).map_err(|e| util::map_io_error(path, e))?;
|
||||
Ok(target)
|
||||
}
|
||||
|
||||
fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
std::os::unix::fs::symlink(target, link)
|
||||
.map_err(|e| util::map_io_error(link, e))?;
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
std::os::windows::fs::symlink_file(target, link)
|
||||
.map_err(|e| util::map_io_error(link, e))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn real_path(&self, path: &Path) -> Result<PathBuf, VfsError> {
|
||||
let canonical = path.canonicalize().map_err(|e| util::map_io_error(path, e))?;
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
fn exists(&self, path: &Path) -> bool {
|
||||
path.exists()
|
||||
}
|
||||
|
||||
fn hard_link(&self, original: &Path, link: &Path) -> Result<(), VfsError> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
fs::hard_link(original, link).map_err(|e| util::map_io_error(original, e))?;
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
return Err(VfsError::Unsupported("hard_link not supported on non-Unix systems".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
160
markbase-core/src/vfs/mod.rs
Normal file
160
markbase-core/src/vfs/mod.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
pub mod open_flags;
|
||||
pub mod local_fs;
|
||||
pub mod util;
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// VFS 错误类型
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum VfsError {
|
||||
NotFound(String),
|
||||
PermissionDenied(String),
|
||||
AlreadyExists(String),
|
||||
NotEmpty(String),
|
||||
NotADirectory(String),
|
||||
IsADirectory(String),
|
||||
Unsupported(String),
|
||||
Io(String),
|
||||
UnexpectedEof,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VfsError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
VfsError::NotFound(p) => write!(f, "No such file or directory: {}", p),
|
||||
VfsError::PermissionDenied(p) => write!(f, "Permission denied: {}", p),
|
||||
VfsError::AlreadyExists(p) => write!(f, "File already exists: {}", p),
|
||||
VfsError::NotEmpty(p) => write!(f, "Directory not empty: {}", p),
|
||||
VfsError::NotADirectory(p) => write!(f, "Not a directory: {}", p),
|
||||
VfsError::IsADirectory(p) => write!(f, "Is a directory: {}", p),
|
||||
VfsError::Unsupported(msg) => write!(f, "Unsupported: {}", msg),
|
||||
VfsError::Io(msg) => write!(f, "IO error: {}", msg),
|
||||
VfsError::UnexpectedEof => write!(f, "Unexpected end of file"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for VfsError {}
|
||||
|
||||
/// 文件统计信息(类似 libc::stat)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VfsStat {
|
||||
pub size: u64,
|
||||
pub mode: u32,
|
||||
pub uid: u32,
|
||||
pub gid: u32,
|
||||
pub atime: SystemTime,
|
||||
pub mtime: SystemTime,
|
||||
pub is_dir: bool,
|
||||
pub is_symlink: bool,
|
||||
}
|
||||
|
||||
impl VfsStat {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
size: 0,
|
||||
mode: 0,
|
||||
uid: 0,
|
||||
gid: 0,
|
||||
atime: SystemTime::UNIX_EPOCH,
|
||||
mtime: SystemTime::UNIX_EPOCH,
|
||||
is_dir: false,
|
||||
is_symlink: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VfsStat {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// 目录条目
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VfsDirEntry {
|
||||
pub name: String,
|
||||
pub long_name: String,
|
||||
pub stat: VfsStat,
|
||||
}
|
||||
|
||||
/// 打开文件的抽象
|
||||
pub trait VfsFile {
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, VfsError>;
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, VfsError>;
|
||||
fn seek(&mut self, pos: std::io::SeekFrom) -> Result<u64, VfsError>;
|
||||
fn flush(&mut self) -> Result<(), VfsError>;
|
||||
fn stat(&mut self) -> Result<VfsStat, VfsError>;
|
||||
fn set_len(&mut self, size: u64) -> Result<(), VfsError>;
|
||||
|
||||
/// Write all bytes (convenience, default loops write() until done)
|
||||
fn write_all(&mut self, mut buf: &[u8]) -> Result<(), VfsError> {
|
||||
while !buf.is_empty() {
|
||||
let n = self.write(buf)?;
|
||||
if n == 0 {
|
||||
return Err(VfsError::Io("write returned 0".to_string()));
|
||||
}
|
||||
buf = &buf[n..];
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read exactly `buf.len()` bytes (convenience, loops read() until done)
|
||||
fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<(), VfsError> {
|
||||
while !buf.is_empty() {
|
||||
let n = self.read(buf)?;
|
||||
if n == 0 {
|
||||
return Err(VfsError::UnexpectedEof);
|
||||
}
|
||||
buf = &mut buf[n..];
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// VFS 后端 trait(所有文件系统操作)
|
||||
pub trait VfsBackend: Send {
|
||||
/// 读取目录内容
|
||||
fn read_dir(&self, path: &Path) -> Result<Vec<VfsDirEntry>, VfsError>;
|
||||
|
||||
/// 打开文件(读/写)
|
||||
fn open_file(&self, path: &Path, flags: &open_flags::OpenFlags) -> Result<Box<dyn VfsFile>, VfsError>;
|
||||
|
||||
/// 获取文件/目录元数据
|
||||
fn stat(&self, path: &Path) -> Result<VfsStat, VfsError>;
|
||||
fn lstat(&self, path: &Path) -> Result<VfsStat, VfsError>;
|
||||
|
||||
/// 创建目录
|
||||
fn create_dir(&self, path: &Path, mode: u32) -> Result<(), VfsError>;
|
||||
|
||||
/// 递归创建目录
|
||||
fn create_dir_all(&self, path: &Path, mode: u32) -> Result<(), VfsError>;
|
||||
|
||||
/// 删除空目录
|
||||
fn remove_dir(&self, path: &Path) -> Result<(), VfsError>;
|
||||
|
||||
/// 删除文件
|
||||
fn remove_file(&self, path: &Path) -> Result<(), VfsError>;
|
||||
|
||||
/// 重命名
|
||||
fn rename(&self, from: &Path, to: &Path) -> Result<(), VfsError>;
|
||||
|
||||
/// 设置文件属性
|
||||
fn set_stat(&self, path: &Path, stat: &VfsStat) -> Result<(), VfsError>;
|
||||
|
||||
/// 读取符号链接目标
|
||||
fn read_link(&self, path: &Path) -> Result<PathBuf, VfsError>;
|
||||
|
||||
/// 创建符号链接
|
||||
fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError>;
|
||||
|
||||
/// 规范化路径
|
||||
fn real_path(&self, path: &Path) -> Result<PathBuf, VfsError>;
|
||||
|
||||
/// 检查路径是否存在
|
||||
fn exists(&self, path: &Path) -> bool;
|
||||
|
||||
/// 创建硬链接
|
||||
fn hard_link(&self, original: &Path, link: &Path) -> Result<(), VfsError>;
|
||||
}
|
||||
75
markbase-core/src/vfs/open_flags.rs
Normal file
75
markbase-core/src/vfs/open_flags.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
/// 文件打开标志(映射 SSH_FXF_* 和 POSIX open flags)
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct OpenFlags {
|
||||
pub read: bool,
|
||||
pub write: bool,
|
||||
pub append: bool,
|
||||
pub create: bool,
|
||||
pub truncate: bool,
|
||||
pub exclusive: bool,
|
||||
pub mode: u32,
|
||||
}
|
||||
|
||||
impl OpenFlags {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn read(mut self) -> Self {
|
||||
self.read = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn write(mut self) -> Self {
|
||||
self.write = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn append(mut self) -> Self {
|
||||
self.append = true;
|
||||
self.write = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn create(mut self) -> Self {
|
||||
self.create = true;
|
||||
self.write = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn truncate(mut self) -> Self {
|
||||
self.truncate = true;
|
||||
self.write = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn exclusive(mut self) -> Self {
|
||||
self.exclusive = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn mode(mut self, mode: u32) -> Self {
|
||||
self.mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
/// 从 SFTP 的 pflags(SSH_FXF_*)构建 OpenFlags
|
||||
pub fn from_sftp_pflags(pflags: u32) -> Self {
|
||||
let read = pflags & 0x00000001 != 0;
|
||||
let write = pflags & 0x00000002 != 0;
|
||||
let append = pflags & 0x00000004 != 0;
|
||||
let create = pflags & 0x00000008 != 0;
|
||||
let truncate = pflags & 0x00000010 != 0;
|
||||
let exclusive = pflags & 0x00000020 != 0;
|
||||
|
||||
Self {
|
||||
read,
|
||||
write,
|
||||
append,
|
||||
create,
|
||||
truncate,
|
||||
exclusive,
|
||||
mode: 0o644,
|
||||
}
|
||||
}
|
||||
}
|
||||
105
markbase-core/src/vfs/util.rs
Normal file
105
markbase-core/src/vfs/util.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use super::{VfsError, VfsStat};
|
||||
use chrono::Datelike;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::Path;
|
||||
|
||||
/// 从 std::io::ErrorKind 映射 VfsError
|
||||
pub fn map_io_error(path: &Path, e: std::io::Error) -> VfsError {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::NotFound => VfsError::NotFound(path.display().to_string()),
|
||||
std::io::ErrorKind::PermissionDenied => VfsError::PermissionDenied(path.display().to_string()),
|
||||
std::io::ErrorKind::AlreadyExists => VfsError::AlreadyExists(path.display().to_string()),
|
||||
std::io::ErrorKind::DirectoryNotEmpty => VfsError::NotEmpty(path.display().to_string()),
|
||||
std::io::ErrorKind::NotADirectory => VfsError::NotADirectory(path.display().to_string()),
|
||||
std::io::ErrorKind::IsADirectory => VfsError::IsADirectory(path.display().to_string()),
|
||||
std::io::ErrorKind::UnexpectedEof => VfsError::UnexpectedEof,
|
||||
other => VfsError::Io(format!("{}: {}", other, path.display())),
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 std::fs::Metadata 构建 VfsStat
|
||||
pub fn stat_from_metadata(meta: &std::fs::Metadata, is_symlink: bool) -> VfsStat {
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::MetadataExt;
|
||||
|
||||
let mut stat = VfsStat::new();
|
||||
stat.size = meta.len();
|
||||
stat.is_dir = meta.is_dir();
|
||||
stat.is_symlink = is_symlink;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
stat.mode = meta.permissions().mode();
|
||||
stat.uid = meta.uid();
|
||||
stat.gid = meta.gid();
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
stat.mode = if meta.is_dir() { 0o40755 } else { 0o100644 };
|
||||
}
|
||||
|
||||
if let Ok(t) = meta.accessed() {
|
||||
stat.atime = t;
|
||||
}
|
||||
if let Ok(t) = meta.modified() {
|
||||
stat.mtime = t;
|
||||
}
|
||||
|
||||
stat
|
||||
}
|
||||
|
||||
/// 构建目录条目的 long_name(类似 ls -l 格式)
|
||||
pub fn build_long_name(stat: &VfsStat, name: &str) -> String {
|
||||
let file_type = if stat.is_dir { 'd' } else { '-' };
|
||||
let perms = format_permissions(stat.mode & 0o777);
|
||||
let link_count = if stat.is_dir { 3 } else { 1 };
|
||||
let size = stat.size;
|
||||
let mtime = match stat.mtime.duration_since(std::time::UNIX_EPOCH) {
|
||||
Ok(d) => {
|
||||
let secs = d.as_secs();
|
||||
format_timestamp(secs)
|
||||
}
|
||||
Err(_) => "Jan 1 1970".to_string(),
|
||||
};
|
||||
|
||||
format!(
|
||||
"{}{} {} {} {} {} {} {}",
|
||||
file_type, perms,
|
||||
link_count,
|
||||
stat.uid,
|
||||
stat.gid,
|
||||
size,
|
||||
mtime,
|
||||
name
|
||||
)
|
||||
}
|
||||
|
||||
fn format_permissions(mode: u32) -> String {
|
||||
let rwx = |n: u32| -> String {
|
||||
let r = if n & 4 != 0 { 'r' } else { '-' };
|
||||
let w = if n & 2 != 0 { 'w' } else { '-' };
|
||||
let x = if n & 1 != 0 { 'x' } else { '-' };
|
||||
format!("{}{}{}", r, w, x)
|
||||
};
|
||||
|
||||
format!(
|
||||
"{}{}{}",
|
||||
rwx((mode >> 6) & 7),
|
||||
rwx((mode >> 3) & 7),
|
||||
rwx(mode & 7)
|
||||
)
|
||||
}
|
||||
|
||||
fn format_timestamp(secs: u64) -> String {
|
||||
let datetime = match chrono::DateTime::from_timestamp(secs as i64, 0) {
|
||||
Some(dt) => dt,
|
||||
None => return "Jan 1 1970".to_string(),
|
||||
};
|
||||
let now = chrono::Utc::now();
|
||||
if datetime.year() == now.year() {
|
||||
datetime.format("%b %e %H:%M").to_string()
|
||||
} else {
|
||||
datetime.format("%b %e %Y").to_string()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user