use std::path::PathBuf; use rusqlite::{Connection, params}; use bcrypt::verify; use super::{DataProvider, ProviderError, User}; /// SQLite 数据提供者 pub struct SqliteProvider { db_path: PathBuf, } impl SqliteProvider { pub fn new(db_path: &str) -> Result { let path = PathBuf::from(db_path); if !path.exists() { return Err(ProviderError::NotFound(format!( "Database not found: {}", db_path ))); } Ok(Self { db_path: path }) } fn open_conn(&self) -> Result { Connection::open(&self.db_path) .map_err(|e| ProviderError::Internal(format!("Failed to open database: {}", e))) } } impl DataProvider for SqliteProvider { fn get_user(&self, username: &str) -> Result, ProviderError> { let conn = self.open_conn()?; let result = conn.query_row( "SELECT username, password_hash, home_dir, permissions, uid, gid, status FROM sftpgo_users WHERE username = ?1 AND status = 1", params![username], |row| { Ok(User { username: row.get(0)?, password_hash: row.get(1)?, home_dir: PathBuf::from(row.get::<_, String>(2)?), permissions: row.get(3)?, uid: row.get::<_, i64>(4)? as u32, gid: row.get::<_, i64>(5)? as u32, status: row.get(6)?, }) }, ); match result { Ok(user) => Ok(Some(user)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(ProviderError::Internal(format!( "Database query error: {}", e ))), } } fn check_password(&self, username: &str, password: &str) -> Result { let hash = match self.get_user(username)? { Some(user) => user.password_hash, None => return Ok(false), }; verify(password, &hash) .map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e))) } fn get_home_dir(&self, username: &str) -> Result, ProviderError> { Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string())) } fn get_public_keys(&self, username: &str) -> Result, ProviderError> { let _ = username; Ok(Vec::new()) } fn get_user_groups(&self, username: &str) -> Result, ProviderError> { let conn = self.open_conn()?; let groups: Vec = conn .prepare("SELECT group_name FROM users_groups_mapping WHERE username = ?1") .map_err(|e| ProviderError::Internal(format!("Query prepare error: {}", e)))? .query_map(params![username], |row| row.get(0)) .map_err(|e| ProviderError::Internal(format!("Query map error: {}", e)))? .filter_map(|r| r.ok()) .collect(); Ok(groups) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_provider_not_found() { let provider = SqliteProvider::new("/tmp/nonexistent.db"); assert!(provider.is_err()); } #[test] fn test_provider_get_user() { let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); let user = provider.get_user("demo").unwrap(); assert!(user.is_some()); assert_eq!(user.unwrap().username, "demo"); } #[test] fn test_provider_nonexistent_user() { let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); let user = provider.get_user("__nonexistent__").unwrap(); assert!(user.is_none()); } #[test] fn test_check_password_valid() { let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); let valid = provider.check_password("demo", "demo123").unwrap(); assert!(valid); } #[test] fn test_check_password_invalid() { let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); let valid = provider.check_password("demo", "wrong").unwrap(); assert!(!valid); } #[test] fn test_get_home_dir() { let provider = SqliteProvider::new("data/auth.sqlite").unwrap(); let dir = provider.get_home_dir("demo").unwrap(); assert!(dir.is_some()); } }