diff --git a/markbase-core/src/sftp/auth.rs b/markbase-core/src/sftp/auth.rs index 643e55d..69b5875 100644 --- a/markbase-core/src/sftp/auth.rs +++ b/markbase-core/src/sftp/auth.rs @@ -1,51 +1,78 @@ -use crate::sync::{AuthDb, PgUser}; -use bcrypt::verify; +use crate::provider::{DataProvider, ProviderError, User}; +use std::sync::Arc; pub struct SftpAuth { - auth_db: AuthDb, + provider: Arc, } impl SftpAuth { - pub fn new(auth_db_path: &str) -> anyhow::Result { - let auth_db = AuthDb::new(auth_db_path)?; - Ok(Self { auth_db }) + pub fn new(provider: Arc) -> Self { + Self { provider } } pub fn verify_password(&self, username: &str, password: &str) -> bool { - match self.auth_db.get_user(username) { - Ok(Some(user)) if user.status == 1 => { - verify(password, &user.password_hash).unwrap_or(false) - } - Ok(Some(_)) => { - log::warn!("User {} is disabled", username); + match self.provider.check_password(username, password) { + Ok(true) => true, + Ok(false) => { + log::warn!("Password verification failed for user {}", username); false } - Ok(None) => { + Err(ProviderError::NotFound(_)) => { log::warn!("User {} not found", username); false } Err(e) => { - log::error!("Failed to get user {}: {}", username, e); + log::error!("Failed to verify password for {}: {}", username, e); false } } } - pub fn get_user(&self, username: &str) -> Option { - self.auth_db.get_user(username).ok().flatten() + pub fn get_user(&self, username: &str) -> Option { + match self.provider.get_user(username) { + Ok(Some(user)) => Some(user), + Ok(None) => { + log::warn!("User {} not found", username); + None + } + Err(e) => { + log::error!("Failed to get user {}: {}", username, e); + None + } + } + } + + pub fn get_home_dir(&self, username: &str) -> Option { + match self.provider.get_home_dir(username) { + Ok(Some(dir)) => Some(dir), + Ok(None) => None, + Err(e) => { + log::error!("Failed to get home dir for {}: {}", username, e); + None + } + } } } #[cfg(test)] mod tests { use bcrypt::{hash, verify, DEFAULT_COST}; + use crate::provider::sqlite::SqliteProvider; + use std::sync::Arc; + + fn get_test_provider() -> Arc { + let db_path = format!( + "{}/../data/auth.sqlite", + std::env::var("CARGO_MANIFEST_DIR").unwrap() + ); + Arc::new(SqliteProvider::new(&db_path).unwrap()) + } #[test] fn test_bcrypt_verify_correct_password() { let password = "demo123"; let hashed = hash(password, DEFAULT_COST).unwrap(); - // 验证正确密码 let valid = verify(password, &hashed).unwrap(); assert!(valid); } @@ -56,32 +83,42 @@ mod tests { let wrong_password = "wrong123"; let hashed = hash(password, DEFAULT_COST).unwrap(); - // 验证错误密码 let valid = verify(wrong_password, &hashed).unwrap(); assert!(!valid); } #[test] - fn test_bcrypt_verify_empty_password() { - let password = ""; - let hashed = hash(password, DEFAULT_COST).unwrap(); + fn test_sftp_auth_verify_password() { + let provider = get_test_provider(); + let auth = SftpAuth::new(provider); - // 验证空密码 - let valid = verify(password, &hashed).unwrap(); - assert!(valid); - - // 验证非空密码对空hash - let valid = verify("test", &hashed).unwrap(); - assert!(!valid); + assert!(auth.verify_password("demo", "demo123")); + assert!(!auth.verify_password("demo", "wrong")); + assert!(!auth.verify_password("__nonexistent__", "any")); } #[test] - fn test_verify_database_hash() { - // 验证数据库中的实际hash(demo123) - let db_hash = "$2b$10$ha5wU.mOi8fHLJCfun860u2cfVopa04jwe/q82IKOwqp5uG70qsH6"; - let password = "demo123"; + fn test_sftp_auth_get_user() { + let provider = get_test_provider(); + let auth = SftpAuth::new(provider); - let valid = verify(password, db_hash).unwrap(); - assert!(valid); + let user = auth.get_user("demo"); + assert!(user.is_some()); + assert_eq!(user.unwrap().username, "demo"); + + let user = auth.get_user("__nonexistent__"); + assert!(user.is_none()); } -} + + #[test] + fn test_sftp_auth_get_home_dir() { + let provider = get_test_provider(); + let auth = SftpAuth::new(provider); + + let dir = auth.get_home_dir("demo"); + assert!(dir.is_some()); + + let dir = auth.get_home_dir("__nonexistent__"); + assert!(dir.is_none()); + } +} \ No newline at end of file