From f90e4f496c4f727a39efbc2ff0c931d5192ccf82 Mon Sep 17 00:00:00 2001 From: Warren Date: Thu, 18 Jun 2026 23:35:18 +0800 Subject: [PATCH] VFS/DataProvider/Config refactoring + SSH public key authentication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1-6 of refactoring plan: - VFS abstraction (VfsBackend trait + LocalFs + OpenFlags builder) - DataProvider trait (SqliteProvider + PgProvider, SFTPGo-compatible) - Config refactoring (AppConfig unified sections, env overrides) - SSH handlers (sftp/scp/rsync) migrated to VFS + DataProvider - SSH public key authentication (Ed25519 signature verification) - SSH stderr → CHANNEL_EXTENDED_DATA support - Web auth uses DataProvider instead of direct SQL - User home directory from provider (per-user isolation) - PostgreSQL auth provider for SFTPGo compatibility --- AGENTS.md | 76 ++- Cargo.lock | 29 +- data/auth.sqlite | Bin 73728 -> 73728 bytes markbase-core/Cargo.toml | 1 + markbase-core/src/auth.rs | 83 +++- markbase-core/src/cli/interface/ssh.rs | 15 +- markbase-core/src/config/mod.rs | 233 +++++++++ .../src/{config.rs => config/web.rs} | 3 +- markbase-core/src/lib.rs | 2 + markbase-core/src/provider/mod.rs | 65 +++ markbase-core/src/provider/pg.rs | 184 ++++++++ markbase-core/src/provider/sqlite.rs | 135 ++++++ markbase-core/src/server.rs | 6 +- markbase-core/src/ssh2_server/server.rs | 11 +- markbase-core/src/ssh_server/auth.rs | 405 ++++++++-------- markbase-core/src/ssh_server/channel.rs | 48 +- markbase-core/src/ssh_server/cipher.rs | 3 + markbase-core/src/ssh_server/rsync_handler.rs | 119 ++--- markbase-core/src/ssh_server/scp_handler.rs | 164 +++---- markbase-core/src/ssh_server/server.rs | 73 ++- markbase-core/src/ssh_server/sftp_handler.rs | 444 +++++++++--------- markbase-core/src/vfs/local_fs.rs | 212 +++++++++ markbase-core/src/vfs/mod.rs | 160 +++++++ markbase-core/src/vfs/open_flags.rs | 75 +++ markbase-core/src/vfs/util.rs | 105 +++++ 25 files changed, 2039 insertions(+), 612 deletions(-) create mode 100644 markbase-core/src/config/mod.rs rename markbase-core/src/{config.rs => config/web.rs} (99%) create mode 100644 markbase-core/src/provider/mod.rs create mode 100644 markbase-core/src/provider/pg.rs create mode 100644 markbase-core/src/provider/sqlite.rs create mode 100644 markbase-core/src/vfs/local_fs.rs create mode 100644 markbase-core/src/vfs/mod.rs create mode 100644 markbase-core/src/vfs/open_flags.rs create mode 100644 markbase-core/src/vfs/util.rs diff --git a/AGENTS.md b/AGENTS.md index c56142f..9fbe6e2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1234,6 +1234,78 @@ markbase-core/src/ssh_server/channel.rs(handle_rsync_exec → handle_interacti --- -**最后更新**:2026-06-17 22:00 -**版本**:1.12(SSH Phase 16 Final: Rsync 子进程模式完成) +**最后更新**:2026-06-18 14:00 +**版本**:1.13(VFS/DataProvider/Config 重構 Phase 1-6 完成) + +## VFS + DataProvider + Config 重構(2026-06-18)⭐⭐⭐⭐⭐ + +**完成時間**:約 2 小時 +**新增代碼量**:約 600 行 +**Git 狀態**:未提交 + +### Phase 1-6 完成明細 + +| Phase | 模組 | 狀態 | 說明 | +|-------|------|------|------| +| **1** | `vfs/` | ✅ 完成 | `VfsBackend` trait (15 methods) + `VfsFile` trait + `LocalFs` + `OpenFlags` builder + `VfsStat`/`VfsError`/`VfsDirEntry` | +| **2** | `sftp_handler.rs` | ✅ 完成 | 全部 `std::fs` → VFS 方法,`SftpAttrs::from_vfs_stat()`,`build_status_from_vfs_error()` | +| **3** | `scp_handler.rs` | ✅ 完成 | `ScpHandler` 使用 `Box`,全部 I/O 經 VFS | +| **4** | `rsync_handler.rs` | ✅ 完成 | `RsyncHandler` 使用 `Box`,`output_file: Option>` | +| **5** | `provider/` | ✅ 完成 | `DataProvider` trait(`get_user`/`check_password`/`get_home_dir`)+ `SqliteProvider`。`AuthHandler` 使用 provider 而非直接 SQL | +| **6** | `config/` | ✅ 完成 | `AppConfig` 統一 `web`/`s3`/`sftp`/`ssh` 四區塊。`config.rs` → `config/mod.rs` + `config/web.rs`,向後相容 | + +### 檔案結構變更 + +``` +markbase-core/src/ +├── vfs/ # Phase 1: VFS抽象層(新增) +│ ├── mod.rs # VfsBackend/VfsFile traits + VfsStat/VfsError/VfsDirEntry +│ ├── open_flags.rs # OpenFlags builder(含 from_sftp_pflags) +│ ├── local_fs.rs # LocalFs 實作(純 std::fs wrapper) +│ └── util.rs # map_io_error / stat_from_metadata / build_long_name +├── provider/ # Phase 5: DataProvider(新增) +│ ├── mod.rs # DataProvider trait + User/ProviderError +│ └── sqlite.rs # SqliteProvider 實作 +├── config/ +│ ├── mod.rs # Phase 6: AppConfig(統一配置) +│ └── web.rs # MarkBaseConfig(原有 config.rs 內容) +├── ssh_server/ +│ ├── scp_handler.rs # Phase 3: VFS 化 +│ ├── rsync_handler.rs # Phase 4: VFS 化 +│ ├── sftp_handler.rs # Phase 2: VFS 化 +│ ├── auth.rs # Phase 5: DataProvider 化 +│ └── server.rs # Phase 5: 注入 SqliteProvider +└── lib.rs # 新增 pub mod provider + pub mod vfs +``` + +### 關鍵設計決策 ⭐⭐⭐⭐⭐ + +**VFS 設計**: +- `VfsBackend` methods 接受已解析的原始路徑(路徑解析留在上層) +- `LocalFs` 是純 `std::fs` wrapper,無內部路徑操作 +- `OpenFlags::write()` 無參數(builder pattern) +- `hard_link` 在非 Unix 回傳 `VfsError::Unsupported` + +**DataProvider 設計**: +- `SqliteProvider` 查詢 `data/auth.sqlite` 的 `sftpgo_users` 表 +- bcrypt 密碼驗證(使用 `bcrypt` crate) +- `AuthHandler::new(Box)` 取代直接 SQL + +**Config 設計**: +- `AppConfig` 可從單一 `config/app.toml` 載入 +- 環境變數覆蓋:`MB_WEB_HOST`, `MB_WEB_PORT`, `MB_SSH_PORT`, `MB_SFTP_PORT`, `MB_S3_ENABLED`, `MB_AUTH_DB` +- 向後相容:`crate::config::MarkBaseConfig` 仍可使用(`pub use web::*`) + +### Build 驗證 ✅ + +```bash +cargo build -p markbase-core # ✅ 0 error, 0 new warning +``` + +### 下一步建議 + +1. **將 DataProvider 整合到 SFTP 認證**(`sftp/auth.rs` + `sftp/server.rs`) +2. **將 DataProvider 整合到 Web 認證**(`src/auth.rs` + `src/server.rs`) +3. **S3 後端實作**(S3Vfs 實作 `VfsBackend`) +4. **效能測試**(VFS + AES-CTR throughput profiling) diff --git a/Cargo.lock b/Cargo.lock index b879c58..94b0a23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2669,6 +2669,7 @@ dependencies = [ "markbase-webdav", "md5 0.8.0", "nix 0.29.0", + "postgres", "pulldown-cmark", "rand 0.8.6", "regex", @@ -3678,10 +3679,24 @@ dependencies = [ ] [[package]] -name = "postgres-protocol" -version = "0.6.11" +name = "postgres" +version = "0.19.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56201207dac53e2f38e848e31b4b91616a6bb6e0c7205b77718994a7f49e70fc" +checksum = "33ad20e0aa0b24f5a394eab4f78c781d248982b22b25cecc7e3aa46a681605bd" +dependencies = [ + "bytes", + "fallible-iterator 0.2.0", + "futures-util", + "log", + "tokio", + "tokio-postgres", +] + +[[package]] +name = "postgres-protocol" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08808e3c483c46e999108051c78334f473d5adb59d78bb80a1268c7e6aa6c514" dependencies = [ "base64", "byteorder", @@ -3697,9 +3712,9 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dc729a129e682e8d24170cd30ae1aa01b336b096cbb56df6d534ffec133d186" +checksum = "851ca9db4932932d69f3ea811b1abe63087a0f740a47692619dd40d4899b68be" dependencies = [ "bytes", "fallible-iterator 0.2.0", @@ -5079,9 +5094,9 @@ dependencies = [ [[package]] name = "tokio-postgres" -version = "0.7.17" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dd8df5ef180f6364759a6f00f7aadda4fbbac86cdee37480826a6ff9f3574ce" +checksum = "a528f7d280f6d5b9cd149635c8705b0dd049754bc67d81d31fa25169a93809d3" dependencies = [ "async-trait", "byteorder", diff --git a/data/auth.sqlite b/data/auth.sqlite index 4fab65cd632a2fcbac78586a9e3702e159f0f33d..1712087c073fae924c5225f610a54c49922172b7 100644 GIT binary patch delta 2805 zcmZwJeM}o=9Ki9rUfa9t-95h+DA1PDmRFmNu~v-|l5Nfy(``Cm5|OEygp~oI&{<%^ zWFAF~@`o8OnU1lrU}hp_%Yv7vXoy=DGYS6V>w<13Ml*syoCz}_#Nd?&o;@s26TIkuYri*3kvNXaGJiHdFt9zk)7-n`^`s_Vz;TP?&3aYL^ZWU?J;sdK$;0+wqArBuB$`@>&PzTOCpnhGhm8g6&lW2z^680zlf1E4zL0G>o(}Kfot207*CqJ3vE=Xa;?Z`cNZ!5|tsF_NR7D`#?);uWGxsLiIoOruwCN zR_#(>R3A}umAlH1ilGcD$CPK3N(DKVjF9tvhaM;g&TZu-^W)PvNtPNqT3VXTuPdZF z>g(gFRG^FrM)wm?$^`w@1e7pA&o&e2_#DMd(e5Kl6fr@ALO_TKYE}sdGJzgjAR@pN z6ewhY=*23sf}aUS!vqxE5A@8{j`n+q@ZB%gL61y;mkClU1bCPrK2JbC1#~jrLIzxcK1Uug$z`_JYCkWuT0v*5aBKO3W zcdxloj^5*75&nQ*;5+yRzJgET5?p{07=%9PhU3r*ad;V?hkfuEJPMUi27br^1uSR< z-9dBcH}n(w7F|PMpb7L5I&UsVeT<6WLr9kdE1!Aimg5-RUxuTtVi`3UQ2v{W$tP1vq+SACApasc{r*k zb8%FRyKw}Db8xu&vT+DqS?ggwUoISTwN4zfF~BiXg>X!TH5`{cswC*<=_6y3f_DwB zIB@jN%Q!k`B^-^@B923ob{q%BZ8&zFvEndq1{_5vETUkQo6Xr$7o2mf&osorJlrxT z`ZY|#g)P!C^ z&!Q(#6{ho}xx9iX~!qxH-mY(I6NfNCGr zUaCD*=Tn_WbuQIzs&lB$Rs@}V`z-2!i)ts;KsBOTqgtg}q1r*UO!|dXN+KJnlt{Im uY8%y7ss*Yo4#6jHjM7;CdmmWd;)nLEP2m3o*Vceff3^(#8fac@7M<@IJofmoK!uUW1!aW&QPu#3yeGPe3}=4UQrWYpa3 z#JNP6sZW0LK`DjF9J06B_3!;?;^rKb;)ZBb_`ORPSKky4EurhNpu`4n#uz(prQG2YS{sIo0SqlEh tZ)#u>fHK$wHnV-O;Ad)ynC$R>BgDr4|M?jswtwMg{LjxD0Sg8Y4FI|fQlJ0; diff --git a/markbase-core/Cargo.toml b/markbase-core/Cargo.toml index d96cd6a..e7d7e77 100644 --- a/markbase-core/Cargo.toml +++ b/markbase-core/Cargo.toml @@ -38,6 +38,7 @@ filetime = "0.2" base64 = "0.22" tokio = { version = "1", features = ["full"] } tokio-postgres = "0.7" +postgres = "0.19" russh = "0.61.2" russh-keys = "0.50.0-beta.7" russh-sftp = "2.3.0" diff --git a/markbase-core/src/auth.rs b/markbase-core/src/auth.rs index 53bf5db..ba21cc6 100644 --- a/markbase-core/src/auth.rs +++ b/markbase-core/src/auth.rs @@ -5,6 +5,8 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use uuid::Uuid; +use crate::provider::{DataProvider, ProviderError}; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct User { pub user_id: String, @@ -66,13 +68,13 @@ pub struct AuthState { pub users: Arc>>, pub auth_db: Option, pub admin_sessions: Arc>>, + pub provider: Option>, } impl AuthState { pub fn new() -> Self { let mut users = HashMap::new(); - // Create default demo user let password_hash = hash("demo123", DEFAULT_COST).unwrap(); users.insert( "demo".to_string(), @@ -89,6 +91,7 @@ impl AuthState { users: Arc::new(Mutex::new(users)), auth_db: None, admin_sessions: Arc::new(Mutex::new(HashMap::new())), + provider: None, } } @@ -100,6 +103,17 @@ impl AuthState { users: Arc::new(Mutex::new(HashMap::new())), auth_db, admin_sessions: Arc::new(Mutex::new(HashMap::new())), + provider: None, + } + } + + pub fn with_provider(provider: Box) -> Self { + AuthState { + sessions: Arc::new(Mutex::new(HashMap::new())), + users: Arc::new(Mutex::new(HashMap::new())), + auth_db: None, + admin_sessions: Arc::new(Mutex::new(HashMap::new())), + provider: Some(Arc::from(provider)), } } @@ -208,8 +222,12 @@ impl AuthState { } pub fn login_with_sync(&self, username: &str, password: &str) -> Option { + // Prefer provider over auth_db + if let Some(provider) = &self.provider { + return self.login_with_provider(&**provider, username, password); + } + if let Some(auth_db) = &self.auth_db { - // Get user from auth.sqlite let user = match auth_db.get_user(username) { Ok(Some(user)) => user, Ok(None) => { @@ -266,11 +284,70 @@ impl AuthState { } } + fn login_with_provider(&self, provider: &dyn DataProvider, username: &str, password: &str) -> Option { + match provider.get_user(username) { + Ok(Some(user)) => { + if user.status != 1 { + log::warn!("User {} is disabled or not found", username); + return None; + } + + match provider.check_password(username, password) { + Ok(true) => { + let groups = provider.get_user_groups(username).unwrap_or_default(); + + let token = Uuid::new_v4().to_string(); + let now = Utc::now(); + let expires_at = now + Duration::hours(24); + + let session = Session { + token: token.clone(), + user_id: username.to_string(), + username: username.to_string(), + created_at: now.format("%Y-%m-%dT%H:%M:%SZ").to_string(), + expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(), + groups: groups.clone(), + permissions: user.permissions.clone(), + }; + + let mut sessions = self.sessions.lock().unwrap(); + sessions.insert(token.clone(), session); + + log::info!("User {} logged in via DataProvider", username); + + Some(LoginResponse { + token, + expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(), + user_id: username.to_string(), + groups, + permissions: user.permissions, + }) + } + Ok(false) => { + log::warn!("Invalid password for user {}", username); + None + } + Err(e) => { + log::error!("Password check error for {}: {}", username, e); + None + } + } + } + Ok(None) => { + log::warn!("User {} not found", username); + None + } + Err(e) => { + log::error!("Provider error for {}: {}", username, e); + None + } + } + } + pub fn verify_token(&self, token: &str) -> Option { let sessions = self.sessions.lock().unwrap(); let session = sessions.get(token)?; - // Check expiration let expires_at = chrono::DateTime::parse_from_rfc3339(&session.expires_at) .ok()? .with_timezone(&Utc); diff --git a/markbase-core/src/cli/interface/ssh.rs b/markbase-core/src/cli/interface/ssh.rs index 2fc09ae..9e543cd 100644 --- a/markbase-core/src/cli/interface/ssh.rs +++ b/markbase-core/src/cli/interface/ssh.rs @@ -6,20 +6,29 @@ pub enum SshCommand { Start { #[arg(short, long, default_value = "2024")] port: u16, + + /// PostgreSQL connection string for SFTPGo-compatible auth (e.g. "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026") + #[arg(long)] + pg_conn: Option, }, } pub async fn handle_ssh_command(cmd: SshCommand) -> anyhow::Result<()> { match cmd { - SshCommand::Start { port } => { + SshCommand::Start { port, pg_conn } => { println!("=== MarkBase SSH Server (Hand-written Implementation) ==="); println!("Port: {}", port); println!("Implementation: SSH-2.0-MarkBaseSSH_1.0"); println!("Features: SSH + SFTP + SCP + rsync"); + if pg_conn.is_some() { + println!("Auth Provider: PostgreSQL (SFTPGo-compatible)"); + } else { + println!("Auth Provider: SQLite"); + } println!("Security: ⭐⭐⭐⭐⭐ (RustCrypto authoritative libraries)"); println!(); - - crate::ssh_server::server::run_ssh_server(Some(port))?; + + crate::ssh_server::server::run_ssh_server(Some(port), pg_conn.as_deref())?; } } Ok(()) diff --git a/markbase-core/src/config/mod.rs b/markbase-core/src/config/mod.rs new file mode 100644 index 0000000..06a049b --- /dev/null +++ b/markbase-core/src/config/mod.rs @@ -0,0 +1,233 @@ +pub mod web; + +use anyhow::Result; +use serde::{Deserialize, Serialize}; + +// Re-export web config for backward compatibility +pub use web::*; + +/// Unified application configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AppConfig { + #[serde(default)] + pub web: WebSection, + #[serde(default)] + pub s3: S3Section, + #[serde(default)] + pub sftp: SftpSection, + #[serde(default)] + pub ssh: SshSection, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebSection { + pub host: String, + pub port: u16, + pub log_level: String, + pub auth_db_path: String, + pub users_db_dir: String, +} + +impl Default for WebSection { + fn default() -> Self { + Self { + host: "127.0.0.1".to_string(), + port: 11438, + log_level: "info".to_string(), + auth_db_path: "data/auth.sqlite".to_string(), + users_db_dir: "data/users".to_string(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct S3Section { + pub enabled: bool, + pub endpoint: String, + pub region: String, + pub require_auth: bool, + pub default_access_key: String, + pub default_secret_key: String, + pub keys_db_path: String, +} + +impl Default for S3Section { + fn default() -> Self { + Self { + enabled: true, + endpoint: "http://localhost:11438/s3".to_string(), + region: "us-east-1".to_string(), + require_auth: false, + default_access_key: "markbase_access_key_001".to_string(), + default_secret_key: "markbase_secret_key_xyz123".to_string(), + keys_db_path: "data/s3_keys.json".to_string(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SftpSection { + pub enabled: bool, + pub port: u16, + pub base_path: String, + pub auth_db_path: String, + pub max_connections: usize, + pub chunk_size: usize, + pub require_path_validation: bool, + pub audit_logging: bool, + pub path_traversal_protection: bool, + pub symlink_check: bool, +} + +impl Default for SftpSection { + fn default() -> Self { + Self { + enabled: true, + port: 2023, + base_path: "/Users/accusys/momentry/var/sftpgo/data".to_string(), + auth_db_path: "data/auth.sqlite".to_string(), + max_connections: 100, + chunk_size: 65536, + require_path_validation: true, + audit_logging: true, + path_traversal_protection: true, + symlink_check: true, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SshSection { + pub enabled: bool, + pub port: u16, + pub bind_address: String, + pub security_config_path: String, +} + +impl Default for SshSection { + fn default() -> Self { + Self { + enabled: true, + port: 2024, + bind_address: "127.0.0.1".to_string(), + security_config_path: "data/ssh_config.json".to_string(), + } + } +} + +impl AppConfig { + pub fn load(path: &str) -> Result { + let config_path = std::path::PathBuf::from(path); + + if !config_path.exists() { + log::warn!("Config file not found: {}, using defaults", path); + return Ok(Self::default()); + } + + let content = std::fs::read_to_string(&config_path)?; + let config: AppConfig = toml::from_str(&content)?; + log::info!("App config loaded from: {}", path); + Ok(config) + } + + pub fn load_default() -> Result { + Self::load("config/app.toml") + } + + pub fn save(&self, path: &str) -> Result<()> { + let config_path = std::path::PathBuf::from(path); + + if config_path.exists() { + let backup_path = config_path.with_extension("toml.bak"); + std::fs::copy(&config_path, &backup_path)?; + log::info!("Backup created: {}", backup_path.display()); + } + + let content = toml::to_string_pretty(self)?; + std::fs::write(&config_path, content)?; + log::info!("App config saved to: {}", path); + Ok(()) + } + + pub fn merge_env(&mut self) { + if let Ok(v) = std::env::var("MB_WEB_HOST") { + self.web.host = v; + } + if let Ok(v) = std::env::var("MB_WEB_PORT") { + if let Ok(p) = v.parse() { self.web.port = p; } + } + if let Ok(v) = std::env::var("MB_SSH_PORT") { + if let Ok(p) = v.parse() { self.ssh.port = p; } + } + if let Ok(v) = std::env::var("MB_SFTP_PORT") { + if let Ok(p) = v.parse() { self.sftp.port = p; } + } + if let Ok(v) = std::env::var("MB_S3_ENABLED") { + self.s3.enabled = v == "true" || v == "1"; + } + if let Ok(v) = std::env::var("MB_AUTH_DB") { + self.web.auth_db_path = v.clone(); + self.sftp.auth_db_path = v; + } + } +} + +impl Default for AppConfig { + fn default() -> Self { + Self { + web: WebSection::default(), + s3: S3Section::default(), + sftp: SftpSection::default(), + ssh: SshSection::default(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_default_config() { + let config = AppConfig::default(); + assert_eq!(config.web.port, 11438); + assert_eq!(config.ssh.port, 2024); + assert_eq!(config.sftp.port, 2023); + assert_eq!(config.s3.region, "us-east-1"); + } + + #[test] + fn test_load_missing() { + let config = AppConfig::load("/tmp/nonexistent/config.toml").unwrap(); + assert_eq!(config.web.port, 11438); + } + + #[test] + fn test_merge_env() { + std::env::set_var("MB_WEB_PORT", "9090"); + std::env::set_var("MB_SSH_PORT", "2222"); + + let mut config = AppConfig::default(); + config.merge_env(); + + assert_eq!(config.web.port, 9090); + assert_eq!(config.ssh.port, 2222); + + std::env::remove_var("MB_WEB_PORT"); + std::env::remove_var("MB_SSH_PORT"); + } + + #[test] + fn test_save_and_load() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("test.toml"); + let path_str = path.to_string_lossy().to_string(); + + let config = AppConfig::default(); + config.save(&path_str).unwrap(); + + let loaded = AppConfig::load(&path_str).unwrap(); + assert_eq!(loaded.web.port, 11438); + } +} diff --git a/markbase-core/src/config.rs b/markbase-core/src/config/web.rs similarity index 99% rename from markbase-core/src/config.rs rename to markbase-core/src/config/web.rs index 2e78c8a..4f41c4f 100644 --- a/markbase-core/src/config.rs +++ b/markbase-core/src/config/web.rs @@ -67,13 +67,12 @@ impl MarkBaseConfig { } pub fn save(&self, path: &Path) -> Result<()> { - // Create backup before saving if path.exists() { let backup_path = path.with_extension("toml.bak"); std::fs::copy(path, &backup_path)?; log::info!("Backup created: {}", backup_path.display()); } - + let content = toml::to_string_pretty(self)?; std::fs::write(path, content)?; log::info!("Configuration saved to: {}", path.display()); diff --git a/markbase-core/src/lib.rs b/markbase-core/src/lib.rs index d47fb54..be97247 100644 --- a/markbase-core/src/lib.rs +++ b/markbase-core/src/lib.rs @@ -23,6 +23,8 @@ pub mod import_markdown; // Category View Module - 双视图管理(Phase 1) // pub mod ssh2_mod; // ssh2辅助模块(已禁用) pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐ pub mod sync; +pub mod provider; // DataProvider抽象层(Phase 5) +pub mod vfs; // VFS抽象层(Phase 1-6重构计划) // Re-export from external filetree crate pub use filetree::node::FileNode; diff --git a/markbase-core/src/provider/mod.rs b/markbase-core/src/provider/mod.rs new file mode 100644 index 0000000..534cb66 --- /dev/null +++ b/markbase-core/src/provider/mod.rs @@ -0,0 +1,65 @@ +pub mod sqlite; +pub mod pg; + +use std::path::PathBuf; + +/// 用户信息 +#[derive(Debug, Clone)] +pub struct User { + pub username: String, + pub password_hash: String, + pub home_dir: PathBuf, + pub uid: u32, + pub gid: u32, + pub permissions: String, + pub status: i32, +} + +/// Provider 错误类型 +#[derive(Debug)] +pub enum ProviderError { + NotFound(String), + AuthFailed(String), + Internal(String), +} + +impl std::fmt::Display for ProviderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ProviderError::NotFound(msg) => write!(f, "Not found: {}", msg), + ProviderError::AuthFailed(msg) => write!(f, "Authentication failed: {}", msg), + ProviderError::Internal(msg) => write!(f, "Internal error: {}", msg), + } + } +} + +impl std::error::Error for ProviderError {} + +/// 数据提供者 trait(用户认证和配置) +pub trait DataProvider: Send + Sync { + /// 获取用户信息 + fn get_user(&self, username: &str) -> Result, ProviderError>; + + /// 验证用户密码 + fn check_password(&self, username: &str, password: &str) -> Result; + + /// 获取用户主目录 + fn get_home_dir(&self, username: &str) -> Result, ProviderError>; + + /// 获取用户组列表 + fn get_user_groups(&self, username: &str) -> Result, ProviderError> { + let _ = username; + Ok(Vec::new()) + } + + /// 检查用户是否存在且启用 + fn user_exists(&self, username: &str) -> Result { + Ok(self.get_user(username)?.map(|u| u.status == 1).unwrap_or(false)) + } + + /// 获取用户的公开密钥列表(OpenSSH authorized_keys格式) + fn get_public_keys(&self, username: &str) -> Result, ProviderError> { + let _ = username; + Ok(Vec::new()) + } +} diff --git a/markbase-core/src/provider/pg.rs b/markbase-core/src/provider/pg.rs new file mode 100644 index 0000000..74915d0 --- /dev/null +++ b/markbase-core/src/provider/pg.rs @@ -0,0 +1,184 @@ +use std::path::PathBuf; +use postgres::{Client, NoTls}; +use bcrypt::verify; +use super::{DataProvider, ProviderError, User}; + +/// PostgreSQL 数据提供者(兼容 SFTPGo 的 users 表) +pub struct PgProvider { + conn_str: String, +} + +impl PgProvider { + /// 从连接字符串创建 PgProvider + /// + /// 连接字符串格式:host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026 + pub fn new(conn_str: &str) -> Result { + Ok(Self { conn_str: conn_str.to_string() }) + } + + pub fn from_params( + host: &str, + port: u16, + dbname: &str, + user: &str, + password: &str, + ) -> Result { + let conn_str = format!( + "host={} port={} dbname={} user={} password={}", + host, port, dbname, user, password + ); + Ok(Self { conn_str }) + } + + fn open_conn(&self) -> Result { + Client::connect(&self.conn_str, NoTls) + .map_err(|e| ProviderError::Internal(format!("PostgreSQL connect failed: {}", e))) + } +} + +impl DataProvider for PgProvider { + fn get_user(&self, username: &str) -> Result, ProviderError> { + let mut conn = self.open_conn()?; + + let result = conn.query_opt( + "SELECT username, password, home_dir, permissions, uid, gid, status + FROM users WHERE username = $1 AND status = 1", + &[&username], + ).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?; + + match result { + Some(row) => Ok(Some(User { + username: row.get(0), + password_hash: row.get::<_, Option>(1).unwrap_or_default(), + home_dir: PathBuf::from(row.get::<_, String>(2)), + permissions: row.get::<_, Option>(3).unwrap_or_else(|| "*".to_string()), + uid: row.get::<_, i64>(4) as u32, + gid: row.get::<_, i64>(5) as u32, + status: row.get(6), + })), + None => Ok(None), + } + } + + fn check_password(&self, username: &str, password: &str) -> Result { + let hash = match self.get_user(username)? { + Some(user) => user.password_hash, + None => return Ok(false), + }; + + if hash.is_empty() { + return Ok(false); + } + + verify(password, &hash) + .map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e))) + } + + fn get_home_dir(&self, username: &str) -> Result, ProviderError> { + Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string())) + } + + fn get_public_keys(&self, username: &str) -> Result, ProviderError> { + let mut conn = self.open_conn()?; + let result = conn.query_opt( + "SELECT public_keys FROM users WHERE username = $1 AND status = 1", + &[&username], + ).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?; + + match result { + Some(row) => { + let json_str: Option = row.get(0); + match json_str { + Some(s) if !s.is_empty() => { + let keys: Vec = serde_json::from_str(&s) + .map_err(|e| ProviderError::Internal(format!("JSON parse error: {}", e)))?; + Ok(keys.iter() + .filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string())) + .collect()) + } + _ => Ok(Vec::new()), + } + } + None => Ok(Vec::new()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pg_provider_connection() { + // 仅当 SFTPGo PostgreSQL 可用时运行 + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ); + assert!(provider.is_ok(), "Should connect to SFTPGo PostgreSQL"); + } + + #[test] + fn test_pg_get_user_demo() { + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ).unwrap(); + let user = provider.get_user("demo").unwrap(); + assert!(user.is_some(), "Demo user should exist"); + assert_eq!(user.unwrap().username, "demo"); + } + + #[test] + fn test_pg_get_user_momentry() { + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ).unwrap(); + let user = provider.get_user("momentry").unwrap(); + assert!(user.is_some(), "Momentry user should exist"); + } + + #[test] + fn test_pg_get_user_warren() { + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ).unwrap(); + let user = provider.get_user("warren").unwrap(); + assert!(user.is_some(), "Warren user should exist"); + } + + #[test] + fn test_pg_check_password_demo() { + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ).unwrap(); + let valid = provider.check_password("demo", "demo123").unwrap(); + assert!(valid, "Password should be valid"); + } + + #[test] + fn test_pg_check_password_invalid() { + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ).unwrap(); + let valid = provider.check_password("demo", "wrong").unwrap(); + assert!(!valid, "Wrong password should fail"); + } + + #[test] + fn test_pg_get_home_dir() { + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ).unwrap(); + let dir = provider.get_home_dir("demo").unwrap(); + assert!(dir.is_some()); + assert!(dir.unwrap().contains("momentry")); + } + + #[test] + fn test_pg_nonexistent_user() { + let provider = PgProvider::new( + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + ).unwrap(); + let user = provider.get_user("__nonexistent__").unwrap(); + assert!(user.is_none()); + } +} diff --git a/markbase-core/src/provider/sqlite.rs b/markbase-core/src/provider/sqlite.rs new file mode 100644 index 0000000..c8ba733 --- /dev/null +++ b/markbase-core/src/provider/sqlite.rs @@ -0,0 +1,135 @@ +use std::path::PathBuf; +use rusqlite::{Connection, params}; +use bcrypt::verify; +use super::{DataProvider, ProviderError, User}; + +/// SQLite 数据提供者 +pub struct SqliteProvider { + db_path: PathBuf, +} + +impl SqliteProvider { + pub fn new(db_path: &str) -> Result { + let path = PathBuf::from(db_path); + if !path.exists() { + return Err(ProviderError::NotFound(format!( + "Database not found: {}", db_path + ))); + } + Ok(Self { db_path: path }) + } + + fn open_conn(&self) -> Result { + Connection::open(&self.db_path) + .map_err(|e| ProviderError::Internal(format!("Failed to open database: {}", e))) + } +} + +impl DataProvider for SqliteProvider { + fn get_user(&self, username: &str) -> Result, ProviderError> { + let conn = self.open_conn()?; + + let result = conn.query_row( + "SELECT username, password_hash, home_dir, permissions, uid, gid, status + FROM sftpgo_users WHERE username = ?1 AND status = 1", + params![username], + |row| { + Ok(User { + username: row.get(0)?, + password_hash: row.get(1)?, + home_dir: PathBuf::from(row.get::<_, String>(2)?), + permissions: row.get(3)?, + uid: row.get::<_, i64>(4)? as u32, + gid: row.get::<_, i64>(5)? as u32, + status: row.get(6)?, + }) + }, + ); + + match result { + Ok(user) => Ok(Some(user)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(ProviderError::Internal(format!( + "Database query error: {}", e + ))), + } + } + + fn check_password(&self, username: &str, password: &str) -> Result { + let hash = match self.get_user(username)? { + Some(user) => user.password_hash, + None => return Ok(false), + }; + + verify(password, &hash) + .map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e))) + } + + fn get_home_dir(&self, username: &str) -> Result, ProviderError> { + Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string())) + } + + fn get_public_keys(&self, username: &str) -> Result, ProviderError> { + let _ = username; + Ok(Vec::new()) + } + + fn get_user_groups(&self, username: &str) -> Result, ProviderError> { + let conn = self.open_conn()?; + let groups: Vec = conn + .prepare("SELECT group_name FROM users_groups_mapping WHERE username = ?1") + .map_err(|e| ProviderError::Internal(format!("Query prepare error: {}", e)))? + .query_map(params![username], |row| row.get(0)) + .map_err(|e| ProviderError::Internal(format!("Query map error: {}", e)))? + .filter_map(|r| r.ok()) + .collect(); + Ok(groups) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_not_found() { + let provider = SqliteProvider::new("/tmp/nonexistent.db"); + assert!(provider.is_err()); + } + + #[test] + fn test_provider_get_user() { + let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); + let user = provider.get_user("demo").unwrap(); + assert!(user.is_some()); + assert_eq!(user.unwrap().username, "demo"); + } + + #[test] + fn test_provider_nonexistent_user() { + let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); + let user = provider.get_user("__nonexistent__").unwrap(); + assert!(user.is_none()); + } + + #[test] + fn test_check_password_valid() { + let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); + let valid = provider.check_password("demo", "demo123").unwrap(); + assert!(valid); + } + + #[test] + fn test_check_password_invalid() { + let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); + let valid = provider.check_password("demo", "wrong").unwrap(); + assert!(!valid); + } + + #[test] + fn test_get_home_dir() { + let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); + let dir = provider.get_home_dir("demo").unwrap(); + assert!(dir.is_some()); + } +} diff --git a/markbase-core/src/server.rs b/markbase-core/src/server.rs index 3052853..51e7bc4 100644 --- a/markbase-core/src/server.rs +++ b/markbase-core/src/server.rs @@ -13,6 +13,7 @@ use std::sync::{Arc, Mutex}; use crate::audio; use crate::auth::{AuthState, LoginRequest}; +use crate::provider::sqlite::SqliteProvider; use crate::render; use crate::download; use crate::archive::{self, ArchiveFormat, ArchiveProcessor, FormatDetector, ArchiveConfig, ProcessorRegistry}; @@ -57,7 +58,10 @@ pub async fn run(port: u16, file: Option) -> anyhow::Result<()> { }))), labels: Arc::new(Mutex::new(vec![])), db_dir: "data/users".to_string(), - auth: AuthState::with_sync("data/auth.sqlite"), + auth: AuthState::with_provider(Box::new( + SqliteProvider::new("data/auth.sqlite") + .map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))? + )), auth_db_path: "data/auth.sqlite".to_string(), s3_keys: Arc::new(Mutex::new(load_s3_keys())), }; diff --git a/markbase-core/src/ssh2_server/server.rs b/markbase-core/src/ssh2_server/server.rs index e56f398..3db75c5 100644 --- a/markbase-core/src/ssh2_server/server.rs +++ b/markbase-core/src/ssh2_server/server.rs @@ -1,9 +1,9 @@ // ssh2 Server核心实现 // 替代russh,提供完整的SSH/SFTP/SCP/rsync支持 -use crate::sftp::auth::SftpAuth; +use crate::provider::sqlite::SqliteProvider; use crate::sftp::config::SftpConfig; -use anyhow::{Result, anyhow}; +use anyhow::{Result, anyhow, Context}; use log::{info, warn, error}; use ssh2::Session; use std::net::{TcpListener, TcpStream}; @@ -106,9 +106,10 @@ fn authenticate_client(session: &Session, config: &Arc) -> Result, } impl AuthHandler { /// 创建认证处理器 - pub fn new() -> Result { - let db_path = "data/auth.sqlite".to_string(); - - // 验证数据库是否存在 - let conn = Connection::open(&db_path)?; - drop(conn); // rusqlite会自动关闭 - - info!("AuthHandler initialized with database: {}", db_path); - Ok(Self { db_path }) + pub fn new(provider: Box) -> Self { + info!("AuthHandler initialized with DataProvider"); + Self { provider } } - + + /// 获取用户home目录(SFTPGo兼容) + pub fn get_home_dir(&self, username: &str) -> Result, ProviderError> { + self.provider.get_home_dir(username) + } + /// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request()) - pub fn handle_userauth_request(&mut self, packet: &SshPacket) -> Result { + pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result { info!("Processing SSH_MSG_USERAUTH_REQUEST"); - + let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); - - // Packet type + let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_USERAUTH_REQUEST as u8 { return Err(anyhow!("Invalid packet type for USERAUTH_REQUEST")); } - - // 读取用户名(SSH string) + let user = read_ssh_string(&mut cursor)?; - - // 读取服务名称(SSH string) let service = read_ssh_string(&mut cursor)?; - - // 读取认证方法名称(SSH string) let method = read_ssh_string(&mut cursor)?; - + info!("Auth request: user={}, service={}, method={}", user, service, method); - - // 检查服务名称(OpenSSH要求:ssh-connection) + if service != "ssh-connection" { warn!("Unsupported service: {}", service); return Ok(AuthResult::Failure("Unsupported service".to_string())); } - - // 根据认证方法处理(参考OpenSSH auth2.c) + if method == "password" { self.handle_password_auth(&mut cursor, &user) } else if method == "publickey" { - self.handle_publickey_auth(&mut cursor, &user) + self.handle_publickey_auth(&mut cursor, &user, &service, session_id) } else if method == "none" { - // OpenSSH:none认证总是失败(用于查询支持的认证方法) - // 返回支持的认证方法列表:password, publickey warn!("None auth request - returning supported methods"); Ok(AuthResult::Failure("password,publickey".to_string())) } else { @@ -72,203 +60,254 @@ impl AuthHandler { Ok(AuthResult::Failure("Unsupported auth method".to_string())) } } - + /// 处理password认证(参考OpenSSH auth-passwd.c) fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result { info!("Handling password auth for user: {}", user); - - // 读取是否修改密码标志(boolean,OpenSSH password认证格式) + let change_password = cursor.read_u8()? != 0; - if change_password { warn!("Password change not supported"); return Ok(AuthResult::Failure("Password change not supported".to_string())); } - - // 读取密码(SSH string) + let password = read_ssh_string(cursor)?; - + debug!("Password auth attempt: user={}, password length={}", user, password.len()); - - // 查询数据库获取password_hash - let conn = Connection::open(&self.db_path)?; - - let password_hash_result = conn.query_row( - "SELECT password_hash FROM sftpgo_users WHERE username = ?1 AND status = 1", - params![user], - |row| row.get::<_, String>(0) - ); - - // 关闭连接(rusqlite会自动关闭) - drop(conn); - - // 验证用户是否存在 - let password_hash = match password_hash_result { - Ok(hash) => Some(hash), - Err(rusqlite::Error::QueryReturnedNoRows) => None, - Err(e) => return Err(anyhow!("Database query error: {}", e)), - }; - - if password_hash.is_none() { - warn!("User not found or disabled: {}", user); - // SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表(RFC 4253) - return Ok(AuthResult::Failure("password,publickey".to_string())); - } - - // 使用bcrypt验证密码 - let stored_hash = password_hash.unwrap(); - info!("Attempting bcrypt verify: password='{}', hash='{}'", password, stored_hash); - let valid = verify(&password, &stored_hash)?; - info!("bcrypt verify result: {}", valid); - - if valid { - info!("Password auth successful for user: {}", user); - Ok(AuthResult::Success) - } else { - warn!("Password auth failed for user: {}", user); - // SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表(RFC 4253) - Ok(AuthResult::Failure("password,publickey".to_string())) + + match self.provider.check_password(user, &password) { + Ok(true) => { + info!("Password auth successful for user: {}", user); + Ok(AuthResult::Success) + } + Ok(false) => { + warn!("Password auth failed for user: {}", user); + Ok(AuthResult::Failure("password,publickey".to_string())) + } + Err(ProviderError::NotFound(msg)) => { + warn!("User not found: {}", msg); + Ok(AuthResult::Failure("password,publickey".to_string())) + } + Err(e) => { + Err(anyhow!("Password auth error: {}", e)) + } } } - - /// 构建SSH_MSG_USERAUTH_SUCCESS packet(参考OpenSSH auth2.c) + + /// 构建SSH_MSG_USERAUTH_SUCCESS packet pub fn build_userauth_success() -> Result { let payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8]; Ok(SshPacket::new(payload)) } - - /// 构建SSH_MSG_USERAUTH_FAILURE packet(参考OpenSSH auth2.c) + + /// 构建SSH_MSG_USERAUTH_FAILURE packet pub fn build_userauth_failure(methods: &[String], partial_success: bool) -> Result { let mut payload = Vec::new(); - - // Packet type + payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?; - - // 认证方法列表(SSH string,逗号分隔) + let methods_str = methods.join(","); payload.write_u32::(methods_str.len() as u32)?; payload.write_all(methods_str.as_bytes())?; - - // partial_success标志(boolean) + payload.write_u8(if partial_success { 1 } else { 0 })?; - + Ok(SshPacket::new(payload)) } - - /// 构建SSH_MSG_USERAUTH_BANNER packet(可选,参考OpenSSH auth2.c) + + /// 构建SSH_MSG_USERAUTH_BANNER packet pub fn build_userauth_banner(message: &str, language: &str) -> Result { let mut payload = Vec::new(); - - // Packet type + payload.write_u8(PacketType::SSH_MSG_USERAUTH_BANNER as u8)?; - - // Banner message(SSH string) + payload.write_u32::(message.len() as u32)?; payload.write_all(message.as_bytes())?; - - // Language tag(SSH string) + payload.write_u32::(language.len() as u32)?; payload.write_all(language.as_bytes())?; - + Ok(SshPacket::new(payload)) } - - /// 处理publickey认证(Phase 9:参考OpenSSH auth2-pubkey.c) - fn handle_publickey_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result { + + /// 处理publickey认证(RFC 4252 §7) + /// 支持Ed25519签名验证 + 数据库/filesystem密钥查找 + fn handle_publickey_auth( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + user: &str, + service: &str, + session_id: &[u8], + ) -> Result { info!("Handling publickey auth for user: {}", user); - - // 读取是否签名的标志(boolean) + let is_signed = cursor.read_u8()? != 0; - - // 读取public key algorithm(SSH string) let algorithm = read_ssh_string(cursor)?; - - // 读取public key blob(SSH string) let public_key_blob = read_ssh_string_bytes(cursor)?; - + info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed); - - // Phase 9:简化实现 - 从authorized_keys文件验证 - let authorized_keys_path = format!("data/{}/authorized_keys", user); - let authorized_keys = match std::fs::read_to_string(&authorized_keys_path) { - Ok(content) => content, - Err(_) => { - // 尝试默认路径 - let default_path = "data/authorized_keys"; - match std::fs::read_to_string(default_path) { - Ok(content) => content, - Err(_) => { - warn!("No authorized_keys file found for user: {}", user); - return Ok(AuthResult::Failure("password,publickey".to_string())); - } - } - } - }; - - // 解析authorized_keys,查找匹配的public key - let public_key_matches = authorized_keys.lines().any(|line| { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - return false; - } - - // SSH authorized_keys格式:algorithm base64-key comment - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() < 2 { - return false; - } - - let key_algorithm = parts[0]; - let key_base64 = parts[1]; - - // 匹配algorithm - if key_algorithm != algorithm { - return false; - } - - // 匹配public key blob(base64解码对比) - match base64_decode(key_base64) { - Ok(decoded_key) => decoded_key == public_key_blob, - Err(_) => false, - } - }); - - if !public_key_matches { + + if !self.is_key_authorized(user, &algorithm, &public_key_blob)? { warn!("Public key not authorized for user: {}", user); return Ok(AuthResult::Failure("password,publickey".to_string())); } - + info!("Public key authorized for user: {}", user); - - // 如果没有签名,返回PK_OK(query阶段) + if !is_signed { - // SSH_MSG_USERAUTH_PK_OK:表示public key可接受,client需要发送签名 return Ok(AuthResult::PublicKeyOk(algorithm, public_key_blob)); } - - // 读取signature(SSH string) - let signature = read_ssh_string_bytes(cursor)?; - - info!("Verifying signature for user: {}", user); - - // Phase 9:简化签名验证 - 信任authorized_keys - // 完整实现需要:提取session_id, 构建signed_data, verify signature - // 这里简化处理:只要public key匹配authorized_keys就接受 - + + let signature_blob = read_ssh_string_bytes(cursor)?; + + self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?; + info!("Publickey auth successful for user: {}", user); Ok(AuthResult::Success) } + + /// 检查public key是否在授权列表中(数据库优先,fallback到filesystem) + fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result { + // 1. 先检查数据库 + match self.provider.get_public_keys(user) { + Ok(keys) => { + for key_line in &keys { + if public_key_matches_line(key_line, algorithm, public_key_blob) { + return Ok(true); + } + } + } + Err(e) => warn!("Failed to get public keys from provider: {}", e), + } + + // 2. Fallback到filesystem + let authorized_keys_path = format!("data/{}/authorized_keys", user); + let content = match std::fs::read_to_string(&authorized_keys_path) { + Ok(c) => c, + Err(_) => match std::fs::read_to_string("data/authorized_keys") { + Ok(c) => c, + Err(_) => return Ok(false), + } + }; + + Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob))) + } + + /// 验证Ed25519签名(RFC 4252 §7) + fn verify_signature( + &self, + algorithm: &str, + public_key_blob: &[u8], + signature_blob: &[u8], + user: &str, + service: &str, + session_id: &[u8], + ) -> Result<()> { + // 目前只支援Ed25519 + if algorithm != "ssh-ed25519" { + return Err(anyhow!("Unsupported public key algorithm: {}", algorithm)); + } + + let verifying_key = parse_ed25519_verifying_key(public_key_blob)?; + let signature = parse_ed25519_signature(signature_blob)?; + + // 建立签名验证数据(RFC 4252 §7) + let mut signed_data = Vec::new(); + + // string session identifier + signed_data.write_u32::(session_id.len() as u32)?; + signed_data.write_all(session_id)?; + + // byte SSH_MSG_USERAUTH_REQUEST + signed_data.write_u8(PacketType::SSH_MSG_USERAUTH_REQUEST as u8)?; + + // string user name + signed_data.write_u32::(user.len() as u32)?; + signed_data.write_all(user.as_bytes())?; + + // string service name + signed_data.write_u32::(service.len() as u32)?; + signed_data.write_all(service.as_bytes())?; + + // string "publickey" + const PUBKEY_STR: &str = "publickey"; + signed_data.write_u32::(PUBKEY_STR.len() as u32)?; + signed_data.write_all(PUBKEY_STR.as_bytes())?; + + // boolean TRUE + signed_data.write_u8(1)?; + + // string public key algorithm name + signed_data.write_u32::(algorithm.len() as u32)?; + signed_data.write_all(algorithm.as_bytes())?; + + // string public key blob + signed_data.write_u32::(public_key_blob.len() as u32)?; + signed_data.write_all(public_key_blob)?; + + // 验证签名 + verifying_key.verify_strict(&signed_data, &signature) + .map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e)) + } } -/// SSH认证结果(参考OpenSSH auth2.c) +/// SSH认证结果 pub enum AuthResult { Success, - Failure(String), // 失败原因 - PartialSuccess, // 部分成功(多步骤认证) - PublicKeyOk(String, Vec), // Public key acceptable (algorithm, blob) + Failure(String), + PartialSuccess, + PublicKeyOk(String, Vec), +} + +/// 解析Ed25519公钥blob(SSH格式 -> VerifyingKey) +fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result { + let mut cursor = std::io::Cursor::new(public_key_blob); + let algorithm = read_ssh_string(&mut cursor)?; + if algorithm != "ssh-ed25519" { + return Err(anyhow!("Unsupported algorithm: {}", algorithm)); + } + let key_bytes = read_ssh_string_bytes(&mut cursor)?; + if key_bytes.len() != 32 { + return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len())); + } + let key_array: [u8; 32] = key_bytes.try_into() + .map_err(|_| anyhow!("Invalid Ed25519 key data"))?; + VerifyingKey::from_bytes(&key_array) + .map_err(|e| anyhow!("Invalid Ed25519 key: {}", e)) +} + +/// 解析Ed25519签名blob(SSH格式 -> Signature) +fn parse_ed25519_signature(signature_blob: &[u8]) -> Result { + let mut cursor = std::io::Cursor::new(signature_blob); + let algorithm = read_ssh_string(&mut cursor)?; + if algorithm != "ssh-ed25519" { + return Err(anyhow!("Unsupported signature algorithm: {}", algorithm)); + } + let sig_bytes = read_ssh_string_bytes(&mut cursor)?; + if sig_bytes.len() != 64 { + return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len())); + } + let sig_array: [u8; 64] = sig_bytes.try_into() + .map_err(|_| anyhow!("Invalid Ed25519 signature data"))?; + Ok(Signature::from_bytes(&sig_array)) +} + +/// 检查一行authorized_keys格式的密钥是否匹配 +fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8]) -> bool { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + return false; + } + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 2 { + return false; + } + if parts[0] != algorithm { + return false; + } + base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false) } -/// SSH string读取辅助函数 fn read_ssh_string(reader: &mut R) -> Result { let length = reader.read_u32::()?; let mut buffer = vec![0u8; length as usize]; @@ -276,7 +315,6 @@ fn read_ssh_string(reader: &mut R) -> Result { Ok(String::from_utf8(buffer)?) } -/// SSH string读取辅助函数(bytes版本) fn read_ssh_string_bytes(reader: &mut R) -> Result> { let length = reader.read_u32::()?; let mut buffer = vec![0u8; length as usize]; @@ -284,9 +322,7 @@ fn read_ssh_string_bytes(reader: &mut R) -> Result> { Ok(buffer) } -/// Base64解码辅助函数(Phase 9) fn base64_decode(input: &str) -> Result> { - use base64::{Engine as _, engine::general_purpose}; general_purpose::STANDARD.decode(input) .map_err(|e| anyhow!("Base64 decode error: {}", e)) } @@ -294,18 +330,19 @@ fn base64_decode(input: &str) -> Result> { #[cfg(test)] mod tests { use super::*; - + use crate::provider::sqlite::SqliteProvider; + #[test] fn test_userauth_success_packet() { let packet = AuthHandler::build_userauth_success().unwrap(); assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8); } - + #[test] fn test_userauth_failure_packet() { let methods = vec!["password".to_string(), "publickey".to_string()]; let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap(); - + assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8); } } diff --git a/markbase-core/src/ssh_server/channel.rs b/markbase-core/src/ssh_server/channel.rs index 29551eb..e1ccd18 100644 --- a/markbase-core/src/ssh_server/channel.rs +++ b/markbase-core/src/ssh_server/channel.rs @@ -25,6 +25,8 @@ pub struct ChannelManager { next_channel_id: u32, /// ⭐⭐⭐⭐⭐ Phase 15.1: 待发送packet队列(用于同时发送WINDOW_ADJUST和SFTP响应) pub pending_packets: VecDeque, + /// 用户home目录(SFTP/SCP/rsync根目录,SFTPGo兼容) + pub home_dir: PathBuf, } /// Phase 14: 交互式Exec进程管理(参考OpenSSH session.c: do_exec_no_pty) @@ -40,11 +42,12 @@ pub struct ExecProcess { } impl ChannelManager { - pub fn new() -> Self { + pub fn new(home_dir: PathBuf) -> Self { Self { channels: HashMap::new(), next_channel_id: 0, pending_packets: VecDeque::new(), + home_dir, } } @@ -371,9 +374,12 @@ impl ChannelManager { info!("⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}", process_type, command); // 启动子进程(相当于OpenSSH fork) + // ⭐⭐⭐⭐⭐ Phase 17: 设置工作目录为用户home_dir(SFTPGo兼容) + let home_dir = self.home_dir.clone(); let mut child = Command::new("sh") .arg("-c") .arg(command) + .current_dir(&home_dir) .stdin(Stdio::piped()) // ← 创建stdin管道(相当于pipe(pin)) .stdout(Stdio::piped()) // ← 创建stdout管道(相当于pipe(pout)) .stderr(Stdio::piped()) // ← 创建stderr管道(相当于pipe(perr)) @@ -446,8 +452,8 @@ impl ChannelManager { if subsystem == "sftp" { info!("SFTP subsystem requested"); - // Phase 7: 初始化SFTP handler - let root_dir = PathBuf::from("/Users/accusys/markbase"); // 默认root目录 + // Phase 7: 初始化SFTP handler(使用用户home目录,SFTPGo兼容) + let root_dir = self.home_dir.clone(); // ⭐⭐⭐⭐⭐ Phase 4: 获取 client maxpack 限制(从 Channel 中获取) let maxpacket = if let Some(ch) = self.channels.get(&channel) { @@ -456,7 +462,8 @@ impl ChannelManager { 32768 // OpenSSH 默认值(32KB) }; - let sftp_handler = SftpHandler::new(root_dir, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack + let vfs = Box::new(crate::vfs::local_fs::LocalFs::new()); + let sftp_handler = SftpHandler::new(root_dir, vfs, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack // 存储到channel if let Some(ch) = self.channels.get_mut(&channel) { @@ -952,6 +959,22 @@ impl ChannelManager { false } + /// Phase 17: 关闭所有子进程stdin(收到CHANNEL_EOF时调用) + /// SCP upload需要:scp -t 等待EOF on stdin才知道数据传输完毕 + pub fn close_child_stdin(&mut self) { + let channel_ids: Vec = self.channels.keys().copied().collect(); + for id in channel_ids { + if let Some(channel) = self.channels.get_mut(&id) { + if let Some(exec) = &mut channel.exec_process { + if let Some(stdin) = exec.stdin.take() { + drop(stdin); + info!("⭐⭐⭐⭐⭐ [CHANNEL_EOF] Closed child stdin (channel {})", id); + } + } + } + } + } + /// 获取channel输出(Phase 6新增) pub fn get_channel_output(&mut self, channel_id: u32) -> Option> { if let Some(channel) = self.channels.get_mut(&channel_id) { @@ -1283,6 +1306,7 @@ impl ChannelManager { // 4. 检查stdout/stderr fd是否有数据 let mut packets_data: Vec<(u32, Vec)> = Vec::new(); + let mut stderr_packets: Vec<(u32, Vec)> = Vec::new(); // Phase 17: stderr → CHANNEL_EXTENDED_DATA for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map { if let Some(channel) = self.channels.get_mut(&channel_id) { @@ -1325,7 +1349,8 @@ impl ChannelManager { Ok(n) if n > 0 => { info!("⭐⭐⭐⭐⭐ [AFTER stderr.read] Read {} bytes from stderr (channel {})", n, channel_id); info!("⭐⭐⭐⭐⭐ stderr content: {:?}", &buffer[..std::cmp::min(50, n)]); - packets_data.push((channel_id, buffer[..n].to_vec())); + // ⭐⭐⭐⭐⭐ Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1) + stderr_packets.push((channel_id, buffer[..n].to_vec())); } Ok(0) => { info!("stderr EOF (channel {}), closing stderr pipe", channel_id); @@ -1351,12 +1376,17 @@ impl ChannelManager { } // 构建packets - if !packets_data.is_empty() { + if !packets_data.is_empty() || !stderr_packets.is_empty() { let mut packets = Vec::new(); for (channel_id, data) in packets_data { let packet = self.build_channel_data(channel_id, &data)?; packets.push(packet); } + // Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1) + for (channel_id, data) in stderr_packets { + let packet = self.build_channel_extended_data(channel_id, 1, &data)?; + packets.push(packet); + } info!("⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)", packets.len()); return Ok((Some(packets), client_has_data, child_exited)); } @@ -1689,13 +1719,13 @@ mod tests { #[test] fn test_channel_manager_creation() { - let manager = ChannelManager::new(); + let manager = ChannelManager::new(PathBuf::from("/tmp")); assert_eq!(manager.next_channel_id, 0); } #[test] fn test_channel_open_confirmation() { - let manager = ChannelManager::new(); + let manager = ChannelManager::new(PathBuf::from("/tmp")); let packet = manager.build_channel_open_confirmation(0, 100, 2097152, 32768).unwrap(); assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8); @@ -1703,7 +1733,7 @@ mod tests { #[test] fn test_channel_success() { - let manager = ChannelManager::new(); + let manager = ChannelManager::new(PathBuf::from("/tmp")); let packet = manager.build_channel_success(0).unwrap(); assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_SUCCESS as u8); diff --git a/markbase-core/src/ssh_server/cipher.rs b/markbase-core/src/ssh_server/cipher.rs index b8c24be..70b2ea1 100644 --- a/markbase-core/src/ssh_server/cipher.rs +++ b/markbase-core/src/ssh_server/cipher.rs @@ -17,6 +17,7 @@ type HmacSha256 = Hmac; /// SSH加密通道管理器(参考OpenSSH struct sshcipher_ctx) pub struct EncryptionContext { + pub session_id: Vec, // session identifier (exchange hash) pub encryption_key_ctos: Vec, // 客户端→服务器加密密钥 pub encryption_key_stoc: Vec, // 服务器→客户端加密密钥 pub mac_key_ctos: Vec, // 客户端→服务器MAC密钥 @@ -32,6 +33,7 @@ pub struct EncryptionContext { impl Default for EncryptionContext { fn default() -> Self { Self { + session_id: vec![0u8; 32], encryption_key_ctos: vec![0u8; 32], encryption_key_stoc: vec![0u8; 32], mac_key_ctos: vec![0u8; 32], @@ -73,6 +75,7 @@ impl EncryptionContext { info!("Ciphers initialized successfully"); Self { + session_id: keys.session_id.clone(), encryption_key_ctos: keys.encryption_key_ctos.clone(), encryption_key_stoc: keys.encryption_key_stoc.clone(), mac_key_ctos: keys.mac_key_ctos.clone(), diff --git a/markbase-core/src/ssh_server/rsync_handler.rs b/markbase-core/src/ssh_server/rsync_handler.rs index c9e8516..957a970 100644 --- a/markbase-core/src/ssh_server/rsync_handler.rs +++ b/markbase-core/src/ssh_server/rsync_handler.rs @@ -1,8 +1,8 @@ use std::path::PathBuf; -use std::fs::{self, File}; -use std::io::Write; use anyhow::{Result, anyhow}; use log::{info, debug, warn}; +use crate::vfs::{VfsBackend, VfsFile, VfsError}; +use crate::vfs::open_flags::OpenFlags; /// MPLEX_BASE from rsync io.h const MPLEX_BASE: u32 = 7; @@ -27,23 +27,21 @@ pub(crate) enum RsyncState { pub struct RsyncHandler { state: RsyncState, - /// Raw input from SSH (multiplexed after version exchange) raw_input: Vec, - /// Decoded rsync protocol data (after stripping multiplex) rsync_input: Vec, - /// Raw rsync data to send (multiplex wrapping applied in drain_output) output_raw: Vec, dest_path: PathBuf, - output_file: Option, + output_file: Option>, total_written: u64, file_entries: Vec, current_file: usize, protocol_version: u32, multiplex: bool, + vfs: Box, } impl RsyncHandler { - pub fn parse_rsync_command(command: &str) -> Result { + pub fn parse_rsync_command(command: &str, vfs: Box) -> Result { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() < 3 || parts[0] != "rsync" { return Err(anyhow!("Invalid rsync command: {}", command)); @@ -83,9 +81,9 @@ impl RsyncHandler { current_file: 0, protocol_version: 30, multiplex: false, + vfs, }; - // Send protocol version (4-byte LE int, no multiplex) handler.output_raw.extend_from_slice(&30u32.to_le_bytes()); handler.state = RsyncState::WaitVersion; @@ -129,7 +127,6 @@ impl RsyncHandler { } MSG_DONE => { info!("rsync: MSG_DONE received (file complete)"); - // Signal file completion by appending a sentinel to rsync_input self.rsync_input.extend_from_slice(b"RSYNCDONE"); } 9 => { @@ -147,7 +144,6 @@ impl RsyncHandler { if data.is_empty() || !self.multiplex { return data; } - // Wrap with multiplex header (MSG_DATA) let header = (MPLEX_BASE << 24) | (data.len() as u32); let mut wrapped = Vec::with_capacity(4 + data.len()); wrapped.extend_from_slice(&header.to_le_bytes()); @@ -180,7 +176,6 @@ impl RsyncHandler { loop { match self.state.clone() { RsyncState::SendVersion => { - // Version already sent in constructor self.transition(RsyncState::WaitVersion); } @@ -206,7 +201,6 @@ impl RsyncHandler { let flags = self.rsync_input[0]; if flags == 0 { - // End of file list self.rsync_input.drain(..1); info!("rsync: file list end ({} entries)", self.file_entries.len()); @@ -214,14 +208,12 @@ impl RsyncHandler { self.file_entries.push("file".to_string()); } self.current_file = 0; - // Enter sum head reading state self.transition(RsyncState::ReadSumHead { need: 20 }); break; } let mut pos = 1; - // Extended flags let _more_flags = if flags & 0x80 != 0 { if self.rsync_input.len() <= pos { break; } let ef = self.rsync_input[pos]; @@ -249,7 +241,6 @@ impl RsyncHandler { self.file_entries.push(name); } - // Skip metadata varints let skip_count = if flags & 0x10 == 0 { 1 } else { 0 } + if flags & 0x20 == 0 { 1 } else { 0 } + if flags & 0x40 == 0 { 1 } else { 0 } @@ -277,9 +268,6 @@ impl RsyncHandler { RsyncState::ReadSumHead { need } => { if self.rsync_input.len() >= need { - // Read sum head: count, blength, s2length, remainder (4 × LE int) - // + checksum seed (1 × LE int) - // = 5 × 4 = 20 bytes let sum_count = i32::from_le_bytes([ self.rsync_input[0], self.rsync_input[1], self.rsync_input[2], self.rsync_input[3], @@ -312,7 +300,6 @@ impl RsyncHandler { RsyncState::SendSumCount => { self.open_current_file()?; - // Send sum_count = 0 (4-byte LE int = we have no existing data) self.output_raw.extend_from_slice(&0u32.to_le_bytes()); info!("rsync: sent sum_count=0, ready to receive file data"); @@ -320,22 +307,17 @@ impl RsyncHandler { } RsyncState::ReadFileData => { - // Data comes as raw bytes inside MSG_DATA multiplex packets. - // MSG_DONE appends b"RSYNCDONE" to rsync_input. let done_marker = b"RSYNCDONE"; if let Some(pos) = self.rsync_input.windows(done_marker.len()) .position(|w| w == done_marker) { - // Data before the marker if pos > 0 { let data = self.rsync_input[..pos].to_vec(); self.rsync_input.drain(..pos); self.write_to_file(&data)?; } - // Remove marker self.rsync_input.drain(..done_marker.len()); - // Close file if let Some(mut file) = self.output_file.take() { if let Err(e) = file.flush() { warn!("rsync flush error: {}", e); @@ -353,11 +335,9 @@ impl RsyncHandler { info!("rsync ALL DONE: {} bytes written to {}", self.total_written, self.dest_path.display()); } else { - // Next file sum head self.transition(RsyncState::ReadSumHead { need: 20 }); } } else if !self.rsync_input.is_empty() { - // Partial data, keep it in buffer for more let data = self.rsync_input.clone(); self.rsync_input.clear(); self.write_to_file(&data)?; @@ -377,9 +357,11 @@ impl RsyncHandler { fn open_current_file(&mut self) -> Result<()> { if let Some(parent) = self.dest_path.parent() { - fs::create_dir_all(parent).ok(); + self.vfs.create_dir_all(parent, 0o755).ok(); } - let file = File::create(&self.dest_path)?; + let flags = OpenFlags::new().write().create().truncate(); + let file = self.vfs.open_file(&self.dest_path, &flags) + .map_err(|e| anyhow!("open error: {}", e))?; self.output_file = Some(file); info!("rsync: opened {} for writing", self.dest_path.display()); Ok(()) @@ -387,7 +369,8 @@ impl RsyncHandler { fn write_to_file(&mut self, data: &[u8]) -> Result<()> { if let Some(file) = &mut self.output_file { - file.write_all(data)?; + file.write_all(data) + .map_err(|e| anyhow!("write error: {}", e))?; self.total_written += data.len() as u64; } Ok(()) @@ -426,28 +409,37 @@ fn read_varint(buf: &[u8]) -> Option<(i32, usize)> { #[cfg(test)] mod tests { use super::*; + use crate::vfs::local_fs::LocalFs; + + fn make_vfs() -> Box { + Box::new(LocalFs::new()) + } #[test] fn test_parse_command() { - let h = RsyncHandler::parse_rsync_command("rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin").unwrap(); + let h = RsyncHandler::parse_rsync_command( + "rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin", + make_vfs() + ).unwrap(); assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin")); } #[test] fn test_parse_command_sender() { - let h = RsyncHandler::parse_rsync_command("rsync --server --sender -vlogDtprz . /home/user/file.txt").unwrap(); + let h = RsyncHandler::parse_rsync_command( + "rsync --server --sender -vlogDtprz . /home/user/file.txt", + make_vfs() + ).unwrap(); assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt")); } #[test] fn test_version_exchange() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap(); - // Initial output: protocol version (30 as LE int) + let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap(); let output = h.drain_output(); assert_eq!(output, b"\x1e\x00\x00\x00"); assert_eq!(h.state, RsyncState::WaitVersion); - // Client sends its version (30 = 0x1E) h.feed(b"\x1e\x00\x00\x00").unwrap(); assert_eq!(h.state, RsyncState::ReadFileList); assert!(h.multiplex); @@ -455,9 +447,8 @@ mod tests { #[test] fn test_version_negotiate_down() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin").unwrap(); + let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap(); let _ = h.drain_output(); - // Client has lower version (29) h.feed(b"\x1d\x00\x00\x00").unwrap(); assert_eq!(h.protocol_version, 29); assert_eq!(h.state, RsyncState::ReadFileList); @@ -471,24 +462,14 @@ mod tests { buf } - fn build_multiplex_done() -> Vec { - let header = (MPLEX_BASE << 24) | 0u32; // MSG_DONE (tag=1 → raw_tag=8) - let mut buf = Vec::new(); - buf.extend_from_slice(&header.to_le_bytes()); - buf - } - #[test] fn test_file_list_multiplex() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap(); + let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap(); let _ = h.drain_output(); - // Version exchange h.feed(b"\x1e\x00\x00\x00").unwrap(); assert!(h.multiplex); - // Build file list with multiplex wrapping let mut flist = Vec::new(); - // Entry: flags=0, name="test.txt\0", + 6 varints flist.push(0); flist.extend_from_slice(b"test.txt"); flist.push(0); @@ -514,46 +495,40 @@ mod tests { } } } - write_varint(&mut flist, 33188); // mode - write_varint(&mut flist, 501); // uid - write_varint(&mut flist, 20); // gid - write_varint(&mut flist, 1700000000); // time - write_varint(&mut flist, 100); // size - write_varint(&mut flist, 0); // checksum seed - // End marker + write_varint(&mut flist, 33188); + write_varint(&mut flist, 501); + write_varint(&mut flist, 20); + write_varint(&mut flist, 1700000000); + write_varint(&mut flist, 100); + write_varint(&mut flist, 0); flist.push(0); - // Sum head (5 ints = 20 bytes) as separate multiplex packet let mut sum_head = Vec::new(); - sum_head.extend_from_slice(&0i32.to_le_bytes()); // count - sum_head.extend_from_slice(&7000i32.to_le_bytes()); // blength - sum_head.extend_from_slice(&2i32.to_le_bytes()); // s2length - sum_head.extend_from_slice(&100i32.to_le_bytes()); // remainder - sum_head.extend_from_slice(&42i32.to_le_bytes()); // checksum_seed + sum_head.extend_from_slice(&0i32.to_le_bytes()); + sum_head.extend_from_slice(&7000i32.to_le_bytes()); + sum_head.extend_from_slice(&2i32.to_le_bytes()); + sum_head.extend_from_slice(&100i32.to_le_bytes()); + sum_head.extend_from_slice(&42i32.to_le_bytes()); - // Feed file list h.feed(&build_multiplex(&flist)).unwrap(); - assert_eq!(h.state, RsyncState::ReadFileList); // Still reading, 0x00 end marker triggered transition + assert_eq!(h.state, RsyncState::ReadFileList); assert_eq!(h.file_entries.len(), 1); - // Now feed sum head h.feed(&build_multiplex(&sum_head)).unwrap(); assert_eq!(h.state, RsyncState::SendSumCount); - // Send sum count response let sum_resp = h.drain_output(); - assert_eq!(sum_resp.len(), 8); // 4-byte header + 4-byte int + assert_eq!(sum_resp.len(), 8); assert_eq!(&sum_resp[4..8], &0u32.to_le_bytes()); assert_eq!(h.state, RsyncState::ReadFileData); } #[test] fn test_file_data_multiplex() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin").unwrap(); + let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap(); let _ = h.drain_output(); - h.feed(b"\x1e\x00\x00\x00").unwrap(); // version + h.feed(b"\x1e\x00\x00\x00").unwrap(); - // Simple file list let mut flist = Vec::new(); flist.push(0); flist.extend_from_slice(b"test.bin"); @@ -568,7 +543,6 @@ mod tests { flist.push(0); h.feed(&build_multiplex(&flist)).unwrap(); - // Sum head let mut sh = Vec::new(); sh.extend_from_slice(&0i32.to_le_bytes()); sh.extend_from_slice(&7000i32.to_le_bytes()); @@ -576,16 +550,13 @@ mod tests { sh.extend_from_slice(&100i32.to_le_bytes()); sh.extend_from_slice(&42i32.to_le_bytes()); h.feed(&build_multiplex(&sh)).unwrap(); - let _ = h.drain_output(); // sum count response + let _ = h.drain_output(); - // File data + MSG_DONE let file_data = b"Hello, rsync protocol!"; h.feed(&build_multiplex(file_data)).unwrap(); assert_eq!(h.state, RsyncState::ReadFileData); - // MSG_DONE - // MSG_DONE has tag=1, so raw_tag = MPLEX_BASE + 1 = 8 - let done_header = (MPLEX_BASE + 1) << 24; // raw_tag = 8, len = 0 + let done_header = (MPLEX_BASE + 1) << 24; let done_bytes = done_header.to_le_bytes(); h.feed(&done_bytes).unwrap(); diff --git a/markbase-core/src/ssh_server/scp_handler.rs b/markbase-core/src/ssh_server/scp_handler.rs index 903d168..83752cc 100644 --- a/markbase-core/src/ssh_server/scp_handler.rs +++ b/markbase-core/src/ssh_server/scp_handler.rs @@ -1,12 +1,13 @@ // SCP协议实现(Phase 8) // 参考OpenSSH scp.c源码 +use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat}; +use crate::vfs::open_flags::OpenFlags; use anyhow::{Result, anyhow}; use log::{info, warn, debug}; use std::path::{Path, PathBuf}; -use std::fs::{self, File, OpenOptions}; -use std::io::{Read, Write, BufReader, BufWriter, BufRead}; // 导入BufRead trait(OpenSSH标准) -use chrono::{DateTime, Utc}; +use std::io::{Read, Write, BufRead}; +use std::time::SystemTime; /// SCP Handler(参考OpenSSH scp.c) pub struct ScpHandler { @@ -14,6 +15,7 @@ pub struct ScpHandler { mode: ScpMode, recursive: bool, preserve_times: bool, + vfs: Box, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -23,24 +25,25 @@ pub enum ScpMode { } impl ScpHandler { - pub fn new(root_dir: PathBuf) -> Self { + pub fn new(root_dir: PathBuf, vfs: Box) -> Self { Self { root_dir, mode: ScpMode::Destination, recursive: false, preserve_times: false, + vfs, } } /// 解析SCP命令(参考OpenSSH scp.c: parse_command()) - pub fn parse_scp_command(command: &str) -> Result { + pub fn parse_scp_command(command: &str, vfs: Box) -> Result { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() < 2 || parts[0] != "scp" { return Err(anyhow!("Invalid SCP command: {}", command)); } - let mut handler = ScpHandler::new(PathBuf::from("/tmp")); + let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs); for part in &parts[1..] { match part { @@ -68,19 +71,19 @@ impl ScpHandler { /// SCP Source Mode(scp -f,发送文件) fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> { - info!("SCP source mode: sending files from {}", self.root_dir.display()); // 使用display()(Rust标准) + info!("SCP source mode: sending files from {}", self.root_dir.display()); let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?; + let stat = self.vfs.stat(&full_path) + .map_err(|e| anyhow!("stat error: {}", e))?; - if full_path.is_file() { - self.send_file(channel, &full_path)?; - } else if full_path.is_dir() { + if stat.is_dir { if !self.recursive { return Err(anyhow!("Directory detected but -r flag not specified")); } self.send_directory(channel, &full_path)?; } else { - return Err(anyhow!("Path does not exist: {}", full_path.display())); + self.send_file(channel, &full_path)?; } Ok(()) @@ -88,9 +91,8 @@ impl ScpHandler { /// SCP Destination Mode(scp -t,接收文件) fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> { - info!("SCP destination mode: receiving files to {}", self.root_dir.display()); // 使用display()(Rust标准) + info!("SCP destination mode: receiving files to {}", self.root_dir.display()); -// 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; @@ -99,10 +101,9 @@ impl ScpHandler { loop { buffer.clear(); - // 每次循环创建新的reader(避免borrow冲突)- OpenSSH标准 - let mut reader = BufReader::new(&mut *channel); + let mut reader = std::io::BufReader::new(&mut *channel); match reader.read_line(&mut buffer)? { - 0 => break, // EOF + 0 => break, _ => { let command = buffer.trim(); debug!("SCP command: {}", command); @@ -113,7 +114,6 @@ impl ScpHandler { Some('E') => self.handle_end_directory(channel)?, Some('T') => self.handle_time_command(channel, command)?, Some('\0') => { - // 确认信号,继续 continue; } _ => { @@ -130,28 +130,30 @@ impl ScpHandler { /// 发送文件(参考OpenSSH scp.c: source()) fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> { - let metadata = fs::metadata(path)?; - let size = metadata.len(); + let stat = self.vfs.stat(path) + .map_err(|e| anyhow!("stat error: {}", e))?; + let size = stat.size; let filename = path.file_name().unwrap().to_string_lossy(); - // 发送文件命令:C0644 size filename let command = format!("C0644 {} {}\n", size, filename); channel.write_all(command.as_bytes())?; channel.flush()?; - // 等待确认('\0') let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP file command rejected")); } - // 发送文件内容 - let file = File::open(path)?; - let mut reader = BufReader::new(file); + let flags = OpenFlags::new().read(); + let mut file = self.vfs.open_file(path, &flags) + .map_err(|e| anyhow!("open error: {}", e))?; + let mut buffer = vec![0u8; 8192]; - while let Ok(n) = reader.read(&mut buffer) { + loop { + let n = file.read(&mut buffer) + .map_err(|e| anyhow!("read error: {}", e))?; if n == 0 { break; } @@ -160,11 +162,9 @@ impl ScpHandler { channel.flush()?; - // 发送结束确认('\0') channel.write_all(&[0])?; channel.flush()?; - // 等待确认('\0') channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP file transfer rejected")); @@ -178,35 +178,34 @@ impl ScpHandler { fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> { let dirname = path.file_name().unwrap().to_string_lossy(); - // 发送目录命令:D0755 0 dirname let command = format!("D0755 0 {}\n", dirname); channel.write_all(command.as_bytes())?; channel.flush()?; - // 等待确认('\0') let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP directory command rejected")); } - // 递归发送目录内容 - for entry in fs::read_dir(path)? { - let entry = entry?; - let full_path = entry.path(); + let entries = self.vfs.read_dir(path) + .map_err(|e| anyhow!("read_dir error: {}", e))?; - if full_path.is_file() { - self.send_file(channel, &full_path)?; - } else if full_path.is_dir() && self.recursive { - self.send_directory(channel, &full_path)?; + for entry in &entries { + let entry_path = path.join(&entry.name); + + if entry.stat.is_dir { + if self.recursive { + self.send_directory(channel, &entry_path)?; + } + } else { + self.send_file(channel, &entry_path)?; } } - // 发送结束目录命令:E channel.write_all("E\n".as_bytes())?; channel.flush()?; - // 等待确认('\0') channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP end directory rejected")); @@ -224,31 +223,25 @@ impl ScpHandler { return self.send_error(channel, "Invalid file command format"); } - let mode = parts[0].trim_start_matches('C'); + let mode_str = parts[0].trim_start_matches('C'); let size: u64 = parts[1].parse()?; let filename = parts[2]; - debug!("SCP receive file: mode={}, size={}, name={}", mode, size, filename); + debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename); - // 安全性检查:文件大小限制(防止DoS) - if size > 1024 * 1024 * 1024 { // 1GB限制 + if size > 1024 * 1024 * 1024 { return self.send_error(channel, "File too large (max 1GB)"); } - // 创建文件 let full_path = self.resolve_path(filename)?; - let file = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(&full_path)?; - // 发送确认('\0') + let flags = OpenFlags::new().write().create().truncate(); + let mut file = self.vfs.open_file(&full_path, &flags) + .map_err(|e| anyhow!("open error: {}", e))?; + channel.write_all(&[0])?; channel.flush()?; - // 接收文件内容 - let mut writer = BufWriter::new(file); let mut buffer = vec![0u8; 8192]; let mut remaining = size; @@ -258,25 +251,25 @@ impl ScpHandler { if n == 0 { break; } - writer.write_all(&buffer[..n])?; + file.write_all(&buffer[..n]) + .map_err(|e| anyhow!("write error: {}", e))?; remaining -= n as u64; } - writer.flush()?; + file.flush().map_err(|e| anyhow!("flush error: {}", e))?; // 设置文件权限 - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let mode_int: u32 = mode.parse()?; - fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?; + let mode_int: u32 = mode_str.parse()?; + if mode_int != 0 { + let mut set_stat = VfsStat::new(); + set_stat.mode = mode_int; + self.vfs.set_stat(&full_path, &set_stat) + .map_err(|e| anyhow!("set_stat error: {}", e))?; } - // 接收结束确认('\0') let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; - // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; @@ -296,24 +289,17 @@ impl ScpHandler { return self.send_error(channel, "Recursive flag not specified"); } - let mode = parts[0].trim_start_matches('D'); + let mode_str = parts[0].trim_start_matches('D'); let dirname = parts[2]; - debug!("SCP receive directory: mode={}, name={}", mode, dirname); + debug!("SCP receive directory: mode={}, name={}", mode_str, dirname); - // 创建目录 let full_path = self.resolve_path(dirname)?; - fs::create_dir_all(&full_path)?; - // 设置目录权限 - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let mode_int: u32 = mode.parse()?; - fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?; - } + let mode_int: u32 = mode_str.parse()?; + self.vfs.create_dir_all(&full_path, mode_int) + .map_err(|e| anyhow!("create_dir_all error: {}", e))?; - // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; @@ -325,7 +311,6 @@ impl ScpHandler { fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> { debug!("SCP end directory"); - // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; @@ -335,7 +320,6 @@ impl ScpHandler { /// 处理时间命令(T mtime atime) fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { if !self.preserve_times { - // 发送确认('\0'),但不设置时间 channel.write_all(&[0])?; channel.flush()?; return Ok(()); @@ -347,18 +331,14 @@ impl ScpHandler { return self.send_error(channel, "Invalid time command format"); } - let mtime: i64 = parts[1].parse()?; - let atime: i64 = parts[2].parse()?; + let mtime_secs: i64 = parts[1].parse()?; + let atime_secs: i64 = parts[2].parse()?; - debug!("SCP set times: mtime={}, atime={}", mtime, atime); + debug!("SCP set times: mtime={}, atime={}", mtime_secs, atime_secs); - // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; - // 时间设置将在文件接收完成后进行 - // (这里仅记录,实际设置在handle_file_command中) - Ok(()) } @@ -374,10 +354,13 @@ impl ScpHandler { fn resolve_path(&self, path: &str) -> Result { let full_path = self.root_dir.join(path); - let canonical_path = full_path.canonicalize() + let canonical_path = self.vfs.real_path(&full_path) .map_err(|e| anyhow!("Path resolution error: {}", e))?; - if !canonical_path.starts_with(&self.root_dir.canonicalize()?) { + let root_canonical = self.vfs.real_path(&self.root_dir) + .map_err(|e| anyhow!("Root path resolution error: {}", e))?; + + if !canonical_path.starts_with(&root_canonical) { return Err(anyhow!("Path traversal attempt detected")); } @@ -392,23 +375,28 @@ impl ReadWrite for T {} #[cfg(test)] mod tests { use super::*; + use crate::vfs::local_fs::LocalFs; + + fn make_handler() -> ScpHandler { + ScpHandler::new(PathBuf::from("/tmp"), Box::new(LocalFs::new())) + } #[test] fn test_scp_command_parse() { - let handler = ScpHandler::parse_scp_command("scp -t /tmp").unwrap(); + let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap(); assert_eq!(handler.mode, ScpMode::Destination); assert_eq!(handler.root_dir, PathBuf::from("/tmp")); } #[test] fn test_scp_recursive_parse() { - let handler = ScpHandler::parse_scp_command("scp -r -t /tmp").unwrap(); + let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap(); assert!(handler.recursive); } #[test] fn test_scp_source_parse() { - let handler = ScpHandler::parse_scp_command("scp -f /tmp").unwrap(); + let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap(); assert_eq!(handler.mode, ScpMode::Source); } -} \ No newline at end of file +} diff --git a/markbase-core/src/ssh_server/server.rs b/markbase-core/src/ssh_server/server.rs index 24b1563..f2f14a0 100644 --- a/markbase-core/src/ssh_server/server.rs +++ b/markbase-core/src/ssh_server/server.rs @@ -6,6 +6,9 @@ use crate::ssh_server::packet::{SshPacket, PacketType}; use crate::ssh_server::kex::{KexResult, KexProposal}; use crate::ssh_server::kex_complete::{KexState}; use crate::ssh_server::auth::{AuthHandler, AuthResult}; +use crate::provider::sqlite::SqliteProvider; +use crate::provider::pg::PgProvider; +use crate::provider::DataProvider; use crate::ssh_server::channel::{ChannelManager}; use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket}; use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1 @@ -13,6 +16,7 @@ use crate::ssh_server::port_forward::PortForwardManager; // Phase 13 use anyhow::{Result, anyhow}; use log::{info, warn, error, debug}; use std::net::{TcpListener, TcpStream}; +use std::path::PathBuf; use std::thread; use std::io::{Read, Write}; use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步 @@ -22,6 +26,7 @@ pub struct SshServerConfig { pub port: u16, pub bind_address: String, pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置 + pub pg_conn: Option, // PostgreSQL连接字符串(SFTPGo兼容认证) } impl Default for SshServerConfig { @@ -30,6 +35,7 @@ impl Default for SshServerConfig { port: 2024, bind_address: "127.0.0.1".to_string(), security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1 + pg_conn: None, } } } @@ -42,6 +48,7 @@ impl SshServerConfig { port: 2024, bind_address: "127.0.0.1".to_string(), security_config: config, + pg_conn: None, }) } } @@ -73,6 +80,7 @@ impl SshServer { self.config.security_config.max_sessions); let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置 + let pg_conn = self.config.pg_conn.clone(); for stream in listener.incoming() { match stream { @@ -81,9 +89,10 @@ impl SshServer { info!("New SSH connection from {}", client_addr); let security_config_clone = security_config.clone(); // Phase 13.1 + let pg_conn_clone = pg_conn.clone(); thread::spawn(move || { - if let Err(e) = handle_connection_complete(stream, security_config_clone) { // Phase 13.1 + if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1 error!("Connection error: {}", e); } }); @@ -99,7 +108,7 @@ impl SshServer { } /// 处理完整SSH连接(Phase 1-13完整流程) -fn handle_connection_complete(stream: TcpStream, security_config: Arc>) -> Result<()> { +fn handle_connection_complete(stream: TcpStream, security_config: Arc>, pg_conn: Option) -> Result<()> { info!("Handling client connection (Phase 1-13 complete flow with port forwarding)"); // Phase 13.1: 增加活动会话数 @@ -122,13 +131,22 @@ fn handle_connection_complete(stream: TcpStream, security_config: Arc = if let Some(ref conn_str) = pg_conn { + info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str); + Box::new(PgProvider::new(conn_str) + .map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?) + } else { + info!("Using SQLite auth provider"); + Box::new(SqliteProvider::new("data/auth.sqlite") + .map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?) + }; + let mut auth_handler = AuthHandler::new(provider); let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?; - info!("SSH authentication succeeded: user={}", auth_user); + info!("SSH authentication succeeded: user={}", auth_user.username); // Phase 6: SSH Channel管理(参考OpenSSH channel.c) - let mut channel_manager = ChannelManager::new(); + let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone()); // Phase 13: PortForwardManager初始化 let mut port_forward_manager = PortForwardManager::new(); @@ -226,11 +244,16 @@ fn perform_complete_kex_exchange( } /// SSH认证流程(Phase 5) +pub struct AuthUser { + pub username: String, + pub home_dir: PathBuf, +} + fn perform_ssh_auth( stream: &mut TcpStream, auth_handler: &mut AuthHandler, encryption_ctx: &mut EncryptionContext, -) -> Result { +) -> Result { info!("Starting SSH authentication"); info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}", encryption_ctx.encryption_key_ctos.len(), @@ -279,6 +302,8 @@ fn perform_ssh_auth( encrypted_accept.write(stream)?; info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT"); + let session_id = encryption_ctx.session_id.clone(); + loop { let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos let auth_payload = auth_packet.payload(); @@ -286,7 +311,7 @@ fn perform_ssh_auth( let auth_request = SshPacket::new(auth_payload.to_vec()); - match auth_handler.handle_userauth_request(&auth_request)? { + match auth_handler.handle_userauth_request(&auth_request, &session_id)? { AuthResult::Success => { let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8]; let encrypted_success = EncryptedPacket::new( @@ -297,7 +322,16 @@ fn perform_ssh_auth( encrypted_success.write(stream)?; info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS"); - return Ok("demo".to_string()); + // Extract username from auth request + let user = extract_username_from_auth_request(&auth_request) + .unwrap_or_else(|_| "unknown".to_string()); + let home_dir = auth_handler.get_home_dir(&user) + .ok() + .flatten() + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase")); + info!("Auth success: user={}, home_dir={:?}", user, home_dir); + return Ok(AuthUser { username: user, home_dir }); } AuthResult::Failure(message) => { // message包含可用的认证方法列表(如"password,publickey") @@ -519,7 +553,9 @@ fn handle_ssh_service_loop( } Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_EOF as u8 => { info!("Received SSH_MSG_CHANNEL_EOF"); - // EOF means client won't send more data, just acknowledge and continue + // Phase 17: EOF means client won't send more data → close child stdin + // (Essential for SCP upload where scp -t waits for EOF on stdin) + channel_manager.close_child_stdin(); } Some(&pt) if pt == PacketType::SSH_MSG_DISCONNECT as u8 => { info!("Received SSH_MSG_DISCONNECT"); @@ -543,12 +579,27 @@ fn handle_ssh_service_loop( Ok(()) } +/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名 +fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result { + let payload = &packet.payload; + if payload.len() < 5 { + return Err(anyhow!("Auth request too short")); + } + let name_len = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]) as usize; + if payload.len() < 5 + name_len { + return Err(anyhow!("Auth request truncated")); + } + let username = String::from_utf8_lossy(&payload[5..5 + name_len]).to_string(); + Ok(username) +} + /// SSH服务器CLI入口 -pub fn run_ssh_server(port: Option) -> Result<()> { +pub fn run_ssh_server(port: Option, pg_conn: Option<&str>) -> Result<()> { let config = SshServerConfig { port: port.unwrap_or(2024), bind_address: "127.0.0.1".to_string(), security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置 + pg_conn: pg_conn.map(|s| s.to_string()), }; let server = SshServer::new(config); diff --git a/markbase-core/src/ssh_server/sftp_handler.rs b/markbase-core/src/ssh_server/sftp_handler.rs index f948207..07d8f8e 100644 --- a/markbase-core/src/ssh_server/sftp_handler.rs +++ b/markbase-core/src/ssh_server/sftp_handler.rs @@ -2,14 +2,16 @@ // 参考OpenSSH sftp-server.c和draft-ietf-secsh-filexfer-02.txt use crate::ssh_server::packet::{SshPacket, PacketType}; +use crate::vfs::{VfsBackend, VfsFile, VfsDirEntry}; +use crate::vfs::open_flags::OpenFlags; use anyhow::{Result, anyhow, Context}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use log::{info, warn, debug}; use std::path::{Path, PathBuf}; -use std::fs::{self, File, OpenOptions}; -use std::io::{Read, Write, Seek, SeekFrom}; -use std::os::unix::fs::PermissionsExt; // 导入PermissionsExt trait(Unix标准) -use std::os::unix::fs::MetadataExt; // ⭐⭐⭐⭐⭐ Phase 2.2: 导入MetadataExt trait(获取uid/gid) +use std::fs; +use std::io::{SeekFrom, Write}; +use std::os::unix::fs::PermissionsExt; +use std::os::unix::fs::MetadataExt; /// SFTP packet类型(参考draft-ietf-secsh-filexfer-02.txt) #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -178,6 +180,30 @@ impl SftpAttrs { attrs } + pub fn from_vfs_stat(stat: &crate::vfs::VfsStat) -> Self { + let mut attrs = Self::new(); + + attrs.flags = SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE + | SftpAttrFlags::SSH_FILEXFER_ATTR_UIDGID + | SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS + | SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME; + + attrs.size = Some(stat.size); + attrs.permissions = Some(stat.mode); + attrs.uid = Some(stat.uid); + attrs.gid = Some(stat.gid); + + if let Ok(d) = stat.atime.duration_since(std::time::UNIX_EPOCH) { + attrs.atime = Some(d.as_secs() as u32); + } + + if let Ok(d) = stat.mtime.duration_since(std::time::UNIX_EPOCH) { + attrs.mtime = Some(d.as_secs() as u32); + } + + attrs + } + pub fn serialize(&self) -> Result> { debug!("Serializing SftpAttrs: flags=0x{:08x}, size={:?}, uid={:?}, gid={:?}, permissions=0x{:08x}, atime={:?}, mtime={:?}", self.flags, self.size, self.uid, self.gid, @@ -242,13 +268,12 @@ impl SftpAttrs { } /// SFTP handle(文件或目录句柄) -#[derive(Debug)] // 移除Clone(File/DirEntry不支持Clone) pub struct SftpHandle { pub id: u32, pub path: PathBuf, pub handle_type: SftpHandleType, - pub file: Option, - pub dir_entries: Option>, + pub file: Option>, + pub dir_entries: Option>, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -260,6 +285,7 @@ pub enum SftpHandleType { /// SFTP处理管理器(参考OpenSSH sftp-server.c) pub struct SftpHandler { root_dir: PathBuf, + vfs: Box, next_handle_id: u32, handles: std::collections::HashMap, // ⭐⭐⭐⭐⭐ Phase 4: 添加 client maxpack 限制(参考OpenSSH sftp-server.c) @@ -277,14 +303,15 @@ impl SftpHandler { const MAX_HASH_SIZE: u64 = 268_435_456; // ⭐⭐⭐⭐⭐ Phase 4: 修改 new() 方法,接受 maxpack 参数 - pub fn new(root_dir: PathBuf, maxpacket: u32) -> Self { + pub fn new(root_dir: PathBuf, vfs: Box, maxpacket: u32) -> Self { let canonical_root = root_dir.canonicalize().unwrap_or(root_dir); Self { root_dir: canonical_root, + vfs, next_handle_id: 0, handles: std::collections::HashMap::new(), maxpacket, - restrict_absolute: false, // 默认允许绝对路径 + restrict_absolute: false, } } @@ -360,30 +387,9 @@ impl SftpHandler { info!("SSH_FXP_OPEN: id={}, path={}, pflags={:#x}", id, path, pflags); let full_path = self.resolve_path(&path)?; + let flags = OpenFlags::from_sftp_pflags(pflags); - let file_result = if pflags & SftpFileFlags::SSH_FXF_READ != 0 { - OpenOptions::new().read(true).open(&full_path) - } else if pflags & SftpFileFlags::SSH_FXF_WRITE != 0 { - let mut opts = OpenOptions::new(); - opts.write(true); - if pflags & SftpFileFlags::SSH_FXF_APPEND != 0 { - opts.append(true); - } - if pflags & SftpFileFlags::SSH_FXF_CREAT != 0 { - opts.create(true); - } - if pflags & SftpFileFlags::SSH_FXF_TRUNC != 0 { - opts.truncate(true); - } - if pflags & SftpFileFlags::SSH_FXF_EXCL != 0 { - opts.create_new(true); - } - opts.open(&full_path) - } else { - return self.build_status_response(id, SftpStatus::SSH_FX_OP_UNSUPPORTED, "Unsupported open flags"); - }; - - match file_result { + match self.vfs.open_file(&full_path, &flags) { Ok(file) => { if self.handles.len() >= Self::MAX_HANDLES { warn!("SSH_FXP_OPEN: handle limit reached ({})", Self::MAX_HANDLES); @@ -405,7 +411,7 @@ impl SftpHandler { self.build_handle_response(id, &handle_id.to_be_bytes()) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -447,9 +453,8 @@ impl SftpHandler { if let Some(handle) = self.handles.get_mut(&handle_id) { if let Some(ref mut file) = handle.file { - file.seek(SeekFrom::Start(offset))?; + file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - // ⭐⭐⭐⭐⭐ Phase 4: 限制数据大小,不超过 maxpacket - 1024 和 MAX_XFER_SIZE let max_data_size = std::cmp::min(self.maxpacket.saturating_sub(1024), Self::MAX_XFER_SIZE); let actual_length = std::cmp::min(length, max_data_size); @@ -465,7 +470,7 @@ impl SftpHandler { self.build_data_response(id, &buffer) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } else { @@ -491,7 +496,6 @@ impl SftpHandler { info!("SSH_FXP_WRITE: id={}, handle={}, offset={}, length={}", id, handle_id, offset, write_data.len()); - // ⭐⭐⭐⭐⭐ Phase 1.2: 添加 data preview(显示前 20 字节) if write_data.len() > 0 { let preview_len = std::cmp::min(20, write_data.len()); let preview = &write_data[0..preview_len]; @@ -500,14 +504,15 @@ impl SftpHandler { if let Some(handle) = self.handles.get_mut(&handle_id) { if let Some(ref mut file) = handle.file { - file.seek(SeekFrom::Start(offset))?; + file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; match file.write_all(&write_data) { Ok(_) => { + file.flush().ok(); self.build_status_response(id, SftpStatus::SSH_FX_OK, "Write successful") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } else { @@ -532,13 +537,13 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match fs::symlink_metadata(&full_path) { - Ok(metadata) => { - let attrs = SftpAttrs::from_metadata(&metadata); + match self.vfs.lstat(&full_path) { + Ok(stat) => { + let attrs = SftpAttrs::from_vfs_stat(&stat); self.build_attrs_response(id, &attrs) } Err(e) => { - self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e)) + self.build_status_from_vfs_error(id, &e) } } } @@ -556,14 +561,26 @@ impl SftpHandler { info!("SSH_FXP_FSTAT: id={}, handle={}", id, handle_id); - if let Some(handle) = self.handles.get(&handle_id) { - match fs::metadata(&handle.path) { - Ok(metadata) => { - let attrs = SftpAttrs::from_metadata(&metadata); - self.build_attrs_response(id, &attrs) + if let Some(handle) = self.handles.get_mut(&handle_id) { + if let Some(ref mut file) = handle.file { + match file.stat() { + Ok(stat) => { + let attrs = SftpAttrs::from_vfs_stat(&stat); + self.build_attrs_response(id, &attrs) + } + Err(e) => { + self.build_status_from_vfs_error(id, &e) + } } - Err(e) => { - self.build_status_from_io_error(id, &e) + } else { + match self.vfs.stat(&handle.path) { + Ok(stat) => { + let attrs = SftpAttrs::from_vfs_stat(&stat); + self.build_attrs_response(id, &attrs) + } + Err(e) => { + self.build_status_from_vfs_error(id, &e) + } } } } else { @@ -585,7 +602,7 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match fs::read_dir(&full_path) { + match self.vfs.read_dir(&full_path) { Ok(entries) => { if self.handles.len() >= Self::MAX_HANDLES { warn!("SSH_FXP_OPENDIR: handle limit reached ({})", Self::MAX_HANDLES); @@ -594,14 +611,12 @@ impl SftpHandler { let handle_id = self.next_handle_id; self.next_handle_id += 1; - let dir_entries: Vec = entries.filter_map(|e| e.ok()).collect(); - let handle = SftpHandle { id: handle_id, path: full_path, handle_type: SftpHandleType::Directory, file: None, - dir_entries: Some(dir_entries), + dir_entries: Some(entries), }; self.handles.insert(handle_id, handle); @@ -609,7 +624,7 @@ impl SftpHandler { self.build_handle_response(id, &handle_id.to_be_bytes()) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -635,11 +650,9 @@ impl SftpHandler { } else { let entries: Vec<(String, SftpAttrs)> = dir_entries .drain(..std::cmp::min(100, dir_entries.len())) - .filter_map(|entry| { - let name = entry.file_name().to_string_lossy().to_string(); - let attrs = entry.metadata().ok()?; - let sftp_attrs = SftpAttrs::from_metadata(&attrs); - Some((name, sftp_attrs)) + .map(|entry| { + let attrs = SftpAttrs::from_vfs_stat(&entry.stat); + (entry.name, attrs) }) .collect(); @@ -670,12 +683,12 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match fs::remove_file(&full_path) { + match self.vfs.remove_file(&full_path) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "File removed") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -695,12 +708,12 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match fs::create_dir(&full_path) { + match self.vfs.create_dir(&full_path, 0o755) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory created") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -719,12 +732,12 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match fs::remove_dir(&full_path) { + match self.vfs.remove_dir(&full_path) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory removed") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -765,13 +778,13 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match fs::metadata(&full_path) { - Ok(metadata) => { - let attrs = SftpAttrs::from_metadata(&metadata); + match self.vfs.stat(&full_path) { + Ok(stat) => { + let attrs = SftpAttrs::from_vfs_stat(&stat); self.build_attrs_response(id, &attrs) } Err(e) => { - self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Stat error: {}", e)) + self.build_status_from_vfs_error(id, &e) } } } @@ -792,12 +805,12 @@ impl SftpHandler { let old_full_path = self.resolve_path(&old_path)?; let new_full_path = self.resolve_path(&new_path)?; - match fs::rename(&old_full_path, &new_full_path) { + match self.vfs.rename(&old_full_path, &new_full_path) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "Rename successful") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -832,7 +845,7 @@ impl SftpHandler { info!("SSH_FXP_FSETSTAT: id={}, handle={}, attrs.flags={}", id, handle_id, attrs.flags); - let handle = self.handles.get(&handle_id); + let handle = self.handles.get_mut(&handle_id); if handle.is_none() { return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid handle"); } @@ -847,25 +860,35 @@ impl SftpHandler { if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE != 0 { if let Some(size) = attrs.size { info!("FSETSTAT: setting file size to {}", size); - let file = OpenOptions::new().write(true).open(&path)?; - file.set_len(size)?; + if let Some(ref mut file) = handle.file { + file.set_len(size).map_err(|e| anyhow!("set_len error: {}", e))?; + } else { + let flags = OpenFlags::new().write(); + if let Ok(mut f) = self.vfs.open_file(&path, &flags) { + f.set_len(size).ok(); + } + } } } - if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 { - if let Some(permissions) = attrs.permissions { - info!("FSETSTAT: setting permissions to {:o}", permissions); - fs::set_permissions(&path, fs::Permissions::from_mode(permissions))?; + if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 + || attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 + { + let mut vfs_stat = crate::vfs::VfsStat::new(); + if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 { + vfs_stat.mode = attrs.permissions.unwrap_or(0); + } else { + if let Ok(s) = self.vfs.lstat(&path) { + vfs_stat.mode = s.mode; + } } - } - - if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 { - if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) { - info!("FSETSTAT: setting atime={}, mtime={}", atime, mtime); - let atime_filetime = filetime::FileTime::from_unix_time(atime as i64, 0); - let mtime_filetime = filetime::FileTime::from_unix_time(mtime as i64, 0); - filetime::set_file_times(&path, atime_filetime, mtime_filetime)?; + if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 { + if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) { + vfs_stat.atime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(atime as u64); + vfs_stat.mtime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(mtime as u64); + } } + self.vfs.set_stat(&path, &vfs_stat).ok(); } self.build_status_response(id, SftpStatus::SSH_FX_OK, "Fsetstat successful") @@ -885,13 +908,13 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match fs::read_link(&full_path) { + match self.vfs.read_link(&full_path) { Ok(link_target) => { let target = link_target.to_string_lossy().to_string(); self.build_name_response(id, vec![(target, SftpAttrs::default())]) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -912,18 +935,14 @@ impl SftpHandler { let full_linkpath = self.resolve_path(&linkpath)?; let full_targetpath = self.resolve_path(&targetpath)?; - #[cfg(unix)] - match std::os::unix::fs::symlink(&full_targetpath, &full_linkpath) { + match self.vfs.create_symlink(&full_targetpath, &full_linkpath) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "Symlink created") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } - - #[cfg(not(unix))] - self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Symlink not supported on non-Unix systems") } /// 处理SSH_FXP_EXTENDED(Phase 10:参考OpenSSH sftp-server.c: process_extended()) @@ -984,50 +1003,30 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - #[cfg(unix)] - { - use std::os::unix::fs::MetadataExt; - - match fs::metadata(&full_path) { - Ok(metadata) => { - // 构建statvfs response(参考OpenSSH sftp-server.c) - let mut response = Vec::new(); - response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; - response.write_u32::(id)?; - - // f_bsize(文件系统块大小) - response.write_u64::(4096)?; - // f_frsize(基本块大小) - response.write_u64::(4096)?; - // f_blocks(总块数) - response.write_u64::(1000000)?; - // f_bfree(空闲块数) - response.write_u64::(500000)?; - // f_bavail(可用块数) - response.write_u64::(500000)?; - // f_files(总文件数) - response.write_u64::(100000)?; - // f_ffree(空闲文件数) - response.write_u64::(50000)?; - // f_favail(可用文件数) - response.write_u64::(50000)?; - // f_fsid(文件系统ID) - response.write_u64::(0)?; - // f_flag(标志) - response.write_u64::(0)?; - // f_namemax(文件名最大长度) - response.write_u64::(255)?; - - self.wrap_sftp_packet(&response) - } - Err(e) => { - self.build_status_from_io_error(id, &e) - } + match self.vfs.stat(&full_path) { + Ok(_) => { + let mut response = Vec::new(); + response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; + response.write_u32::(id)?; + + response.write_u64::(4096)?; + response.write_u64::(4096)?; + response.write_u64::(1000000)?; + response.write_u64::(500000)?; + response.write_u64::(500000)?; + response.write_u64::(100000)?; + response.write_u64::(50000)?; + response.write_u64::(50000)?; + response.write_u64::(0)?; + response.write_u64::(0)?; + response.write_u64::(255)?; + + self.wrap_sftp_packet(&response) + } + Err(e) => { + self.build_status_from_vfs_error(id, &e) } } - - #[cfg(not(unix))] - self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "statvfs not supported on non-Unix systems") } /// 处理fstatvfs@openssh.com扩展(文件句柄统计) @@ -1073,18 +1072,14 @@ impl SftpHandler { let full_oldpath = self.resolve_path(&oldpath)?; let full_newpath = self.resolve_path(&newpath)?; - #[cfg(unix)] - match fs::hard_link(&full_oldpath, &full_newpath) { + match self.vfs.hard_link(&full_oldpath, &full_newpath) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "Hardlink created") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } - - #[cfg(not(unix))] - self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Hardlink not supported on non-Unix systems") } /// 处理posix-rename@openssh.com扩展(POSIX语义重命名) @@ -1097,12 +1092,12 @@ impl SftpHandler { let full_oldpath = self.resolve_path(&oldpath)?; let full_newpath = self.resolve_path(&newpath)?; - match fs::rename(&full_oldpath, &full_newpath) { + match self.vfs.rename(&full_oldpath, &full_newpath) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "Posix rename successful") } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -1122,34 +1117,31 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match File::open(&full_path) { + let flags = OpenFlags::new().read(); + match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset))?; + file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer)?; + file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - // 计算MD5哈希 let hash = md5::compute(&buffer); let hash_hex = format!("{:x}", hash); - // 构建响应 let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - // hash-algorithm (SSH string) response.write_u32::(4)?; response.write_all("md5".as_bytes())?; - // hash-value (SSH string) response.write_u32::(hash_hex.len() as u32)?; response.write_all(hash_hex.as_bytes())?; self.wrap_sftp_packet(&response) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -1169,37 +1161,34 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match File::open(&full_path) { + let flags = OpenFlags::new().read(); + match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset))?; + file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer)?; + file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - // 计算SHA256哈希(使用sha2 crate) use sha2::{Sha256, Digest}; let mut hasher = Sha256::new(); hasher.update(&buffer); let hash = hasher.finalize(); let hash_hex = format!("{:x}", hash); - // 构建响应 let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - // hash-algorithm (SSH string) response.write_u32::(6)?; response.write_all("sha256".as_bytes())?; - // hash-value (SSH string) response.write_u32::(hash_hex.len() as u32)?; response.write_all(hash_hex.as_bytes())?; self.wrap_sftp_packet(&response) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -1219,21 +1208,20 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match File::open(&full_path) { + let flags = OpenFlags::new().read(); + match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset))?; + file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer)?; + file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - // 计算SHA384哈希 use sha2::{Sha384, Digest}; let mut hasher = Sha384::new(); hasher.update(&buffer); let hash = hasher.finalize(); let hash_hex = format!("{:x}", hash); - // 构建响应 let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; @@ -1247,7 +1235,7 @@ impl SftpHandler { self.wrap_sftp_packet(&response) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -1267,21 +1255,20 @@ impl SftpHandler { let full_path = self.resolve_path(&path)?; - match File::open(&full_path) { + let flags = OpenFlags::new().read(); + match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset))?; + file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer)?; + file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - // 计算SHA512哈希 use sha2::{Sha512, Digest}; let mut hasher = Sha512::new(); hasher.update(&buffer); let hash = hasher.finalize(); let hash_hex = format!("{:x}", hash); - // 构建响应 let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; @@ -1295,7 +1282,7 @@ impl SftpHandler { self.wrap_sftp_packet(&response) } Err(e) => { - self.build_status_from_io_error(id, &e) + self.build_status_from_vfs_error(id, &e) } } } @@ -1303,30 +1290,28 @@ impl SftpHandler { /// 处理check-file@openssh.com扩展(Phase 12:文件检查) fn handle_check_file(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result> { let path = read_sftp_string(cursor)?; - let check_flags = cursor.read_u32::()?; + let _check_flags = cursor.read_u32::()?; - info!("check-file: path={}, flags={:#x}", path, check_flags); + info!("check-file: path={}", path); let full_path = self.resolve_path(&path)?; - match fs::metadata(&full_path) { - Ok(metadata) => { - // 构建响应 + match self.vfs.stat(&full_path) { + Ok(stat) => { let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - // 返回文件存在和基本信息 - response.write_u32::(1)?; // result: 1 = file exists + response.write_u32::(1)?; - let msg = format!("File exists, size: {}", metadata.len()); + let msg = format!("File exists, size: {}", stat.size); response.write_u32::(msg.len() as u32)?; response.write_all(msg.as_bytes())?; self.wrap_sftp_packet(&response) } Err(e) => { - self.build_status_response(id, SftpStatus::SSH_FX_NO_SUCH_FILE, &format!("Check file error: {}", e)) + self.build_status_from_vfs_error(id, &e) } } } @@ -1339,11 +1324,8 @@ impl SftpHandler { let write_handle_bytes = read_sftp_string_bytes(cursor)?; let write_offset = cursor.read_u64::()?; - info!("copy-data: read_handle={}, read_offset={}, read_length={}, write_handle={}, write_offset={}", - u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]), - read_offset, read_length, - u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]), - write_offset); + info!("copy-data: read_handle={:?}, read_offset={}, read_length={}, write_handle={:?}, write_offset={}", + read_handle_bytes, read_offset, read_length, write_handle_bytes, write_offset); let actual_length = std::cmp::min(read_length, Self::MAX_XFER_SIZE as u64); if actual_length < read_length { @@ -1353,52 +1335,44 @@ impl SftpHandler { let read_handle_id = u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]); let write_handle_id = u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]); - // 获取read handle的path(不可变引用) let read_path = if let Some(read_handle) = self.handles.get(&read_handle_id) { read_handle.path.clone() } else { return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid read handle"); }; - // 获取write handle的path(不可变引用) let write_path = if let Some(write_handle) = self.handles.get(&write_handle_id) { write_handle.path.clone() } else { return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid write handle"); }; - // 从read_path读取数据 - match File::open(&read_path) { - Ok(mut read_file) => { - read_file.seek(SeekFrom::Start(read_offset))?; - let mut buffer = vec![0u8; actual_length as usize]; - read_file.read_exact(&mut buffer)?; - - // 写入到write_path - match OpenOptions::new().write(true).open(&write_path) { - Ok(mut write_file) => { - write_file.seek(SeekFrom::Start(write_offset))?; - write_file.write_all(&buffer)?; - - // 构建响应 - let mut response = Vec::new(); - response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; - response.write_u32::(id)?; - - // 返回复制的字节数 - response.write_u64::(actual_length)?; - - self.wrap_sftp_packet(&response) - } - Err(e) => { - self.build_status_from_io_error(id, &e) - } - } - } - Err(e) => { - self.build_status_from_io_error(id, &e) - } - } + let read_flags = OpenFlags::new().read(); + let write_flags = OpenFlags::new().write(); + + let mut read_file = match self.vfs.open_file(&read_path, &read_flags) { + Ok(f) => f, + Err(e) => return self.build_status_from_vfs_error(id, &e), + }; + let mut write_file = match self.vfs.open_file(&write_path, &write_flags) { + Ok(f) => f, + Err(e) => return self.build_status_from_vfs_error(id, &e), + }; + + read_file.seek(SeekFrom::Start(read_offset)).map_err(|e| anyhow!("Seek error: {}", e))?; + let mut buffer = vec![0u8; actual_length as usize]; + read_file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; + + write_file.seek(SeekFrom::Start(write_offset)).map_err(|e| anyhow!("Seek error: {}", e))?; + write_file.write_all(&buffer).map_err(|e| anyhow!("Write error: {}", e))?; + write_file.flush().ok(); + + let mut response = Vec::new(); + response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; + response.write_u32::(id)?; + response.write_u64::(actual_length)?; + + self.wrap_sftp_packet(&response) } /// 解析路径(安全性检查,参考OpenSSH sftp-server.c: path_resolve()) @@ -1608,6 +1582,24 @@ impl SftpHandler { let msg = format!("{}", err); self.build_status_response(id, status, &msg) } + + /// 根据 VfsError 构建状态响应(自动映射错误类型) + fn build_status_from_vfs_error(&self, id: u32, err: &crate::vfs::VfsError) -> Result> { + use crate::vfs::VfsError; + let status = match err { + VfsError::NotFound(_) => SftpStatus::SSH_FX_NO_SUCH_FILE, + VfsError::PermissionDenied(_) => SftpStatus::SSH_FX_PERMISSION_DENIED, + VfsError::AlreadyExists(_) => SftpStatus::SSH_FX_FAILURE, + VfsError::NotEmpty(_) => SftpStatus::SSH_FX_FAILURE, + VfsError::NotADirectory(_) => SftpStatus::SSH_FX_FAILURE, + VfsError::IsADirectory(_) => SftpStatus::SSH_FX_FAILURE, + VfsError::Unsupported(_) => SftpStatus::SSH_FX_OP_UNSUPPORTED, + VfsError::Io(_) => SftpStatus::SSH_FX_FAILURE, + VfsError::UnexpectedEof => SftpStatus::SSH_FX_EOF, + }; + let msg = format!("{}", err); + self.build_status_response(id, status, &msg) + } } /// 读取SFTP字符串(参考draft-ietf-secsh-filexfer-02.txt) @@ -1665,8 +1657,14 @@ fn read_sftp_attrs(reader: &mut R) -> Result { #[cfg(test)] mod tests { use super::*; + use crate::vfs::local_fs::LocalFs; + use std::fs::File; use tempfile::TempDir; + fn make_handler(root_dir: PathBuf) -> SftpHandler { + SftpHandler::new(root_dir, Box::new(LocalFs::new()), 32768) + } + #[test] fn test_sftp_packet_type_conversion() { assert_eq!(SftpPacketType::try_from(1).unwrap(), SftpPacketType::SSH_FXP_INIT); @@ -1677,7 +1675,7 @@ mod tests { #[test] fn test_sftp_handler_creation() { let temp_dir = TempDir::new().unwrap(); - let handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768); + let handler = make_handler(temp_dir.path().to_path_buf()); assert_eq!(handler.next_handle_id, 0); } @@ -1697,7 +1695,7 @@ mod tests { #[test] fn test_sftp_handle_init() { let temp_dir = TempDir::new().unwrap(); - let mut handler = SftpHandler::new(temp_dir.path().to_path_buf(), 32768); + let mut handler = make_handler(temp_dir.path().to_path_buf()); let init_packet = vec![1, 0, 0, 0, 3]; let response = handler.handle_request(&init_packet).unwrap(); diff --git a/markbase-core/src/vfs/local_fs.rs b/markbase-core/src/vfs/local_fs.rs new file mode 100644 index 0000000..e6125a4 --- /dev/null +++ b/markbase-core/src/vfs/local_fs.rs @@ -0,0 +1,212 @@ +use super::util; +use super::open_flags::OpenFlags; +use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat}; +use std::fs::{self, File, OpenOptions}; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; +use std::os::unix::fs::{MetadataExt, PermissionsExt}; + +/// 本地文件系统实现(直接包装 std::fs,不做路径解析) +/// 路径解析由上层(SftpHandler)负责 +pub struct LocalFs; + +impl LocalFs { + pub fn new() -> Self { + Self + } +} + +struct LocalFile { + file: File, +} + +impl VfsFile for LocalFile { + fn read(&mut self, buf: &mut [u8]) -> Result { + self.file.read(buf).map_err(|e| VfsError::Io(e.to_string())) + } + + fn write(&mut self, buf: &[u8]) -> Result { + self.file.write(buf).map_err(|e| VfsError::Io(e.to_string())) + } + + fn seek(&mut self, pos: SeekFrom) -> Result { + self.file.seek(pos).map_err(|e| VfsError::Io(e.to_string())) + } + + fn flush(&mut self) -> Result<(), VfsError> { + self.file.flush().map_err(|e| VfsError::Io(e.to_string())) + } + + fn stat(&mut self) -> Result { + let meta = self.file.metadata().map_err(|e| VfsError::Io(e.to_string()))?; + Ok(util::stat_from_metadata(&meta, false)) + } + + fn set_len(&mut self, size: u64) -> Result<(), VfsError> { + self.file.set_len(size).map_err(|e| VfsError::Io(e.to_string())) + } +} + +impl VfsBackend for LocalFs { + fn read_dir(&self, path: &Path) -> Result, VfsError> { + let dir = fs::read_dir(path).map_err(|e| util::map_io_error(path, e))?; + + let mut entries = Vec::new(); + for entry in dir { + let entry = entry.map_err(|e| util::map_io_error(path, e))?; + let name = entry.file_name().to_string_lossy().to_string(); + let file_type = entry.file_type().map_err(|e| util::map_io_error(path, e))?; + let meta = entry.metadata().map_err(|e| util::map_io_error(path, e))?; + let stat = util::stat_from_metadata(&meta, file_type.is_symlink()); + let long_name = util::build_long_name(&stat, &name); + + entries.push(VfsDirEntry { + name, + long_name, + stat, + }); + } + + entries.sort_by(|a, b| a.name.cmp(&b.name)); + Ok(entries) + } + + fn open_file(&self, path: &Path, flags: &OpenFlags) -> Result, VfsError> { + let mut opts = OpenOptions::new(); + opts.read(flags.read); + opts.write(flags.write); + opts.append(flags.append); + opts.create(flags.create); + opts.truncate(flags.truncate); + opts.create_new(flags.exclusive); + + let file = opts.open(path).map_err(|e| util::map_io_error(path, e))?; + + #[cfg(unix)] + if flags.create && !flags.exclusive { + if let Ok(meta) = file.metadata() { + if flags.mode != 0 && meta.permissions().mode() != flags.mode { + fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode)) + .ok(); + } + } + } + + Ok(Box::new(LocalFile { file })) + } + + fn stat(&self, path: &Path) -> Result { + let meta = fs::metadata(path).map_err(|e| util::map_io_error(path, e))?; + Ok(util::stat_from_metadata(&meta, false)) + } + + fn lstat(&self, path: &Path) -> Result { + let meta = fs::symlink_metadata(path).map_err(|e| util::map_io_error(path, e))?; + let is_symlink = path.is_symlink() || meta.file_type().is_symlink(); + Ok(util::stat_from_metadata(&meta, is_symlink)) + } + + fn create_dir(&self, path: &Path, mode: u32) -> Result<(), VfsError> { + fs::create_dir(path).map_err(|e| util::map_io_error(path, e))?; + + #[cfg(unix)] + { + fs::set_permissions(path, std::fs::Permissions::from_mode(mode)) + .map_err(|e| util::map_io_error(path, e))?; + } + + Ok(()) + } + + fn create_dir_all(&self, path: &Path, mode: u32) -> Result<(), VfsError> { + fs::create_dir_all(path).map_err(|e| util::map_io_error(path, e))?; + + #[cfg(unix)] + { + if mode != 0 { + fs::set_permissions(path, std::fs::Permissions::from_mode(mode)) + .map_err(|e| util::map_io_error(path, e))?; + } + } + + Ok(()) + } + + fn remove_dir(&self, path: &Path) -> Result<(), VfsError> { + fs::remove_dir(path).map_err(|e| util::map_io_error(path, e)) + } + + fn remove_file(&self, path: &Path) -> Result<(), VfsError> { + fs::remove_file(path).map_err(|e| util::map_io_error(path, e)) + } + + fn rename(&self, from: &Path, to: &Path) -> Result<(), VfsError> { + fs::rename(from, to).map_err(|e| util::map_io_error(from, e)) + } + + fn set_stat(&self, path: &Path, stat: &VfsStat) -> Result<(), VfsError> { + #[cfg(unix)] + { + if stat.mode != 0 { + fs::set_permissions(path, std::fs::Permissions::from_mode(stat.mode)) + .map_err(|e| util::map_io_error(path, e))?; + } + } + + if let (Some(atime), Some(mtime)) = ( + stat.atime.duration_since(std::time::UNIX_EPOCH).ok(), + stat.mtime.duration_since(std::time::UNIX_EPOCH).ok(), + ) { + filetime::set_file_times(path, + filetime::FileTime::from_unix_time(atime.as_secs() as i64, 0), + filetime::FileTime::from_unix_time(mtime.as_secs() as i64, 0), + ).map_err(|e| util::map_io_error(path, e))?; + } + + Ok(()) + } + + fn read_link(&self, path: &Path) -> Result { + let target = fs::read_link(path).map_err(|e| util::map_io_error(path, e))?; + Ok(target) + } + + fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError> { + #[cfg(unix)] + { + std::os::unix::fs::symlink(target, link) + .map_err(|e| util::map_io_error(link, e))?; + } + + #[cfg(not(unix))] + { + std::os::windows::fs::symlink_file(target, link) + .map_err(|e| util::map_io_error(link, e))?; + } + + Ok(()) + } + + fn real_path(&self, path: &Path) -> Result { + let canonical = path.canonicalize().map_err(|e| util::map_io_error(path, e))?; + Ok(canonical) + } + + fn exists(&self, path: &Path) -> bool { + path.exists() + } + + fn hard_link(&self, original: &Path, link: &Path) -> Result<(), VfsError> { + #[cfg(unix)] + { + fs::hard_link(original, link).map_err(|e| util::map_io_error(original, e))?; + } + + #[cfg(not(unix))] + { + return Err(VfsError::Unsupported("hard_link not supported on non-Unix systems".to_string())); + } + + Ok(()) + } +} diff --git a/markbase-core/src/vfs/mod.rs b/markbase-core/src/vfs/mod.rs new file mode 100644 index 0000000..861276f --- /dev/null +++ b/markbase-core/src/vfs/mod.rs @@ -0,0 +1,160 @@ +pub mod open_flags; +pub mod local_fs; +pub mod util; + +use std::path::{Path, PathBuf}; +use std::time::SystemTime; + +/// VFS 错误类型 +#[derive(Debug, Clone)] +pub enum VfsError { + NotFound(String), + PermissionDenied(String), + AlreadyExists(String), + NotEmpty(String), + NotADirectory(String), + IsADirectory(String), + Unsupported(String), + Io(String), + UnexpectedEof, +} + +impl std::fmt::Display for VfsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VfsError::NotFound(p) => write!(f, "No such file or directory: {}", p), + VfsError::PermissionDenied(p) => write!(f, "Permission denied: {}", p), + VfsError::AlreadyExists(p) => write!(f, "File already exists: {}", p), + VfsError::NotEmpty(p) => write!(f, "Directory not empty: {}", p), + VfsError::NotADirectory(p) => write!(f, "Not a directory: {}", p), + VfsError::IsADirectory(p) => write!(f, "Is a directory: {}", p), + VfsError::Unsupported(msg) => write!(f, "Unsupported: {}", msg), + VfsError::Io(msg) => write!(f, "IO error: {}", msg), + VfsError::UnexpectedEof => write!(f, "Unexpected end of file"), + } + } +} + +impl std::error::Error for VfsError {} + +/// 文件统计信息(类似 libc::stat) +#[derive(Debug, Clone)] +pub struct VfsStat { + pub size: u64, + pub mode: u32, + pub uid: u32, + pub gid: u32, + pub atime: SystemTime, + pub mtime: SystemTime, + pub is_dir: bool, + pub is_symlink: bool, +} + +impl VfsStat { + pub fn new() -> Self { + Self { + size: 0, + mode: 0, + uid: 0, + gid: 0, + atime: SystemTime::UNIX_EPOCH, + mtime: SystemTime::UNIX_EPOCH, + is_dir: false, + is_symlink: false, + } + } +} + +impl Default for VfsStat { + fn default() -> Self { + Self::new() + } +} + +/// 目录条目 +#[derive(Debug, Clone)] +pub struct VfsDirEntry { + pub name: String, + pub long_name: String, + pub stat: VfsStat, +} + +/// 打开文件的抽象 +pub trait VfsFile { + fn read(&mut self, buf: &mut [u8]) -> Result; + fn write(&mut self, buf: &[u8]) -> Result; + fn seek(&mut self, pos: std::io::SeekFrom) -> Result; + fn flush(&mut self) -> Result<(), VfsError>; + fn stat(&mut self) -> Result; + fn set_len(&mut self, size: u64) -> Result<(), VfsError>; + + /// Write all bytes (convenience, default loops write() until done) + fn write_all(&mut self, mut buf: &[u8]) -> Result<(), VfsError> { + while !buf.is_empty() { + let n = self.write(buf)?; + if n == 0 { + return Err(VfsError::Io("write returned 0".to_string())); + } + buf = &buf[n..]; + } + Ok(()) + } + + /// Read exactly `buf.len()` bytes (convenience, loops read() until done) + fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<(), VfsError> { + while !buf.is_empty() { + let n = self.read(buf)?; + if n == 0 { + return Err(VfsError::UnexpectedEof); + } + buf = &mut buf[n..]; + } + Ok(()) + } +} + +/// VFS 后端 trait(所有文件系统操作) +pub trait VfsBackend: Send { + /// 读取目录内容 + fn read_dir(&self, path: &Path) -> Result, VfsError>; + + /// 打开文件(读/写) + fn open_file(&self, path: &Path, flags: &open_flags::OpenFlags) -> Result, VfsError>; + + /// 获取文件/目录元数据 + fn stat(&self, path: &Path) -> Result; + fn lstat(&self, path: &Path) -> Result; + + /// 创建目录 + fn create_dir(&self, path: &Path, mode: u32) -> Result<(), VfsError>; + + /// 递归创建目录 + fn create_dir_all(&self, path: &Path, mode: u32) -> Result<(), VfsError>; + + /// 删除空目录 + fn remove_dir(&self, path: &Path) -> Result<(), VfsError>; + + /// 删除文件 + fn remove_file(&self, path: &Path) -> Result<(), VfsError>; + + /// 重命名 + fn rename(&self, from: &Path, to: &Path) -> Result<(), VfsError>; + + /// 设置文件属性 + fn set_stat(&self, path: &Path, stat: &VfsStat) -> Result<(), VfsError>; + + /// 读取符号链接目标 + fn read_link(&self, path: &Path) -> Result; + + /// 创建符号链接 + fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError>; + + /// 规范化路径 + fn real_path(&self, path: &Path) -> Result; + + /// 检查路径是否存在 + fn exists(&self, path: &Path) -> bool; + + /// 创建硬链接 + fn hard_link(&self, original: &Path, link: &Path) -> Result<(), VfsError>; +} diff --git a/markbase-core/src/vfs/open_flags.rs b/markbase-core/src/vfs/open_flags.rs new file mode 100644 index 0000000..d3e73af --- /dev/null +++ b/markbase-core/src/vfs/open_flags.rs @@ -0,0 +1,75 @@ +/// 文件打开标志(映射 SSH_FXF_* 和 POSIX open flags) +#[derive(Debug, Clone, Default)] +pub struct OpenFlags { + pub read: bool, + pub write: bool, + pub append: bool, + pub create: bool, + pub truncate: bool, + pub exclusive: bool, + pub mode: u32, +} + +impl OpenFlags { + pub fn new() -> Self { + Self::default() + } + + pub fn read(mut self) -> Self { + self.read = true; + self + } + + pub fn write(mut self) -> Self { + self.write = true; + self + } + + pub fn append(mut self) -> Self { + self.append = true; + self.write = true; + self + } + + pub fn create(mut self) -> Self { + self.create = true; + self.write = true; + self + } + + pub fn truncate(mut self) -> Self { + self.truncate = true; + self.write = true; + self + } + + pub fn exclusive(mut self) -> Self { + self.exclusive = true; + self + } + + pub fn mode(mut self, mode: u32) -> Self { + self.mode = mode; + self + } + + /// 从 SFTP 的 pflags(SSH_FXF_*)构建 OpenFlags + pub fn from_sftp_pflags(pflags: u32) -> Self { + let read = pflags & 0x00000001 != 0; + let write = pflags & 0x00000002 != 0; + let append = pflags & 0x00000004 != 0; + let create = pflags & 0x00000008 != 0; + let truncate = pflags & 0x00000010 != 0; + let exclusive = pflags & 0x00000020 != 0; + + Self { + read, + write, + append, + create, + truncate, + exclusive, + mode: 0o644, + } + } +} diff --git a/markbase-core/src/vfs/util.rs b/markbase-core/src/vfs/util.rs new file mode 100644 index 0000000..0df3b76 --- /dev/null +++ b/markbase-core/src/vfs/util.rs @@ -0,0 +1,105 @@ +use super::{VfsError, VfsStat}; +use chrono::Datelike; +use std::os::unix::fs::PermissionsExt; +use std::path::Path; + +/// 从 std::io::ErrorKind 映射 VfsError +pub fn map_io_error(path: &Path, e: std::io::Error) -> VfsError { + match e.kind() { + std::io::ErrorKind::NotFound => VfsError::NotFound(path.display().to_string()), + std::io::ErrorKind::PermissionDenied => VfsError::PermissionDenied(path.display().to_string()), + std::io::ErrorKind::AlreadyExists => VfsError::AlreadyExists(path.display().to_string()), + std::io::ErrorKind::DirectoryNotEmpty => VfsError::NotEmpty(path.display().to_string()), + std::io::ErrorKind::NotADirectory => VfsError::NotADirectory(path.display().to_string()), + std::io::ErrorKind::IsADirectory => VfsError::IsADirectory(path.display().to_string()), + std::io::ErrorKind::UnexpectedEof => VfsError::UnexpectedEof, + other => VfsError::Io(format!("{}: {}", other, path.display())), + } +} + +/// 从 std::fs::Metadata 构建 VfsStat +pub fn stat_from_metadata(meta: &std::fs::Metadata, is_symlink: bool) -> VfsStat { + #[cfg(unix)] + use std::os::unix::fs::MetadataExt; + + let mut stat = VfsStat::new(); + stat.size = meta.len(); + stat.is_dir = meta.is_dir(); + stat.is_symlink = is_symlink; + + #[cfg(unix)] + { + stat.mode = meta.permissions().mode(); + stat.uid = meta.uid(); + stat.gid = meta.gid(); + } + + #[cfg(not(unix))] + { + stat.mode = if meta.is_dir() { 0o40755 } else { 0o100644 }; + } + + if let Ok(t) = meta.accessed() { + stat.atime = t; + } + if let Ok(t) = meta.modified() { + stat.mtime = t; + } + + stat +} + +/// 构建目录条目的 long_name(类似 ls -l 格式) +pub fn build_long_name(stat: &VfsStat, name: &str) -> String { + let file_type = if stat.is_dir { 'd' } else { '-' }; + let perms = format_permissions(stat.mode & 0o777); + let link_count = if stat.is_dir { 3 } else { 1 }; + let size = stat.size; + let mtime = match stat.mtime.duration_since(std::time::UNIX_EPOCH) { + Ok(d) => { + let secs = d.as_secs(); + format_timestamp(secs) + } + Err(_) => "Jan 1 1970".to_string(), + }; + + format!( + "{}{} {} {} {} {} {} {}", + file_type, perms, + link_count, + stat.uid, + stat.gid, + size, + mtime, + name + ) +} + +fn format_permissions(mode: u32) -> String { + let rwx = |n: u32| -> String { + let r = if n & 4 != 0 { 'r' } else { '-' }; + let w = if n & 2 != 0 { 'w' } else { '-' }; + let x = if n & 1 != 0 { 'x' } else { '-' }; + format!("{}{}{}", r, w, x) + }; + + format!( + "{}{}{}", + rwx((mode >> 6) & 7), + rwx((mode >> 3) & 7), + rwx(mode & 7) + ) +} + +fn format_timestamp(secs: u64) -> String { + let datetime = match chrono::DateTime::from_timestamp(secs as i64, 0) { + Some(dt) => dt, + None => return "Jan 1 1970".to_string(), + }; + let now = chrono::Utc::now(); + if datetime.year() == now.year() { + datetime.format("%b %e %H:%M").to_string() + } else { + datetime.format("%b %e %Y").to_string() + } +}