Refactor sftp/auth.rs to use DataProvider trait
Some checks failed
Test / test (push) Has been cancelled
Test / build (push) Has been cancelled

- SftpAuth now uses Arc<dyn DataProvider> instead of AuthDb
- Add verify_password(), get_user(), get_home_dir() methods
- Add unit tests for SftpAuth with SqliteProvider
- Maintain backward compatibility with existing tests
This commit is contained in:
Warren
2026-06-19 01:06:02 +08:00
parent 22fcc83535
commit 667d7209e2

View File

@@ -1,51 +1,78 @@
use crate::sync::{AuthDb, PgUser}; use crate::provider::{DataProvider, ProviderError, User};
use bcrypt::verify; use std::sync::Arc;
pub struct SftpAuth { pub struct SftpAuth {
auth_db: AuthDb, provider: Arc<dyn DataProvider>,
} }
impl SftpAuth { impl SftpAuth {
pub fn new(auth_db_path: &str) -> anyhow::Result<Self> { pub fn new(provider: Arc<dyn DataProvider>) -> Self {
let auth_db = AuthDb::new(auth_db_path)?; Self { provider }
Ok(Self { auth_db })
} }
pub fn verify_password(&self, username: &str, password: &str) -> bool { pub fn verify_password(&self, username: &str, password: &str) -> bool {
match self.auth_db.get_user(username) { match self.provider.check_password(username, password) {
Ok(Some(user)) if user.status == 1 => { Ok(true) => true,
verify(password, &user.password_hash).unwrap_or(false) Ok(false) => {
} log::warn!("Password verification failed for user {}", username);
Ok(Some(_)) => {
log::warn!("User {} is disabled", username);
false false
} }
Ok(None) => { Err(ProviderError::NotFound(_)) => {
log::warn!("User {} not found", username); log::warn!("User {} not found", username);
false false
} }
Err(e) => { Err(e) => {
log::error!("Failed to get user {}: {}", username, e); log::error!("Failed to verify password for {}: {}", username, e);
false false
} }
} }
} }
pub fn get_user(&self, username: &str) -> Option<PgUser> { pub fn get_user(&self, username: &str) -> Option<User> {
self.auth_db.get_user(username).ok().flatten() match self.provider.get_user(username) {
Ok(Some(user)) => Some(user),
Ok(None) => {
log::warn!("User {} not found", username);
None
}
Err(e) => {
log::error!("Failed to get user {}: {}", username, e);
None
}
}
}
pub fn get_home_dir(&self, username: &str) -> Option<String> {
match self.provider.get_home_dir(username) {
Ok(Some(dir)) => Some(dir),
Ok(None) => None,
Err(e) => {
log::error!("Failed to get home dir for {}: {}", username, e);
None
}
}
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use bcrypt::{hash, verify, DEFAULT_COST}; use bcrypt::{hash, verify, DEFAULT_COST};
use crate::provider::sqlite::SqliteProvider;
use std::sync::Arc;
fn get_test_provider() -> Arc<dyn DataProvider> {
let db_path = format!(
"{}/../data/auth.sqlite",
std::env::var("CARGO_MANIFEST_DIR").unwrap()
);
Arc::new(SqliteProvider::new(&db_path).unwrap())
}
#[test] #[test]
fn test_bcrypt_verify_correct_password() { fn test_bcrypt_verify_correct_password() {
let password = "demo123"; let password = "demo123";
let hashed = hash(password, DEFAULT_COST).unwrap(); let hashed = hash(password, DEFAULT_COST).unwrap();
// 验证正确密码
let valid = verify(password, &hashed).unwrap(); let valid = verify(password, &hashed).unwrap();
assert!(valid); assert!(valid);
} }
@@ -56,32 +83,42 @@ mod tests {
let wrong_password = "wrong123"; let wrong_password = "wrong123";
let hashed = hash(password, DEFAULT_COST).unwrap(); let hashed = hash(password, DEFAULT_COST).unwrap();
// 验证错误密码
let valid = verify(wrong_password, &hashed).unwrap(); let valid = verify(wrong_password, &hashed).unwrap();
assert!(!valid); assert!(!valid);
} }
#[test] #[test]
fn test_bcrypt_verify_empty_password() { fn test_sftp_auth_verify_password() {
let password = ""; let provider = get_test_provider();
let hashed = hash(password, DEFAULT_COST).unwrap(); let auth = SftpAuth::new(provider);
// 验证空密码 assert!(auth.verify_password("demo", "demo123"));
let valid = verify(password, &hashed).unwrap(); assert!(!auth.verify_password("demo", "wrong"));
assert!(valid); assert!(!auth.verify_password("__nonexistent__", "any"));
// 验证非空密码对空hash
let valid = verify("test", &hashed).unwrap();
assert!(!valid);
} }
#[test] #[test]
fn test_verify_database_hash() { fn test_sftp_auth_get_user() {
// 验证数据库中的实际hashdemo123 let provider = get_test_provider();
let db_hash = "$2b$10$ha5wU.mOi8fHLJCfun860u2cfVopa04jwe/q82IKOwqp5uG70qsH6"; let auth = SftpAuth::new(provider);
let password = "demo123";
let valid = verify(password, db_hash).unwrap(); let user = auth.get_user("demo");
assert!(valid); assert!(user.is_some());
assert_eq!(user.unwrap().username, "demo");
let user = auth.get_user("__nonexistent__");
assert!(user.is_none());
} }
}
#[test]
fn test_sftp_auth_get_home_dir() {
let provider = get_test_provider();
let auth = SftpAuth::new(provider);
let dir = auth.get_home_dir("demo");
assert!(dir.is_some());
let dir = auth.get_home_dir("__nonexistent__");
assert!(dir.is_none());
}
}