diff --git a/markbase-core/src/cli/interface/webdav.rs b/markbase-core/src/cli/interface/webdav.rs index 484a7f0..2e22a7f 100644 --- a/markbase-core/src/cli/interface/webdav.rs +++ b/markbase-core/src/cli/interface/webdav.rs @@ -1,6 +1,14 @@ -use axum::{extract::Request, response::IntoResponse, Extension}; +use axum::{ + body::Body, + extract::Request, + http::{HeaderValue, StatusCode}, + middleware, + response::IntoResponse, + Extension, +}; +use base64::Engine as _; use clap::Subcommand; -use dav_server::{fakels::FakeLs, DavHandler}; +use dav_server::DavHandler; use std::path::PathBuf; #[derive(Subcommand)] @@ -17,7 +25,14 @@ pub enum WebdavCommand { pub async fn handle_webdav_command(cmd: WebdavCommand) -> anyhow::Result<()> { match cmd { WebdavCommand::Start { port, user } => { - let home_dir = PathBuf::from("/Users/accusys/momentry/var/sftpgo/data").join(&user); + // Parse username and optional password (format: "name:password") + let username = user.split(':').next().unwrap_or(&user).to_string(); + let password = user.split(':').nth(1).map(|s| s.to_string()); + + let default_root = format!("/Users/accusys/momentry/var/sftpgo/data/{}", username); + let home_dir = PathBuf::from( + std::env::var("MB_WEBDAV_ROOT").unwrap_or(default_root), + ); if !home_dir.exists() { return Err(anyhow::anyhow!( @@ -27,12 +42,15 @@ pub async fn handle_webdav_command(cmd: WebdavCommand) -> anyhow::Result<()> { } println!("=== MarkBase WebDAV Server (VFS) ==="); - println!("User: {}", user); + println!("User: {}", username); + if password.is_some() { + println!("Auth: password protected"); + } println!("Port: {}", port); println!("Home: {}", home_dir.display()); println!(); - run_webdav_server(port, home_dir, user).await?; + run_webdav_server(port, home_dir, username, password).await?; } } Ok(()) @@ -42,6 +60,7 @@ async fn run_webdav_server( port: u16, home_dir: PathBuf, user: String, + password: Option, ) -> anyhow::Result<()> { use axum::{routing::any, Router}; use tokio::net::TcpListener; @@ -49,19 +68,56 @@ async fn run_webdav_server( let vfs = Box::new(crate::vfs::local_fs::LocalFs::new()); let upload_hook = None; - let dav_fs = crate::webdav::VfsDavFs::new(vfs, home_dir, upload_hook, user); + let dav_handler = crate::webdav::create_webdav_handler(vfs, home_dir, upload_hook, user.clone()); - let dav_handler = DavHandler::builder() - .filesystem(dav_fs) - .locksystem(FakeLs::new()) - .strip_prefix("/webdav") - .build_handler(); + async fn webdav_auth_middleware( + Extension(expected): Extension, + req: Request, + next: middleware::Next, + ) -> impl IntoResponse { + let auth = req + .headers() + .get("Authorization") + .and_then(|v| v.to_str().ok()) + .filter(|v| v.starts_with("Basic ")) + .and_then(|v| { + let encoded = &v[6..]; + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .ok()?; + let creds = String::from_utf8(decoded).ok()?; + let colon = creds.find(':')?; + Some((creds[..colon].to_string(), creds[colon + 1..].to_string())) + }); + + let valid = auth.is_some_and(|(u, p)| { + u == expected.username && expected.password.as_ref().map_or(true, |exp| p == *exp) + }); + + if !valid { + return ( + StatusCode::UNAUTHORIZED, + [( + "WWW-Authenticate", + HeaderValue::from_static("Basic realm=\"MarkBase WebDAV\""), + )], + Body::from("Unauthorized"), + ).into_response(); + } + + next.run(req).await + } let app = Router::new() .route("/webdav", any(handle_dav)) .route("/webdav/", any(handle_dav)) .route("/webdav/*path", any(handle_dav)) - .layer(Extension(dav_handler)); + .layer(Extension(dav_handler)) + .layer(Extension(crate::webdav::WebdavCredentials { + username: user, + password, + })) + .layer(middleware::from_fn(webdav_auth_middleware)); let addr = format!("127.0.0.1:{}", port); let listener = TcpListener::bind(&addr).await?; diff --git a/markbase-core/src/config/mod.rs b/markbase-core/src/config/mod.rs index 8399ac7..0ce4566 100644 --- a/markbase-core/src/config/mod.rs +++ b/markbase-core/src/config/mod.rs @@ -29,6 +29,7 @@ pub struct WebSection { pub log_level: String, pub auth_db_path: String, pub users_db_dir: String, + pub webdav_root: String, } impl Default for WebSection { @@ -39,6 +40,7 @@ impl Default for WebSection { log_level: "info".to_string(), auth_db_path: "data/auth.sqlite".to_string(), users_db_dir: "data/users".to_string(), + webdav_root: "/Users/accusys/momentry/var/sftpgo/data/demo".to_string(), } } } diff --git a/markbase-core/src/config/web.rs b/markbase-core/src/config/web.rs index 89860f2..0d05107 100644 --- a/markbase-core/src/config/web.rs +++ b/markbase-core/src/config/web.rs @@ -18,6 +18,7 @@ pub struct ServerConfig { pub log_level: String, pub auth_db_path: String, pub users_db_dir: String, + pub webdav_root: String, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -87,6 +88,7 @@ impl MarkBaseConfig { log_level: "info".to_string(), auth_db_path: "data/auth.sqlite".to_string(), users_db_dir: "data/users".to_string(), + webdav_root: "/Users/accusys/momentry/var/sftpgo/data/demo".to_string(), }, postgresql: PostgreSQLConfig { host: "127.0.0.1".to_string(), @@ -138,6 +140,9 @@ impl MarkBaseConfig { if let Ok(log_level) = std::env::var("MB_LOG_LEVEL") { self.server.log_level = log_level; } + if let Ok(webdav_root) = std::env::var("MB_WEBDAV_ROOT") { + self.server.webdav_root = webdav_root; + } if let Ok(pg_host) = std::env::var("PG_HOST") { self.postgresql.host = pg_host; @@ -176,6 +181,7 @@ impl MarkBaseConfig { "server.log_level" => Some(self.server.log_level.clone()), "server.auth_db_path" => Some(self.server.auth_db_path.clone()), "server.users_db_dir" => Some(self.server.users_db_dir.clone()), + "server.webdav_root" => Some(self.server.webdav_root.clone()), "postgresql.host" => Some(self.postgresql.host.clone()), "postgresql.port" => Some(self.postgresql.port.to_string()), @@ -221,6 +227,7 @@ impl MarkBaseConfig { "server.log_level" => self.server.log_level = value.to_string(), "server.auth_db_path" => self.server.auth_db_path = value.to_string(), "server.users_db_dir" => self.server.users_db_dir = value.to_string(), + "server.webdav_root" => self.server.webdav_root = value.to_string(), "postgresql.host" => self.postgresql.host = value.to_string(), "postgresql.port" => self.postgresql.port = value.parse()?, diff --git a/markbase-core/src/s3_auth.rs b/markbase-core/src/s3_auth.rs index 17d7c8c..291744f 100644 --- a/markbase-core/src/s3_auth.rs +++ b/markbase-core/src/s3_auth.rs @@ -6,19 +6,14 @@ use std::fs; type HmacSha256 = Hmac; pub fn verify_signature(headers: HeaderMap, method: &str, path: &str) -> bool { - // Load S3 config and check require_auth flag let config = crate::s3_config::S3Config::load_default().unwrap_or_default(); - - // Merge environment variables (allows override via MB_S3_REQUIRE_AUTH) let mut config = config; config.merge_env(); if !config.s3.require_auth { - // Development mode: allow access without authentication return true; } - // 生产模式:必须提供Authorization header let auth_header = headers .get("Authorization") .and_then(|v| v.to_str().ok()) @@ -28,41 +23,55 @@ pub fn verify_signature(headers: HeaderMap, method: &str, path: &str) -> bool { return false; } - // 2. Parse Credential let credential = extract_credential(auth_header); if credential.is_none() { return false; } - let credential = credential.unwrap(); - // 3. Get secret_key from S3AccessKey database let secret_key = get_secret_key(&credential.access_key); if secret_key.is_none() { return false; } - let secret_key = secret_key.unwrap(); - // 4. Calculate Signature - let calculated_signature = calculate_signature( - headers.clone(), + let x_amz_date = headers + .get("X-Amz-Date") + .and_then(|v| v.to_str().ok()) + .unwrap_or(&credential.date); + + let signed_headers = extract_signed_headers(auth_header); + if signed_headers.is_none() { + return false; + } + let signed_headers = signed_headers.unwrap(); + + let payload_hash = get_payload_hash(&headers); + + let canonical_request = create_canonical_request( + &headers, method, path, - &credential.access_key, - &secret_key, - &credential.region, - &credential.service, - &credential.date, + &signed_headers, + &payload_hash, ); - // 5. Extract Signature from header + let string_to_sign = create_string_to_sign( + x_amz_date, + &credential.region, + &credential.service, + &canonical_request, + ); + + let signing_key = calculate_signing_key(&secret_key, &credential.date, &credential.region, &credential.service); + + let calculated_signature = hmac_sha256_hex(&signing_key, &string_to_sign); + let provided_signature = extract_signature(auth_header); if provided_signature.is_none() { return false; } - // 6. Compare signatures calculated_signature == provided_signature.unwrap() } @@ -74,14 +83,11 @@ struct Credential { } fn extract_credential(auth_header: &str) -> Option { - let parts: Vec<&str> = auth_header.split_whitespace().collect(); - if parts.len() < 2 { - return None; - } + let credential_part = auth_header + .split(',') + .find(|p| p.trim().starts_with("Credential="))?; - let credential_part = parts.iter().find(|p| p.starts_with("Credential="))?; - - let credential_str = credential_part.strip_prefix("Credential=")?; + let credential_str = credential_part.trim().strip_prefix("Credential=")?; let credential_parts: Vec<&str> = credential_str.split('/').collect(); if credential_parts.len() < 5 { @@ -96,16 +102,24 @@ fn extract_credential(auth_header: &str) -> Option { }) } +fn extract_signed_headers(auth_header: &str) -> Option> { + let signed_headers_part = auth_header + .split(',') + .find(|p| p.trim().starts_with("SignedHeaders="))?; + + let signed_headers_str = signed_headers_part.trim().strip_prefix("SignedHeaders=")?; + Some(signed_headers_str.split(';').map(|s| s.to_lowercase()).collect()) +} + fn extract_signature(auth_header: &str) -> Option { - let parts: Vec<&str> = auth_header.split_whitespace().collect(); + let signature_part = auth_header + .split(',') + .find(|p| p.trim().starts_with("Signature="))?; - let signature_part = parts.iter().find(|p| p.starts_with("Signature="))?; - - Some(signature_part.strip_prefix("Signature=")?.to_string()) + Some(signature_part.trim().strip_prefix("Signature=")?.to_string()) } fn get_secret_key(access_key: &str) -> Option { - // Load S3AccessKey database from data/s3_keys.json let s3_keys_path = "data/s3_keys.json"; let s3_keys_json = fs::read_to_string(s3_keys_path).ok()?; @@ -116,62 +130,97 @@ fn get_secret_key(access_key: &str) -> Option { } let s3_keys: Vec = serde_json::from_str(&s3_keys_json).ok()?; - s3_keys .iter() .find(|k| k.access_key == access_key) .map(|k| k.secret_key.clone()) } -fn calculate_signature( - headers: HeaderMap, - method: &str, - path: &str, - _access_key: &str, - secret_key: &str, - region: &str, - service: &str, - date: &str, -) -> String { - // 1. Create Canonical Request - let canonical_request = create_canonical_request(headers, method, path); - - // 2. Create String to Sign - let string_to_sign = create_string_to_sign(date, region, service, &canonical_request); - - // 3. Calculate Signing Key - let signing_key = calculate_signing_key(secret_key, date, region, service); - - // 4. Calculate Signature - - - hmac_sha256_hex(&signing_key, &string_to_sign) +fn get_payload_hash(headers: &HeaderMap) -> String { + headers + .get("X-Amz-Content-Sha256") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| sha256_hex("")) } -fn create_canonical_request(headers: HeaderMap, method: &str, path: &str) -> String { - // Simplified implementation for POC - let host = headers - .get("Host") - .and_then(|v| v.to_str().ok()) - .unwrap_or("localhost:11438"); +fn create_canonical_request( + headers: &HeaderMap, + method: &str, + path: &str, + signed_headers: &[String], + payload_hash: &str, +) -> String { + let canonical_uri = uri_encode(path, false); + + let canonical_query_string = build_canonical_query_string(headers); + + let canonical_headers = build_canonical_headers(headers, signed_headers); + + let signed_headers_str = signed_headers.join(";"); format!( - "{}\n{}\n\nhost:{}\n\nhost\nUNSIGNED-PAYLOAD", - method, path, host + "{}\n{}\n{}\n{}\n{}\n{}", + method, + canonical_uri, + canonical_query_string, + canonical_headers, + signed_headers_str, + payload_hash ) } +fn uri_encode(input: &str, encode_slash: bool) -> String { + input + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '~' { + c.to_string() + } else if c == '/' && !encode_slash { + c.to_string() + } else { + format!("%{:02X}", c as u8) + } + }) + .collect() +} + +fn build_canonical_query_string(headers: &HeaderMap) -> String { + // For S3, query string is typically empty for basic operations + // This can be extended for presigned URLs + String::new() +} + +fn build_canonical_headers(headers: &HeaderMap, signed_headers: &[String]) -> String { + signed_headers + .iter() + .map(|h| { + let value = headers + .get(h) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + format!("{}:{}\n", h, value.trim()) + }) + .collect() +} + fn create_string_to_sign( - date: &str, + amz_date: &str, region: &str, service: &str, canonical_request: &str, ) -> String { let canonical_request_hash = sha256_hex(canonical_request); + let date_stamp = &amz_date[..8]; + format!( - "AWS4-HMAC-SHA256\n{}T000000Z\n{}/{}/{}/aws4_request\n{}", - date, date, region, service, canonical_request_hash + "AWS4-HMAC-SHA256\n{}\n{}/{}/{}/aws4_request\n{}", + amz_date, + date_stamp, + region, + service, + canonical_request_hash ) } @@ -203,7 +252,50 @@ fn sha256_hex(data: &str) -> String { } fn hex_encode(data: &[u8]) -> String { - data.iter() - .map(|b| format!("{:02x}", b)) - .collect::() + data.iter().map(|b| format!("{:02x}", b)).collect() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_uri_encode() { + assert_eq!(uri_encode("/bucket/key", false), "/bucket/key"); + assert_eq!(uri_encode("/bucket/key", true), "%2Fbucket%2Fkey"); + assert_eq!(uri_encode("test file.txt", false), "test%20file.txt"); + } + + #[test] + fn test_sha256_hex() { + let empty_hash = sha256_hex(""); + assert_eq!( + empty_hash, + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + } + + #[test] + fn test_calculate_signing_key() { + let key = calculate_signing_key("secret", "20260621", "us-east-1", "s3"); + assert_eq!(key.len(), 32); + } + + #[test] + fn test_create_canonical_request() { + let mut headers = HeaderMap::new(); + headers.insert("Host", "localhost:11438".parse().unwrap()); + let signed_headers = vec!["host".to_string()]; + + let canonical = create_canonical_request( + &headers, + "GET", + "/bucket/key", + &signed_headers, + "UNSIGNED-PAYLOAD", + ); + + assert!(canonical.contains("GET")); + assert!(canonical.contains("host:localhost:11438")); + } +} \ No newline at end of file