use anyhow::Result; use std::process::Command; use std::time::Duration; use tokio::time::timeout; pub struct ShellHandler { user_id: String, config: std::sync::Arc, audit: crate::sftp::audit::AuditLog, } impl ShellHandler { pub fn new( user_id: &str, config: std::sync::Arc, audit: crate::sftp::audit::AuditLog, ) -> Self { Self { user_id: user_id.to_string(), config, audit, } } pub fn check_command_permission(&self, command: &str) -> bool { // 1. 检查是否启用shell if !self.config.shell.enabled { log::warn!("Shell disabled for user {}", self.user_id); return false; } // 2. 检查命令长度 if command.len() > self.config.shell.max_command_length { log::warn!( "Command too long for user {}: {}", self.user_id, command.len() ); return false; } // 3. 检查黑名单(优先级最高) let cmd_name = command.split_whitespace().next().unwrap_or(""); for forbidden in &self.config.shell.forbidden_commands { if cmd_name == forbidden || command.starts_with(&format!("{} ", forbidden)) { log::warn!("Forbidden command '{}' for user {}", cmd_name, self.user_id); self.audit .log_error(&self.user_id, "shell_check", command, "forbidden"); return false; } } // 4. 检查白名单(如果配置了白名单) if !self.config.shell.allowed_commands.is_empty() { if !self .config .shell .allowed_commands .contains(&cmd_name.to_string()) { log::warn!( "Command '{}' not in whitelist for user {}", cmd_name, self.user_id ); self.audit .log_error(&self.user_id, "shell_check", command, "not_in_whitelist"); return false; } } log::info!("Command '{}' allowed for user {}", cmd_name, self.user_id); true } pub async fn execute_command(&self, command: &str) -> Result { // 1. 检查权限 if !self.check_command_permission(command) { return Err(anyhow::anyhow!("Command not allowed")); } // 2. 执行命令(带timeout) let timeout_duration = Duration::from_secs(self.config.shell.timeout_seconds); let result = timeout(timeout_duration, async { // 使用系统shell执行命令 let output = Command::new(&self.config.shell.shell_path) .arg("-c") .arg(command) .output(); match output { Ok(output) => { let stdout = String::from_utf8_lossy(&output.stdout).to_string(); let stderr = String::from_utf8_lossy(&output.stderr).to_string(); if output.status.success() { Ok(stdout) } else { Err(anyhow::anyhow!("Command failed: {}", stderr)) } } Err(e) => Err(anyhow::anyhow!("Command execution error: {}", e)), } }) .await; // 3. 处理结果 match result { Ok(Ok(output)) => { self.audit.log_success(&self.user_id, "shell_exec", command); Ok(output) } Ok(Err(e)) => { self.audit .log_error(&self.user_id, "shell_exec", command, &e.to_string()); Err(e) } Err(_) => { self.audit .log_error(&self.user_id, "shell_exec", command, "timeout"); Err(anyhow::anyhow!( "Command timeout after {}s", self.config.shell.timeout_seconds )) } } } pub fn get_shell_path(&self) -> &str { &self.config.shell.shell_path } pub fn is_shell_enabled(&self) -> bool { self.config.shell.enabled } } #[cfg(test)] mod tests { use super::*; use crate::sftp::audit::AuditLog; use crate::sftp::config::SftpConfig; use tempfile::TempDir; fn create_test_shell_handler() -> ShellHandler { let mut config = SftpConfig::default(); config.shell.enabled = true; config.shell.allowed_commands = vec!["ls".to_string(), "pwd".to_string()]; config.shell.forbidden_commands = vec!["rm".to_string(), "sudo".to_string()]; let temp_dir = TempDir::new().unwrap(); let audit_log_path = temp_dir .path() .join("shell_audit.log") .to_string_lossy() .to_string(); let audit = AuditLog::new(&audit_log_path).unwrap(); ShellHandler::new("test_user", std::sync::Arc::new(config), audit) } #[test] fn test_check_command_permission_allowed() { let handler = create_test_shell_handler(); // 测试允许的命令 assert!(handler.check_command_permission("ls")); assert!(handler.check_command_permission("pwd")); assert!(handler.check_command_permission("ls -la")); } #[test] fn test_check_command_permission_forbidden() { let handler = create_test_shell_handler(); // 测试禁止的命令 assert!(!handler.check_command_permission("rm")); assert!(!handler.check_command_permission("rm -rf")); assert!(!handler.check_command_permission("sudo ls")); } #[test] fn test_check_command_permission_not_in_whitelist() { let handler = create_test_shell_handler(); // 测试不在白名单的命令 assert!(!handler.check_command_permission("cat")); assert!(!handler.check_command_permission("grep")); } #[test] fn test_check_command_permission_shell_disabled() { let mut config = SftpConfig::default(); config.shell.enabled = false; // 禁用shell let temp_dir = TempDir::new().unwrap(); let audit = AuditLog::new(&temp_dir.path().join("audit.log").to_string_lossy()).unwrap(); let handler = ShellHandler::new("test_user", std::sync::Arc::new(config), audit); // 任何命令都不允许 assert!(!handler.check_command_permission("ls")); } #[test] fn test_check_command_permission_too_long() { let mut config = SftpConfig::default(); config.shell.enabled = true; config.shell.max_command_length = 10; let temp_dir = TempDir::new().unwrap(); let audit = AuditLog::new(&temp_dir.path().join("audit.log").to_string_lossy()).unwrap(); let handler = ShellHandler::new("test_user", std::sync::Arc::new(config), audit); // 测试超长命令 let long_command = "ls -la /very/long/path/that/exceeds/max/length"; assert!(!handler.check_command_permission(long_command)); } }