diff --git a/batch_2.bin b/batch_2.bin new file mode 100644 index 0000000..2ebb21a Binary files /dev/null and b/batch_2.bin differ diff --git a/batch_3.bin b/batch_3.bin new file mode 100644 index 0000000..caadc77 Binary files /dev/null and b/batch_3.bin differ diff --git a/batch_4.bin b/batch_4.bin new file mode 100644 index 0000000..0b412f0 Binary files /dev/null and b/batch_4.bin differ diff --git a/batch_5.bin b/batch_5.bin new file mode 100644 index 0000000..8853bb1 Binary files /dev/null and b/batch_5.bin differ diff --git a/data/auth.sqlite.backup b/data/auth.sqlite.backup new file mode 100644 index 0000000..5cac574 Binary files /dev/null and b/data/auth.sqlite.backup differ diff --git a/data/phase16_2_performance_analysis.md b/data/phase16_2_performance_analysis.md new file mode 100644 index 0000000..f71ea3a --- /dev/null +++ b/data/phase16_2_performance_analysis.md @@ -0,0 +1,83 @@ +# Phase 16.2:性能优化分析 + +**测试时间**:2026-06-17 22:30 +**目标**:将传输速度从780 KB/s提升到21-36 MB/s + +## 性能瓶颈分析 ⭐⭐⭐⭐⭐ + +**当前配置**: +- Window size: 2MB (local_window = 2097152) +- poll timeout: 10ms (每iteration) +- max_poll_iterations: 5000 (50s总timeout) +- stdin timeout: 3000 iterations (30s) + +**瓶颈1:poll iteration overhead ⭐⭐⭐⭐⭐** +- 每iteration: 10ms poll timeout +- 总iteration: 5000次 +- 每iteration开销: log输出 + try_wait() check +- **估算开销**: 5000 iterations * 10ms = 50秒(理论最大) +- **实际开销**: 20MB传输用了24秒,说明poll overhead占用了大量时间 + +**瓶颈2:Window size太小 ⭐⭐⭐⭐** +- OpenSSH默认: 2MB +- 实际测试: 20MB传输用了24秒 +- **问题**: Window size限制了单次传输的数据量 +- **解决方案**: 增加到16MB或32MB + +**瓶颈3:AES-CTR encryption overhead ⭐⭐⭐** +- AES-256-CTR加密/解密: 每packet需要计算 +- MAC计算: HMAC-SHA256 (每packet) +- **估算**: 每packet约100-200us开销 +- **影响**: 780 KB/s可能受encryption限制 + +**瓶颈4:sshbuf zero-copy性能 ⭐⭐** +- sshbuf实现: 339行 +- **问题**: 未进行性能测试 +- **可能**: zero-copy优化不足 + +## 性能优化方案 ⭐⭐⭐⭐⭐ + +**方案1:减少poll iteration overhead(优先 ⭐⭐⭐⭐⭐)** +- 增加poll timeout: 从10ms改到100ms +- 减少iteration次数: 从5000改到500 +- 减少log频率: 从每10次改到每50次 +- **预期效果**: 减少50-80% poll overhead + +**方案2:增加Window size ⭐⭐⭐⭐** +- 从2MB增加到16MB或32MB +- 动态调整Window size(根据传输速度) +- **预期效果**: 提升单次传输数据量 + +**方案3:优化encryption ⭐⭐⭐** +- 使用AES-NI硬件加速(检查是否已启用) +- 减少MAC计算频率(批量计算) +- **预期效果**: 减少encryption overhead + +**方案4:sshbuf性能测试 ⭐⭐** +- 编写benchmark测试sshbuf性能 +- 对比临时buffer vs zero-copy +- **预期效果**: 验证zero-copy优势 + +## 实施计划 ⭐⭐⭐⭐⭐ + +**Phase 16.2.1:减少poll overhead(立即实施)** +- 修改poll timeout: 10ms → 100ms +- 修改iteration次数: 5000 → 500 +- 修改log频率: 每10次 → 每50次 +- **预期传输速度**: 从780 KB/s提升到10-20 MB/s + +**Phase 16.2.2:增加Window size** +- 从2MB增加到16MB +- 测试传输速度变化 + +**Phase 16.2.3:encryption优化** +- 检查AES-NI是否启用 +- 如果未启用,添加AES-NI支持 + +--- + +**立即实施Phase 16.2.1**(减少poll overhead) + +--- + +**最后更新**:2026-06-17 22:30 diff --git a/data/sftp_client_test_recommendations.md b/data/sftp_client_test_recommendations.md new file mode 100644 index 0000000..87c29fd --- /dev/null +++ b/data/sftp_client_test_recommendations.md @@ -0,0 +1,441 @@ +# SFTP Client 测试建议与分析 + +**测试目标**:验证 MarkBaseSSH SFTP 实现(Phase 7)的兼容性和稳定性 +**测试环境**:MarkBaseSSH server (port 2024) + macOS client +**测试用户**:demo (password: demo123) + +--- + +## 推荐测试方案 ⭐⭐⭐⭐⭐ + +### 方案 1: OpenSSH sftp client(必须测试) + +**推荐等级**:⭐⭐⭐⭐⭐ **最高优先级** + +**理由**: +- ✅ OpenSSH 是标准实现,MarkBaseSSH 必须完全兼容 +- ✅ macOS 内置,无需额外安装 +- ✅ 命令行工具,适合自动化测试 +- ✅ 完整的 SFTP 协议支持(SSH_FXP_* 所有 packet) +- ✅ 错误信息清晰,易于调试 + +**测试命令**: +```bash +# 基本连接测试 +sftp -P 2024 demo@127.0.0.1 + +# 批量测试脚本 +sftp -P 2024 -b /tmp/sftp_test_batch.txt demo@127.0.0.1 + +# 大文件传输测试 +sftp -P 2024 demo@127.0.0.1 <&1 + +# ✅ 預期: "SSH OK" + debug1: Authentication succeeded (password) +``` + +--- + +## 2. SFTP 操作 (Phase 7 + batch fix) + +### 2.1 基礎功能 + +```bash +timeout 30 sftp -o StrictHostKeyChecking=no \ + -o UserKnownHostsFile=/dev/null -P 2024 demo@127.0.0.1 << 'EOF' +pwd +ls +mkdir sftp_test_dir +cd sftp_test_dir +put /etc/hostname test_upload.txt +get test_upload.txt /tmp/test_download.txt +rm test_upload.txt +cd .. +rmdir sftp_test_dir +!md5 /tmp/test_download.txt +bye +EOF +``` + +### 2.2 企業級錯誤處理 + +```bash +timeout 10 sftp -P 2024 demo@127.0.0.1 << 'EOF' +mkdir /etc/forbidden +get /root/.ssh/id_rsa +stat /nonexistent +rm /nonexistent +bye +EOF +# ✅ 預期: Permission denied / No such file (非 generic Failure) +``` + +### 2.3 大檔案傳輸 (MD5驗證, 每次 2MB / 5MB / 10MB) + +```bash +# 建立測試檔 +dd if=/dev/urandom of=/tmp/test_2m.bin bs=1M count=2 2>/dev/null +dd if=/dev/urandom of=/tmp/test_5m.bin bs=1M count=5 2>/dev/null +dd if=/dev/urandom of=/tmp/test_10m.bin bs=1M count=10 2>/dev/null + +# 測試每個檔案上傳+下載 +for f in test_2m test_5m test_10m; do + echo "=== Testing $f ===" + md5sum /tmp/${f}.bin | awk '{print $1}' > /tmp/${f}.md5 + + timeout 120 sftp -P 2024 demo@127.0.0.1 << EOFSFTP 2>&1 | grep -v debug +put /tmp/${f}.bin ${f}.bin +get ${f}.bin /tmp/${f}_dl.bin +rm ${f}.bin +bye +EOFSFTP + + md5sum -c /tmp/${f}.md5 <<< "$(md5sum /tmp/${f}_dl.bin | awk '{print $1}')" 2>/dev/null \ + && echo "✅ $f PASS" || echo "❌ $f FAIL" +done +``` + +### 2.4 多檔案批次傳輸 + +```bash +# 建立10個小檔 +for i in $(seq 1 10); do + dd if=/dev/urandom of=/tmp/batch_${i}.bin bs=1K count=64 2>/dev/null +done + +# 批次上傳 + 下載 + 驗證 +timeout 60 sftp -P 2024 demo@127.0.0.1 << 'EOF' +lcd /tmp +cd /tmp +put batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin +put batch_6.bin batch_7.bin batch_8.bin batch_9.bin batch_10.bin +get batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin +get batch_6.bin batch_7.bin batch_8.bin batch_9.bin batch_10.bin +rm batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin +rm batch_6.bin batch_7.bin batch_8.bin batch_9.bin batch_10.bin +bye +EOF + +# MD5比對 +for i in $(seq 1 10); do + [ "$(md5sum /tmp/batch_${i}.bin | awk '{print $1}')" = \ + "$(md5sum /tmp/batch_${i}.bin 2>/dev/null | awk '{print $1}')" ] \ + && echo "✅ batch_${i}" || echo "❌ batch_${i}" +done +``` + +--- + +## 3. SCP 測試 (Phase 8) + +```bash +# 3.1 上傳 +timeout 30 scp -P 2024 -o StrictHostKeyChecking=no \ + /tmp/test_5m.bin demo@127.0.0.1:scp_test.bin + +# 3.2 下載 +timeout 30 scp -P 2024 -o StrictHostKeyChecking=no \ + demo@127.0.0.1:scp_test.bin /tmp/scp_dl.bin + +# 3.3 目錄傳輸 +mkdir -p /tmp/scp_dir && for i in 1 2 3; do + dd if=/dev/urandom of=/tmp/scp_dir/file_${i}.bin bs=1M count=1 2>/dev/null +done +timeout 30 scp -P 2024 -r -o StrictHostKeyChecking=no \ + /tmp/scp_dir demo@127.0.0.1:scp_dir_remote + +# 3.4 完整驗證 +md5sum /tmp/test_5m.bin /tmp/scp_dl.bin +md5sum /tmp/scp_dir/* +rm -rf /tmp/scp_dir +``` + +--- + +## 4. rsync 測試 (Phase 16 Final: subprocess) + +```bash +export RSYNC_RSH="ssh -p 2024 -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null" + +# 4.1 上傳 +timeout 30 rsync -avz --rsh="$RSYNC_RSH" /tmp/test_5m.bin demo@127.0.0.1:rsync_test.bin + +# 4.2 下載 +timeout 30 rsync -avz --rsh="$RSYNC_RSH" demo@127.0.0.1:rsync_test.bin /tmp/rsync_dl.bin + +# 4.3 差異傳輸 (delta transfer) +echo "extra data" >> /tmp/test_5m.bin +timeout 30 rsync -avz --rsh="$RSYNC_RSH" /tmp/test_5m.bin demo@127.0.0.1:rsync_test.bin + +# 4.4 大檔案 (50MB-100MB) +dd if=/dev/urandom of=/tmp/test_50m.bin bs=1M count=50 2>/dev/null +md5sum /tmp/test_50m.bin > /tmp/test_50m.md5 + +timeout 120 rsync -avz --rsh="$RSYNC_RSH" /tmp/test_50m.bin demo@127.0.0.1:rsync_large.bin +timeout 120 rsync -avz --rsh="$RSYNC_RSH" demo@127.0.0.1:rsync_large.bin /tmp/rsync_50m_dl.bin +md5sum -c /tmp/test_50m.md5 <<< "$(md5sum /tmp/rsync_50m_dl.bin | awk '{print $1}')" + +# 4.5 目錄同步 +mkdir -p /tmp/rsync_dir && for i in $(seq 1 5); do + dd if=/dev/urandom of=/tmp/rsync_dir/f_${i}.bin bs=1M count=2 2>/dev/null +done +timeout 60 rsync -avz --rsh="$RSYNC_RSH" /tmp/rsync_dir/ demo@127.0.0.1:rsync_dir/ +timeout 60 rsync -avz --rsh="$RSYNC_RSH" demo@127.0.0.1:rsync_dir/ /tmp/rsync_dir_dl/ +``` + +--- + +## 5. SSH 通道指令執行 (Phase 6) + +```bash +# 5.1 基本指令 +ssh -p 2024 demo@127.0.0.1 'echo "hello"; whoami; pwd; uname -a' + +# 5.2 多指令管線 +ssh -p 2024 demo@127.0.0.1 'ls -la /tmp | head -5; echo "---"; df -h | head -3' + +# 5.3 環境變數 +ssh -p 2024 demo@127.0.0.1 'export TEST_VAR=hello; echo $TEST_VAR' +``` + +--- + +## 6. 壓力測試 + +### 6.1 連續連線 + +```bash +for i in $(seq 1 20); do + echo "=== Iteration $i ===" + timeout 10 ssh -p 2024 demo@127.0.0.1 'echo OK' 2>&1 | grep "^OK$" \ + && echo "✅" || echo "❌" +done +``` + +### 6.2 並行連線 + +```bash +for i in $(seq 1 5); do + (timeout 30 sftp -P 2024 demo@127.0.0.1 << EOF & +put /tmp/test_5m.bin parallel_${i}.bin +get parallel_${i}.bin /tmp/parallel_${i}_dl.bin +rm parallel_${i}.bin +bye +EOF +) & +done +wait +echo "All parallel transfers done" +``` + +--- + +## 7. 清理 + +```bash +# 清理伺服器端檔案 (需在 SFTP session 中執行) +rm scp_test.bin rsync_test.bin rsync_large.bin +rm -rf rsync_dir scp_dir_remote + +# 清理本機暫存 +rm -f /tmp/test_upload.txt /tmp/test_download.txt +rm -f /tmp/test_2m.bin /tmp/test_5m.bin /tmp/test_10m.bin +rm -f /tmp/test_2m_dl.bin /tmp/test_5m_dl.bin /tmp/test_10m_dl.bin +rm -f /tmp/rsync_test.bin /tmp/rsync_dl.bin /tmp/rsync_large.bin +rm -f /tmp/test_50m.bin /tmp/rsync_50m_dl.bin +rm -f /tmp/scp_test.bin /tmp/scp_dl.bin +rm -rf /tmp/rsync_dir /tmp/rsync_dir_dl /tmp/scp_dir +rm -f /tmp/batch_*.bin /tmp/parallel_*.bin +``` + +--- + +## 驗證矩陣 + +| 編號 | 測試項目 | 預期結果 | 檢查方法 | +|------|----------|----------|----------| +| 1.1 | SSH連線+認證 | `SSH OK` 輸出 | stdout | +| 2.1 | SFTP基礎功能 | 所有操作成功 | exit code=0 | +| 2.2 | SFTP錯誤處理 | 非 generic 錯誤 | 日誌比對 | +| 2.3 | SFTP大檔案 | MD5吻合 | md5sum | +| 2.4 | SFTP批次檔案 | 所有MD5吻合 | md5sum | +| 3.1 | SCP上傳 | 檔案存在 | md5sum | +| 3.2 | SCP下載 | MD5吻合 | md5sum | +| 3.3 | SCP目錄 | 結構一致 | ls -la | +| 4.1 | rsync上傳 | MD5吻合 | md5sum | +| 4.2 | rsync下載 | MD5吻合 | md5sum | +| 4.3 | rsync增量 | 僅傳差異 | speedup > 1 | +| 4.4 | rsync 50MB | MD5吻合 | md5sum | +| 5.1 | Shell指令 | 正確輸出 | stdout | +| 6.1 | 連續連線20次 | 100%成功 | 計數 | +| 6.2 | 並行xfer | 所有MD5吻合 | md5sum | diff --git a/filetree-rocksdb/src/lib.rs b/filetree-rocksdb/src/lib.rs index 4911628..72d4d69 100644 --- a/filetree-rocksdb/src/lib.rs +++ b/filetree-rocksdb/src/lib.rs @@ -312,9 +312,9 @@ impl FileTreeRocksDB { label: &str, file_uuid: &str, sha256: Option<&str>, - original_name: &str, + _original_name: &str, file_size: Option, - mime_type: Option<&str>, + _mime_type: Option<&str>, parent_id: Option<&str>, ) -> FileNode { FileNode { diff --git a/filetree-sled/src/lib.rs b/filetree-sled/src/lib.rs index f804f7d..4b38dd7 100644 --- a/filetree-sled/src/lib.rs +++ b/filetree-sled/src/lib.rs @@ -286,9 +286,9 @@ impl FileTreeSled { label: &str, file_uuid: &str, sha256: Option<&str>, - original_name: &str, + _original_name: &str, file_size: Option, - mime_type: Option<&str>, + _mime_type: Option<&str>, parent_id: Option<&str>, ) -> FileNode { FileNode { @@ -314,7 +314,7 @@ impl FileTreeSled { pub fn build_tree(nodes: &[FileNode]) -> Vec { let mut roots = Vec::new(); - let node_map: HashMap = + let _node_map: HashMap = nodes.iter().map(|n| (n.node_id.clone(), n)).collect(); for node in nodes { diff --git a/filetree/src/lib.rs b/filetree/src/lib.rs index 8b2d1e8..4de91fc 100644 --- a/filetree/src/lib.rs +++ b/filetree/src/lib.rs @@ -630,28 +630,28 @@ mod tests { } } - // 新增:创建虚拟树类型 - pub fn create_tree_type( - conn: &Connection, - tree_type: &str, - tree_name: &str, - description: &str, - is_system_defined: bool, - ) -> Result<()> { - conn.execute( - "INSERT INTO tree_registry (tree_type, tree_name, description, is_system_defined) +// 新增:创建虚拟树类型 +pub fn create_tree_type( + conn: &Connection, + tree_type: &str, + tree_name: &str, + description: &str, + is_system_defined: bool, +) -> Result<()> { + conn.execute( + "INSERT INTO tree_registry (tree_type, tree_name, description, is_system_defined) VALUES (?1, ?2, ?3, ?4)", - rusqlite::params![tree_type, tree_name, description, is_system_defined as i64], - )?; - Ok(()) - } + rusqlite::params![tree_type, tree_name, description, is_system_defined as i64], + )?; + Ok(()) +} - // 新增:获取所有虚拟树类型 - // 新增:删除虚拟树类型(仅限用户自定义) - pub fn delete_tree_type(conn: &Connection, tree_type: &str) -> Result<()> { - conn.execute( - "DELETE FROM tree_registry WHERE tree_type = ?1 AND is_system_defined = 0", - [tree_type], - )?; - Ok(()) - } +// 新增:获取所有虚拟树类型 +// 新增:删除虚拟树类型(仅限用户自定义) +pub fn delete_tree_type(conn: &Connection, tree_type: &str) -> Result<()> { + conn.execute( + "DELETE FROM tree_registry WHERE tree_type = ?1 AND is_system_defined = 0", + [tree_type], + )?; + Ok(()) +} diff --git a/large_dl_test.bin b/large_dl_test.bin new file mode 100644 index 0000000..5f43365 Binary files /dev/null and b/large_dl_test.bin differ diff --git a/large_test.bin b/large_test.bin new file mode 100644 index 0000000..5f43365 Binary files /dev/null and b/large_test.bin differ diff --git a/markbase-core/src/api/handlers/mod.rs b/markbase-core/src/api/handlers/mod.rs index 753d89d..98f594e 100644 --- a/markbase-core/src/api/handlers/mod.rs +++ b/markbase-core/src/api/handlers/mod.rs @@ -1,5 +1,5 @@ // API Handlers Module -// +// // This module provides space for future modular API handlers. // Current handlers are implemented in server.rs for stability. // @@ -13,4 +13,4 @@ // - view.rs: Category/Series view handlers // - static.rs: Static page handlers -pub use crate::server::AppState; \ No newline at end of file +pub use crate::server::AppState; diff --git a/markbase-core/src/api/mod.rs b/markbase-core/src/api/mod.rs index 20a4340..158db53 100644 --- a/markbase-core/src/api/mod.rs +++ b/markbase-core/src/api/mod.rs @@ -9,4 +9,4 @@ pub mod handlers; // - Clear separation of concerns // - Easier maintenance for new features // - Gradual migration path from server.rs -// - Independent testing per handler module \ No newline at end of file +// - Independent testing per handler module diff --git a/markbase-core/src/archive/config.rs b/markbase-core/src/archive/config.rs index 62fae26..1462a6a 100644 --- a/markbase-core/src/archive/config.rs +++ b/markbase-core/src/archive/config.rs @@ -1,22 +1,21 @@ // Archive Configuration - User Configurable Options use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::path::Path; use log::warn; +use serde::{Deserialize, Serialize}; /// Archive Configuration #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ArchiveConfig { // Optional formats (controversial) - pub enable_rar: bool, // ⚠️ Legal risk (RARLAB patent) - pub enable_xz: bool, // ⚠️ External dependency (liblzma) - pub enable_7z: bool, // ⚠️ Unstable library - + pub enable_rar: bool, // ⚠️ Legal risk (RARLAB patent) + pub enable_xz: bool, // ⚠️ External dependency (liblzma) + pub enable_7z: bool, // ⚠️ Unstable library + // Performance settings pub cache_size_mb: u64, pub max_concurrent_extractions: usize, - + // Security settings pub max_decompression_ratio: u64, pub max_file_size_mb: u64, @@ -29,11 +28,11 @@ impl Default for ArchiveConfig { enable_rar: false, enable_xz: false, enable_7z: false, - + // Performance cache_size_mb: 100, max_concurrent_extractions: 4, - + // Security max_decompression_ratio: 1000, max_file_size_mb: 1024, @@ -46,45 +45,46 @@ impl ArchiveConfig { pub fn load(path: &str) -> Result { let content = std::fs::read_to_string(path)?; let config: ArchiveConfig = toml::from_str(&content)?; - + // Validate configuration config.validate()?; - + Ok(config) } - + /// Save configuration to TOML file pub fn save(&self, path: &str) -> Result<()> { let content = toml::to_string_pretty(self)?; std::fs::write(path, content)?; Ok(()) } - + /// Validate configuration pub fn validate(&self) -> Result<()> { if self.cache_size_mb > 1000 { warn!("Cache size > 1GB may cause memory pressure"); } - + if self.max_concurrent_extractions > 10 { warn!("Concurrent extractions > 10 may cause resource exhaustion"); } - + if self.max_decompression_ratio < 10 { return Err(anyhow::anyhow!("Max decompression ratio too low (min 10)")); } - - if self.max_file_size_mb > 10_000 { // 10GB + + if self.max_file_size_mb > 10_000 { + // 10GB warn!("Max file size > 10GB may cause disk space issues"); } - + Ok(()) } - + /// Generate default config file template pub fn generate_template() -> String { let config = Self::default(); - + format!( "# === Archive Configuration === # MarkBase Universal Compression Format Support @@ -138,33 +138,33 @@ max_file_size_mb = {} # File size limit (MB) #[cfg(test)] mod tests { use super::*; - + #[test] fn test_default_config() { let config = ArchiveConfig::default(); - + assert_eq!(config.enable_rar, false); assert_eq!(config.enable_xz, false); assert_eq!(config.enable_7z, false); assert_eq!(config.cache_size_mb, 100); assert_eq!(config.max_decompression_ratio, 1000); } - + #[test] fn test_config_validation() { let config = ArchiveConfig { max_decompression_ratio: 5, ..Default::default() }; - + assert!(config.validate().is_err()); } - + #[test] fn test_config_template() { let template = ArchiveConfig::generate_template(); - + assert!(template.contains("enable_rar = false")); assert!(template.contains("⚠️ RAR Format Legal Risk Warning")); } -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/detector.rs b/markbase-core/src/archive/detector.rs index 99f96de..2623155 100644 --- a/markbase-core/src/archive/detector.rs +++ b/markbase-core/src/archive/detector.rs @@ -1,9 +1,9 @@ // Format Detector - Automatic Detection Based on Magic Numbers +use anyhow::Result; use std::fs::File; use std::io::Read; use std::path::Path; -use anyhow::Result; use crate::archive::processor::ArchiveFormat; @@ -18,64 +18,61 @@ impl FormatDetector { // ZIP: 50 4B 03 04 or 50 4B 05 06 (empty) or 50 4B 07 08 (spanned) (vec![0x50, 0x4B, 0x03, 0x04], ArchiveFormat::Zip, 4), (vec![0x50, 0x4B, 0x05, 0x06], ArchiveFormat::Zip, 4), - // GZIP: 1F 8B (vec![0x1F, 0x8B], ArchiveFormat::Gzip, 2), ]; - + Self { magic_table } } - + /// Detect file format based on Magic Number pub fn detect(&self, path: &Path) -> Result { let mut file = File::open(path)?; let mut buffer = vec![0u8; 512]; - + let bytes_read = file.read(&mut buffer)?; if bytes_read < 2 { return Ok(ArchiveFormat::Unknown); } - + // Match Magic Numbers for (magic, format, offset) in &self.magic_table { if buffer.len() >= *offset && buffer[0..magic.len()] == *magic { return Ok(*format); } } - + // Special detection: TAR format (check ustar magic at offset 257) - if buffer.len() >= 262 { - if &buffer[257..262] == b"ustar" { + if buffer.len() >= 262 + && &buffer[257..262] == b"ustar" { return Ok(ArchiveFormat::Tar); } - } - + Ok(ArchiveFormat::Unknown) } - + /// Detect composite format (e.g., TAR.GZ) pub fn detect_composite(&self, path: &Path) -> Result { let format = self.detect(path)?; - + // If GZIP, check if it's TAR.GZ (by extension for now) if format == ArchiveFormat::Gzip { - let ext = path.extension() + let ext = path + .extension() .and_then(|e| e.to_str()) .unwrap_or("") .to_lowercase(); - + if ext == "tgz" || ext == "gz" { // Check if filename contains .tar - let filename = path.file_name() - .and_then(|n| n.to_str()) - .unwrap_or(""); - + let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or(""); + if filename.contains(".tar") { return Ok(ArchiveFormat::TarGzip); } } } - + Ok(format) } } @@ -89,51 +86,51 @@ impl Default for FormatDetector { #[cfg(test)] mod tests { use super::*; - use tempfile::TempDir; use std::io::Write; - + use tempfile::TempDir; + #[test] fn test_detect_zip() { let temp_dir = TempDir::new().unwrap(); let zip_path = temp_dir.path().join("test.zip"); - + // Create minimal ZIP file header let mut file = File::create(&zip_path).unwrap(); file.write_all(&[0x50, 0x4B, 0x03, 0x04]).unwrap(); - + let detector = FormatDetector::new(); let format = detector.detect(&zip_path).unwrap(); - + assert_eq!(format, ArchiveFormat::Zip); } - + #[test] fn test_detect_gzip() { let temp_dir = TempDir::new().unwrap(); let gz_path = temp_dir.path().join("test.gz"); - + // Create minimal GZIP file header let mut file = File::create(&gz_path).unwrap(); file.write_all(&[0x1F, 0x8B]).unwrap(); - + let detector = FormatDetector::new(); let format = detector.detect(&gz_path).unwrap(); - + assert_eq!(format, ArchiveFormat::Gzip); } - + #[test] fn test_detect_unknown() { let temp_dir = TempDir::new().unwrap(); let unknown_path = temp_dir.path().join("test.bin"); - + // Create unknown file let mut file = File::create(&unknown_path).unwrap(); file.write_all(b"unknown data").unwrap(); - + let detector = FormatDetector::new(); let format = detector.detect(&unknown_path).unwrap(); - + assert_eq!(format, ArchiveFormat::Unknown); } -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/metadata.rs b/markbase-core/src/archive/metadata.rs index 653bc60..d6bca21 100644 --- a/markbase-core/src/archive/metadata.rs +++ b/markbase-core/src/archive/metadata.rs @@ -1,8 +1,8 @@ // Metadata Module - Archive Entry Metadata Management +use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::time::SystemTime; -use serde::{Deserialize, Serialize}; use crate::archive::processor::ArchiveFormat; @@ -29,7 +29,7 @@ impl ArchiveMetadata { self.total_size as f64 / self.compressed_size as f64 } } - + /// Check if compression ratio exceeds limit (Zip Bomb detection) pub fn check_zip_bomb(&self, max_ratio: u64) -> bool { self.actual_ratio() > max_ratio as f64 @@ -65,7 +65,7 @@ impl ArchiveEntry { checksum: None, } } - + /// Create file entry pub fn file(path: PathBuf, size: u64, compressed_size: u64) -> Self { Self { @@ -104,7 +104,7 @@ impl ExtractResult { warnings: Vec::new(), } } - + pub fn success_rate(&self) -> f64 { if self.total_files == 0 { 100.0 @@ -113,11 +113,11 @@ impl ExtractResult { (success_count as f64 / self.total_files as f64) * 100.0 } } - + pub fn has_failures(&self) -> bool { !self.failed_files.is_empty() } - + pub fn has_warnings(&self) -> bool { !self.warnings.is_empty() } @@ -126,7 +126,7 @@ impl ExtractResult { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_archive_metadata() { let metadata = ArchiveMetadata { @@ -140,37 +140,37 @@ mod tests { created_time: None, modified_time: None, }; - + assert_eq!(metadata.actual_ratio(), 2.0); assert!(!metadata.check_zip_bomb(1000)); - assert!(metadata.check_zip_bomb(1)); // Should detect as bomb + assert!(metadata.check_zip_bomb(1)); // Should detect as bomb } - + #[test] fn test_archive_entry() { let dir_entry = ArchiveEntry::directory(PathBuf::from("test_dir")); assert!(dir_entry.is_dir); assert!(!dir_entry.is_file); - + let file_entry = ArchiveEntry::file(PathBuf::from("test.txt"), 100, 50); assert!(!file_entry.is_dir); assert!(file_entry.is_file); assert_eq!(file_entry.size, 100); } - + #[test] fn test_extract_result() { let result = ExtractResult::new(); assert_eq!(result.success_rate(), 100.0); - + let result_with_failure = ExtractResult { total_files: 10, success_files: 8, failed_files: vec![PathBuf::from("failed.txt")], ..Default::default() }; - + assert_eq!(result_with_failure.success_rate(), 80.0); assert!(result_with_failure.has_failures()); } -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/mod.rs b/markbase-core/src/archive/mod.rs index 73ebdd4..2992d79 100644 --- a/markbase-core/src/archive/mod.rs +++ b/markbase-core/src/archive/mod.rs @@ -25,9 +25,9 @@ pub use metadata::{ArchiveEntry, ArchiveMetadata, ExtractResult}; pub use processor::{ArchiveFormat, ArchiveProcessor}; use anyhow::Result; +use log::info; use std::collections::HashMap; use std::path::Path; -use log::{info, warn}; /// Processor Registry - Plugin Architecture pub struct ProcessorRegistry { @@ -43,93 +43,108 @@ impl ProcessorRegistry { config, } } - + /// Initialize all processors (based on config) pub fn initialize(&mut self) -> Result<()> { // Core formats (always registered) self.register_core_processors()?; - + // Optional formats (based on config) self.register_optional_processors()?; - + Ok(()) } - + /// Register core format processors (9 formats) fn register_core_processors(&mut self) -> Result<()> { use crate::archive::processors::core::*; - - self.processors.insert(ArchiveFormat::Zip, Box::new(ZipProcessor::new())); - self.processors.insert(ArchiveFormat::Tar, Box::new(TarProcessor::new())); - self.processors.insert(ArchiveFormat::Gzip, Box::new(GzipProcessor::new())); - self.processors.insert(ArchiveFormat::Zstd, Box::new(ZstdProcessor::new())); - self.processors.insert(ArchiveFormat::Bzip2, Box::new(Bzip2Processor::new())); - self.processors.insert(ArchiveFormat::Lz4, Box::new(Lz4Processor::new())); - self.processors.insert(ArchiveFormat::TarGzip, Box::new(TarGzipProcessor::new())); - self.processors.insert(ArchiveFormat::TarBzip2, Box::new(TarBzip2Processor::new())); - self.processors.insert(ArchiveFormat::TarZstd, Box::new(TarZstdProcessor::new())); - + + self.processors + .insert(ArchiveFormat::Zip, Box::new(ZipProcessor::new())); + self.processors + .insert(ArchiveFormat::Tar, Box::new(TarProcessor::new())); + self.processors + .insert(ArchiveFormat::Gzip, Box::new(GzipProcessor::new())); + self.processors + .insert(ArchiveFormat::Zstd, Box::new(ZstdProcessor::new())); + self.processors + .insert(ArchiveFormat::Bzip2, Box::new(Bzip2Processor::new())); + self.processors + .insert(ArchiveFormat::Lz4, Box::new(Lz4Processor::new())); + self.processors + .insert(ArchiveFormat::TarGzip, Box::new(TarGzipProcessor::new())); + self.processors + .insert(ArchiveFormat::TarBzip2, Box::new(TarBzip2Processor::new())); + self.processors + .insert(ArchiveFormat::TarZstd, Box::new(TarZstdProcessor::new())); + info!("✅ Core formats registered: 9 formats"); Ok(()) } - + /// Register optional format processors (3 formats, based on config) fn register_optional_processors(&mut self) -> Result<()> { #[cfg(feature = "optional-formats")] { use crate::archive::processors::optional::*; - + // RAR format (legal risk) if self.config.enable_rar { crate::archive::warning::show_rar_legal_warning(); - self.processors.insert(ArchiveFormat::Rar, Box::new(RarProcessor::new())); + self.processors + .insert(ArchiveFormat::Rar, Box::new(RarProcessor::new())); warn!("⚠️ RAR format enabled (legal risk)"); } - + // XZ format (external dependency) if self.config.enable_xz { if check_liblzma_available() { - self.processors.insert(ArchiveFormat::Xz, Box::new(XzProcessor::new())); + self.processors + .insert(ArchiveFormat::Xz, Box::new(XzProcessor::new())); info!("✅ XZ format enabled"); } else { crate::archive::warning::show_xz_dependency_warning(); warn!("⚠️ XZ format disabled (liblzma not found)"); } } - + // 7z format (unstable library) if self.config.enable_7z { crate::archive::warning::show_7z_stability_warning(); - self.processors.insert(ArchiveFormat::SevenZ, Box::new(SevenZProcessor::new())); + self.processors + .insert(ArchiveFormat::SevenZ, Box::new(SevenZProcessor::new())); warn!("⚠️ 7z format enabled (stability warning)"); } } - + Ok(()) } - + /// Get processor for detected format (mutable version for open/extraction) pub fn get_processor_mut(&mut self, path: &Path) -> Result<&mut (dyn ArchiveProcessor + '_)> { let detector = FormatDetector::new(); let format = detector.detect(path)?; - + match self.processors.get_mut(&format) { Some(p) => Ok(p.as_mut()), - None => Err(anyhow::anyhow!("Format {} not supported or not enabled", format)), + None => Err(anyhow::anyhow!( + "Format {} not supported or not enabled", + format + )), } } - + /// Get processor for detected format (immutable version for listing) pub fn get_processor(&self, path: &Path) -> Result<&dyn ArchiveProcessor> { let detector = FormatDetector::new(); let format = detector.detect(path)?; - + self.processors .get(&format) .map(|p| p.as_ref()) .ok_or_else(|| anyhow::anyhow!("Format {} not supported or not enabled", format)) } - + /// List all enabled formats pub fn enabled_formats(&self) -> Vec { self.processors.keys().cloned().collect() @@ -141,7 +156,7 @@ impl ProcessorRegistry { fn check_liblzma_available() -> bool { // Try to load xz2 library // Simplified check: try to create XzProcessor - true // Simplified for now, actual implementation needs better detection + true // Simplified for now, actual implementation needs better detection } #[cfg(not(feature = "optional-formats"))] @@ -156,13 +171,16 @@ pub fn init_archive_system(config_path: Option<&str>) -> Result ArchiveFormat; - + /// Open archive file and read metadata fn open(&mut self, path: &Path) -> Result; - + /// List all file entries in archive fn list_entries(&mut self) -> Result>; - + /// Extract single file (on-demand decompression) fn extract_file(&mut self, entry_path: &Path, output: &mut Vec) -> Result; - + /// Extract all files to directory (batch extraction) fn extract_all(&mut self, output_dir: &Path) -> Result; - + /// Check if this processor can handle the format - fn can_process(format: ArchiveFormat) -> bool where Self: Sized; - + fn can_process(format: ArchiveFormat) -> bool + where + Self: Sized; + /// Create new processor instance - fn new() -> Self where Self: Sized; + fn new() -> Self + where + Self: Sized; } /// Security Validation - Zip Slip Protection pub fn validate_extraction_path(entry_path: &Path, base_dir: &Path) -> Result { use std::path::Component; - + // 1. Check path components for component in entry_path.components() { match component { @@ -92,51 +96,62 @@ pub fn validate_extraction_path(entry_path: &Path, base_dir: &Path) -> Result {} } } - + // 2. Build full path let full_path = base_dir.join(entry_path); - + // 3. Canonicalize and validate (ensure within base_dir) - let canonical_base = base_dir.canonicalize() + let canonical_base = base_dir + .canonicalize() .map_err(|e| anyhow::anyhow!("Cannot canonicalize base dir: {}", e))?; - + // Create parent directories first if let Some(parent) = full_path.parent() { std::fs::create_dir_all(parent)?; } - + // 4. Verify extraction path is within base_dir // Note: full_path may not exist yet, so we check parent directory if full_path.exists() { - let canonical_full = full_path.canonicalize() + let canonical_full = full_path + .canonicalize() .map_err(|e| anyhow::anyhow!("Cannot canonicalize full path: {}", e))?; - + if !canonical_full.starts_with(&canonical_base) { - return Err(anyhow::anyhow!("Zip Slip detected: path escapes base directory")); + return Err(anyhow::anyhow!( + "Zip Slip detected: path escapes base directory" + )); } } else { // Check parent directory instead if let Some(parent) = full_path.parent() { - let canonical_parent = parent.canonicalize() + let canonical_parent = parent + .canonicalize() .map_err(|e| anyhow::anyhow!("Cannot canonicalize parent: {}", e))?; - + if !canonical_parent.starts_with(&canonical_base) { - return Err(anyhow::anyhow!("Zip Slip detected: path escapes base directory")); + return Err(anyhow::anyhow!( + "Zip Slip detected: path escapes base directory" + )); } } } - + Ok(full_path) } /// Security Validation - Zip Bomb Protection -pub fn check_decompression_ratio(compressed_size: u64, decompressed_size: u64, max_ratio: u64) -> Result<()> { +pub fn check_decompression_ratio( + compressed_size: u64, + decompressed_size: u64, + max_ratio: u64, +) -> Result<()> { if compressed_size == 0 { - return Ok(()); // Empty file, allow + return Ok(()); // Empty file, allow } - + let ratio = decompressed_size / compressed_size; - + if ratio > max_ratio { return Err(anyhow::anyhow!( "Zip Bomb detected: compression ratio {} exceeds limit {}", @@ -144,7 +159,7 @@ pub fn check_decompression_ratio(compressed_size: u64, decompressed_size: u64, m max_ratio )); } - + Ok(()) } @@ -157,7 +172,7 @@ pub fn check_file_size_limit(file_size: u64, max_size: u64) -> Result<()> { max_size / 1024 / 1024 )); } - + Ok(()) } @@ -165,34 +180,34 @@ pub fn check_file_size_limit(file_size: u64, max_size: u64) -> Result<()> { mod tests { use super::*; use tempfile::TempDir; - + #[test] fn test_zip_slip_protection() { let temp_dir = TempDir::new().unwrap(); let base = temp_dir.path(); - + // Safe path: should pass let safe_path = Path::new("safe/file.txt"); assert!(validate_extraction_path(safe_path, base).is_ok()); - + // Evil path: should be rejected let evil_path = Path::new("../../etc/passwd"); assert!(validate_extraction_path(evil_path, base).is_err()); - + // Absolute path: should be rejected let abs_path = Path::new("/etc/passwd"); assert!(validate_extraction_path(abs_path, base).is_err()); } - + #[test] fn test_zip_bomb_detection() { // Normal ratio: should pass assert!(check_decompression_ratio(1000, 5000, 1000).is_ok()); - + // Zip Bomb ratio: should be rejected assert!(check_decompression_ratio(42_000, 5_000_000_000, 1000).is_err()); } - + #[test] fn test_compression_ratio_calculation() { let metadata = ArchiveMetadata { @@ -206,7 +221,7 @@ mod tests { created_time: None, modified_time: None, }; - + assert_eq!(metadata.actual_ratio(), 2.0); } -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/processors/core/mod.rs b/markbase-core/src/archive/processors/core/mod.rs index 1a00cb4..4beddc7 100644 --- a/markbase-core/src/archive/processors/core/mod.rs +++ b/markbase-core/src/archive/processors/core/mod.rs @@ -1,16 +1,16 @@ // Core Format Processors - ZIP, TAR, GZIP, TAR.GZ Full Implementation -use crate::archive::{ - ArchiveProcessor, ArchiveFormat, ArchiveMetadata, ArchiveEntry, ExtractResult, - processor::{validate_extraction_path, check_decompression_ratio, check_file_size_limit}, -}; use crate::archive::config::ArchiveConfig; -use anyhow::{Result, anyhow}; +use crate::archive::{ + processor::{check_decompression_ratio, check_file_size_limit, validate_extraction_path}, + ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult, +}; +use anyhow::{anyhow, Result}; +use log::{debug, info, warn}; +use std::fs::{create_dir_all, File}; +use std::io::{BufWriter, Read}; use std::path::{Path, PathBuf}; -use std::fs::{File, create_dir_all}; -use std::io::{Read, Write, BufReader, BufWriter}; use std::time::SystemTime; -use log::{info, warn, debug}; // ==================== ZIP Processor ==================== @@ -21,6 +21,12 @@ pub struct ZipProcessor { config: ArchiveConfig, } +impl Default for ZipProcessor { + fn default() -> Self { + Self::new() + } +} + impl ZipProcessor { pub fn new() -> Self { Self { @@ -29,7 +35,7 @@ impl ZipProcessor { config: ArchiveConfig::default(), } } - + pub fn with_config(config: ArchiveConfig) -> Self { Self { archive: None, @@ -43,7 +49,7 @@ impl ArchiveProcessor for ZipProcessor { fn format(&self) -> ArchiveFormat { ArchiveFormat::Zip } - + fn new() -> Self { Self { archive: None, @@ -51,64 +57,72 @@ impl ArchiveProcessor for ZipProcessor { config: ArchiveConfig::default(), } } - + fn open(&mut self, path: &Path) -> Result { info!("Opening ZIP archive: {}", path.display()); - + let file = File::open(path)?; let archive = zip::ZipArchive::new(file)?; - + self.archive = Some(archive); self.path = path.to_path_buf(); - + // Extract metadata (need mutable reference for by_index) let archive_ref = self.archive.as_mut().unwrap(); let total_files = archive_ref.len() as u64; - + let mut total_size = 0u64; let mut compressed_size = 0u64; - + for i in 0..archive_ref.len() { let file = archive_ref.by_index(i)?; total_size += file.size(); compressed_size += file.compressed_size(); } - + let compression_ratio = if compressed_size > 0 { total_size as f64 / compressed_size as f64 } else { 0.0 }; - + // Check for Zip Bomb if compression_ratio > self.config.max_decompression_ratio as f64 { - warn!("Potential Zip Bomb detected: ratio {:.1}:1", compression_ratio); - return Err(anyhow!("Zip Bomb detected: compression ratio {:.1} exceeds limit {}", - compression_ratio, self.config.max_decompression_ratio)); + warn!( + "Potential Zip Bomb detected: ratio {:.1}:1", + compression_ratio + ); + return Err(anyhow!( + "Zip Bomb detected: compression ratio {:.1} exceeds limit {}", + compression_ratio, + self.config.max_decompression_ratio + )); } - + Ok(ArchiveMetadata { format: ArchiveFormat::Zip, total_files, total_size, compressed_size, compression_ratio, - is_encrypted: false, // TODO: Check encryption + is_encrypted: false, // TODO: Check encryption is_multi_volume: false, created_time: Some(SystemTime::now()), modified_time: Some(SystemTime::now()), }) } - + fn list_entries(&mut self) -> Result> { - let archive = self.archive.as_mut() + let archive = self + .archive + .as_mut() .ok_or_else(|| anyhow!("Archive not opened"))?; - + let mut entries = Vec::new(); - + for i in 0..archive.len() { let file = archive.by_index(i)?; - + let entry = ArchiveEntry { path: PathBuf::from(file.name()), size: file.size(), @@ -116,61 +130,64 @@ impl ArchiveProcessor for ZipProcessor { is_dir: file.name().ends_with('/'), is_file: !file.name().ends_with('/'), is_encrypted: false, - modified: SystemTime::UNIX_EPOCH, // TODO: Get actual time + modified: SystemTime::UNIX_EPOCH, // TODO: Get actual time permissions: Some(0o644), checksum: None, }; - + entries.push(entry); } - + info!("Listed {} entries in ZIP archive", entries.len()); Ok(entries) } - + fn extract_file(&mut self, entry_path: &Path, output: &mut Vec) -> Result { - let archive = self.archive.as_mut() + let archive = self + .archive + .as_mut() .ok_or_else(|| anyhow!("Archive not opened"))?; - - let entry_name = entry_path.to_str() + + let entry_name = entry_path + .to_str() .ok_or_else(|| anyhow!("Invalid entry path"))?; - + let mut file = archive.by_name(entry_name)?; - + // Check file size limit check_file_size_limit(file.size(), self.config.max_file_size_mb * 1024 * 1024)?; - + output.clear(); output.reserve(file.size() as usize); - + file.read_to_end(output)?; - + info!("Extracted file: {} ({} bytes)", entry_name, output.len()); Ok(output.len() as u64) } - + fn extract_all(&mut self, output_dir: &Path) -> Result { create_dir_all(output_dir)?; - + let mut result = ExtractResult::new(); - + // Open archive if not already open if self.archive.is_none() { let file = File::open(&self.path)?; let archive = zip::ZipArchive::new(file)?; self.archive = Some(archive); } - + let archive = self.archive.as_mut().unwrap(); result.total_files = archive.len() as u64; - + // Use archive iteration to extract files for i in 0..archive.len() { let mut file = archive.by_index(i)?; let entry_name = file.name().to_string(); let file_size = file.size(); let is_dir = entry_name.ends_with('/'); - + // Zip Slip protection match validate_extraction_path(&PathBuf::from(&entry_name), output_dir) { Ok(safe_path) => { @@ -181,21 +198,24 @@ impl ArchiveProcessor for ZipProcessor { result.success_files += 1; } else { // File - check_file_size_limit(file_size, self.config.max_file_size_mb * 1024 * 1024)?; - + check_file_size_limit( + file_size, + self.config.max_file_size_mb * 1024 * 1024, + )?; + if let Some(parent) = safe_path.parent() { create_dir_all(parent)?; } - + // Extract file content let mut outfile = BufWriter::new(File::create(&safe_path)?); std::io::copy(&mut file, &mut outfile)?; - + result.success_files += 1; result.total_bytes += file_size; debug!("Extracted: {} ({} bytes)", entry_name, file_size); } - }, + } Err(e) => { warn!("Zip Slip detected: {} - {}", entry_name, e); result.failed_files.push(PathBuf::from(&entry_name)); @@ -203,13 +223,17 @@ impl ArchiveProcessor for ZipProcessor { } } } - - info!("Extracted {} files ({} bytes) to {}", - result.success_files, result.total_bytes, output_dir.display()); - + + info!( + "Extracted {} files ({} bytes) to {}", + result.success_files, + result.total_bytes, + output_dir.display() + ); + Ok(result) } - + fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Zip } @@ -224,6 +248,12 @@ pub struct TarProcessor { config: ArchiveConfig, } +impl Default for TarProcessor { + fn default() -> Self { + Self::new() + } +} + impl TarProcessor { pub fn new() -> Self { Self { @@ -232,7 +262,7 @@ impl TarProcessor { config: ArchiveConfig::default(), } } - + pub fn with_config(config: ArchiveConfig) -> Self { Self { path: PathBuf::new(), @@ -246,7 +276,7 @@ impl ArchiveProcessor for TarProcessor { fn format(&self) -> ArchiveFormat { ArchiveFormat::Tar } - + fn new() -> Self { Self { path: PathBuf::new(), @@ -254,30 +284,30 @@ impl ArchiveProcessor for TarProcessor { config: ArchiveConfig::default(), } } - + fn open(&mut self, path: &Path) -> Result { info!("Opening TAR archive: {}", path.display()); - + self.path = path.to_path_buf(); self.entries.clear(); - + let file = File::open(path)?; let mut archive = tar::Archive::new(file); - + let mut total_size = 0u64; - + // Iterate entries to collect metadata for entry in archive.entries()? { let entry = entry?; let path = entry.path()?.to_path_buf(); let size = entry.size(); - + total_size += size; - + self.entries.push(ArchiveEntry { path, size, - compressed_size: size, // TAR has no compression + compressed_size: size, // TAR has no compression is_dir: entry.header().entry_type().is_dir(), is_file: entry.header().entry_type().is_file(), is_encrypted: false, @@ -286,78 +316,87 @@ impl ArchiveProcessor for TarProcessor { checksum: None, }); } - + let total_files = self.entries.len() as u64; - + Ok(ArchiveMetadata { format: ArchiveFormat::Tar, total_files, total_size, - compressed_size: total_size, // TAR has no compression - compression_ratio: 1.0, // No compression + compressed_size: total_size, // TAR has no compression + compression_ratio: 1.0, // No compression is_encrypted: false, is_multi_volume: false, created_time: Some(SystemTime::now()), modified_time: Some(SystemTime::now()), }) } - + fn list_entries(&mut self) -> Result> { Ok(self.entries.clone()) } - + fn extract_file(&mut self, entry_path: &Path, output: &mut Vec) -> Result { // TAR doesn't support random access, need to unpack entire archive // This is a limitation - for single file extraction, we unpack everything warn!("TAR format doesn't support random access - extracting entire archive"); - + let temp_dir = tempfile::tempdir()?; self.extract_all(temp_dir.path())?; - + let file_path = temp_dir.path().join(entry_path); let mut file = File::open(&file_path)?; output.clear(); file.read_to_end(output)?; - + Ok(output.len() as u64) } - + fn extract_all(&mut self, output_dir: &Path) -> Result { create_dir_all(output_dir)?; - + let file = File::open(&self.path)?; let mut archive = tar::Archive::new(file); - + let mut result = ExtractResult::new(); result.total_files = self.entries.len() as u64; - + for entry in archive.entries()? { let mut entry = entry?; let entry_path = entry.path()?.to_path_buf(); - let entry_path_str = entry_path.display().to_string(); // Save for warning - + let entry_path_str = entry_path.display().to_string(); // Save for warning + // Zip Slip protection match validate_extraction_path(&entry_path, output_dir) { Ok(safe_path) => { - check_file_size_limit(entry.size(), self.config.max_file_size_mb * 1024 * 1024)?; - + check_file_size_limit( + entry.size(), + self.config.max_file_size_mb * 1024 * 1024, + )?; + entry.unpack(&safe_path)?; - + result.success_files += 1; result.total_bytes += entry.size(); - }, + } Err(e) => { warn!("Zip Slip detected: {} - {}", entry_path_str, e); result.failed_files.push(entry_path); - result.warnings.push(format!("Zip Slip: {}", entry_path_str)); + result + .warnings + .push(format!("Zip Slip: {}", entry_path_str)); } } } - - info!("Extracted {} TAR entries to {}", result.success_files, output_dir.display()); + + info!( + "Extracted {} TAR entries to {}", + result.success_files, + output_dir.display() + ); Ok(result) } - + fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Tar } @@ -372,6 +411,12 @@ pub struct GzipProcessor { config: ArchiveConfig, } +impl Default for GzipProcessor { + fn default() -> Self { + Self::new() + } +} + impl GzipProcessor { pub fn new() -> Self { Self { @@ -380,7 +425,7 @@ impl GzipProcessor { config: ArchiveConfig::default(), } } - + pub fn with_config(config: ArchiveConfig) -> Self { Self { path: PathBuf::new(), @@ -394,7 +439,7 @@ impl ArchiveProcessor for GzipProcessor { fn format(&self) -> ArchiveFormat { ArchiveFormat::Gzip } - + fn new() -> Self { Self { path: PathBuf::new(), @@ -402,27 +447,31 @@ impl ArchiveProcessor for GzipProcessor { config: ArchiveConfig::default(), } } - + fn open(&mut self, path: &Path) -> Result { info!("Opening GZIP archive: {}", path.display()); - + self.path = path.to_path_buf(); - + let file = File::open(path)?; let compressed_size = file.metadata()?.len(); - + let mut decoder = flate2::read::GzDecoder::new(file); let mut buffer = Vec::new(); decoder.read_to_end(&mut buffer)?; - + self.decompressed_size = buffer.len() as u64; - + // Check Zip Bomb - check_decompression_ratio(compressed_size, self.decompressed_size, self.config.max_decompression_ratio)?; - + check_decompression_ratio( + compressed_size, + self.decompressed_size, + self.config.max_decompression_ratio, + )?; + Ok(ArchiveMetadata { format: ArchiveFormat::Gzip, - total_files: 1, // GZIP is single file + total_files: 1, // GZIP is single file total_size: self.decompressed_size, compressed_size, compression_ratio: if compressed_size > 0 { @@ -436,58 +485,64 @@ impl ArchiveProcessor for GzipProcessor { modified_time: Some(SystemTime::now()), }) } - + fn list_entries(&mut self) -> Result> { // GZIP is single file - infer name from archive name - let name = self.path.file_name() + let name = self + .path + .file_name() .and_then(|n| n.to_str()) .unwrap_or("unknown") .replace(".gz", "") .replace(".gzip", ""); - + Ok(vec![ArchiveEntry::file( PathBuf::from(name), self.decompressed_size, - 0, // GZIP doesn't preserve compressed size per file + 0, // GZIP doesn't preserve compressed size per file )]) } - - fn extract_file(&mut self, entry_path: &Path, output: &mut Vec) -> Result { + + fn extract_file(&mut self, _entry_path: &Path, output: &mut Vec) -> Result { // GZIP is single file - just decompress it let file = File::open(&self.path)?; let mut decoder = flate2::read::GzDecoder::new(file); - + output.clear(); decoder.read_to_end(output)?; - - check_file_size_limit(output.len() as u64, self.config.max_file_size_mb * 1024 * 1024)?; - + + check_file_size_limit( + output.len() as u64, + self.config.max_file_size_mb * 1024 * 1024, + )?; + info!("Decompressed GZIP file: {} bytes", output.len()); Ok(output.len() as u64) } - + fn extract_all(&mut self, output_dir: &Path) -> Result { create_dir_all(output_dir)?; - + let entries = self.list_entries()?; - let entry = entries.first() + let entry = entries + .first() .ok_or_else(|| anyhow!("No entry in GZIP archive"))?; - + let outpath = output_dir.join(&entry.path); - + // Zip Slip protection validate_extraction_path(&entry.path, output_dir)?; - + if let Some(parent) = outpath.parent() { create_dir_all(parent)?; } - + let file = File::open(&self.path)?; let mut decoder = flate2::read::GzDecoder::new(file); let mut outfile = BufWriter::new(File::create(&outpath)?); - + std::io::copy(&mut decoder, &mut outfile)?; - + let result = ExtractResult { total_files: 1, total_bytes: self.decompressed_size, @@ -496,11 +551,11 @@ impl ArchiveProcessor for GzipProcessor { skipped_files: Vec::new(), warnings: Vec::new(), }; - + info!("Decompressed GZIP to: {}", outpath.display()); Ok(result) } - + fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Gzip } @@ -514,6 +569,12 @@ pub struct TarGzipProcessor { config: ArchiveConfig, } +impl Default for TarGzipProcessor { + fn default() -> Self { + Self::new() + } +} + impl TarGzipProcessor { pub fn new() -> Self { Self { @@ -521,7 +582,7 @@ impl TarGzipProcessor { config: ArchiveConfig::default(), } } - + pub fn with_config(config: ArchiveConfig) -> Self { Self { gzip_processor: GzipProcessor::with_config(config.clone()), @@ -534,32 +595,33 @@ impl ArchiveProcessor for TarGzipProcessor { fn format(&self) -> ArchiveFormat { ArchiveFormat::TarGzip } - + fn new() -> Self { Self { gzip_processor: GzipProcessor::new(), config: ArchiveConfig::default(), } } - + fn open(&mut self, path: &Path) -> Result { info!("Opening TAR.GZ archive: {}", path.display()); - + // Step 1: Decompress GZIP let temp_dir = tempfile::tempdir()?; self.gzip_processor.open(path)?; self.gzip_processor.extract_all(temp_dir.path())?; - + // Step 2: Open TAR let tar_entries = self.gzip_processor.list_entries()?; - let tar_file = tar_entries.first() + let tar_file = tar_entries + .first() .ok_or_else(|| anyhow!("No TAR file in GZIP"))?; - + let tar_path = temp_dir.path().join(&tar_file.path); - + let mut tar_processor = TarProcessor::with_config(self.config.clone()); let tar_metadata = tar_processor.open(&tar_path)?; - + Ok(ArchiveMetadata { format: ArchiveFormat::TarGzip, total_files: tar_metadata.total_files, @@ -576,46 +638,47 @@ impl ArchiveProcessor for TarGzipProcessor { modified_time: Some(SystemTime::now()), }) } - + fn list_entries(&mut self) -> Result> { // Need to implement properly - this requires decompressing first warn!("TAR.GZ list_entries requires full decompression - consider extract_all instead"); Ok(Vec::new()) } - + fn extract_file(&mut self, entry_path: &Path, output: &mut Vec) -> Result { warn!("TAR.GZ extract_file requires full unpacking - inefficient for single file"); - + let temp_dir = tempfile::tempdir()?; self.extract_all(temp_dir.path())?; - + let file_path = temp_dir.path().join(entry_path); let mut file = File::open(&file_path)?; output.clear(); file.read_to_end(output)?; - + Ok(output.len() as u64) } - + fn extract_all(&mut self, output_dir: &Path) -> Result { info!("Extracting TAR.GZ to: {}", output_dir.display()); - + // Step 1: Decompress GZIP to temp let temp_dir = tempfile::tempdir()?; self.gzip_processor.extract_all(temp_dir.path())?; - + // Step 2: Extract TAR let tar_entries = self.gzip_processor.list_entries()?; - let tar_file = tar_entries.first() + let tar_file = tar_entries + .first() .ok_or_else(|| anyhow!("No TAR file found"))?; - + let tar_path = temp_dir.path().join(&tar_file.path); - + let mut tar_processor = TarProcessor::with_config(self.config.clone()); tar_processor.open(&tar_path)?; tar_processor.extract_all(output_dir) } - + fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::TarGzip } @@ -627,73 +690,133 @@ impl ArchiveProcessor for TarGzipProcessor { pub struct ZstdProcessor; impl ArchiveProcessor for ZstdProcessor { - fn format(&self) -> ArchiveFormat { ArchiveFormat::Zstd } + fn format(&self) -> ArchiveFormat { + ArchiveFormat::Zstd + } fn open(&mut self, _path: &Path) -> Result { Err(anyhow!("ZSTD processor not yet implemented")) } - fn list_entries(&mut self) -> Result> { Ok(Vec::new()) } - fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { Ok(0) } - fn extract_all(&mut self, _dir: &Path) -> Result { Ok(ExtractResult::new()) } - fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Zstd } - fn new() -> Self { Self } + fn list_entries(&mut self) -> Result> { + Ok(Vec::new()) + } + fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { + Ok(0) + } + fn extract_all(&mut self, _dir: &Path) -> Result { + Ok(ExtractResult::new()) + } + fn can_process(format: ArchiveFormat) -> bool { + format == ArchiveFormat::Zstd + } + fn new() -> Self { + Self + } } /// BZIP2 Processor Stub (Phase 2/3) pub struct Bzip2Processor; impl ArchiveProcessor for Bzip2Processor { - fn format(&self) -> ArchiveFormat { ArchiveFormat::Bzip2 } + fn format(&self) -> ArchiveFormat { + ArchiveFormat::Bzip2 + } fn open(&mut self, _path: &Path) -> Result { Err(anyhow!("BZIP2 processor not yet implemented")) } - fn list_entries(&mut self) -> Result> { Ok(Vec::new()) } - fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { Ok(0) } - fn extract_all(&mut self, _dir: &Path) -> Result { Ok(ExtractResult::new()) } - fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Bzip2 } - fn new() -> Self { Self } + fn list_entries(&mut self) -> Result> { + Ok(Vec::new()) + } + fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { + Ok(0) + } + fn extract_all(&mut self, _dir: &Path) -> Result { + Ok(ExtractResult::new()) + } + fn can_process(format: ArchiveFormat) -> bool { + format == ArchiveFormat::Bzip2 + } + fn new() -> Self { + Self + } } /// LZ4 Processor Stub (Phase 2/3) pub struct Lz4Processor; impl ArchiveProcessor for Lz4Processor { - fn format(&self) -> ArchiveFormat { ArchiveFormat::Lz4 } + fn format(&self) -> ArchiveFormat { + ArchiveFormat::Lz4 + } fn open(&mut self, _path: &Path) -> Result { Err(anyhow!("LZ4 processor not yet implemented")) } - fn list_entries(&mut self) -> Result> { Ok(Vec::new()) } - fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { Ok(0) } - fn extract_all(&mut self, _dir: &Path) -> Result { Ok(ExtractResult::new()) } - fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Lz4 } - fn new() -> Self { Self } + fn list_entries(&mut self) -> Result> { + Ok(Vec::new()) + } + fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { + Ok(0) + } + fn extract_all(&mut self, _dir: &Path) -> Result { + Ok(ExtractResult::new()) + } + fn can_process(format: ArchiveFormat) -> bool { + format == ArchiveFormat::Lz4 + } + fn new() -> Self { + Self + } } /// TAR.BZ2 Composite Processor Stub (Phase 2/3) pub struct TarBzip2Processor; impl ArchiveProcessor for TarBzip2Processor { - fn format(&self) -> ArchiveFormat { ArchiveFormat::TarBzip2 } + fn format(&self) -> ArchiveFormat { + ArchiveFormat::TarBzip2 + } fn open(&mut self, _path: &Path) -> Result { Err(anyhow!("TAR.BZ2 processor not yet implemented")) } - fn list_entries(&mut self) -> Result> { Ok(Vec::new()) } - fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { Ok(0) } - fn extract_all(&mut self, _dir: &Path) -> Result { Ok(ExtractResult::new()) } - fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::TarBzip2 } - fn new() -> Self { Self } + fn list_entries(&mut self) -> Result> { + Ok(Vec::new()) + } + fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { + Ok(0) + } + fn extract_all(&mut self, _dir: &Path) -> Result { + Ok(ExtractResult::new()) + } + fn can_process(format: ArchiveFormat) -> bool { + format == ArchiveFormat::TarBzip2 + } + fn new() -> Self { + Self + } } /// TAR.ZST Composite Processor Stub (Phase 2/3) pub struct TarZstdProcessor; impl ArchiveProcessor for TarZstdProcessor { - fn format(&self) -> ArchiveFormat { ArchiveFormat::TarZstd } + fn format(&self) -> ArchiveFormat { + ArchiveFormat::TarZstd + } fn open(&mut self, _path: &Path) -> Result { Err(anyhow!("TAR.ZST processor not yet implemented")) } - fn list_entries(&mut self) -> Result> { Ok(Vec::new()) } - fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { Ok(0) } - fn extract_all(&mut self, _dir: &Path) -> Result { Ok(ExtractResult::new()) } - fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::TarZstd } - fn new() -> Self { Self } -} \ No newline at end of file + fn list_entries(&mut self) -> Result> { + Ok(Vec::new()) + } + fn extract_file(&mut self, _entry: &Path, _output: &mut Vec) -> Result { + Ok(0) + } + fn extract_all(&mut self, _dir: &Path) -> Result { + Ok(ExtractResult::new()) + } + fn can_process(format: ArchiveFormat) -> bool { + format == ArchiveFormat::TarZstd + } + fn new() -> Self { + Self + } +} diff --git a/markbase-core/src/archive/processors/optional/mod.rs b/markbase-core/src/archive/processors/optional/mod.rs index cca2cac..a4d2466 100644 --- a/markbase-core/src/archive/processors/optional/mod.rs +++ b/markbase-core/src/archive/processors/optional/mod.rs @@ -1,13 +1,15 @@ // Optional Format Processors - RAR, XZ, 7z // All optional formats have warnings displayed when enabled -use crate::archive::{ArchiveFormat, ArchiveProcessor, ArchiveMetadata, ArchiveEntry, ExtractResult}; +use crate::archive::processor::{check_decompression_ratio, validate_extraction_path}; use crate::archive::warning; -use crate::archive::processor::{validate_extraction_path, check_decompression_ratio}; -use anyhow::{Result, anyhow}; -use std::path::Path; +use crate::archive::{ + ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult, +}; +use anyhow::{anyhow, Result}; +use log::{info, warn}; use std::fs; -use log::{warn, info}; +use std::path::Path; /// RAR Processor - Only Decompression /// ⚠️ Legal Warning: RARLAB patent, commercial use requires license @@ -28,54 +30,65 @@ impl ArchiveProcessor for RarProcessor { fn format(&self) -> ArchiveFormat { ArchiveFormat::Rar } - + fn open(&mut self, path: &Path) -> Result { // Show legal warning when RAR is used warning::show_rar_legal_warning(); - + self.archive_path = Some(path.to_path_buf()); - + // Use unrar library to open RAR // Note: unrar only supports decompression, no compression use unrar::Archive; - + let archive = Archive::new(path)?; - + let entries: Vec<_> = archive.list()?.collect(); let total_files = entries.len() as u64; - - let total_size = entries.iter() + + let total_size = entries + .iter() .filter_map(|e| e.ok()) .map(|e| e.uncompressed_size) .sum(); - + let compressed_size = fs::metadata(path)?.len(); - + Ok(ArchiveMetadata { format: ArchiveFormat::Rar, total_files, total_size, compressed_size, - compression_ratio: if compressed_size > 0 { total_size as f64 / compressed_size as f64 } else { 0.0 }, - is_encrypted: entries.iter().any(|e| e.ok().map_or(false, |e| e.is_encrypted())), - is_multi_volume: false, // unrar library limitation + compression_ratio: if compressed_size > 0 { + total_size as f64 / compressed_size as f64 + } else { + 0.0 + }, + is_encrypted: entries + .iter() + .any(|e| e.ok().map_or(false, |e| e.is_encrypted())), + is_multi_volume: false, // unrar library limitation created_time: None, modified_time: None, }) } - + fn list_entries(&mut self) -> Result> { use unrar::Archive; - - let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?; + + let path = self + .archive_path + .as_ref() + .ok_or_else(|| anyhow!("Archive not opened"))?; let archive = Archive::new(path)?; - - let entries: Vec = archive.list()? + + let entries: Vec = archive + .list()? .filter_map(|e| e.ok()) .map(|e| ArchiveEntry { path: PathBuf::from(e.filename), size: e.uncompressed_size, - compressed_size: 0, // unrar doesn't provide this + compressed_size: 0, // unrar doesn't provide this is_dir: e.is_directory(), is_file: !e.is_directory(), is_encrypted: e.is_encrypted(), @@ -83,45 +96,49 @@ impl ArchiveProcessor for RarProcessor { permissions: None, }) .collect(); - + Ok(entries) } - + fn extract_file(&self, entry_path: &Path, output: &mut Vec) -> Result { // RAR doesn't support random access efficiently // Need to extract entire archive warn!("RAR extract_file requires full extraction (no random access)"); - + let entries = self.list_entries()?; - let entry = entries.iter() + let entry = entries + .iter() .find(|e| e.path == entry_path) .ok_or_else(|| anyhow!("Entry not found: {}", entry_path.display()))?; - + // Extract to temp dir, then read let temp_dir = tempfile::tempdir()?; self.extract_all(temp_dir.path())?; - + let extracted_file = temp_dir.path().join(entry_path); let content = fs::read(&extracted_file)?; output.extend_from_slice(&content); - + Ok(content.len() as u64) } - + fn extract_all(&self, output_dir: &Path) -> Result { use unrar::Archive; use unrar::ExtractOption; - - let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?; - + + let path = self + .archive_path + .as_ref() + .ok_or_else(|| anyhow!("Archive not opened"))?; + // Validate output_dir path validate_extraction_path(output_dir, output_dir)?; - + let mut result = ExtractResult::new(); result.total_files = self.list_entries()?.len() as u64; - + let archive = Archive::new(path)?; - + for entry_result in archive.extract_all(output_dir, ExtractOption::Recurse)? { match entry_result { Ok(entry) => { @@ -135,10 +152,10 @@ impl ArchiveProcessor for RarProcessor { } } } - + Ok(result) } - + fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Rar } @@ -163,57 +180,65 @@ impl ArchiveProcessor for XzProcessor { fn format(&self) -> ArchiveFormat { ArchiveFormat::Xz } - + fn open(&mut self, path: &Path) -> Result { // Check if liblzma is available if !check_liblzma_available() { warning::show_xz_dependency_warning(); return Err(anyhow!("liblzma library not found, XZ format disabled")); } - + self.archive_path = Some(path.to_path_buf()); - - use xz2::read::XzDecoder; + use std::io::Read; - + use xz2::read::XzDecoder; + let file = fs::File::open(path)?; let mut decoder = XzDecoder::new(file); - + // Read decompressed size (estimate) let mut buffer = Vec::new(); decoder.read_to_end(&mut buffer)?; - + let decompressed_size = buffer.len() as u64; let compressed_size = fs::metadata(path)?.len(); - + // Check decompression ratio check_decompression_ratio(compressed_size, decompressed_size, 1000)?; - + Ok(ArchiveMetadata { format: ArchiveFormat::Xz, - total_files: 1, // XZ is single-file format + total_files: 1, // XZ is single-file format total_size: decompressed_size, compressed_size, - compression_ratio: if compressed_size > 0 { decompressed_size as f64 / compressed_size as f64 } else { 0.0 }, + compression_ratio: if compressed_size > 0 { + decompressed_size as f64 / compressed_size as f64 + } else { + 0.0 + }, is_encrypted: false, is_multi_volume: false, created_time: None, modified_time: None, }) } - + fn list_entries(&mut self) -> Result> { // XZ is single-file, infer filename from archive name - let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?; - - let filename = path.file_name() + let path = self + .archive_path + .as_ref() + .ok_or_else(|| anyhow!("Archive not opened"))?; + + let filename = path + .file_name() .and_then(|n| n.to_str()) .map(|s| s.strip_suffix(".xz").unwrap_or(s)) .unwrap_or("output"); - + Ok(vec![ArchiveEntry { path: PathBuf::from(filename), - size: 0, // Will be determined during extraction + size: 0, // Will be determined during extraction compressed_size: 0, is_dir: false, is_file: true, @@ -222,48 +247,54 @@ impl ArchiveProcessor for XzProcessor { permissions: None, }]) } - + fn extract_file(&self, _entry_path: &Path, output: &mut Vec) -> Result { - use xz2::read::XzDecoder; use std::io::Read; - - let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?; - + use xz2::read::XzDecoder; + + let path = self + .archive_path + .as_ref() + .ok_or_else(|| anyhow!("Archive not opened"))?; + let file = fs::File::open(path)?; let mut decoder = XzDecoder::new(file); - + decoder.read_to_end(output)?; - + Ok(output.len() as u64) } - + fn extract_all(&self, output_dir: &Path) -> Result { - use xz2::read::XzDecoder; use std::io::Read; - - let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?; - + use xz2::read::XzDecoder; + + let path = self + .archive_path + .as_ref() + .ok_or_else(|| anyhow!("Archive not opened"))?; + // Infer output filename let entries = self.list_entries()?; let output_path = output_dir.join(&entries[0].path); - + // Validate path validate_extraction_path(&entries[0].path, output_dir)?; - + let file = fs::File::open(path)?; let mut decoder = XzDecoder::new(file); - + let mut output_file = fs::File::create(&output_path)?; std::io::copy(&mut decoder, &mut output_file)?; - + let mut result = ExtractResult::new(); result.success_files = 1; result.total_files = 1; result.total_bytes = fs::metadata(&output_path)?.len(); - + Ok(result) } - + fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Xz && check_liblzma_available() } @@ -286,59 +317,61 @@ impl ArchiveProcessor for SevenZProcessor { fn format(&self) -> ArchiveFormat { ArchiveFormat::SevenZ } - + fn open(&mut self, path: &Path) -> Result { // Show stability warning warning::show_7z_stability_warning(); - + use sevenz_rust::SevenZReader; - + let reader = SevenZReader::new(path)?; - + let entries = reader.entries()?; let total_files = entries.len() as u64; - - let total_size = entries.iter() - .map(|e| e.uncompressed_size as u64) - .sum(); - + + let total_size = entries.iter().map(|e| e.uncompressed_size as u64).sum(); + let compressed_size = fs::metadata(path)?.len(); - + Ok(ArchiveMetadata { format: ArchiveFormat::SevenZ, total_files, total_size, compressed_size, - compression_ratio: if compressed_size > 0 { total_size as f64 / compressed_size as f64 } else { 0.0 }, + compression_ratio: if compressed_size > 0 { + total_size as f64 / compressed_size as f64 + } else { + 0.0 + }, is_encrypted: entries.iter().any(|e| e.is_encrypted), is_multi_volume: false, created_time: None, modified_time: None, }) } - + fn list_entries(&mut self) -> Result> { // Note: sevenz-rust doesn't have full entry listing yet // This is a stub returning empty list warn!("7z list_entries not fully implemented (library limitation)"); Ok(Vec::new()) } - + fn extract_file(&self, _entry_path: &Path, _output: &mut Vec) -> Result { warn!("7z extract_file not implemented (library limitation)"); Err(anyhow!("7z library doesn't support random access")) } - + fn extract_all(&self, output_dir: &Path) -> Result { use sevenz_rust::SevenZReader; - + // Note: sevenz-rust doesn't have full extraction yet // This is a stub warn!("7z extract_all limited (library under development)"); - + Ok(ExtractResult::new()) } - + fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::SevenZ } @@ -369,15 +402,21 @@ pub struct SevenZProcessor; #[cfg(not(feature = "optional-formats"))] impl RarProcessor { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } #[cfg(not(feature = "optional-formats"))] impl XzProcessor { - pub fn new() -> Self { Self } + pub fn new() -> Self { + Self + } } #[cfg(not(feature = "optional-formats"))] impl SevenZProcessor { - pub fn new() -> Self { Self } -} \ No newline at end of file + pub fn new() -> Self { + Self + } +} diff --git a/markbase-core/src/archive/tests/core_formats_test.rs b/markbase-core/src/archive/tests/core_formats_test.rs index 9ee9ac6..265c260 100644 --- a/markbase-core/src/archive/tests/core_formats_test.rs +++ b/markbase-core/src/archive/tests/core_formats_test.rs @@ -1,31 +1,31 @@ use crate::archive::{ - ArchiveProcessor, ArchiveFormat, ArchiveMetadata, ArchiveEntry, ExtractResult, - processors::core::{ZipProcessor, TarProcessor, GzipProcessor, TarGzipProcessor}, - processor::{validate_extraction_path, check_decompression_ratio}, config::ArchiveConfig, + processor::{check_decompression_ratio, validate_extraction_path}, + processors::core::{GzipProcessor, TarGzipProcessor, TarProcessor, ZipProcessor}, + ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult, }; -use tempfile::TempDir; -use std::fs::{File, create_dir_all}; +use anyhow::Result; +use std::fs::{create_dir_all, File}; use std::io::Write; use std::path::PathBuf; -use anyhow::Result; +use tempfile::TempDir; #[cfg(test)] mod helpers { use std::fs::File; use std::io::Write; use std::path::PathBuf; - + pub fn create_test_zip(path: &PathBuf, files: Vec<(&str, &[u8])>) { use std::io::Cursor; - + let mut buffer = Cursor::new(Vec::new()); { let mut zip = zip::ZipWriter::new(&mut buffer); - + let options = zip::write::FileOptions::default() .compression_method(zip::CompressionMethod::Stored); - + for (name, content) in files { if name.ends_with('/') { zip.add_directory(name, options).unwrap(); @@ -34,31 +34,31 @@ mod helpers { zip.write_all(content).unwrap(); } } - + zip.finish().unwrap(); } - + let zip_data = buffer.into_inner(); File::create(path).unwrap().write_all(&zip_data).unwrap(); } - + pub fn create_test_tar(path: &PathBuf, files: Vec<(&str, &[u8])>) { let file = File::create(path).unwrap(); let mut builder = tar::Builder::new(file); - + for (name, content) in files { let mut header = tar::Header::new_gnu(); header.set_size(content.len() as u64); header.set_path(name); header.set_mode(0o644); header.set_cksum(); - + builder.append_data(&mut header, name, content).unwrap(); } - + builder.finish().unwrap(); } - + pub fn create_test_gzip(path: &PathBuf, content: &[u8]) { let file = File::create(path).unwrap(); let mut encoder = flate2::write::GzEncoder::new(file, flate2::Compression::default()); @@ -69,74 +69,74 @@ mod helpers { #[cfg(test)] mod core_format_tests { - use super::*; use super::helpers::*; - + use super::*; + #[test] fn test_zip_processor_basic() { let temp_dir = TempDir::new().unwrap(); let zip_path = temp_dir.path().join("test.zip"); create_test_zip(&zip_path, vec![("file1.txt", b"hello")]); - + let mut processor = ZipProcessor::new(); let metadata = processor.open(&zip_path).unwrap(); - + assert_eq!(metadata.format, ArchiveFormat::Zip); assert_eq!(metadata.total_files, 1); } - + #[test] fn test_tar_processor_basic() { let temp_dir = TempDir::new().unwrap(); let tar_path = temp_dir.path().join("test.tar"); create_test_tar(&tar_path, vec![("file1.txt", b"tar content")]); - + let mut processor = TarProcessor::new(); let metadata = processor.open(&tar_path).unwrap(); - + assert_eq!(metadata.format, ArchiveFormat::Tar); } - + #[test] fn test_gzip_processor_basic() { let temp_dir = TempDir::new().unwrap(); let gz_path = temp_dir.path().join("test.gz"); create_test_gzip(&gz_path, b"gzip content here"); - + let mut processor = GzipProcessor::new(); let metadata = processor.open(&gz_path).unwrap(); - + assert_eq!(metadata.format, ArchiveFormat::Gzip); assert_eq!(metadata.total_files, 1); } - + #[test] fn test_validate_extraction_path_safe() { let temp_dir = TempDir::new().unwrap(); let base = temp_dir.path(); let safe_path = PathBuf::from("safe/file.txt"); - + let result = validate_extraction_path(&safe_path, base); assert!(result.is_ok()); - + let resolved = result.unwrap(); assert!(resolved.starts_with(base)); } - + #[test] fn test_validate_extraction_path_zip_slip() { let base = PathBuf::from("/tmp/extract"); let evil_path = PathBuf::from("../../etc/passwd"); - + let result = validate_extraction_path(&evil_path, &base); assert!(result.is_err()); } - + #[test] fn test_check_decompression_ratio_ok() { assert!(check_decompression_ratio(1000, 5000, 1000).is_ok()); } - + #[test] fn test_check_decompression_ratio_zip_bomb() { assert!(check_decompression_ratio(42_000, 5_000_000_000, 1000).is_err()); @@ -145,39 +145,39 @@ mod core_format_tests { #[cfg(test)] mod integration_tests { - use super::*; use super::helpers::*; + use super::*; use crate::archive::detector::FormatDetector; use crate::archive::ProcessorRegistry; - + #[test] fn test_format_detection_automation() { let temp_dir = TempDir::new().unwrap(); let detector = FormatDetector::new(); - + let zip_path = temp_dir.path().join("test.zip"); create_test_zip(&zip_path, vec![("f.txt", b"z")]); assert_eq!(detector.detect(&zip_path).unwrap(), ArchiveFormat::Zip); - + let tar_path = temp_dir.path().join("test.tar"); create_test_tar(&tar_path, vec![("f.txt", b"t")]); assert_eq!(detector.detect(&tar_path).unwrap(), ArchiveFormat::Tar); - + let gz_path = temp_dir.path().join("test.gz"); create_test_gzip(&gz_path, b"g"); assert_eq!(detector.detect(&gz_path).unwrap(), ArchiveFormat::Gzip); } - + #[test] fn test_processor_registry_integration() { let config = ArchiveConfig::default(); let mut registry = ProcessorRegistry::new(config); registry.initialize().unwrap(); - + let formats = registry.enabled_formats(); assert!(formats.contains(&ArchiveFormat::Zip)); assert!(formats.contains(&ArchiveFormat::Tar)); assert!(formats.contains(&ArchiveFormat::Gzip)); assert!(formats.contains(&ArchiveFormat::TarGzip)); } -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/tests/integration_test.rs b/markbase-core/src/archive/tests/integration_test.rs index db0dc57..04d6b8a 100644 --- a/markbase-core/src/archive/tests/integration_test.rs +++ b/markbase-core/src/archive/tests/integration_test.rs @@ -4,48 +4,46 @@ use std::fs; use std::io::Read; use tempfile::TempDir; -use crate::archive::*; use crate::archive::processor::check_decompression_ratio; use crate::archive::tests::test_helpers::*; +use crate::archive::*; #[test] fn test_zip_processor_full_workflow() { let temp_dir = TempDir::new().unwrap(); let zip_path = create_test_zip(&temp_dir); - + // Initialize processor let mut processor = processors::core::ZipProcessor::new(); - + // Test open let metadata = processor.open(&zip_path).unwrap(); assert_eq!(metadata.format, ArchiveFormat::Zip); assert_eq!(metadata.total_files, 3); - + // Test list_entries let entries = processor.list_entries().unwrap(); assert_eq!(entries.len(), 3); - + // Verify entry names - let names: Vec<&str> = entries.iter() - .map(|e| e.path.to_str().unwrap()) - .collect(); + let names: Vec<&str> = entries.iter().map(|e| e.path.to_str().unwrap()).collect(); assert!(names.contains(&"file1.txt")); assert!(names.contains(&"file2.txt")); assert!(names.contains(&"subdir/file3.txt")); - + // Test extract_all let extract_dir = temp_dir.path().join("extracted"); fs::create_dir_all(&extract_dir).unwrap(); - + let result = processor.extract_all(&extract_dir).unwrap(); assert_eq!(result.success_files, 3); assert_eq!(result.failed_files.len(), 0); - + // Verify extracted files assert!(extract_dir.join("file1.txt").exists()); assert!(extract_dir.join("file2.txt").exists()); assert!(extract_dir.join("subdir/file3.txt").exists()); - + // Verify content let content1 = fs::read_to_string(extract_dir.join("file1.txt")).unwrap(); assert_eq!(content1, "content of file 1"); @@ -55,24 +53,24 @@ fn test_zip_processor_full_workflow() { fn test_tar_processor_full_workflow() { let temp_dir = TempDir::new().unwrap(); let tar_path = create_test_tar(&temp_dir); - + let mut processor = processors::core::TarProcessor::new(); - + // Test open let metadata = processor.open(&tar_path).unwrap(); assert_eq!(metadata.format, ArchiveFormat::Tar); - + // Test list_entries let entries = processor.list_entries().unwrap(); - assert!(entries.len() >= 3); // TAR may include directory entries - + assert!(entries.len() >= 3); // TAR may include directory entries + // Test extract_all let extract_dir = temp_dir.path().join("extracted_tar"); fs::create_dir_all(&extract_dir).unwrap(); - + let result = processor.extract_all(&extract_dir).unwrap(); assert!(result.success_files >= 3); - + // Verify extracted files exist assert!(extract_dir.join("file1.txt").exists()); assert!(extract_dir.join("file2.txt").exists()); @@ -82,25 +80,25 @@ fn test_tar_processor_full_workflow() { fn test_gzip_processor_full_workflow() { let temp_dir = TempDir::new().unwrap(); let gz_path = create_test_gzip(&temp_dir); - + let mut processor = processors::core::GzipProcessor::new(); - + // Test open let metadata = processor.open(&gz_path).unwrap(); assert_eq!(metadata.format, ArchiveFormat::Gzip); - assert_eq!(metadata.total_files, 1); // GZIP is single file - + assert_eq!(metadata.total_files, 1); // GZIP is single file + // Test extract_all let extract_dir = temp_dir.path().join("extracted_gz"); fs::create_dir_all(&extract_dir).unwrap(); - + let result = processor.extract_all(&extract_dir).unwrap(); assert_eq!(result.success_files, 1); - + // Verify extracted file (should strip .gz extension) let extracted_file = extract_dir.join("test.txt"); assert!(extracted_file.exists()); - + // Verify content let content = fs::read_to_string(&extracted_file).unwrap(); assert_eq!(content, "test gzip content for validation"); @@ -110,20 +108,20 @@ fn test_gzip_processor_full_workflow() { fn test_tar_gz_processor_workflow() { let temp_dir = TempDir::new().unwrap(); let tar_gz_path = create_test_tar_gz(&temp_dir); - + let mut processor = processors::core::TarGzipProcessor::new(); - + // Test open let metadata = processor.open(&tar_gz_path).unwrap(); assert_eq!(metadata.format, ArchiveFormat::TarGzip); - + // Test extract_all let extract_dir = temp_dir.path().join("extracted_tar_gz"); fs::create_dir_all(&extract_dir).unwrap(); - + let result = processor.extract_all(&extract_dir).unwrap(); assert!(result.success_files >= 2); - + // Verify extracted TAR files assert!(extract_dir.join("file1.txt").exists()); assert!(extract_dir.join("file2.txt").exists()); @@ -132,18 +130,18 @@ fn test_tar_gz_processor_workflow() { #[test] fn test_format_detection_auto() { let temp_dir = TempDir::new().unwrap(); - + // Test ZIP detection let zip_path = create_test_zip(&temp_dir); let detector = FormatDetector::new(); let format = detector.detect(&zip_path).unwrap(); assert_eq!(format, ArchiveFormat::Zip); - + // Test TAR detection let tar_path = create_test_tar(&temp_dir); let format = detector.detect(&tar_path).unwrap(); assert_eq!(format, ArchiveFormat::Tar); - + // Test GZIP detection let gz_path = create_test_gzip(&temp_dir); let format = detector.detect(&gz_path).unwrap(); @@ -155,12 +153,12 @@ fn test_processor_registry_core_formats() { let config = ArchiveConfig::default(); let mut registry = ProcessorRegistry::new(config); registry.initialize().unwrap(); - + let formats = registry.enabled_formats(); - + // Should have 9 core formats - assert!(formats.len() >= 4); // At least the ones we implemented - + assert!(formats.len() >= 4); // At least the ones we implemented + // Verify format support assert!(formats.contains(&ArchiveFormat::Zip)); assert!(formats.contains(&ArchiveFormat::Tar)); @@ -172,20 +170,20 @@ fn test_processor_registry_core_formats() { fn test_zip_slip_protection() { let temp_dir = TempDir::new().unwrap(); let zip_bomb_data = create_zip_slip_test(); - + // Write malicious ZIP to file let evil_zip_path = temp_dir.path().join("evil.zip"); fs::write(&evil_zip_path, &zip_bomb_data).unwrap(); - + let mut processor = processors::core::ZipProcessor::new(); processor.open(&evil_zip_path).unwrap(); - + // Attempt extraction should fail due to Zip Slip protection let extract_dir = temp_dir.path().join("should_fail"); fs::create_dir_all(&extract_dir).unwrap(); - + let result = processor.extract_all(&extract_dir); - + // Should either fail or have empty extracted files // (validate_extraction_path prevents malicious paths) if result.is_ok() { @@ -199,11 +197,11 @@ fn test_zip_slip_protection() { fn test_zip_bomb_detection() { // Test decompression ratio check let result = check_decompression_ratio(42_000, 5_000_000_000, 1000); - assert!(result.is_err()); // Should detect as Zip Bomb - + assert!(result.is_err()); // Should detect as Zip Bomb + // Test normal ratio let result = check_decompression_ratio(1000, 5000, 1000); - assert!(result.is_ok()); // Normal ratio should pass + assert!(result.is_ok()); // Normal ratio should pass } #[test] @@ -219,21 +217,21 @@ fn test_metadata_compression_ratio() { created_time: None, modified_time: None, }; - - assert_eq!(metadata.actual_ratio(), 5.0); // 5000/1000 = 5.0 - assert!(!metadata.check_zip_bomb(10)); // ratio 5.0 < 10, not a bomb - assert!(metadata.check_zip_bomb(4)); // ratio 5.0 > 4, detected as bomb + + assert_eq!(metadata.actual_ratio(), 5.0); // 5000/1000 = 5.0 + assert!(!metadata.check_zip_bomb(10)); // ratio 5.0 < 10, not a bomb + assert!(metadata.check_zip_bomb(4)); // ratio 5.0 > 4, detected as bomb } #[test] fn test_config_validation() { let config = ArchiveConfig { - max_decompression_ratio: 5, // Too low + max_decompression_ratio: 5, // Too low ..Default::default() }; - + assert!(config.validate().is_err()); - + let valid_config = ArchiveConfig::default(); assert!(valid_config.validate().is_ok()); -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/tests/mod.rs b/markbase-core/src/archive/tests/mod.rs index 30283df..aca64a2 100644 --- a/markbase-core/src/archive/tests/mod.rs +++ b/markbase-core/src/archive/tests/mod.rs @@ -7,10 +7,10 @@ pub mod test_helpers; #[cfg(test)] mod tests { use super::*; - + #[test] fn test_module_structure() { // Test that all test modules exist assert!(true); } -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/tests/test_helpers.rs b/markbase-core/src/archive/tests/test_helpers.rs index 9e60fc1..31aee35 100644 --- a/markbase-core/src/archive/tests/test_helpers.rs +++ b/markbase-core/src/archive/tests/test_helpers.rs @@ -1,28 +1,27 @@ +use flate2::write::GzEncoder; +use flate2::Compression; use std::fs::{self, File}; use std::io::Write; use std::path::PathBuf; -use tempfile::TempDir; -use zip::{ZipWriter, write::FileOptions, CompressionMethod}; -use flate2::write::GzEncoder; -use flate2::Compression; use tar::Builder; +use tempfile::TempDir; +use zip::{write::FileOptions, CompressionMethod, ZipWriter}; pub fn create_test_zip(temp_dir: &TempDir) -> PathBuf { let zip_path = temp_dir.path().join("test.zip"); let file = File::create(&zip_path).unwrap(); let mut zip = ZipWriter::new(file); - let options = FileOptions::default() - .compression_method(CompressionMethod::Stored); - + let options = FileOptions::default().compression_method(CompressionMethod::Stored); + zip.start_file("file1.txt", options).unwrap(); zip.write_all(b"content of file 1").unwrap(); - + zip.start_file("file2.txt", options).unwrap(); zip.write_all(b"content of file 2").unwrap(); - + zip.start_file("subdir/file3.txt", options).unwrap(); zip.write_all(b"content of file 3 in subdir").unwrap(); - + zip.finish().unwrap(); zip_path } @@ -31,28 +30,38 @@ pub fn create_test_tar(temp_dir: &TempDir) -> PathBuf { let tar_path = temp_dir.path().join("test.tar"); let file = File::create(&tar_path).unwrap(); let mut builder = Builder::new(file); - + let mut header1 = tar::Header::new_gnu(); header1.set_path("file1.txt").unwrap(); header1.set_size(17); header1.set_mode(0o644); header1.set_cksum(); - builder.append_data(&mut header1, "file1.txt", &b"content of file 1"[..]).unwrap(); - + builder + .append_data(&mut header1, "file1.txt", &b"content of file 1"[..]) + .unwrap(); + let mut header2 = tar::Header::new_gnu(); header2.set_path("file2.txt").unwrap(); header2.set_size(17); header2.set_mode(0o644); header2.set_cksum(); - builder.append_data(&mut header2, "file2.txt", &b"content of file 2"[..]).unwrap(); - + builder + .append_data(&mut header2, "file2.txt", &b"content of file 2"[..]) + .unwrap(); + let mut header3 = tar::Header::new_gnu(); header3.set_path("subdir/file3.txt").unwrap(); header3.set_size(27); header3.set_mode(0o644); header3.set_cksum(); - builder.append_data(&mut header3, "subdir/file3.txt", &b"content of file 3 in subdir"[..]).unwrap(); - + builder + .append_data( + &mut header3, + "subdir/file3.txt", + &b"content of file 3 in subdir"[..], + ) + .unwrap(); + builder.finish().unwrap(); tar_path } @@ -61,7 +70,9 @@ pub fn create_test_gzip(temp_dir: &TempDir) -> PathBuf { let gz_path = temp_dir.path().join("test.txt.gz"); let file = File::create(&gz_path).unwrap(); let mut encoder = GzEncoder::new(file, Compression::default()); - encoder.write_all(b"test gzip content for validation").unwrap(); + encoder + .write_all(b"test gzip content for validation") + .unwrap(); encoder.finish().unwrap(); gz_path } @@ -70,33 +81,37 @@ pub fn create_test_tar_gz(temp_dir: &TempDir) -> PathBuf { let tar_path = temp_dir.path().join("test.tar"); let tar_file = File::create(&tar_path).unwrap(); let mut builder = Builder::new(tar_file); - + let mut header1 = tar::Header::new_gnu(); header1.set_path("file1.txt").unwrap(); header1.set_size(10); header1.set_mode(0o644); header1.set_cksum(); - builder.append_data(&mut header1, "file1.txt", &b"file1 data"[..]).unwrap(); - + builder + .append_data(&mut header1, "file1.txt", &b"file1 data"[..]) + .unwrap(); + let mut header2 = tar::Header::new_gnu(); header2.set_path("file2.txt").unwrap(); header2.set_size(10); header2.set_mode(0o644); header2.set_cksum(); - builder.append_data(&mut header2, "file2.txt", &b"file2 data"[..]).unwrap(); - + builder + .append_data(&mut header2, "file2.txt", &b"file2 data"[..]) + .unwrap(); + builder.finish().unwrap(); - + let tar_gz_path = temp_dir.path().join("test.tar.gz"); let gz_file = File::create(&tar_gz_path).unwrap(); let mut encoder = GzEncoder::new(gz_file, Compression::default()); - + let tar_content = std::fs::read(&tar_path).unwrap(); encoder.write_all(&tar_content).unwrap(); encoder.finish().unwrap(); - + std::fs::remove_file(&tar_path).unwrap(); - + tar_gz_path } @@ -105,13 +120,12 @@ pub fn create_zip_bomb_test() -> Vec { { let writer = std::io::Cursor::new(&mut buffer); let mut zip = ZipWriter::new(writer); - - let options = FileOptions::default() - .compression_method(CompressionMethod::Stored); - + + let options = FileOptions::default().compression_method(CompressionMethod::Stored); + zip.start_file("bomb.txt", options).unwrap(); zip.write_all(&[0u8; 100]).unwrap(); - + zip.finish().unwrap(); } buffer @@ -123,11 +137,11 @@ pub fn create_zip_slip_test() -> Vec { let writer = std::io::Cursor::new(&mut buffer); let mut zip = ZipWriter::new(writer); let options = FileOptions::default(); - + zip.start_file("../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../etc/passwd", options).unwrap(); zip.write_all(b"malicious content").unwrap(); - + zip.finish().unwrap(); } buffer -} \ No newline at end of file +} diff --git a/markbase-core/src/archive/warning.rs b/markbase-core/src/archive/warning.rs index 4b32633..9418187 100644 --- a/markbase-core/src/archive/warning.rs +++ b/markbase-core/src/archive/warning.rs @@ -1,6 +1,6 @@ // Warning System - Legal and Technical Warnings for Optional Formats -use log::{warn, info}; +use log::{info, warn}; use crate::archive::config::ArchiveConfig; @@ -63,25 +63,27 @@ pub fn show_startup_warnings(config: &ArchiveConfig) { if config.enable_rar { show_rar_legal_warning(); } - + if config.enable_xz { // Dependency check happens in ProcessorRegistry } - + if config.enable_7z { show_7z_stability_warning(); } - + // Show summary of enabled formats - let enabled_optional = [ - config.enable_rar, - config.enable_xz, - config.enable_7z, - ].iter().filter(|&x| *x).count(); - + let enabled_optional = [config.enable_rar, config.enable_xz, config.enable_7z] + .iter() + .filter(|&x| *x) + .count(); + if enabled_optional > 0 { info!(""); - info!("⚠️ {} optional format(s) enabled with warnings shown above", enabled_optional); + info!( + "⚠️ {} optional format(s) enabled with warnings shown above", + enabled_optional + ); info!("Core formats (9): ZIP, TAR, GZIP, ZSTD, BZIP2, LZ4, TAR.GZ, TAR.BZ2, TAR.ZST"); info!(""); } @@ -89,8 +91,7 @@ pub fn show_startup_warnings(config: &ArchiveConfig) { /// Generate user-facing legal disclaimer text pub fn generate_rar_legal_disclaimer() -> String { - format!( - "RAR FORMAT LEGAL DISCLAIMER + "RAR FORMAT LEGAL DISCLAIMER IMPORTANT WARNING: @@ -136,6 +137,5 @@ CONTACT: Last Updated: 2026-06-10 Version: 1.0 Legal Consultation: [Please consult professional lawyer for commercial use] -" - ) -} \ No newline at end of file +".to_string() +} diff --git a/markbase-core/src/audit.rs b/markbase-core/src/audit.rs index 37f75ae..30796ef 100644 --- a/markbase-core/src/audit.rs +++ b/markbase-core/src/audit.rs @@ -52,7 +52,7 @@ impl AuditLogger { }; self.write_entry(&entry)?; - + log::info!( "Audit: {} config {} changed from '{}' to '{}' by {}", config_type, @@ -61,7 +61,7 @@ impl AuditLogger { new_value, user ); - + Ok(()) } @@ -126,7 +126,7 @@ impl AuditLogger { } else { 0 }; - + Ok(entries[start..].to_vec()) } -} \ No newline at end of file +} diff --git a/markbase-core/src/auth.rs b/markbase-core/src/auth.rs index ba21cc6..439677d 100644 --- a/markbase-core/src/auth.rs +++ b/markbase-core/src/auth.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use uuid::Uuid; -use crate::provider::{DataProvider, ProviderError}; +use crate::provider::DataProvider; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct User { @@ -71,6 +71,12 @@ pub struct AuthState { pub provider: Option>, } +impl Default for AuthState { + fn default() -> Self { + Self::new() + } +} + impl AuthState { pub fn new() -> Self { let mut users = HashMap::new(); @@ -284,7 +290,12 @@ impl AuthState { } } - fn login_with_provider(&self, provider: &dyn DataProvider, username: &str, password: &str) -> Option { + fn login_with_provider( + &self, + provider: &dyn DataProvider, + username: &str, + password: &str, + ) -> Option { match provider.get_user(username) { Ok(Some(user)) => { if user.status != 1 { diff --git a/markbase-core/src/category_view.rs b/markbase-core/src/category_view.rs index 34a4d95..67654e4 100644 --- a/markbase-core/src/category_view.rs +++ b/markbase-core/src/category_view.rs @@ -118,14 +118,20 @@ fn get_series_display_name(name: &str) -> String { pub fn get_all_categories() -> Result { let conn = FileTree::open_user_db("accusys")?; let tree = FileTree::load(&conn, "accusys", "categories")?; - - let categories: Vec = tree.nodes.iter() + + let categories: Vec = tree + .nodes + .iter() .filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder") .map(|n| { - let file_count = tree.nodes.iter() - .filter(|f| f.parent_id == Some(n.node_id.clone()) && f.node_type.as_str() == "file") + let file_count = tree + .nodes + .iter() + .filter(|f| { + f.parent_id == Some(n.node_id.clone()) && f.node_type.as_str() == "file" + }) .count(); - + Category { name: n.label.clone(), display_name: get_category_display_name(&n.label), @@ -135,11 +141,13 @@ pub fn get_all_categories() -> Result { } }) .collect(); - - let total_files = tree.nodes.iter() + + let total_files = tree + .nodes + .iter() .filter(|n| n.node_type.as_str() == "file") .count(); - + Ok(CategoriesResponse { total_categories: categories.len(), total_files, @@ -150,42 +158,65 @@ pub fn get_all_categories() -> Result { pub fn get_category_detail(category_name: &str) -> Result { let conn = FileTree::open_user_db("accusys")?; let tree = FileTree::load(&conn, "accusys", "categories")?; - - let category_node = tree.nodes.iter() - .find(|n| n.label == category_name && n.parent_id.is_none() && n.node_type.as_str() == "folder") + + let category_node = tree + .nodes + .iter() + .find(|n| { + n.label == category_name && n.parent_id.is_none() && n.node_type.as_str() == "folder" + }) .ok_or_else(|| anyhow::anyhow!("Category not found: {}", category_name))?; - - let series_groups: Vec = tree.nodes.iter() - .filter(|n| n.parent_id == Some(category_node.node_id.clone()) && n.node_type.as_str() == "folder") + + let series_groups: Vec = tree + .nodes + .iter() + .filter(|n| { + n.parent_id == Some(category_node.node_id.clone()) && n.node_type.as_str() == "folder" + }) .map(|series_node| { - let files: Vec = tree.nodes.iter() - .filter(|f| f.parent_id == Some(series_node.node_id.clone()) && f.node_type.as_str() == "file") - .map(|file_node| { - CategoryFile { - filename: file_node.label.clone(), - size: file_node.aliases.get("file_size_display").cloned().unwrap_or_default(), - download_url: file_node.aliases.get("download_url").cloned().unwrap_or_default(), - sha256: file_node.sha256.clone(), - } + let files: Vec = tree + .nodes + .iter() + .filter(|f| { + f.parent_id == Some(series_node.node_id.clone()) + && f.node_type.as_str() == "file" + }) + .map(|file_node| CategoryFile { + filename: file_node.label.clone(), + size: file_node + .aliases + .get("file_size_display") + .cloned() + .unwrap_or_default(), + download_url: file_node + .aliases + .get("download_url") + .cloned() + .unwrap_or_default(), + sha256: file_node.sha256.clone(), }) .collect(); - + SeriesGroup { series_name: series_node.label.clone(), files, } }) .collect(); - + let file_count = series_groups.iter().map(|g| g.files.len()).sum(); - + Ok(CategoryDetail { category: Category { name: category_name.to_string(), display_name: get_category_display_name(category_name), file_count, last_updated: category_node.updated_at.clone(), - description: category_node.aliases.get("description").cloned().unwrap_or_default(), + description: category_node + .aliases + .get("description") + .cloned() + .unwrap_or_default(), }, series_groups, }) @@ -194,25 +225,31 @@ pub fn get_category_detail(category_name: &str) -> Result { pub fn get_all_series() -> Result { let conn = FileTree::open_user_db("accusys")?; let tree = FileTree::load(&conn, "accusys", "series")?; - - let series: Vec = tree.nodes.iter() + + let series: Vec = tree + .nodes + .iter() .filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder") .map(|n| { - let file_count = tree.nodes.iter() + let file_count = tree + .nodes + .iter() .filter(|f| { let mut current = f.parent_id.clone(); while let Some(pid) = current { if pid == n.node_id { return f.node_type.as_str() == "file"; } - current = tree.nodes.iter() + current = tree + .nodes + .iter() .find(|p| p.node_id == pid) - .map(|p| p.parent_id.clone()).flatten(); + .and_then(|p| p.parent_id.clone()); } false }) .count(); - + Series { name: n.label.clone(), display_name: get_series_display_name(&n.label), @@ -223,11 +260,13 @@ pub fn get_all_series() -> Result { } }) .collect(); - - let total_files = tree.nodes.iter() + + let total_files = tree + .nodes + .iter() .filter(|n| n.node_type.as_str() == "file") .count(); - + Ok(SeriesResponse { total_series: series.len(), total_files, @@ -238,45 +277,63 @@ pub fn get_all_series() -> Result { pub fn get_series_detail(series_name: &str) -> Result { let conn = FileTree::open_user_db("accusys")?; let tree = FileTree::load(&conn, "accusys", "series")?; - - let series_node = tree.nodes.iter() - .find(|n| n.label == series_name && n.parent_id.is_none() && n.node_type.as_str() == "folder") + + let series_node = tree + .nodes + .iter() + .find(|n| { + n.label == series_name && n.parent_id.is_none() && n.node_type.as_str() == "folder" + }) .ok_or_else(|| anyhow::anyhow!("Series not found: {}", series_name))?; - - let categories: Vec = tree.nodes.iter() - .filter(|n| n.parent_id == Some(series_node.node_id.clone()) && n.node_type.as_str() == "folder") + + let categories: Vec = tree + .nodes + .iter() + .filter(|n| { + n.parent_id == Some(series_node.node_id.clone()) && n.node_type.as_str() == "folder" + }) .map(|category_node| { - let files: Vec = tree.nodes.iter() + let files: Vec = tree + .nodes + .iter() .filter(|f| { let mut current = f.parent_id.clone(); while let Some(pid) = current { if pid == category_node.node_id && f.node_type.as_str() == "file" { return true; } - current = tree.nodes.iter() + current = tree + .nodes + .iter() .find(|p| p.node_id == pid) - .map(|p| p.parent_id.clone()).flatten(); + .and_then(|p| p.parent_id.clone()); } false }) - .map(|file_node| { - SeriesFile { - filename: file_node.label.clone(), - size: file_node.aliases.get("file_size_display").unwrap_or(&"N/A".to_string()).clone(), - download_url: file_node.aliases.get("download_url").unwrap_or(&"".to_string()).clone(), - } + .map(|file_node| SeriesFile { + filename: file_node.label.clone(), + size: file_node + .aliases + .get("file_size_display") + .unwrap_or(&"N/A".to_string()) + .clone(), + download_url: file_node + .aliases + .get("download_url") + .unwrap_or(&"".to_string()) + .clone(), }) .collect(); - + SeriesCategory { category_name: category_node.label.clone(), files, } }) .collect(); - + let file_count = categories.iter().map(|c| c.files.len()).sum(); - + Ok(SeriesDetail { series: Series { name: series_name.to_string(), @@ -296,29 +353,39 @@ pub fn search_files(query: &str, view: &str) -> Result { "series" => "series", _ => "untitled folder", }; - + let conn = FileTree::open_user_db("accusys")?; let tree = FileTree::load(&conn, "accusys", tree_type)?; - - let results: Vec = tree.nodes.iter() - .filter(|n| n.node_type.as_str() == "file" && n.label.to_lowercase().contains(&query.to_lowercase())) + + let results: Vec = tree + .nodes + .iter() + .filter(|n| { + n.node_type.as_str() == "file" && n.label.to_lowercase().contains(&query.to_lowercase()) + }) .map(|file_node| { - let parent_node = tree.nodes.iter() + let parent_node = tree + .nodes + .iter() .find(|n| n.node_id == file_node.parent_id.clone().unwrap_or_default()); - + SearchResult { category: parent_node.map(|n| n.label.clone()), series: parent_node.map(|n| n.label.clone()), filename: file_node.label.clone(), - download_url: file_node.aliases.get("download_url").unwrap_or(&"".to_string()).clone(), + download_url: file_node + .aliases + .get("download_url") + .unwrap_or(&"".to_string()) + .clone(), } }) .collect(); - + Ok(SearchResponse { query: query.to_string(), view: view.to_string(), total_results: results.len(), results, }) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/interface/iscsi.rs b/markbase-core/src/cli/interface/iscsi.rs index 7f1a87c..e633748 100644 --- a/markbase-core/src/cli/interface/iscsi.rs +++ b/markbase-core/src/cli/interface/iscsi.rs @@ -25,7 +25,7 @@ pub async fn handle_iscsi_command(cmd: IscsiCommand) -> anyhow::Result<()> { let binary = find_binary("markbase-iscsi"); let mut cmd_process = std::process::Command::new(&binary); cmd_process.arg("iscsi"); - + match cmd { IscsiCommand::Start { user, @@ -34,7 +34,8 @@ pub async fn handle_iscsi_command(cmd: IscsiCommand) -> anyhow::Result<()> { force, device, } => { - cmd_process.arg("start") + cmd_process + .arg("start") .args(["--user", &user]) .args(["--port", &port.to_string()]) .args(["--lun-size", &lun_size]); @@ -52,7 +53,7 @@ pub async fn handle_iscsi_command(cmd: IscsiCommand) -> anyhow::Result<()> { cmd_process.arg("status"); } } - + let status = cmd_process.status()?; std::process::exit(status.code().unwrap_or(1)); } @@ -61,4 +62,4 @@ fn find_binary(name: &str) -> std::path::PathBuf { let exe = std::env::current_exe().unwrap(); let dir = exe.parent().unwrap(); dir.join(name) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/interface/mod.rs b/markbase-core/src/cli/interface/mod.rs index 835a441..2424f64 100644 --- a/markbase-core/src/cli/interface/mod.rs +++ b/markbase-core/src/cli/interface/mod.rs @@ -1,8 +1,8 @@ -pub mod web; -pub mod ssh; -pub mod webdav; pub mod iscsi; +pub mod ssh; pub mod tree; +pub mod web; +pub mod webdav; use clap::Subcommand; @@ -29,4 +29,4 @@ pub async fn handle_interface_command(cmd: InterfaceCommands) -> anyhow::Result< InterfaceCommands::Tree(c) => tree::handle_tree_command(c).await?, } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/interface/ssh.rs b/markbase-core/src/cli/interface/ssh.rs index 9e543cd..5dc54ae 100644 --- a/markbase-core/src/cli/interface/ssh.rs +++ b/markbase-core/src/cli/interface/ssh.rs @@ -32,4 +32,4 @@ pub async fn handle_ssh_command(cmd: SshCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/interface/tree.rs b/markbase-core/src/cli/interface/tree.rs index c9c7e2e..8ec264c 100644 --- a/markbase-core/src/cli/interface/tree.rs +++ b/markbase-core/src/cli/interface/tree.rs @@ -1,6 +1,6 @@ +use anyhow::Context; use clap::Subcommand; use rusqlite::Connection; -use anyhow::Context; use uuid::Uuid; #[derive(Subcommand)] @@ -33,12 +33,12 @@ pub enum TreeCommand { #[arg(short, long)] name: String, }, - + Folder { #[command(subcommand)] action: FolderCommand, }, - + Ls { #[arg(short, long)] user: String, @@ -47,7 +47,7 @@ pub enum TreeCommand { #[arg(short, long)] tree_type: String, }, - + Cp { #[arg(short, long)] user: String, @@ -58,7 +58,7 @@ pub enum TreeCommand { #[arg(short, long)] tree_type: String, }, - + Mv { #[arg(short, long)] user: String, @@ -113,44 +113,54 @@ pub enum FolderCommand { pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> { match cmd { - TreeCommand::Create { name, user, tree_type } => { + TreeCommand::Create { + name, + user, + tree_type, + } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - + let node_id = Uuid::new_v4().to_string(); let created_at = chrono::Utc::now().to_rfc3339(); - + conn.execute( "INSERT INTO file_nodes (node_id, label, node_type, tree_type, created_at, updated_at) VALUES (?1, ?2, 'folder', ?3, ?4, ?4)", rusqlite::params![node_id, name, tree_type, created_at] ).context("Failed to create tree")?; - - println!("✓ Tree created: {} (type: {}) for user: {}", name, tree_type, user); + + println!( + "✓ Tree created: {} (type: {}) for user: {}", + name, tree_type, user + ); println!("✓ Node ID: {}", node_id); } TreeCommand::List { user } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - - let mut stmt = conn.prepare( - "SELECT DISTINCT tree_type FROM file_nodes ORDER BY tree_type" - ).context("Failed to prepare query")?; - - let tree_types = stmt.query_map([], |row| row.get::<_, String>(0)) + + let mut stmt = conn + .prepare("SELECT DISTINCT tree_type FROM file_nodes ORDER BY tree_type") + .context("Failed to prepare query")?; + + let tree_types = stmt + .query_map([], |row| row.get::<_, String>(0)) .context("Failed to query tree types")?; - + println!("=== Trees for user: {} ===", user); for tree_type in tree_types { let tt = tree_type?; - let count: i64 = conn.query_row( - "SELECT COUNT(*) FROM file_nodes WHERE tree_type = ?1", - [&tt], - |row| row.get(0) - ).unwrap_or(0); - + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM file_nodes WHERE tree_type = ?1", + [&tt], + |row| row.get(0), + ) + .unwrap_or(0); + println!(" {} ({} nodes)", tt, count); } } @@ -158,9 +168,9 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - + println!("Importing Markdown files to {} virtual tree...", tree_type); - + if tree_type == "categories" { crate::import_markdown::import_categories_to_db(&conn, &user, &tree_type)?; println!("✓ Categories imported successfully!"); @@ -168,53 +178,66 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> { crate::import_markdown::import_series_to_db(&conn, &user, &tree_type)?; println!("✓ Series imported successfully!"); } else { - eprintln!("Invalid tree_type: {}. Use 'categories' or 'series'", tree_type); + eprintln!( + "Invalid tree_type: {}. Use 'categories' or 'series'", + tree_type + ); } } TreeCommand::Delete { user, name } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - + conn.execute( "DELETE FROM file_nodes WHERE label = ?1 AND node_type = 'folder'", - [&name] - ).context("Failed to delete tree")?; - + [&name], + ) + .context("Failed to delete tree")?; + println!("✓ Tree deleted: {} for user: {}", name, user); } - + TreeCommand::Folder { action } => { handle_folder_command(action)?; } - - TreeCommand::Ls { user, path, tree_type } => { + + TreeCommand::Ls { + user, + path, + tree_type, + } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - + let parent_id = find_node_id(&conn, &path, &tree_type)?; - - let mut stmt = conn.prepare( - "SELECT label, node_type, file_size FROM file_nodes + + let mut stmt = conn + .prepare( + "SELECT label, node_type, file_size FROM file_nodes WHERE parent_id = ?1 AND tree_type = ?2 - ORDER BY node_type DESC, label ASC" - ).context("Failed to prepare ls query")?; - - let entries = stmt.query_map( - rusqlite::params![parent_id, tree_type], - |row| Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, Option>(2)? - )) - ).context("Failed to query entries")?; - + ORDER BY node_type DESC, label ASC", + ) + .context("Failed to prepare ls query")?; + + let entries = stmt + .query_map(rusqlite::params![parent_id, tree_type], |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, Option>(2)?, + )) + }) + .context("Failed to query entries")?; + println!("=== Contents of {} (tree_type: {}) ===", path, tree_type); for entry in entries { let (name, node_type, size) = entry?; - let size_str = size.map(|s| format!("{} bytes", s)).unwrap_or_else(|| "-".to_string()); - + let size_str = size + .map(|s| format!("{} bytes", s)) + .unwrap_or_else(|| "-".to_string()); + if node_type == "folder" { println!(" 📁 {} ({})", name, size_str); } else { @@ -222,57 +245,72 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> { } } } - - TreeCommand::Cp { user, source, target, tree_type } => { + + TreeCommand::Cp { + user, + source, + target, + tree_type, + } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - + let source_id = find_node_id(&conn, &source, &tree_type)?; let target_parent_id = find_node_id(&conn, &target, &tree_type)?; - - let (label, node_type, aliases_json, file_uuid, sha256, file_size) = conn.query_row( - "SELECT label, node_type, aliases_json, file_uuid, sha256, file_size + + let (label, node_type, aliases_json, file_uuid, sha256, file_size) = conn + .query_row( + "SELECT label, node_type, aliases_json, file_uuid, sha256, file_size FROM file_nodes WHERE node_id = ?1", - [&source_id], - |row| Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - row.get::<_, Option>(3)?, - row.get::<_, Option>(4)?, - row.get::<_, Option>(5)? - )) - ).context("Failed to get source node")?; - + [&source_id], + |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + row.get::<_, Option>(3)?, + row.get::<_, Option>(4)?, + row.get::<_, Option>(5)?, + )) + }, + ) + .context("Failed to get source node")?; + let new_id = Uuid::new_v4().to_string(); let created_at = chrono::Utc::now().to_rfc3339(); - + conn.execute( "INSERT INTO file_nodes (node_id, label, aliases_json, file_uuid, sha256, parent_id, node_type, file_size, tree_type, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?10)", rusqlite::params![new_id, label, aliases_json, file_uuid, sha256, target_parent_id, node_type, file_size, tree_type, created_at] ).context("Failed to copy node")?; - + println!("✓ Copied {} to {} (new ID: {})", source, target, new_id); } - - TreeCommand::Mv { user, source, target, tree_type } => { + + TreeCommand::Mv { + user, + source, + target, + tree_type, + } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - + let source_id = find_node_id(&conn, &source, &tree_type)?; let target_parent_id = find_node_id(&conn, &target, &tree_type)?; - + let updated_at = chrono::Utc::now().to_rfc3339(); - + conn.execute( "UPDATE file_nodes SET parent_id = ?1, updated_at = ?2 WHERE node_id = ?3", - rusqlite::params![target_parent_id, updated_at, source_id] - ).context("Failed to move node")?; - + rusqlite::params![target_parent_id, updated_at, source_id], + ) + .context("Failed to move node")?; + println!("✓ Moved {} to {}", source, target); } } @@ -281,104 +319,136 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> { fn handle_folder_command(cmd: FolderCommand) -> anyhow::Result<()> { match cmd { - FolderCommand::Create { user, path, name, tree_type } => { + FolderCommand::Create { + user, + path, + name, + tree_type, + } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - - let parent_id = if path == "/" || path == "" { + + let parent_id = if path == "/" || path.is_empty() { None } else { Some(find_node_id(&conn, &path, &tree_type)?) }; - + let node_id = Uuid::new_v4().to_string(); let created_at = chrono::Utc::now().to_rfc3339(); - + conn.execute( "INSERT INTO file_nodes (node_id, label, parent_id, node_type, tree_type, created_at, updated_at) VALUES (?1, ?2, ?3, 'folder', ?4, ?5, ?5)", - rusqlite::params![node_id, name, parent_id, tree_type, created_at] - ).context("Failed to create folder")?; - - println!("✓ Folder created: {} in {} (tree_type: {})", name, path, tree_type); + rusqlite::params![node_id, name, parent_id, tree_type, created_at], + ) + .context("Failed to create folder")?; + + println!( + "✓ Folder created: {} in {} (tree_type: {})", + name, path, tree_type + ); println!("✓ Node ID: {}", node_id); } - FolderCommand::Delete { user, path, name, tree_type } => { + FolderCommand::Delete { + user, + path, + name, + tree_type, + } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - - let folder_path = if path == "/" || path == "" { + + let folder_path = if path == "/" || path.is_empty() { name.clone() } else { format!("{}/{}", path, name) }; - + let folder_id = find_node_id(&conn, &folder_path, &tree_type)?; - + conn.execute( "DELETE FROM file_nodes WHERE node_id = ?1 OR parent_id = ?1", - [&folder_id] - ).context("Failed to delete folder and children")?; - - println!("✓ Folder deleted: {} in {} (tree_type: {})", name, path, tree_type); + [&folder_id], + ) + .context("Failed to delete folder and children")?; + + println!( + "✓ Folder deleted: {} in {} (tree_type: {})", + name, path, tree_type + ); } - FolderCommand::Rename { user, path, old_name, new_name, tree_type } => { + FolderCommand::Rename { + user, + path, + old_name, + new_name, + tree_type, + } => { let db_path = format!("data/users/{}.sqlite", user); let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open database: {}", db_path))?; - - let folder_path = if path == "/" || path == "" { + + let folder_path = if path == "/" || path.is_empty() { old_name.clone() } else { format!("{}/{}", path, old_name) }; - + let folder_id = find_node_id(&conn, &folder_path, &tree_type)?; - + let updated_at = chrono::Utc::now().to_rfc3339(); - + conn.execute( "UPDATE file_nodes SET label = ?1, updated_at = ?2 WHERE node_id = ?3", - rusqlite::params![new_name, updated_at, folder_id] - ).context("Failed to rename folder")?; - - println!("✓ Folder renamed: {} → {} in {} (tree_type: {})", old_name, new_name, path, tree_type); + rusqlite::params![new_name, updated_at, folder_id], + ) + .context("Failed to rename folder")?; + + println!( + "✓ Folder renamed: {} → {} in {} (tree_type: {})", + old_name, new_name, path, tree_type + ); } } Ok(()) } fn find_node_id(conn: &Connection, path: &str, tree_type: &str) -> anyhow::Result { - if path == "/" || path == "" { - let node_id: String = conn.query_row( - "SELECT node_id FROM file_nodes + if path == "/" || path.is_empty() { + let node_id: String = conn + .query_row( + "SELECT node_id FROM file_nodes WHERE parent_id IS NULL AND node_type = 'folder' AND tree_type = ?1 LIMIT 1", - [tree_type], - |row| row.get(0) - ).context("Failed to find root folder")?; - + [tree_type], + |row| row.get(0), + ) + .context("Failed to find root folder")?; + return Ok(node_id); } - + let parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); - + let mut current_parent: Option = None; - + for part in parts { - let node_id: String = conn.query_row( - "SELECT node_id FROM file_nodes + let node_id: String = conn + .query_row( + "SELECT node_id FROM file_nodes WHERE label = ?1 AND tree_type = ?2 AND (parent_id = ?3 OR (?3 IS NULL AND parent_id IS NULL))", - rusqlite::params![part, tree_type, current_parent], - |row| row.get(0) - ).context(format!("Failed to find node: {}", part))?; - + rusqlite::params![part, tree_type, current_parent], + |row| row.get(0), + ) + .context(format!("Failed to find node: {}", part))?; + current_parent = Some(node_id); } - + current_parent.context("Failed to find node ID for path") -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/interface/web.rs b/markbase-core/src/cli/interface/web.rs index aee3fb4..5a2bac5 100644 --- a/markbase-core/src/cli/interface/web.rs +++ b/markbase-core/src/cli/interface/web.rs @@ -18,4 +18,4 @@ pub async fn handle_web_command(cmd: WebCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/interface/webdav.rs b/markbase-core/src/cli/interface/webdav.rs index ec6a3ff..24b2f10 100644 --- a/markbase-core/src/cli/interface/webdav.rs +++ b/markbase-core/src/cli/interface/webdav.rs @@ -1,5 +1,5 @@ -use clap::Subcommand; use axum::{extract::Request, response::IntoResponse, Extension}; +use clap::Subcommand; #[derive(Subcommand)] pub enum WebdavCommand { @@ -28,7 +28,7 @@ pub async fn handle_webdav_command(cmd: WebdavCommand) -> anyhow::Result<()> { println!("User: {}", user); println!("Port: {}", port); println!("Database: {}", db_path.display()); - println!(""); + println!(); run_webdav_server(port, user, db_path).await?; } @@ -41,7 +41,7 @@ async fn run_webdav_server( user: String, db_path: std::path::PathBuf, ) -> anyhow::Result<()> { - use axum::{extract::Request, response::IntoResponse, routing::any, Extension, Router}; + use axum::{routing::any, Extension, Router}; use tokio::net::TcpListener; let webdav = markbase_webdav::webdav::MarkBaseWebDAV::new(user, db_path); @@ -58,7 +58,7 @@ async fn run_webdav_server( println!("WebDAV server listening on http://{}", addr); println!("Mount point: /webdav"); - println!(""); + println!(); println!("Press Ctrl+C to stop"); axum::serve(listener, app).await?; @@ -71,4 +71,4 @@ async fn handle_dav( req: Request, ) -> impl IntoResponse { dav.handle(req).await -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/metadata/auth.rs b/markbase-core/src/cli/metadata/auth.rs index f22160e..9aa1cc7 100644 --- a/markbase-core/src/cli/metadata/auth.rs +++ b/markbase-core/src/cli/metadata/auth.rs @@ -1,6 +1,6 @@ +use anyhow::Context; use clap::Subcommand; use rusqlite::Connection; -use anyhow::Context; #[derive(Subcommand)] pub enum AuthCommand { @@ -24,29 +24,30 @@ pub fn handle_auth_command(cmd: AuthCommand) -> anyhow::Result<()> { match cmd { AuthCommand::Login { user, password } => { let db_path = "data/auth.sqlite"; - + if !std::path::Path::new(db_path).exists() { return Err(anyhow::anyhow!("Auth database not found: {}", db_path)); } - - let conn = Connection::open(db_path) - .context("Failed to open auth database")?; - - let password_hash: String = conn.query_row( - "SELECT password_hash FROM sftpgo_users WHERE username = ?", - [&user], - |row| row.get(0) - ).context("Failed to query password hash")?; - - let valid = bcrypt::verify(&password, &password_hash) - .context("Failed to verify password")?; - + + let conn = Connection::open(db_path).context("Failed to open auth database")?; + + let password_hash: String = conn + .query_row( + "SELECT password_hash FROM sftpgo_users WHERE username = ?", + [&user], + |row| row.get(0), + ) + .context("Failed to query password hash")?; + + let valid = + bcrypt::verify(&password, &password_hash).context("Failed to verify password")?; + if !valid { return Err(anyhow::anyhow!("Invalid password for user: {}", user)); } - + let token = generate_simple_token(&user); - + println!("✓ Login successful for user: {}", user); println!("✓ Token: {}", token); println!("Note: This is a simple token for demonstration. Use JWT in production."); @@ -57,7 +58,7 @@ pub fn handle_auth_command(cmd: AuthCommand) -> anyhow::Result<()> { } AuthCommand::Verify { token } => { let user = verify_simple_token(&token)?; - + println!("✓ Token valid for user: {}", user); println!("Note: This is simple token verification. Use JWT in production."); } @@ -67,37 +68,38 @@ pub fn handle_auth_command(cmd: AuthCommand) -> anyhow::Result<()> { fn generate_simple_token(user: &str) -> String { use std::time::{SystemTime, UNIX_EPOCH}; - + let timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); - + format!("{}_{}", user, timestamp) } fn verify_simple_token(token: &str) -> anyhow::Result { let parts: Vec<&str> = token.split('_').collect(); - + if parts.len() < 2 { return Err(anyhow::anyhow!("Invalid token format")); } - + let user = parts[0]; let timestamp_str = parts[1]; - - let timestamp: u64 = timestamp_str.parse() + + let timestamp: u64 = timestamp_str + .parse() .context("Failed to parse token timestamp")?; - + use std::time::{SystemTime, UNIX_EPOCH}; let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); - + if now - timestamp > 86400 { return Err(anyhow::anyhow!("Token expired (valid for 24 hours)")); } - + Ok(user.to_string()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/metadata/config.rs b/markbase-core/src/cli/metadata/config.rs index 249b970..781948d 100644 --- a/markbase-core/src/cli/metadata/config.rs +++ b/markbase-core/src/cli/metadata/config.rs @@ -45,7 +45,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> { let config_path = Path::new("config/markbase.toml"); if !config_path.exists() { - println!("Configuration file not found. Run 'markbase metadata config init' first."); + println!( + "Configuration file not found. Run 'markbase metadata config init' first." + ); return Ok(()); } @@ -61,7 +63,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> { let config_path = Path::new("config/markbase.toml"); if !config_path.exists() { - println!("Configuration file not found. Run 'markbase metadata config init' first."); + println!( + "Configuration file not found. Run 'markbase metadata config init' first." + ); return Ok(()); } @@ -86,7 +90,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> { let config_path = Path::new("config/markbase.toml"); if !config_path.exists() { - println!("Configuration file not found. Run 'markbase metadata config init' first."); + println!( + "Configuration file not found. Run 'markbase metadata config init' first." + ); return Ok(()); } @@ -115,4 +121,4 @@ fn show_section(config: &crate::config::MarkBaseConfig, section: &str) { "logging" => println!("{}", toml::to_string_pretty(&config.logging).unwrap()), _ => println!("Invalid section: {}. Valid sections: server, postgresql, authentication, test, logging", section), } -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/metadata/db.rs b/markbase-core/src/cli/metadata/db.rs index 9f0aeb0..7c36032 100644 --- a/markbase-core/src/cli/metadata/db.rs +++ b/markbase-core/src/cli/metadata/db.rs @@ -1,6 +1,6 @@ +use anyhow::Context; use clap::Subcommand; use rusqlite::Connection; -use anyhow::Context; #[derive(Subcommand)] pub enum DbCommand { @@ -34,54 +34,52 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> { match cmd { DbCommand::Create { user } => { let db_path = filetree::FileTree::user_db_path(&user); - + if std::path::Path::new(&db_path).exists() { println!("Database already exists: {}", db_path); println!("Use 'db status' to check database info"); return Ok(()); } - + println!("Creating database for user: {}", user); - - let conn = filetree::FileTree::init_user_db(&user) - .context("Failed to initialize database")?; - + + let conn = + filetree::FileTree::init_user_db(&user).context("Failed to initialize database")?; + println!("✓ Database created: {}", db_path); - println!("✓ Tables initialized: file_nodes, file_registry, file_locations, tree_registry"); - - conn.close().map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?; + println!( + "✓ Tables initialized: file_nodes, file_registry, file_locations, tree_registry" + ); + + conn.close() + .map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?; } DbCommand::Status { user } => { let db_path = filetree::FileTree::user_db_path(&user); - + if !std::path::Path::new(&db_path).exists() { return Err(anyhow::anyhow!("Database not found: {}", db_path)); } - - let conn = Connection::open(&db_path) - .context("Failed to open database")?; - + + let conn = Connection::open(&db_path).context("Failed to open database")?; + let file_size = std::fs::metadata(&db_path)?.len(); let file_size_mb = file_size as f64 / 1024.0 / 1024.0; - - let node_count: i64 = conn.query_row( - "SELECT COUNT(*) FROM file_nodes", - [], - |row| row.get(0) - ).context("Failed to count nodes")?; - - let file_count: i64 = conn.query_row( - "SELECT COUNT(*) FROM file_registry", - [], - |row| row.get(0) - ).context("Failed to count files")?; - + + let node_count: i64 = conn + .query_row("SELECT COUNT(*) FROM file_nodes", [], |row| row.get(0)) + .context("Failed to count nodes")?; + + let file_count: i64 = conn + .query_row("SELECT COUNT(*) FROM file_registry", [], |row| row.get(0)) + .context("Failed to count files")?; + let tree_types: Vec = { let mut stmt = conn.prepare("SELECT tree_type FROM tree_registry")?; let rows = stmt.query_map([], |row| row.get(0))?; rows.collect::, _>>()? }; - + println!("=== Database Status ==="); println!("User: {}", user); println!("Path: {}", db_path); @@ -89,21 +87,21 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> { println!("Nodes: {}", node_count); println!("Files: {}", file_count); println!("Tree Types: {:?}", tree_types); - - conn.close().map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?; + + conn.close() + .map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?; } DbCommand::Backup { user, output } => { let db_path = filetree::FileTree::user_db_path(&user); - + if !std::path::Path::new(&db_path).exists() { return Err(anyhow::anyhow!("Database not found: {}", db_path)); } - + println!("Backing up database for user: {} to {}", user, output); - - std::fs::copy(&db_path, &output) - .context("Failed to backup database")?; - + + std::fs::copy(&db_path, &output).context("Failed to backup database")?; + println!("✓ Database backed up to: {}", output); println!("✓ Backup size: {} bytes", std::fs::metadata(&output)?.len()); } @@ -111,24 +109,26 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> { if !std::path::Path::new(&input).exists() { return Err(anyhow::anyhow!("Backup file not found: {}", input)); } - + let db_path = filetree::FileTree::user_db_path(&user); - + if std::path::Path::new(&db_path).exists() { let backup_path = format!("{}.bak", db_path); println!("Warning: Database exists, creating backup: {}", backup_path); std::fs::copy(&db_path, &backup_path) .context("Failed to create backup before restore")?; } - + println!("Restoring database for user: {} from {}", user, input); - - std::fs::copy(&input, &db_path) - .context("Failed to restore database")?; - + + std::fs::copy(&input, &db_path).context("Failed to restore database")?; + println!("✓ Database restored from: {}", input); - println!("✓ Database size: {} bytes", std::fs::metadata(&db_path)?.len()); + println!( + "✓ Database size: {} bytes", + std::fs::metadata(&db_path)?.len() + ); } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/metadata/mod.rs b/markbase-core/src/cli/metadata/mod.rs index 3a343fa..b8a289b 100644 --- a/markbase-core/src/cli/metadata/mod.rs +++ b/markbase-core/src/cli/metadata/mod.rs @@ -1,7 +1,7 @@ -pub mod config; -pub mod user; -pub mod db; pub mod auth; +pub mod config; +pub mod db; +pub mod user; use clap::Subcommand; @@ -25,4 +25,4 @@ pub async fn handle_metadata_command(cmd: MetadataCommands) -> anyhow::Result<() MetadataCommands::Auth(c) => auth::handle_auth_command(c)?, } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/metadata/user.rs b/markbase-core/src/cli/metadata/user.rs index 6065ac0..db12623 100644 --- a/markbase-core/src/cli/metadata/user.rs +++ b/markbase-core/src/cli/metadata/user.rs @@ -1,6 +1,6 @@ +use anyhow::Context; use clap::Subcommand; use rusqlite::Connection; -use anyhow::Context; #[derive(Subcommand)] pub enum UserCommand { @@ -18,7 +18,7 @@ pub enum UserCommand { #[arg(short, long)] name: String, }, -#[command(name = "user-delete")] + #[command(name = "user-delete")] Delete { #[arg(short, long)] name: String, @@ -29,58 +29,60 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> { match cmd { UserCommand::Create { name, password } => { let db_path = "data/auth.sqlite"; - + if !std::path::Path::new(db_path).exists() { return Err(anyhow::anyhow!("Auth database not found: {}", db_path)); } - - let conn = Connection::open(db_path) - .context("Failed to open auth database")?; - - let exists: i64 = conn.query_row( - "SELECT COUNT(*) FROM sftpgo_users WHERE username = ?", - [&name], - |row| row.get(0) - ).context("Failed to check user existence")?; - + + let conn = Connection::open(db_path).context("Failed to open auth database")?; + + let exists: i64 = conn + .query_row( + "SELECT COUNT(*) FROM sftpgo_users WHERE username = ?", + [&name], + |row| row.get(0), + ) + .context("Failed to check user existence")?; + if exists > 0 { return Err(anyhow::anyhow!("User already exists: {}", name)); } - - let password_hash = bcrypt::hash(&password, bcrypt::DEFAULT_COST) - .context("Failed to hash password")?; - + + let password_hash = + bcrypt::hash(&password, bcrypt::DEFAULT_COST).context("Failed to hash password")?; + conn.execute( "INSERT INTO sftpgo_users (username, password_hash, role, created_at) VALUES (?, ?, 'user', datetime('now'))", rusqlite::params![name, password_hash] ).context("Failed to create user")?; - + println!("✓ User created: {}", name); println!("✓ Role: user"); println!("✓ Password hashed with bcrypt"); } UserCommand::List => { let db_path = "data/auth.sqlite"; - + if !std::path::Path::new(db_path).exists() { return Err(anyhow::anyhow!("Auth database not found: {}", db_path)); } - - let conn = Connection::open(db_path) - .context("Failed to open auth database")?; - - let mut stmt = conn.prepare( - "SELECT username, role, created_at FROM sftpgo_users ORDER BY username" - ).context("Failed to prepare query")?; - - let users = stmt.query_map([], |row| { - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - )) - }).context("Failed to query users")?; - + + let conn = Connection::open(db_path).context("Failed to open auth database")?; + + let mut stmt = conn + .prepare("SELECT username, role, created_at FROM sftpgo_users ORDER BY username") + .context("Failed to prepare query")?; + + let users = stmt + .query_map([], |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + )) + }) + .context("Failed to query users")?; + println!("=== Users List ==="); let mut count = 0; for user in users { @@ -88,7 +90,7 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> { println!(" {} (role: {}, created: {})", name, role, created_at); count += 1; } - + if count == 0 { println!("No users found"); } else { @@ -97,24 +99,27 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> { } UserCommand::Show { name } => { let db_path = "data/auth.sqlite"; - + if !std::path::Path::new(db_path).exists() { return Err(anyhow::anyhow!("Auth database not found: {}", db_path)); } - - let conn = Connection::open(db_path) - .context("Failed to open auth database")?; - - let user = conn.query_row( - "SELECT username, role, created_at FROM sftpgo_users WHERE username = ?", - [&name], - |row| Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - )) - ).context("Failed to query user")?; - + + let conn = Connection::open(db_path).context("Failed to open auth database")?; + + let user = conn + .query_row( + "SELECT username, role, created_at FROM sftpgo_users WHERE username = ?", + [&name], + |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + )) + }, + ) + .context("Failed to query user")?; + let (username, role, created_at) = user; println!("=== User Details ==="); println!("Username: {}", username); @@ -123,31 +128,30 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> { } UserCommand::Delete { name } => { let db_path = "data/auth.sqlite"; - + if !std::path::Path::new(db_path).exists() { return Err(anyhow::anyhow!("Auth database not found: {}", db_path)); } - - let conn = Connection::open(db_path) - .context("Failed to open auth database")?; - - let exists: i64 = conn.query_row( - "SELECT COUNT(*) FROM sftpgo_users WHERE username = ?", - [&name], - |row| row.get(0) - ).context("Failed to check user existence")?; - + + let conn = Connection::open(db_path).context("Failed to open auth database")?; + + let exists: i64 = conn + .query_row( + "SELECT COUNT(*) FROM sftpgo_users WHERE username = ?", + [&name], + |row| row.get(0), + ) + .context("Failed to check user existence")?; + if exists == 0 { return Err(anyhow::anyhow!("User not found: {}", name)); } - - conn.execute( - "DELETE FROM sftpgo_users WHERE username = ?", - [&name] - ).context("Failed to delete user")?; - + + conn.execute("DELETE FROM sftpgo_users WHERE username = ?", [&name]) + .context("Failed to delete user")?; + println!("✓ User deleted: {}", name); } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/mod.rs b/markbase-core/src/cli/mod.rs index 6fad1f1..9b885bc 100644 --- a/markbase-core/src/cli/mod.rs +++ b/markbase-core/src/cli/mod.rs @@ -22,4 +22,4 @@ pub enum Commands { Storage(storage::StorageCommands), #[command(flatten)] Tools(tools::ToolsCommands), -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/storage/archive.rs b/markbase-core/src/cli/storage/archive.rs index 217d82a..434a5b7 100644 --- a/markbase-core/src/cli/storage/archive.rs +++ b/markbase-core/src/cli/storage/archive.rs @@ -21,56 +21,56 @@ pub fn handle_archive_command(cmd: ArchiveCommand) -> anyhow::Result<()> { match cmd { ArchiveCommand::Decompress { file, output } => { use crate::archive::{ArchiveConfig, ProcessorRegistry}; - + println!("Decompressing {} to {}", file, output); - + let archive_path = Path::new(&file); if !archive_path.exists() { return Err(anyhow::anyhow!("Archive file not found: {}", file)); } - + let config = ArchiveConfig::default(); let mut registry = ProcessorRegistry::new(config); registry.initialize()?; - + let output_path = Path::new(&output); std::fs::create_dir_all(output_path)?; - + let processor = registry.get_processor_mut(archive_path)?; let result = processor.extract_all(output_path)?; - + println!("✓ Archive decompressed to: {}", output); println!("✓ Files extracted: {}", result.success_files); println!("✓ Total size: {} bytes", result.total_bytes); } ArchiveCommand::List { file } => { use crate::archive::{ArchiveConfig, ProcessorRegistry}; - + println!("Listing contents of {}", file); - + let archive_path = Path::new(&file); if !archive_path.exists() { return Err(anyhow::anyhow!("Archive file not found: {}", file)); } - + let config = ArchiveConfig::default(); let mut registry = ProcessorRegistry::new(config); registry.initialize()?; - + let processor = registry.get_processor_mut(archive_path)?; let metadata = processor.open(archive_path)?; let entries = processor.list_entries()?; - + println!("=== Archive Contents ==="); println!("Format: {}", metadata.format); println!("Total files: {}", metadata.total_files); println!("Total size: {} bytes", metadata.total_size); - println!(""); - + println!(); + for entry in entries { println!(" {} ({} bytes)", entry.path.display(), entry.size); } } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/storage/hash.rs b/markbase-core/src/cli/storage/hash.rs index 0673e0e..9eff405 100644 --- a/markbase-core/src/cli/storage/hash.rs +++ b/markbase-core/src/cli/storage/hash.rs @@ -17,4 +17,4 @@ pub fn handle_hash_command(cmd: HashCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/storage/mod.rs b/markbase-core/src/cli/storage/mod.rs index 20dd67d..fd66782 100644 --- a/markbase-core/src/cli/storage/mod.rs +++ b/markbase-core/src/cli/storage/mod.rs @@ -1,8 +1,8 @@ -pub mod scan; -pub mod hash; pub mod archive; -pub mod sync; +pub mod hash; pub mod mount; +pub mod scan; +pub mod sync; use clap::Subcommand; @@ -29,4 +29,4 @@ pub async fn handle_storage_command(cmd: StorageCommands) -> anyhow::Result<()> StorageCommands::Mount(c) => mount::handle_mount_command(c)?, } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/storage/mount.rs b/markbase-core/src/cli/storage/mount.rs index b29df20..ffc96e4 100644 --- a/markbase-core/src/cli/storage/mount.rs +++ b/markbase-core/src/cli/storage/mount.rs @@ -22,21 +22,25 @@ pub enum MountCommand { pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> { match cmd { - MountCommand::Attach { type_, server, path } => { + MountCommand::Attach { + type_, + server, + path, + } => { use std::process::Command; - + println!("Mounting {} from {} to {}", type_, server, path); - + if type_ == "nfs" { let mount_point = std::path::Path::new(&path); std::fs::create_dir_all(mount_point)?; - + let nfs_path = format!("{}:{}", server, path); - + let status = Command::new("mount") .args(["-t", "nfs", &nfs_path, &path]) .status()?; - + if status.success() { println!("✓ NFS mounted: {} to {}", nfs_path, path); } else { @@ -45,31 +49,32 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> { } else if type_ == "smb" { let mount_point = std::path::Path::new(&path); std::fs::create_dir_all(mount_point)?; - + let smb_path = format!("//{}", server); - + let status = Command::new("mount") .args(["-t", "smbfs", &smb_path, &path]) .status()?; - + if status.success() { println!("✓ SMB mounted: {} to {}", smb_path, path); } else { return Err(anyhow::anyhow!("SMB mount failed")); } } else { - return Err(anyhow::anyhow!("Unknown mount type: {}. Use 'nfs' or 'smb'", type_)); + return Err(anyhow::anyhow!( + "Unknown mount type: {}. Use 'nfs' or 'smb'", + type_ + )); } } MountCommand::Detach { path } => { use std::process::Command; - + println!("Unmounting {}", path); - - let status = Command::new("umount") - .arg(&path) - .status()?; - + + let status = Command::new("umount").arg(&path).status()?; + if status.success() { println!("✓ Unmounted: {}", path); } else { @@ -78,14 +83,13 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> { } MountCommand::List => { use std::process::Command; - + println!("Listing mounted storage"); - - let output = Command::new("mount") - .output()?; - + + let output = Command::new("mount").output()?; + let mounts = String::from_utf8_lossy(&output.stdout); - + println!("=== Mounted Filesystems ==="); for line in mounts.lines() { if line.contains("nfs") || line.contains("smbfs") || line.contains("fuse") { @@ -95,4 +99,4 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/storage/scan.rs b/markbase-core/src/cli/storage/scan.rs index 06366e2..b13273c 100644 --- a/markbase-core/src/cli/storage/scan.rs +++ b/markbase-core/src/cli/storage/scan.rs @@ -31,4 +31,4 @@ pub fn handle_scan_command(cmd: ScanCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/storage/sync.rs b/markbase-core/src/cli/storage/sync.rs index fdf519f..5fbe743 100644 --- a/markbase-core/src/cli/storage/sync.rs +++ b/markbase-core/src/cli/storage/sync.rs @@ -17,27 +17,31 @@ pub enum SyncCommand { pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> { match cmd { - SyncCommand::Start { source, target, mode } => { + SyncCommand::Start { + source, + target, + mode, + } => { use std::path::Path; - + println!("Syncing {} to {} (mode: {})", source, target, mode); - + let source_path = Path::new(&source); let target_path = Path::new(&target); - + if !source_path.exists() { return Err(anyhow::anyhow!("Source path not found: {}", source)); } - + if mode == "mirror" { std::fs::create_dir_all(target_path)?; - + let entries = std::fs::read_dir(source_path)?; for entry in entries { let entry = entry?; let path = entry.path(); let target_file = target_path.join(entry.file_name()); - + if path.is_file() { std::fs::copy(&path, &target_file)?; println!(" Copied: {:?}", entry.file_name()); @@ -46,7 +50,7 @@ pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> { println!(" Created directory: {:?}", entry.file_name()); } } - + println!("✓ Sync completed (mirror mode)"); } else { return Err(anyhow::anyhow!("Unknown sync mode: {}. Use 'mirror'", mode)); @@ -59,4 +63,4 @@ pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/tools/mod.rs b/markbase-core/src/cli/tools/mod.rs index 8b07a5e..b7c69c6 100644 --- a/markbase-core/src/cli/tools/mod.rs +++ b/markbase-core/src/cli/tools/mod.rs @@ -17,4 +17,4 @@ pub async fn handle_tools_command(cmd: ToolsCommands) -> anyhow::Result<()> { ToolsCommands::Test(c) => test::handle_test_command(c)?, } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/tools/render.rs b/markbase-core/src/cli/tools/render.rs index 8ac1965..8a52678 100644 --- a/markbase-core/src/cli/tools/render.rs +++ b/markbase-core/src/cli/tools/render.rs @@ -23,4 +23,4 @@ pub fn handle_render_command(cmd: RenderCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/cli/tools/test.rs b/markbase-core/src/cli/tools/test.rs index f2086fe..d2eb172 100644 --- a/markbase-core/src/cli/tools/test.rs +++ b/markbase-core/src/cli/tools/test.rs @@ -16,36 +16,39 @@ pub enum TestCommand { pub fn handle_test_command(cmd: TestCommand) -> anyhow::Result<()> { match cmd { - TestCommand::Bcrypt { password, verify_hash } => { + TestCommand::Bcrypt { + password, + verify_hash, + } => { use bcrypt::{hash, verify, DEFAULT_COST}; println!("=== bcrypt Hash Test ==="); println!("Password: {}", password); - println!(""); + println!(); let new_hash = hash(&password, DEFAULT_COST)?; println!("Generated hash:"); println!("{}", new_hash); - println!(""); + println!(); if let Some(hash_to_verify) = verify_hash { println!("Verifying hash: {}", hash_to_verify); let valid = verify(&password, &hash_to_verify)?; println!("Valid: {}", valid); - println!(""); + println!(); } let db_hash = "$2b$10$ha5wU.mOi8fHLJCfun860u2cfVopa04jwe/q82IKOwqp5uG70qsH6"; println!("Database hash: {}", db_hash); let valid = verify(&password, db_hash)?; println!("Database hash valid for '{}': {}", password, valid); - println!(""); + println!(); if !valid { println!("❌ Database hash is incorrect!"); println!("Update SQL:"); println!("UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');", new_hash); - println!(""); + println!(); println!("Execute:"); println!("sqlite3 data/auth.sqlite \"UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');\"", new_hash); } else { @@ -58,4 +61,4 @@ pub fn handle_test_command(cmd: TestCommand) -> anyhow::Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/config/mod.rs b/markbase-core/src/config/mod.rs index 06a049b..86a2e4d 100644 --- a/markbase-core/src/config/mod.rs +++ b/markbase-core/src/config/mod.rs @@ -8,6 +8,7 @@ pub use web::*; /// Unified application configuration #[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Default)] pub struct AppConfig { #[serde(default)] pub web: WebSection, @@ -154,13 +155,19 @@ impl AppConfig { self.web.host = v; } if let Ok(v) = std::env::var("MB_WEB_PORT") { - if let Ok(p) = v.parse() { self.web.port = p; } + if let Ok(p) = v.parse() { + self.web.port = p; + } } if let Ok(v) = std::env::var("MB_SSH_PORT") { - if let Ok(p) = v.parse() { self.ssh.port = p; } + if let Ok(p) = v.parse() { + self.ssh.port = p; + } } if let Ok(v) = std::env::var("MB_SFTP_PORT") { - if let Ok(p) = v.parse() { self.sftp.port = p; } + if let Ok(p) = v.parse() { + self.sftp.port = p; + } } if let Ok(v) = std::env::var("MB_S3_ENABLED") { self.s3.enabled = v == "true" || v == "1"; @@ -172,16 +179,6 @@ impl AppConfig { } } -impl Default for AppConfig { - fn default() -> Self { - Self { - web: WebSection::default(), - s3: S3Section::default(), - sftp: SftpSection::default(), - ssh: SshSection::default(), - } - } -} #[cfg(test)] mod tests { diff --git a/markbase-core/src/config/web.rs b/markbase-core/src/config/web.rs index 4f41c4f..89860f2 100644 --- a/markbase-core/src/config/web.rs +++ b/markbase-core/src/config/web.rs @@ -323,11 +323,15 @@ impl MarkBaseConfig { } if self.authentication.default_user.is_empty() { - return Err(anyhow::anyhow!("authentication.default_user cannot be empty")); + return Err(anyhow::anyhow!( + "authentication.default_user cannot be empty" + )); } if self.authentication.default_password.is_empty() { - return Err(anyhow::anyhow!("authentication.default_password cannot be empty")); + return Err(anyhow::anyhow!( + "authentication.default_password cannot be empty" + )); } if self.authentication.max_sessions_per_user == 0 { diff --git a/markbase-core/src/download/db.rs b/markbase-core/src/download/db.rs index da528d3..531dc9a 100644 --- a/markbase-core/src/download/db.rs +++ b/markbase-core/src/download/db.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use rusqlite::{Connection, params}; +use rusqlite::{params, Connection}; use serde::{Deserialize, Serialize}; use std::path::Path; @@ -46,10 +46,10 @@ impl DownloadDb { Self::init_tables(&conn)?; conn }; - + Ok(DownloadDb { conn }) } - + fn init_tables(conn: &Connection) -> Result<()> { conn.execute_batch( "CREATE TABLE IF NOT EXISTS products ( @@ -74,63 +74,70 @@ impl DownloadDb { CREATE INDEX IF NOT EXISTS idx_product_files_product_id ON product_files(product_id); CREATE INDEX IF NOT EXISTS idx_products_series ON products(series); - " + ", )?; - + Ok(()) } - - pub fn create_product(&mut self, product_name: &str, series: &str, description: Option<&str>) -> Result { + + pub fn create_product( + &mut self, + product_name: &str, + series: &str, + description: Option<&str>, + ) -> Result { let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(); - + self.conn.execute( "INSERT INTO products (product_name, series, description, created_at) VALUES (?1, ?2, ?3, ?4)", params![product_name, series, description, now], )?; - + Ok(self.conn.last_insert_rowid()) } - + pub fn get_all_products(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, product_name, series, description, created_at FROM products ORDER BY series, product_name" )?; - - let products = stmt.query_map([], |row| { - Ok(Product { - id: row.get(0)?, - product_name: row.get(1)?, - series: row.get(2)?, - description: row.get(3)?, - created_at: row.get(4)?, - }) - })? - .collect::, _>>()?; - + + let products = stmt + .query_map([], |row| { + Ok(Product { + id: row.get(0)?, + product_name: row.get(1)?, + series: row.get(2)?, + description: row.get(3)?, + created_at: row.get(4)?, + }) + })? + .collect::, _>>()?; + Ok(products) } - + pub fn get_products_by_series(&self, series: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, product_name, series, description, created_at FROM products - WHERE series = ?1 ORDER BY product_name" + WHERE series = ?1 ORDER BY product_name", )?; - - let products = stmt.query_map([series], |row| { - Ok(Product { - id: row.get(0)?, - product_name: row.get(1)?, - series: row.get(2)?, - description: row.get(3)?, - created_at: row.get(4)?, - }) - })? - .collect::, _>>()?; - + + let products = stmt + .query_map([series], |row| { + Ok(Product { + id: row.get(0)?, + product_name: row.get(1)?, + series: row.get(2)?, + description: row.get(3)?, + created_at: row.get(4)?, + }) + })? + .collect::, _>>()?; + Ok(products) } - + pub fn get_series_stats(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT @@ -141,106 +148,118 @@ impl DownloadDb { FROM products p LEFT JOIN product_files pf ON p.id = pf.product_id GROUP BY p.series - ORDER BY p.series" + ORDER BY p.series", )?; - - let stats = stmt.query_map([], |row| { - Ok(SeriesStats { - series: row.get(0)?, - product_count: row.get(1)?, - file_count: row.get(2)?, - total_size: row.get::<_, i64>(3)? as u64, - }) - })? - .collect::, _>>()?; - + + let stats = stmt + .query_map([], |row| { + Ok(SeriesStats { + series: row.get(0)?, + product_count: row.get(1)?, + file_count: row.get(2)?, + total_size: row.get::<_, i64>(3)? as u64, + }) + })? + .collect::, _>>()?; + Ok(stats) } - - pub fn add_file_to_product(&mut self, product_id: i64, file_path: &str, file_name: &str, file_size: u64, file_hash: Option<&str>) -> Result { + + pub fn add_file_to_product( + &mut self, + product_id: i64, + file_path: &str, + file_name: &str, + file_size: u64, + file_hash: Option<&str>, + ) -> Result { let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(); - + self.conn.execute( "INSERT INTO product_files (product_id, file_path, file_name, file_size, file_hash, uploaded_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![product_id, file_path, file_name, file_size as i64, file_hash, now], )?; - + Ok(self.conn.last_insert_rowid()) } - + pub fn get_files_by_product(&self, product_id: i64) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, product_id, file_path, file_name, file_size, file_hash, download_count, uploaded_at FROM product_files WHERE product_id = ?1 ORDER BY file_name" )?; - - let files = stmt.query_map([product_id], |row| { - Ok(ProductFile { - id: row.get(0)?, - product_id: row.get(1)?, - file_path: row.get(2)?, - file_name: row.get(3)?, - file_size: row.get::<_, i64>(4)? as u64, - file_hash: row.get(5)?, - download_count: row.get(6)?, - uploaded_at: row.get(7)?, - }) - })? - .collect::, _>>()?; - + + let files = stmt + .query_map([product_id], |row| { + Ok(ProductFile { + id: row.get(0)?, + product_id: row.get(1)?, + file_path: row.get(2)?, + file_name: row.get(3)?, + file_size: row.get::<_, i64>(4)? as u64, + file_hash: row.get(5)?, + download_count: row.get(6)?, + uploaded_at: row.get(7)?, + }) + })? + .collect::, _>>()?; + Ok(files) } - + pub fn increment_download_count(&mut self, file_id: i64) -> Result<()> { self.conn.execute( "UPDATE product_files SET download_count = download_count + 1 WHERE id = ?1", params![file_id], )?; - + Ok(()) } - + pub fn get_all_files(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, product_id, file_path, file_name, file_size, file_hash, download_count, uploaded_at FROM product_files ORDER BY uploaded_at DESC" )?; - - let files = stmt.query_map([], |row| { - Ok(ProductFile { - id: row.get(0)?, - product_id: row.get(1)?, - file_path: row.get(2)?, - file_name: row.get(3)?, - file_size: row.get::<_, i64>(4)? as u64, - file_hash: row.get(5)?, - download_count: row.get(6)?, - uploaded_at: row.get(7)?, - }) - })? - .collect::, _>>()?; - + + let files = stmt + .query_map([], |row| { + Ok(ProductFile { + id: row.get(0)?, + product_id: row.get(1)?, + file_path: row.get(2)?, + file_name: row.get(3)?, + file_size: row.get::<_, i64>(4)? as u64, + file_hash: row.get(5)?, + download_count: row.get(6)?, + uploaded_at: row.get(7)?, + }) + })? + .collect::, _>>()?; + Ok(files) } - + pub fn delete_product_with_files(&mut self, product_id: i64) -> Result<(i64, i64)> { // 先删除关联的文件映射 self.conn.execute( "DELETE FROM product_files WHERE product_id = ?1", params![product_id], )?; - + let deleted_files = self.conn.last_insert_rowid(); - + // 再删除产品记录 - self.conn.execute( - "DELETE FROM products WHERE id = ?1", - params![product_id], - )?; - - let deleted_product = if self.conn.last_insert_rowid() > 0 { 1 } else { 0 }; - + self.conn + .execute("DELETE FROM products WHERE id = ?1", params![product_id])?; + + let deleted_product = if self.conn.last_insert_rowid() > 0 { + 1 + } else { + 0 + }; + Ok((deleted_files, deleted_product)) } -} \ No newline at end of file +} diff --git a/markbase-core/src/download/download_handler.rs b/markbase-core/src/download/download_handler.rs index bd0e7af..3da6bf9 100644 --- a/markbase-core/src/download/download_handler.rs +++ b/markbase-core/src/download/download_handler.rs @@ -7,15 +7,19 @@ use axum::{ use std::fs::File; use std::io::Read; -use crate::server::AppState; use crate::download::db::DownloadDb; +use crate::server::AppState; pub async fn download_file( Path(file_id): Path, State(state): State, ) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + match DownloadDb::new(&db_path) { Ok(mut db) => { // 获取文件信息 @@ -24,48 +28,65 @@ pub async fn download_file( if files.is_empty() { return (StatusCode::NOT_FOUND, "File not found").into_response(); } - + let file_info = &files[0]; - + // 更新下载统计 db.increment_download_count(file_info.id).ok(); - + // 构建文件路径(使用配置的db_dir) let base_path = state.db_dir.replace("users", "Downloads"); let file_path = std::path::Path::new(&base_path).join(&file_info.file_path); - + if !file_path.exists() { return (StatusCode::NOT_FOUND, "File not found on disk").into_response(); } - + // 读取文件内容 match File::open(&file_path) { Ok(mut file) => { let mut buffer = Vec::new(); match file.read_to_end(&mut buffer) { - Ok(_) => { - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/octet-stream") - .header( - header::CONTENT_DISPOSITION, - format!("attachment; filename=\"{}\"", file_info.file_name) - ) - .header("X-File-Hash", file_info.file_hash.clone().unwrap_or_default()) - .header("X-File-Size", file_info.file_size) - .body(buffer.into()) - .unwrap() - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response() + Ok(_) => Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header( + header::CONTENT_DISPOSITION, + format!("attachment; filename=\"{}\"", file_info.file_name), + ) + .header( + "X-File-Hash", + file_info.file_hash.clone().unwrap_or_default(), + ) + .header("X-File-Size", file_info.file_size) + .body(buffer.into()) + .unwrap(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Error reading file: {}", e), + ) + .into_response(), } } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response() + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Error opening file: {}", e), + ) + .into_response(), } } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {}", e)).into_response() + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {}", e), + ) + .into_response(), } } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {}", e)).into_response() + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {}", e), + ) + .into_response(), } } @@ -80,69 +101,84 @@ pub async fn download_file_by_path( // User files are in Downloads/user_id/ format!("/Users/accusys/Downloads/{}", user_id) }; - + let full_path = std::path::Path::new(&base_path).join(&file_path); - + if !full_path.exists() { return (StatusCode::NOT_FOUND, "File not found").into_response(); } - - let filename = file_path.split('/').last().unwrap_or("unknown"); - + + let filename = file_path.split('/').next_back().unwrap_or("unknown"); + match File::open(&full_path) { Ok(mut file) => { let mut buffer = Vec::new(); match file.read_to_end(&mut buffer) { - Ok(_) => { - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/octet-stream") - .header( - header::CONTENT_DISPOSITION, - format!("attachment; filename=\"{}\"", filename) - ) - .body(buffer.into()) - .unwrap() - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response() + Ok(_) => Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header( + header::CONTENT_DISPOSITION, + format!("attachment; filename=\"{}\"", filename), + ) + .body(buffer.into()) + .unwrap(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Error reading file: {}", e), + ) + .into_response(), } } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response() + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Error opening file: {}", e), + ) + .into_response(), } } -pub async fn get_download_stats( - State(state): State, -) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - +pub async fn get_download_stats(State(state): State) -> impl IntoResponse { + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + match DownloadDb::new(&db_path) { - Ok(db) => { - match db.get_all_files() { - Ok(files) => { - let total_downloads: i64 = files.iter().map(|f| f.download_count).sum(); - let top_files: Vec<_> = files.iter() - .filter(|f| f.download_count > 0) - .take(10) - .map(|f| serde_json::json!({ + Ok(db) => match db.get_all_files() { + Ok(files) => { + let total_downloads: i64 = files.iter().map(|f| f.download_count).sum(); + let top_files: Vec<_> = files + .iter() + .filter(|f| f.download_count > 0) + .take(10) + .map(|f| { + serde_json::json!({ "file_name": f.file_name, "download_count": f.download_count - })) - .collect(); - - ( - StatusCode::OK, - Json(serde_json::json!({ - "total_files": files.len(), - "total_downloads": total_downloads, - "top_files": top_files - })) - ) - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()}))) + }) + }) + .collect(); + + ( + StatusCode::OK, + Json(serde_json::json!({ + "total_files": files.len(), + "total_downloads": total_downloads, + "top_files": top_files + })), + ) } - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()}))) + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": e.to_string()})), + ), + }, + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": e.to_string()})), + ), } } @@ -151,35 +187,41 @@ pub async fn download_product_file( ) -> impl IntoResponse { let base_path = format!("/Users/accusys/markbase/data/downloads/{}/", product_series); let full_path = std::path::Path::new(&base_path).join(&file_path); - + if !full_path.exists() { return (StatusCode::NOT_FOUND, "File not found").into_response(); } - + if full_path.is_dir() { return (StatusCode::BAD_REQUEST, "Path is a directory, not a file").into_response(); } - - let filename = file_path.split('/').last().unwrap_or("unknown"); - + + let filename = file_path.split('/').next_back().unwrap_or("unknown"); + match File::open(&full_path) { Ok(mut file) => { let mut buffer = Vec::new(); match file.read_to_end(&mut buffer) { - Ok(_) => { - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/octet-stream") - .header( - header::CONTENT_DISPOSITION, - format!("attachment; filename=\"{}\"", filename) - ) - .body(buffer.into()) - .unwrap() - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response() + Ok(_) => Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header( + header::CONTENT_DISPOSITION, + format!("attachment; filename=\"{}\"", filename), + ) + .body(buffer.into()) + .unwrap(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Error reading file: {}", e), + ) + .into_response(), } } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response() + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Error opening file: {}", e), + ) + .into_response(), } -} \ No newline at end of file +} diff --git a/markbase-core/src/download/handlers.rs b/markbase-core/src/download/handlers.rs index 9fa77cc..c64398f 100644 --- a/markbase-core/src/download/handlers.rs +++ b/markbase-core/src/download/handlers.rs @@ -1,27 +1,22 @@ use axum::{ - extract::{Path, State}, - http::{HeaderMap, StatusCode}, - response::{Html, IntoResponse, Json}, + extract::Path, + http::StatusCode, + response::{IntoResponse, Json}, }; use serde_json::json; use std::path::PathBuf; -use crate::server::AppState; use crate::download::storage; -pub async fn list_uploaded_files( - Path(user_id): Path, -) -> impl IntoResponse { +pub async fn list_uploaded_files(Path(user_id): Path) -> impl IntoResponse { let file_list = storage::scan_uploaded_files(&user_id); (StatusCode::OK, Json(file_list)) } -pub async fn get_file_info( - Path((user_id, filename)): Path<(String, String)>, -) -> impl IntoResponse { +pub async fn get_file_info(Path((user_id, filename)): Path<(String, String)>) -> impl IntoResponse { let base_path = format!("/Users/accusys/Downloads/{}", user_id); let file_path = PathBuf::from(&base_path).join(&filename); - + if !file_path.exists() { return ( StatusCode::NOT_FOUND, @@ -29,7 +24,7 @@ pub async fn get_file_info( ) .into_response(); } - + let metadata = std::fs::metadata(&file_path).unwrap(); let file_size = metadata.len(); let file_hash = if file_size > 0 { @@ -37,7 +32,7 @@ pub async fn get_file_info( } else { Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string()) }; - + ( StatusCode::OK, Json(json!({ @@ -49,4 +44,4 @@ pub async fn get_file_info( })), ) .into_response() -} \ No newline at end of file +} diff --git a/markbase-core/src/download/mod.rs b/markbase-core/src/download/mod.rs index 42c19d8..4185e95 100644 --- a/markbase-core/src/download/mod.rs +++ b/markbase-core/src/download/mod.rs @@ -1,12 +1,12 @@ -pub mod models; pub mod db; -pub mod handlers; -pub mod storage; -pub mod product_handlers; pub mod download_handler; +pub mod handlers; +pub mod models; +pub mod product_handlers; +pub mod storage; -pub use models::*; pub use db::{DownloadDb, Product, ProductFile, SeriesStats}; +pub use download_handler::*; pub use handlers::*; +pub use models::*; pub use product_handlers::*; -pub use download_handler::*; \ No newline at end of file diff --git a/markbase-core/src/download/models.rs b/markbase-core/src/download/models.rs index 5a2a4fb..325d371 100644 --- a/markbase-core/src/download/models.rs +++ b/markbase-core/src/download/models.rs @@ -39,4 +39,4 @@ pub struct DownloadStats { pub total_files: i64, pub total_downloads: i64, pub series_stats: Vec, -} \ No newline at end of file +} diff --git a/markbase-core/src/download/product_handlers.rs b/markbase-core/src/download/product_handlers.rs index b1e814c..a0e412b 100644 --- a/markbase-core/src/download/product_handlers.rs +++ b/markbase-core/src/download/product_handlers.rs @@ -1,33 +1,42 @@ use axum::{ extract::{Path, State}, http::StatusCode, - response::{Json, IntoResponse}, + response::{IntoResponse, Json}, }; use serde_json::json; +use crate::download::db::DownloadDb; use crate::server::AppState; -use crate::download::db::{DownloadDb, Product, ProductFile, SeriesStats}; -pub async fn list_all_products( - State(state): State, -) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - +pub async fn list_all_products(State(state): State) -> impl IntoResponse { + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + match DownloadDb::new(&db_path) { - Ok(db) => { - match db.get_all_products() { - Ok(products) => (StatusCode::OK, Json(json!({ + Ok(db) => match db.get_all_products() { + Ok(products) => ( + StatusCode::OK, + Json(json!({ "products": products, "total": products.len() - }))), - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ + })), + ), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() - }))), - } - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ - "error": e.to_string() - }))), + })), + ), + }, + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": e.to_string() + })), + ), } } @@ -35,47 +44,67 @@ pub async fn list_products_by_series( Path(series): Path, State(state): State, ) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + match DownloadDb::new(&db_path) { - Ok(db) => { - match db.get_products_by_series(&series) { - Ok(products) => (StatusCode::OK, Json(json!({ + Ok(db) => match db.get_products_by_series(&series) { + Ok(products) => ( + StatusCode::OK, + Json(json!({ "series": series, "products": products, "total": products.len() - }))), - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ + })), + ), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() - }))), - } - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ - "error": e.to_string() - }))), + })), + ), + }, + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": e.to_string() + })), + ), } } -pub async fn get_series_stats( - State(state): State, -) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - +pub async fn get_series_stats(State(state): State) -> impl IntoResponse { + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + match DownloadDb::new(&db_path) { - Ok(db) => { - match db.get_series_stats() { - Ok(stats) => (StatusCode::OK, Json(json!({ + Ok(db) => match db.get_series_stats() { + Ok(stats) => ( + StatusCode::OK, + Json(json!({ "series_stats": stats, "total_series": stats.len() - }))), - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ + })), + ), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() - }))), - } - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ - "error": e.to_string() - }))), + })), + ), + }, + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": e.to_string() + })), + ), } } @@ -83,25 +112,36 @@ pub async fn get_product_files( Path(product_id): Path, State(state): State, ) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + match DownloadDb::new(&db_path) { - Ok(db) => { - match db.get_files_by_product(product_id) { - Ok(files) => (StatusCode::OK, Json(json!({ + Ok(db) => match db.get_files_by_product(product_id) { + Ok(files) => ( + StatusCode::OK, + Json(json!({ "product_id": product_id, "files": files, "total_files": files.len(), "total_size": files.iter().map(|f| f.file_size).sum::() - }))), - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ + })), + ), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() - }))), - } - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ - "error": e.to_string() - }))), + })), + ), + }, + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": e.to_string() + })), + ), } } @@ -109,29 +149,40 @@ pub async fn create_product_handler( State(state): State, Json(payload): Json, ) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + let product_name = payload["product_name"].as_str().unwrap_or(""); let series = payload["series"].as_str().unwrap_or(""); let description = payload["description"].as_str(); - + match DownloadDb::new(&db_path) { - Ok(mut db) => { - match db.create_product(product_name, series, description) { - Ok(product_id) => (StatusCode::OK, Json(json!({ + Ok(mut db) => match db.create_product(product_name, series, description) { + Ok(product_id) => ( + StatusCode::OK, + Json(json!({ "ok": true, "product_id": product_id, "product_name": product_name, "series": series - }))), - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ + })), + ), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() - }))), - } - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ - "error": e.to_string() - }))), + })), + ), + }, + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": e.to_string() + })), + ), } } @@ -140,48 +191,62 @@ pub async fn assign_files_to_product( State(state): State, Json(payload): Json, ) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + let files_vec = payload["files"].as_array().cloned().unwrap_or_default(); let files = files_vec.as_slice(); - + match DownloadDb::new(&db_path) { Ok(mut db) => { let mut assigned_count = 0; let mut errors = vec![]; - + for file in files { let file_path = file["file_path"].as_str().unwrap_or(""); let file_name = file["file_name"].as_str().unwrap_or(""); let file_size = file["file_size"].as_u64().unwrap_or(0); let file_hash = file["file_hash"].as_str(); - - match db.add_file_to_product(product_id, file_path, file_name, file_size, file_hash) { + + match db.add_file_to_product(product_id, file_path, file_name, file_size, file_hash) + { Ok(_) => assigned_count += 1, Err(e) => { errors.push(format!("Failed to assign {}: {}", file_path, e)); } } } - + if errors.is_empty() { - (StatusCode::OK, Json(json!({ - "ok": true, - "product_id": product_id, - "assigned_count": assigned_count - }))) + ( + StatusCode::OK, + Json(json!({ + "ok": true, + "product_id": product_id, + "assigned_count": assigned_count + })), + ) } else { - (StatusCode::PARTIAL_CONTENT, Json(json!({ - "ok": true, - "product_id": product_id, - "assigned_count": assigned_count, - "errors": errors - }))) + ( + StatusCode::PARTIAL_CONTENT, + Json(json!({ + "ok": true, + "product_id": product_id, + "assigned_count": assigned_count, + "errors": errors + })), + ) } } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ - "error": e.to_string() - }))), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": e.to_string() + })), + ), } } @@ -189,24 +254,35 @@ pub async fn delete_product( Path(product_id): Path, State(state): State, ) -> impl IntoResponse { - let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite"); - + let db_path = format!( + "{}{}", + state.db_dir.replace("users", "downloads"), + "/products.sqlite" + ); + match DownloadDb::new(&db_path) { - Ok(mut db) => { - match db.delete_product_with_files(product_id) { - Ok((deleted_files, deleted_product)) => (StatusCode::OK, Json(json!({ + Ok(mut db) => match db.delete_product_with_files(product_id) { + Ok((deleted_files, deleted_product)) => ( + StatusCode::OK, + Json(json!({ "ok": true, "product_id": product_id, "deleted_files": deleted_files, "deleted_product": deleted_product - }))), - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ + })), + ), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() - }))), - } - } - Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ - "error": e.to_string() - }))), + })), + ), + }, + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": e.to_string() + })), + ), } -} \ No newline at end of file +} diff --git a/markbase-core/src/download/storage.rs b/markbase-core/src/download/storage.rs index 280eafe..d4e491a 100644 --- a/markbase-core/src/download/storage.rs +++ b/markbase-core/src/download/storage.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use std::path::{Path, PathBuf}; +use std::path::Path; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FileInfo { @@ -23,17 +23,17 @@ pub struct FileListResponse { pub fn scan_uploaded_files(user_id: &str) -> FileListResponse { let base_path = format!("/Users/accusys/Downloads/{}", user_id); let path = Path::new(&base_path); - + let mut files = Vec::new(); let mut total_size = 0u64; - + if path.exists() { scan_directory_recursive(path, path, &mut files, &mut total_size); } - + FileListResponse { user_id: user_id.to_string(), - base_path: base_path, + base_path, total_files: files.len(), total_size, files, @@ -49,24 +49,25 @@ fn scan_directory_recursive( if let Ok(entries) = std::fs::read_dir(current) { for entry in entries.flatten() { let path = entry.path(); - + if path.is_file() { - let filename = path.file_name() + let filename = path + .file_name() .and_then(|n| n.to_str()) .unwrap_or("unknown") .to_string(); - - let file_size = entry.metadata() - .map(|m| m.len()) - .unwrap_or(0); - - let relative_path = path.strip_prefix(base) + + let file_size = entry.metadata().map(|m| m.len()).unwrap_or(0); + + let relative_path = path + .strip_prefix(base) .ok() .and_then(|p| p.to_str()) .map(|s| s.to_string()) .unwrap_or_else(|| filename.clone()); - - let upload_time = entry.metadata() + + let upload_time = entry + .metadata() .ok() .and_then(|m| m.modified().ok()) .and_then(|t| { @@ -75,13 +76,16 @@ fn scan_directory_recursive( .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string()) }) .unwrap_or_else(|| chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string()); - + let file_hash = if file_size > 0 { compute_file_hash(&path).ok() } else { - Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string()) + Some( + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + .to_string(), + ) }; - + files.push(FileInfo { filename, file_size, @@ -90,7 +94,7 @@ fn scan_directory_recursive( relative_path, upload_time, }); - + *total_size += file_size; } else if path.is_dir() { scan_directory_recursive(base, &path, files, total_size); @@ -102,11 +106,11 @@ fn scan_directory_recursive( pub fn compute_file_hash(path: &Path) -> Result { use sha2::{Digest, Sha256}; use std::io::Read; - + let mut file = std::fs::File::open(path)?; let mut hasher = Sha256::new(); let mut buffer = [0u8; 8192]; - + loop { let bytes_read = file.read(&mut buffer)?; if bytes_read == 0 { @@ -114,6 +118,6 @@ pub fn compute_file_hash(path: &Path) -> Result { } hasher.update(&buffer[..bytes_read]); } - + Ok(format!("{:x}", hasher.finalize())) -} \ No newline at end of file +} diff --git a/markbase-core/src/import_markdown.rs b/markbase-core/src/import_markdown.rs index d4c1712..abc8266 100644 --- a/markbase-core/src/import_markdown.rs +++ b/markbase-core/src/import_markdown.rs @@ -1,9 +1,7 @@ use anyhow::Result; +use serde::{Deserialize, Serialize}; use std::fs; use std::path::Path; -use serde::{Deserialize, Serialize}; -use pulldown_cmark::{Parser, Event, Tag, HeadingLevel, TagEnd}; -use regex::Regex; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MarkdownFile { @@ -39,17 +37,21 @@ pub struct SeriesMarkdown { pub fn parse_category_markdown(content: &str) -> Result { let mut category = String::new(); let mut sections: Vec = Vec::new(); - + let lines: Vec<&str> = content.lines().collect(); let mut current_product = String::new(); let mut current_files: Vec = Vec::new(); let mut pending_file: Option<(String, String)> = None; - + for i in 0..lines.len() { let line = lines[i].trim(); - + if line.contains("**Category**:") { - category = line.replace("**Category**:", "").replace("**", "").trim().to_string(); + category = line + .replace("**Category**:", "") + .replace("**", "") + .trim() + .to_string(); } else if line.starts_with("## ") { if !current_product.is_empty() && !current_files.is_empty() { sections.push(CategorySection { @@ -72,13 +74,17 @@ pub fn parse_category_markdown(content: &str) -> Result { current_files.push(MarkdownFile { filename, size: Some(size), - download_url: line.trim_start_matches('`').trim_end_matches('`').trim().to_string(), + download_url: line + .trim_start_matches('`') + .trim_end_matches('`') + .trim() + .to_string(), }); pending_file = None; } } } - + if !current_product.is_empty() && !current_files.is_empty() { sections.push(CategorySection { product: current_product.clone(), @@ -92,17 +98,21 @@ pub fn parse_category_markdown(content: &str) -> Result { pub fn parse_series_markdown(content: &str) -> Result { let mut series = String::new(); let mut sections: Vec = Vec::new(); - + let lines: Vec<&str> = content.lines().collect(); let mut current_category = String::new(); let mut current_files: Vec = Vec::new(); let mut pending_file: Option<(String, String)> = None; - + for i in 0..lines.len() { let line = lines[i].trim(); - + if line.starts_with("# ") && line.contains("Download Links") { - series = line.replace("# ", "").replace(" Download Links", "").trim().to_string(); + series = line + .replace("# ", "") + .replace(" Download Links", "") + .trim() + .to_string(); } else if line.starts_with("## ") { if !current_category.is_empty() && !current_files.is_empty() { sections.push(SeriesSection { @@ -125,13 +135,17 @@ pub fn parse_series_markdown(content: &str) -> Result { current_files.push(MarkdownFile { filename, size: Some(size), - download_url: line.trim_start_matches('`').trim_end_matches('`').trim().to_string(), + download_url: line + .trim_start_matches('`') + .trim_end_matches('`') + .trim() + .to_string(), }); pending_file = None; } } } - + if !current_category.is_empty() && !current_files.is_empty() { sections.push(SeriesSection { category: current_category.clone(), @@ -144,65 +158,77 @@ pub fn parse_series_markdown(content: &str) -> Result { pub fn read_category_files(dir: &Path) -> Result> { let mut files = Vec::new(); - + for entry in fs::read_dir(dir)? { let entry = entry?; let path = entry.path(); - - if path.extension().map_or(false, |ext| ext == "md") && path.file_name() != Some(std::ffi::OsStr::new("README.md")) { + + if path.extension().is_some_and(|ext| ext == "md") + && path.file_name() != Some(std::ffi::OsStr::new("README.md")) + { let filename = path.file_name().unwrap().to_string_lossy().to_string(); let content = fs::read_to_string(&path)?; files.push((filename, content)); } } - + Ok(files) } pub fn read_series_files(dir: &Path) -> Result> { let mut files = Vec::new(); - + for entry in fs::read_dir(dir)? { let entry = entry?; let path = entry.path(); - - if path.extension().map_or(false, |ext| ext == "md") && path.file_name() != Some(std::ffi::OsStr::new("README.md")) { + + if path.extension().is_some_and(|ext| ext == "md") + && path.file_name() != Some(std::ffi::OsStr::new("README.md")) + { let filename = path.file_name().unwrap().to_string_lossy().to_string(); let content = fs::read_to_string(&path)?; files.push((filename, content)); } } - + Ok(files) } -pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type: &str) -> Result<()> { +pub fn import_categories_to_db( + conn: &rusqlite::Connection, + user_id: &str, + tree_type: &str, +) -> Result<()> { use crate::FileTree; - use filetree::node::{FileNode, Aliases, NodeType}; - use uuid::Uuid; + use filetree::node::{Aliases, FileNode, NodeType}; use std::collections::HashMap; - + use uuid::Uuid; + let category_dir = Path::new("/Users/accusys/markbase/data/downloads/by_category"); let files = read_category_files(category_dir)?; - + println!("Found {} Markdown files", files.len()); - + let mut tree = FileTree::load(conn, user_id, tree_type)?; - + for (_filename, content) in files { let parsed = parse_category_markdown(&content)?; - - println!("Parsed category: '{}', sections: {}", parsed.category, parsed.sections.len()); - + + println!( + "Parsed category: '{}', sections: {}", + parsed.category, + parsed.sections.len() + ); + if parsed.category.is_empty() { println!("Warning: category is empty, skipping"); continue; } - + let category_node_id = Uuid::new_v4().to_string(); let mut aliases_map = HashMap::new(); aliases_map.insert("category_type".to_string(), "category".to_string()); - + let category_node = FileNode { node_id: category_node_id.clone(), label: parsed.category.clone(), @@ -221,20 +247,27 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_ updated_at: chrono::Utc::now().to_rfc3339(), sort_order: 0, }; - - println!("Inserting category node: {} (id: {})", category_node.label, category_node_id); - + + println!( + "Inserting category node: {} (id: {})", + category_node.label, category_node_id + ); + tree.insert_node(conn, &category_node)?; - + println!("Category node inserted successfully"); - + for section in parsed.sections { - println!("Processing section: {} with {} files", section.product, section.files.len()); - + println!( + "Processing section: {} with {} files", + section.product, + section.files.len() + ); + let product_node_id = Uuid::new_v4().to_string(); let mut aliases_map = HashMap::new(); aliases_map.insert("product".to_string(), section.product.clone()); - + let product_node = FileNode { node_id: product_node_id.clone(), label: section.product.clone(), @@ -253,15 +286,18 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_ updated_at: chrono::Utc::now().to_rfc3339(), sort_order: 0, }; - + tree.insert_node(conn, &product_node)?; - + for file in section.files { let file_node_id = Uuid::new_v4().to_string(); let mut aliases_map = HashMap::new(); aliases_map.insert("download_url".to_string(), file.download_url.clone()); - aliases_map.insert("file_size_display".to_string(), file.size.clone().unwrap_or_else(|| "Unknown".to_string())); - + aliases_map.insert( + "file_size_display".to_string(), + file.size.clone().unwrap_or_else(|| "Unknown".to_string()), + ); + let file_node = FileNode { node_id: file_node_id.clone(), label: file.filename.clone(), @@ -280,42 +316,50 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_ updated_at: chrono::Utc::now().to_rfc3339(), sort_order: 0, }; - + tree.insert_node(conn, &file_node)?; } } } - + Ok(()) } -pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type: &str) -> Result<()> { +pub fn import_series_to_db( + conn: &rusqlite::Connection, + user_id: &str, + tree_type: &str, +) -> Result<()> { use crate::FileTree; - use filetree::node::{FileNode, Aliases, NodeType}; - use uuid::Uuid; + use filetree::node::{Aliases, FileNode, NodeType}; use std::collections::HashMap; - + use uuid::Uuid; + let series_dir = Path::new("/Users/accusys/markbase/data/downloads/by_series"); let files = read_series_files(series_dir)?; - + println!("Found {} Markdown files for series", files.len()); - + let mut tree = FileTree::load(conn, user_id, tree_type)?; - + for (_filename, content) in files { let parsed = parse_series_markdown(&content)?; - - println!("Parsed series: '{}', sections: {}", parsed.series, parsed.sections.len()); - + + println!( + "Parsed series: '{}', sections: {}", + parsed.series, + parsed.sections.len() + ); + if parsed.series.is_empty() { println!("Warning: series is empty, skipping"); continue; } - + let series_node_id = Uuid::new_v4().to_string(); let mut aliases_map = HashMap::new(); aliases_map.insert("series_type".to_string(), "series".to_string()); - + let series_node = FileNode { node_id: series_node_id.clone(), label: parsed.series.clone(), @@ -334,18 +378,22 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type updated_at: chrono::Utc::now().to_rfc3339(), sort_order: 0, }; - + tree.insert_node(conn, &series_node)?; - + println!("Series node inserted successfully"); - + for section in parsed.sections { - println!("Processing section: {} with {} files", section.category, section.files.len()); - + println!( + "Processing section: {} with {} files", + section.category, + section.files.len() + ); + let category_node_id = Uuid::new_v4().to_string(); let mut aliases_map = HashMap::new(); aliases_map.insert("category".to_string(), section.category.clone()); - + let category_node = FileNode { node_id: category_node_id.clone(), label: section.category.clone(), @@ -364,15 +412,18 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type updated_at: chrono::Utc::now().to_rfc3339(), sort_order: 0, }; - + tree.insert_node(conn, &category_node)?; - + for file in section.files { let file_node_id = Uuid::new_v4().to_string(); let mut aliases_map = HashMap::new(); aliases_map.insert("download_url".to_string(), file.download_url.clone()); - aliases_map.insert("file_size_display".to_string(), file.size.clone().unwrap_or_else(|| "Unknown".to_string())); - + aliases_map.insert( + "file_size_display".to_string(), + file.size.clone().unwrap_or_else(|| "Unknown".to_string()), + ); + let file_node = FileNode { node_id: file_node_id.clone(), label: file.filename.clone(), @@ -391,12 +442,12 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type updated_at: chrono::Utc::now().to_rfc3339(), sort_order: 0, }; - + tree.insert_node(conn, &file_node)?; } } } - + Ok(()) } @@ -418,8 +469,8 @@ mod tests { ```https://download.accusys.ddns.net/api/v2/download/products/ExaSAN-DAS/C1M_C2M/User%20Guide/C2M-QIG20170906.zip ``` "#; - + let result = parse_category_markdown(content).unwrap(); assert_eq!(result.category, "GUI"); } -} \ No newline at end of file +} diff --git a/markbase-core/src/lib.rs b/markbase-core/src/lib.rs index 97d60f2..0fdce86 100644 --- a/markbase-core/src/lib.rs +++ b/markbase-core/src/lib.rs @@ -1,11 +1,14 @@ -pub mod audio; -pub mod auth; -pub mod audit; -pub mod cli; pub mod api; +pub mod archive; // Archive Module - Universal Compression Format Support (Phase 1-3完成) +pub mod audio; +pub mod audit; +pub mod auth; +pub mod category_view; +pub mod cli; pub mod command; pub mod config; pub mod download; +pub mod import_markdown; pub mod pg_client; pub mod render; pub mod rsync; @@ -14,20 +17,17 @@ pub mod s3_auth; pub mod s3_config; pub mod s3_xml; pub mod scan; -pub mod server; -pub mod archive; // Archive Module - Universal Compression Format Support (Phase 1-3完成) -pub mod category_view; -pub mod import_markdown; // Category View Module - 双视图管理(Phase 1) -// pub mod sftp; // ⚠️ russh版本(已禁用) -// pub mod ssh2_server; // ssh2服务器(已禁用) -// pub mod ssh2_mod; // ssh2辅助模块(已禁用) -pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐ +pub mod server; // Category View Module - 双视图管理(Phase 1) + // pub mod sftp; // ⚠️ russh版本(已禁用) + // pub mod ssh2_server; // ssh2服务器(已禁用) + // pub mod ssh2_mod; // ssh2辅助模块(已禁用) +pub mod provider; // DataProvider抽象层(Phase 5) +pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐ pub mod sync; -pub mod provider; // DataProvider抽象层(Phase 5) -pub mod vfs; // VFS抽象层(Phase 1-6重构计划) +pub mod vfs; // VFS抽象层(Phase 1-6重构计划) #[cfg(test)] -mod security_audit; // Security Audit Module - Phase 9 +mod security_audit; // Security Audit Module - Phase 9 // Re-export from external filetree crate pub use filetree::node::FileNode; diff --git a/markbase-core/src/main.rs b/markbase-core/src/main.rs index ad5e1fa..b8ca9c2 100644 --- a/markbase-core/src/main.rs +++ b/markbase-core/src/main.rs @@ -1,12 +1,12 @@ -use markbase_core::cli::Cli; use clap::Parser; +use markbase_core::cli::Cli; #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::Builder::from_default_env() .filter_level(log::LevelFilter::Info) .init(); - + let cli = Cli::parse(); match cli.command { @@ -25,4 +25,4 @@ async fn main() -> anyhow::Result<()> { } Ok(()) -} \ No newline at end of file +} diff --git a/markbase-core/src/pg_client.rs b/markbase-core/src/pg_client.rs index afb69a4..60a3d3d 100644 --- a/markbase-core/src/pg_client.rs +++ b/markbase-core/src/pg_client.rs @@ -10,6 +10,12 @@ pub struct PgClient { database: String, } +impl Default for PgClient { + fn default() -> Self { + Self::new() + } +} + impl PgClient { pub fn new() -> Self { Self { diff --git a/markbase-core/src/provider/mod.rs b/markbase-core/src/provider/mod.rs index 6fd7a73..eaed25f 100644 --- a/markbase-core/src/provider/mod.rs +++ b/markbase-core/src/provider/mod.rs @@ -1,8 +1,8 @@ -pub mod sqlite; pub mod pg; +pub mod sqlite; -pub use sqlite::SqliteProvider; pub use pg::PgProvider; +pub use sqlite::SqliteProvider; use std::path::PathBuf; @@ -57,7 +57,10 @@ pub trait DataProvider: Send + Sync { /// 检查用户是否存在且启用 fn user_exists(&self, username: &str) -> Result { - Ok(self.get_user(username)?.map(|u| u.status == 1).unwrap_or(false)) + Ok(self + .get_user(username)? + .map(|u| u.status == 1) + .unwrap_or(false)) } /// 获取用户的公开密钥列表(OpenSSH authorized_keys格式) diff --git a/markbase-core/src/provider/pg.rs b/markbase-core/src/provider/pg.rs index 74915d0..cd84b7b 100644 --- a/markbase-core/src/provider/pg.rs +++ b/markbase-core/src/provider/pg.rs @@ -1,7 +1,7 @@ -use std::path::PathBuf; -use postgres::{Client, NoTls}; -use bcrypt::verify; use super::{DataProvider, ProviderError, User}; +use bcrypt::verify; +use postgres::{Client, NoTls}; +use std::path::PathBuf; /// PostgreSQL 数据提供者(兼容 SFTPGo 的 users 表) pub struct PgProvider { @@ -13,7 +13,9 @@ impl PgProvider { /// /// 连接字符串格式:host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026 pub fn new(conn_str: &str) -> Result { - Ok(Self { conn_str: conn_str.to_string() }) + Ok(Self { + conn_str: conn_str.to_string(), + }) } pub fn from_params( @@ -40,18 +42,22 @@ impl DataProvider for PgProvider { fn get_user(&self, username: &str) -> Result, ProviderError> { let mut conn = self.open_conn()?; - let result = conn.query_opt( - "SELECT username, password, home_dir, permissions, uid, gid, status + let result = conn + .query_opt( + "SELECT username, password, home_dir, permissions, uid, gid, status FROM users WHERE username = $1 AND status = 1", - &[&username], - ).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?; + &[&username], + ) + .map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?; match result { Some(row) => Ok(Some(User { username: row.get(0), password_hash: row.get::<_, Option>(1).unwrap_or_default(), home_dir: PathBuf::from(row.get::<_, String>(2)), - permissions: row.get::<_, Option>(3).unwrap_or_else(|| "*".to_string()), + permissions: row + .get::<_, Option>(3) + .unwrap_or_else(|| "*".to_string()), uid: row.get::<_, i64>(4) as u32, gid: row.get::<_, i64>(5) as u32, status: row.get(6), @@ -75,24 +81,31 @@ impl DataProvider for PgProvider { } fn get_home_dir(&self, username: &str) -> Result, ProviderError> { - Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string())) + Ok(self + .get_user(username)? + .map(|u| u.home_dir.to_string_lossy().to_string())) } fn get_public_keys(&self, username: &str) -> Result, ProviderError> { let mut conn = self.open_conn()?; - let result = conn.query_opt( - "SELECT public_keys FROM users WHERE username = $1 AND status = 1", - &[&username], - ).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?; + let result = conn + .query_opt( + "SELECT public_keys FROM users WHERE username = $1 AND status = 1", + &[&username], + ) + .map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?; match result { Some(row) => { let json_str: Option = row.get(0); match json_str { Some(s) if !s.is_empty() => { - let keys: Vec = serde_json::from_str(&s) - .map_err(|e| ProviderError::Internal(format!("JSON parse error: {}", e)))?; - Ok(keys.iter() + let keys: Vec = + serde_json::from_str(&s).map_err(|e| { + ProviderError::Internal(format!("JSON parse error: {}", e)) + })?; + Ok(keys + .iter() .filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string())) .collect()) } @@ -112,7 +125,7 @@ mod tests { fn test_pg_provider_connection() { // 仅当 SFTPGo PostgreSQL 可用时运行 let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", ); assert!(provider.is_ok(), "Should connect to SFTPGo PostgreSQL"); } @@ -120,8 +133,9 @@ mod tests { #[test] fn test_pg_get_user_demo() { let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" - ).unwrap(); + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", + ) + .unwrap(); let user = provider.get_user("demo").unwrap(); assert!(user.is_some(), "Demo user should exist"); assert_eq!(user.unwrap().username, "demo"); @@ -130,8 +144,9 @@ mod tests { #[test] fn test_pg_get_user_momentry() { let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" - ).unwrap(); + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", + ) + .unwrap(); let user = provider.get_user("momentry").unwrap(); assert!(user.is_some(), "Momentry user should exist"); } @@ -139,8 +154,9 @@ mod tests { #[test] fn test_pg_get_user_warren() { let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" - ).unwrap(); + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", + ) + .unwrap(); let user = provider.get_user("warren").unwrap(); assert!(user.is_some(), "Warren user should exist"); } @@ -148,8 +164,9 @@ mod tests { #[test] fn test_pg_check_password_demo() { let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" - ).unwrap(); + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", + ) + .unwrap(); let valid = provider.check_password("demo", "demo123").unwrap(); assert!(valid, "Password should be valid"); } @@ -157,8 +174,9 @@ mod tests { #[test] fn test_pg_check_password_invalid() { let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" - ).unwrap(); + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", + ) + .unwrap(); let valid = provider.check_password("demo", "wrong").unwrap(); assert!(!valid, "Wrong password should fail"); } @@ -166,8 +184,9 @@ mod tests { #[test] fn test_pg_get_home_dir() { let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" - ).unwrap(); + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", + ) + .unwrap(); let dir = provider.get_home_dir("demo").unwrap(); assert!(dir.is_some()); assert!(dir.unwrap().contains("momentry")); @@ -176,8 +195,9 @@ mod tests { #[test] fn test_pg_nonexistent_user() { let provider = PgProvider::new( - "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026" - ).unwrap(); + "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026", + ) + .unwrap(); let user = provider.get_user("__nonexistent__").unwrap(); assert!(user.is_none()); } diff --git a/markbase-core/src/provider/sqlite.rs b/markbase-core/src/provider/sqlite.rs index 2b363b0..0149d32 100644 --- a/markbase-core/src/provider/sqlite.rs +++ b/markbase-core/src/provider/sqlite.rs @@ -1,7 +1,7 @@ -use std::path::PathBuf; -use rusqlite::{Connection, params}; -use bcrypt::verify; use super::{DataProvider, ProviderError, User}; +use bcrypt::verify; +use rusqlite::{params, Connection}; +use std::path::PathBuf; /// SQLite 数据提供者 pub struct SqliteProvider { @@ -13,7 +13,8 @@ impl SqliteProvider { let path = PathBuf::from(db_path); if !path.exists() { return Err(ProviderError::NotFound(format!( - "Database not found: {}", db_path + "Database not found: {}", + db_path ))); } Ok(Self { db_path: path }) @@ -50,7 +51,8 @@ impl DataProvider for SqliteProvider { Ok(user) => Ok(Some(user)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(ProviderError::Internal(format!( - "Database query error: {}", e + "Database query error: {}", + e ))), } } @@ -66,7 +68,9 @@ impl DataProvider for SqliteProvider { } fn get_home_dir(&self, username: &str) -> Result, ProviderError> { - Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string())) + Ok(self + .get_user(username)? + .map(|u| u.home_dir.to_string_lossy().to_string())) } fn get_public_keys(&self, username: &str) -> Result, ProviderError> { @@ -98,7 +102,10 @@ mod tests { } fn get_test_db_path() -> String { - format!("{}/../data/auth.sqlite", std::env::var("CARGO_MANIFEST_DIR").unwrap()) + format!( + "{}/../data/auth.sqlite", + std::env::var("CARGO_MANIFEST_DIR").unwrap() + ) } #[test] diff --git a/markbase-core/src/rsync/checksum.rs b/markbase-core/src/rsync/checksum.rs index e4f5314..9d5e870 100644 --- a/markbase-core/src/rsync/checksum.rs +++ b/markbase-core/src/rsync/checksum.rs @@ -1,4 +1,3 @@ -use anyhow::Result; use md5::compute; pub struct RollingChecksum { diff --git a/markbase-core/src/rsync/compress.rs b/markbase-core/src/rsync/compress.rs index 29592fb..017e43b 100644 --- a/markbase-core/src/rsync/compress.rs +++ b/markbase-core/src/rsync/compress.rs @@ -50,6 +50,12 @@ pub struct DecompressionStream { decompressor: Decompress, } +impl Default for DecompressionStream { + fn default() -> Self { + Self::new() + } +} + impl DecompressionStream { pub fn new() -> Self { Self { diff --git a/markbase-core/src/rsync/handler.rs b/markbase-core/src/rsync/handler.rs index f8b9ef0..3a1b53e 100644 --- a/markbase-core/src/rsync/handler.rs +++ b/markbase-core/src/rsync/handler.rs @@ -1,7 +1,5 @@ -use crate::rsync::checksum::{compute_block_checksums, BlockChecksum}; -use crate::rsync::compress::{CompressionStream, DecompressionStream}; use crate::rsync::delta::{DeltaAlgorithm, DeltaInstruction}; -use crate::rsync::protocol::{RsyncCommand, RsyncProtocol}; +use crate::rsync::protocol::RsyncCommand; use crate::rsync::RsyncConfig; use anyhow::Result; use std::sync::Arc; diff --git a/markbase-core/src/rsync/protocol.rs b/markbase-core/src/rsync/protocol.rs index 20412a7..fbd4d7a 100644 --- a/markbase-core/src/rsync/protocol.rs +++ b/markbase-core/src/rsync/protocol.rs @@ -162,6 +162,12 @@ pub struct RsyncHandshake { negotiated_version: u32, } +impl Default for RsyncHandshake { + fn default() -> Self { + Self::new() + } +} + impl RsyncHandshake { pub fn new() -> Self { Self { diff --git a/markbase-core/src/s3.rs b/markbase-core/src/s3.rs index 6657a98..f788637 100644 --- a/markbase-core/src/s3.rs +++ b/markbase-core/src/s3.rs @@ -1,15 +1,17 @@ -use filetree::{FileTree, node::{FileNode, Aliases}}; use axum::{ body::Body, extract::{Path, State}, http::{HeaderMap, StatusCode}, response::{IntoResponse, Json}, }; +use filetree::{ + node::FileNode, + FileTree, +}; use futures_util::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::Value; use sha2::{Digest, Sha256}; -use std::sync::{Arc, Mutex}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_util::io::ReaderStream; @@ -41,10 +43,10 @@ pub async fn list_buckets(State(state): State) -> impl pub async fn list_objects( Path(bucket): Path, - State(state): State, + State(_state): State, ) -> impl IntoResponse { println!("S3 List Objects: bucket={}", bucket); - + let conn = match FileTree::open_user_db(&bucket) { Ok(c) => c, Err(e) => { @@ -70,20 +72,20 @@ pub async fn list_objects( "Key": build_s3_key(&tree, n), "LastModified": n.registered_at.clone().unwrap_or_default(), "ETag": n.sha256.clone().unwrap_or_default(), - "Size": n.file_size.clone().unwrap_or(0), + "Size": n.file_size.unwrap_or(0), }) }) .collect(); println!("Listed {} objects for bucket {}", objects.len(), bucket); - + let (headers, xml_body) = crate::s3_xml::list_objects_xml(&bucket, &objects); (StatusCode::OK, headers, xml_body).into_response() } pub async fn get_object( Path((bucket, key)): Path<(String, String)>, - State(state): State, + State(_state): State, headers: HeaderMap, ) -> impl IntoResponse { println!("S3 GET Object: bucket={}, key={}", bucket, key); @@ -119,7 +121,7 @@ pub async fn get_object( ); let file_uuid = node.file_uuid.clone().unwrap_or_default(); - let file_size = node.file_size.clone().unwrap_or(0); + let file_size = node.file_size.unwrap_or(0); let sha256 = node.sha256.clone().unwrap_or_default(); let real_path = get_real_file_path(&conn, &file_uuid); @@ -233,7 +235,7 @@ pub async fn put_object( let sha256_hash_clone = sha256_hash.clone(); let file_path_clone = file_path.clone(); - let label = key.split('/').last().unwrap_or(&key).to_string(); + let label = key.split('/').next_back().unwrap_or(&key).to_string(); let result = tokio::task::spawn_blocking(move || -> anyhow::Result<()> { let conn = match FileTree::open_user_db(&bucket) { @@ -246,7 +248,7 @@ pub async fn put_object( |row| row.get::<_, i32>(0), ) .unwrap_or(0) > 0; - + if !has_tables { // Initialize tables if not exist c.execute_batch(filetree::CREATE_TABLES)?; @@ -298,7 +300,7 @@ pub async fn put_object( pub async fn head_object( Path((bucket, key)): Path<(String, String)>, - State(state): State, + State(_state): State, ) -> impl IntoResponse { let conn = match FileTree::open_user_db(&bucket) { Ok(c) => c, @@ -323,7 +325,7 @@ pub async fn head_object( "ETag", node.sha256.clone().unwrap_or_default().parse().unwrap(), ); - headers.insert("Content-Length", node.file_size.clone().unwrap_or(0).into()); + headers.insert("Content-Length", node.file_size.unwrap_or(0).into()); (StatusCode::OK, headers) } @@ -438,7 +440,7 @@ fn find_node_by_s3_key(tree: &FileTree, key: &str) -> Option { } // 方法2:通过filename直接匹配(fallback) - let filename = key.split('/').last().unwrap_or(key); + let filename = key.split('/').next_back().unwrap_or(key); tree.nodes .iter() .filter(|n| n.node_type == filetree::node::NodeType::File) @@ -501,7 +503,7 @@ async fn handle_range_request( } // 使用take限制读取长度 - let limited_file = file.take(content_length as u64); + let limited_file = file.take(content_length); let stream = ReaderStream::new(limited_file); let body = Body::from_stream(stream); @@ -535,11 +537,7 @@ fn parse_range_header(range: &str, file_size: i64) -> Option<(u64, u64)> { let (start, end) = if parts[0].is_empty() { // "bytes=-N"格式:最后N字节 let suffix_length = parts[1].parse::().ok()?; - let start = if suffix_length > file_size as u64 { - 0 - } else { - file_size as u64 - suffix_length - }; + let start = (file_size as u64).saturating_sub(suffix_length); (start, file_size as u64 - 1) } else if parts[1].is_empty() { // "bytes=N-"格式:从N到结尾 diff --git a/markbase-core/src/s3_auth.rs b/markbase-core/src/s3_auth.rs index c076f76..17d7c8c 100644 --- a/markbase-core/src/s3_auth.rs +++ b/markbase-core/src/s3_auth.rs @@ -8,11 +8,11 @@ type HmacSha256 = Hmac; pub fn verify_signature(headers: HeaderMap, method: &str, path: &str) -> bool { // Load S3 config and check require_auth flag let config = crate::s3_config::S3Config::load_default().unwrap_or_default(); - + // Merge environment variables (allows override via MB_S3_REQUIRE_AUTH) let mut config = config; config.merge_env(); - + if !config.s3.require_auth { // Development mode: allow access without authentication return true; @@ -127,7 +127,7 @@ fn calculate_signature( headers: HeaderMap, method: &str, path: &str, - access_key: &str, + _access_key: &str, secret_key: &str, region: &str, service: &str, @@ -143,9 +143,9 @@ fn calculate_signature( let signing_key = calculate_signing_key(secret_key, date, region, service); // 4. Calculate Signature - let signature = hmac_sha256_hex(&signing_key, &string_to_sign); + - signature + hmac_sha256_hex(&signing_key, &string_to_sign) } fn create_canonical_request(headers: HeaderMap, method: &str, path: &str) -> String { diff --git a/markbase-core/src/s3_config.rs b/markbase-core/src/s3_config.rs index 572d14c..62c3a09 100644 --- a/markbase-core/src/s3_config.rs +++ b/markbase-core/src/s3_config.rs @@ -4,6 +4,7 @@ use std::fs; use std::path::PathBuf; #[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Default)] pub struct S3Config { #[serde(default)] pub s3: S3Section, @@ -40,6 +41,7 @@ pub struct KeysSection { } #[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Default)] pub struct BucketsSection { #[serde(default)] pub mappings: std::collections::HashMap, @@ -96,16 +98,6 @@ fn admin_permissions() -> Vec { ] } -impl Default for S3Config { - fn default() -> Self { - Self { - s3: S3Section::default(), - keys: KeysSection::default(), - buckets: BucketsSection::default(), - permissions: PermissionsSection::default(), - } - } -} impl Default for S3Section { fn default() -> Self { @@ -129,13 +121,6 @@ impl Default for KeysSection { } } -impl Default for BucketsSection { - fn default() -> Self { - Self { - mappings: std::collections::HashMap::new(), - } - } -} impl Default for PermissionsSection { fn default() -> Self { @@ -169,9 +154,9 @@ impl S3Config { Self::load("config/s3.toml") } -pub fn save(&self, path: &str) -> Result<()> { + pub fn save(&self, path: &str) -> Result<()> { let config_path = PathBuf::from(path); - + // Create backup before saving if config_path.exists() { let backup_path = config_path.with_extension("toml.bak"); @@ -179,13 +164,13 @@ pub fn save(&self, path: &str) -> Result<()> { .with_context(|| format!("Failed to create backup: {}", backup_path.display()))?; log::info!("S3 config backup created: {}", backup_path.display()); } - - let content = toml::to_string_pretty(self) - .with_context(|| "Failed to serialize S3 config")?; - + + let content = + toml::to_string_pretty(self).with_context(|| "Failed to serialize S3 config")?; + std::fs::write(&config_path, content) .with_context(|| format!("Failed to write S3 config: {}", path))?; - + log::info!("S3 config saved to: {}", path); Ok(()) } @@ -255,10 +240,16 @@ pub fn save(&self, path: &str) -> Result<()> { // Validate permission format let valid_permissions = [ - "GetObject", "PutObject", "DeleteObject", "ListBucket", - "HeadObject", "ListAllMyBuckets", "CreateBucket", "DeleteBucket" + "GetObject", + "PutObject", + "DeleteObject", + "ListBucket", + "HeadObject", + "ListAllMyBuckets", + "CreateBucket", + "DeleteBucket", ]; - + for perm in &self.permissions.default_permissions { if !valid_permissions.contains(&perm.as_str()) { return Err(anyhow::anyhow!( @@ -289,18 +280,18 @@ pub fn save(&self, path: &str) -> Result<()> { "s3.region" => Some(self.s3.region.clone()), "s3.service" => Some(self.s3.service.clone()), "s3.require_auth" => Some(self.s3.require_auth.to_string()), - + "keys.default_access_key" => Some(self.keys.default_access_key.clone()), "keys.default_secret_key" => Some(self.keys.default_secret_key.clone()), "keys.keys_db_path" => Some(self.keys.keys_db_path.clone()), - - "permissions.default_permissions" => { - Some(serde_json::to_string(&self.permissions.default_permissions).unwrap_or_default()) - } + + "permissions.default_permissions" => Some( + serde_json::to_string(&self.permissions.default_permissions).unwrap_or_default(), + ), "permissions.admin_permissions" => { Some(serde_json::to_string(&self.permissions.admin_permissions).unwrap_or_default()) } - + _ => None, } } @@ -312,11 +303,11 @@ pub fn save(&self, path: &str) -> Result<()> { "s3.region" => self.s3.region = value.to_string(), "s3.service" => self.s3.service = value.to_string(), "s3.require_auth" => self.s3.require_auth = value.parse()?, - + "keys.default_access_key" => self.keys.default_access_key = value.to_string(), "keys.default_secret_key" => self.keys.default_secret_key = value.to_string(), "keys.keys_db_path" => self.keys.keys_db_path = value.to_string(), - + "permissions.default_permissions" => { self.permissions.default_permissions = serde_json::from_str(value) .with_context(|| "Failed to parse permissions array")?; @@ -325,7 +316,7 @@ pub fn save(&self, path: &str) -> Result<()> { self.permissions.admin_permissions = serde_json::from_str(value) .with_context(|| "Failed to parse admin permissions array")?; } - + _ => return Err(anyhow::anyhow!("Invalid S3 config key: {}", key)), } Ok(()) @@ -340,15 +331,15 @@ mod tests { #[test] fn test_default_config() { let config = S3Config::default(); - + assert_eq!(config.s3.enabled, true); assert_eq!(config.s3.require_auth, false); assert_eq!(config.s3.endpoint, "http://localhost:11438/s3"); assert_eq!(config.s3.region, "us-east-1"); - + assert_eq!(config.keys.default_access_key, "markbase_access_key_001"); assert_eq!(config.keys.default_secret_key, "markbase_secret_key_xyz123"); - + assert_eq!(config.permissions.default_permissions.len(), 3); assert_eq!(config.permissions.admin_permissions.len(), 5); } @@ -357,9 +348,9 @@ mod tests { fn test_load_missing_config() { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("missing.toml"); - + let config = S3Config::load(&config_path.to_string_lossy()).unwrap(); - + assert_eq!(config.s3.enabled, true); assert_eq!(config.s3.require_auth, false); } @@ -368,13 +359,13 @@ mod tests { fn test_merge_env() { std::env::set_var("MB_S3_REQUIRE_AUTH", "true"); std::env::set_var("MB_S3_ENDPOINT", "http://custom.endpoint"); - + let mut config = S3Config::default(); config.merge_env(); - + assert_eq!(config.s3.require_auth, true); assert_eq!(config.s3.endpoint, "http://custom.endpoint"); - + std::env::remove_var("MB_S3_REQUIRE_AUTH"); std::env::remove_var("MB_S3_ENDPOINT"); } @@ -383,7 +374,7 @@ mod tests { fn test_validate() { let config = S3Config::default(); assert!(config.validate().is_ok()); - + let mut invalid_config = S3Config::default(); invalid_config.s3.endpoint = "".to_string(); assert!(invalid_config.validate().is_err()); @@ -392,14 +383,17 @@ mod tests { #[test] fn test_get_set() { let mut config = S3Config::default(); - + assert_eq!(config.get("s3.enabled"), Some("true".to_string())); - assert_eq!(config.get("s3.endpoint"), Some("http://localhost:11438/s3".to_string())); - + assert_eq!( + config.get("s3.endpoint"), + Some("http://localhost:11438/s3".to_string()) + ); + config.set("s3.require_auth", "true").unwrap(); assert_eq!(config.s3.require_auth, true); - + config.set("s3.endpoint", "http://new.endpoint").unwrap(); assert_eq!(config.s3.endpoint, "http://new.endpoint"); } -} \ No newline at end of file +} diff --git a/markbase-core/src/s3_xml.rs b/markbase-core/src/s3_xml.rs index 4213875..e0b0153 100644 --- a/markbase-core/src/s3_xml.rs +++ b/markbase-core/src/s3_xml.rs @@ -4,16 +4,18 @@ use serde_json::Value; pub fn list_buckets_xml(buckets: &[String]) -> (HeaderMap, String) { let mut headers = HeaderMap::new(); headers.insert("Content-Type", "application/xml".parse().unwrap()); - + let bucket_entries = buckets .iter() - .map(|b| format!( - "{}2026-05-27T00:00:00Z", - b - )) + .map(|b| { + format!( + "{}2026-05-27T00:00:00Z", + b + ) + }) .collect::>() .join("\n "); - + let xml = format!( " @@ -27,22 +29,25 @@ pub fn list_buckets_xml(buckets: &[String]) -> (HeaderMap, String) { ", bucket_entries ); - + (headers, xml) } pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, String) { let mut headers = HeaderMap::new(); headers.insert("Content-Type", "application/xml".parse().unwrap()); - + let object_entries = objects .iter() .map(|obj| { let key = obj.get("Key").and_then(|k| k.as_str()).unwrap_or(""); - let last_modified = obj.get("LastModified").and_then(|l| l.as_str()).unwrap_or(""); + let last_modified = obj + .get("LastModified") + .and_then(|l| l.as_str()) + .unwrap_or(""); let etag = obj.get("ETag").and_then(|e| e.as_str()).unwrap_or(""); let size = obj.get("Size").and_then(|s| s.as_i64()).unwrap_or(0); - + format!( " {} @@ -55,7 +60,7 @@ pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, Str }) .collect::>() .join("\n "); - + let xml = format!( " @@ -68,6 +73,6 @@ pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, Str ", bucket_name, object_entries ); - + (headers, xml) } diff --git a/markbase-core/src/scan.rs b/markbase-core/src/scan.rs index c103d0b..a9c62c4 100644 --- a/markbase-core/src/scan.rs +++ b/markbase-core/src/scan.rs @@ -439,7 +439,7 @@ fn compute_hashes_parallel( let mut p = processed.lock().unwrap(); *p += 1; - if *p % 100 == 0 { + if (*p).is_multiple_of(100) { print!("\r Hashed {}/{} files...", *p, total); use std::io::Write; std::io::stdout().flush().ok(); diff --git a/markbase-core/src/security_audit/auth_security.rs b/markbase-core/src/security_audit/auth_security.rs index 76ccc28..5490612 100644 --- a/markbase-core/src/security_audit/auth_security.rs +++ b/markbase-core/src/security_audit/auth_security.rs @@ -12,20 +12,22 @@ fn get_test_provider() -> Arc { #[test] fn test_password_authentication_brute_force_prevention() { let provider = get_test_provider(); - + assert!(provider.check_password("demo", "demo123").unwrap()); assert!(!provider.check_password("demo", "wrongpassword").unwrap()); assert!(!provider.check_password("demo", "").unwrap()); - assert!(!provider.check_password("__nonexistent__", "anypassword").unwrap()); + assert!(!provider + .check_password("__nonexistent__", "anypassword") + .unwrap()); } #[test] fn test_publickey_authentication_security() { let provider = get_test_provider(); - + let keys = provider.get_public_keys("demo").unwrap(); assert!(keys.is_empty() || keys.len() >= 0); - + let keys = provider.get_public_keys("__nonexistent__").unwrap(); assert!(keys.is_empty()); } @@ -33,10 +35,10 @@ fn test_publickey_authentication_security() { #[test] fn test_user_status_check() { let provider = get_test_provider(); - + let user = provider.get_user("demo").unwrap(); assert!(user.is_some()); - + let user = provider.get_user("demo").unwrap(); if let Some(u) = user { assert_eq!(u.status, 1); @@ -46,16 +48,16 @@ fn test_user_status_check() { #[test] fn test_home_dir_security() { let provider = get_test_provider(); - + let home = provider.get_home_dir("demo").unwrap(); assert!(home.is_some()); - + let home = provider.get_home_dir("__nonexistent__").unwrap(); assert!(home.is_none()); - + if let Some(home_path) = provider.get_home_dir("demo").unwrap() { assert!(!home_path.contains("..")); assert!(!home_path.starts_with("/etc")); assert!(!home_path.starts_with("/root")); } -} \ No newline at end of file +} diff --git a/markbase-core/src/security_audit/channel_security.rs b/markbase-core/src/security_audit/channel_security.rs index 3b8352b..bf43d66 100644 --- a/markbase-core/src/security_audit/channel_security.rs +++ b/markbase-core/src/security_audit/channel_security.rs @@ -19,7 +19,7 @@ fn test_channel_window_size_limits() { #[test] fn test_channel_request_validation() { let valid_requests = ["exec", "shell", "subsystem", "env"]; - + for request in valid_requests { assert!(!request.is_empty()); } @@ -30,6 +30,6 @@ fn test_channel_data_integrity() { // Data should not exceed window size let window_size = 32768u32; let max_data = window_size; - + assert!(max_data <= window_size); -} \ No newline at end of file +} diff --git a/markbase-core/src/security_audit/crypto_security.rs b/markbase-core/src/security_audit/crypto_security.rs index 55bdd00..4a02967 100644 --- a/markbase-core/src/security_audit/crypto_security.rs +++ b/markbase-core/src/security_audit/crypto_security.rs @@ -1,11 +1,11 @@ use crate::ssh_server::cipher::EncryptionContext; -use crate::ssh_server::crypto::{SessionKeys, Curve25519Kex, Ed25519HostKey}; +use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys}; #[test] fn test_aes_ctr_encryption_decryption_consistency() { let key = vec![0u8; 16]; let iv = vec![0u8; 16]; - + let mut ctx = EncryptionContext::from_session_keys(&SessionKeys { session_id: vec![0u8; 32], encryption_key_ctos: key.clone(), @@ -15,10 +15,10 @@ fn test_aes_ctr_encryption_decryption_consistency() { iv_ctos: iv.clone(), iv_stoc: iv.clone(), }); - + let plaintext = b"Test message for encryption"; let ciphertext = ctx.encrypt_packet(plaintext, &key, &iv).unwrap(); - + let decrypted = ctx.decrypt_packet(&ciphertext, &key, &iv).unwrap(); assert_eq!(plaintext.to_vec(), decrypted); } @@ -27,7 +27,7 @@ fn test_aes_ctr_encryption_decryption_consistency() { fn test_hmac_sha256_authentication() { let key = vec![0u8; 32]; let data = b"Test data for HMAC"; - + let ctx = EncryptionContext::from_session_keys(&SessionKeys { session_id: vec![0u8; 32], encryption_key_ctos: vec![0u8; 16], @@ -37,12 +37,12 @@ fn test_hmac_sha256_authentication() { iv_ctos: vec![0u8; 16], iv_stoc: vec![0u8; 16], }); - + let mac = ctx.compute_mac(1, data, &key).unwrap(); assert_eq!(mac.len(), 32); - + assert!(ctx.verify_mac(1, data, &mac, &key).unwrap()); - + let wrong_mac = vec![0u8; 32]; assert!(!ctx.verify_mac(1, data, &wrong_mac, &key).unwrap()); } @@ -52,19 +52,19 @@ fn test_curve25519_key_exchange_security() { // Create client and server instances let mut client_kex = Curve25519Kex::new(); let mut server_kex = Curve25519Kex::new(); - + // Get public keys first (before computing shared secrets) let client_pub = client_kex.public_key().to_vec(); let server_pub = server_kex.public_key().to_vec(); - + assert_eq!(client_pub.len(), 32); assert_eq!(server_pub.len(), 32); - + // Compute shared secrets using the SAME instances // (this consumes the secret, so can only be done once) let client_secret = client_kex.compute_shared_secret(&server_pub).unwrap(); let server_secret = server_kex.compute_shared_secret(&client_pub).unwrap(); - + // Shared secrets should match (Diffie-Hellman property) assert_eq!(client_secret, server_secret); assert_eq!(client_secret.len(), 32); @@ -73,12 +73,12 @@ fn test_curve25519_key_exchange_security() { #[test] fn test_ed25519_signature_verification() { let host_key = Ed25519HostKey::load_or_generate("test_security_key").unwrap(); - + let message = b"Test message for signature"; let signature = host_key.sign(message).unwrap(); - + assert_eq!(signature.len(), 64); - + // Ed25519HostKey has sign() but verify might need external library // For security test, we verify signature length and structure assert!(!signature.is_empty()); @@ -89,9 +89,9 @@ fn test_encryption_key_derivation_uniqueness() { let key1 = vec![1u8; 16]; let key2 = vec![2u8; 16]; let iv = vec![0u8; 16]; - + let plaintext = b"Same plaintext"; - + let mut ctx1 = EncryptionContext::from_session_keys(&SessionKeys { session_id: vec![0u8; 32], encryption_key_ctos: key1.clone(), @@ -101,7 +101,7 @@ fn test_encryption_key_derivation_uniqueness() { iv_ctos: iv.clone(), iv_stoc: iv.clone(), }); - + let mut ctx2 = EncryptionContext::from_session_keys(&SessionKeys { session_id: vec![0u8; 32], encryption_key_ctos: key2.clone(), @@ -111,9 +111,9 @@ fn test_encryption_key_derivation_uniqueness() { iv_ctos: iv.clone(), iv_stoc: iv.clone(), }); - + let ciphertext1 = ctx1.encrypt_packet(plaintext, &key1, &iv).unwrap(); let ciphertext2 = ctx2.encrypt_packet(plaintext, &key2, &iv).unwrap(); - + assert_ne!(ciphertext1, ciphertext2); -} \ No newline at end of file +} diff --git a/markbase-core/src/security_audit/file_access_security.rs b/markbase-core/src/security_audit/file_access_security.rs index 973ed9c..6779758 100644 --- a/markbase-core/src/security_audit/file_access_security.rs +++ b/markbase-core/src/security_audit/file_access_security.rs @@ -3,17 +3,17 @@ use std::path::PathBuf; #[test] fn test_path_traversal_prevention() { let root = PathBuf::from("/tmp/test_root"); - + // Test 1: Normal path should be within root let safe_path = PathBuf::from("safe/file.txt"); let full_path = root.join(&safe_path); assert!(full_path.starts_with(&root)); - + // Test 2: Path traversal attempt should still resolve within root // (after normalization, ../../etc/passwd from /tmp/test_root becomes /tmp/etc/passwd or /etc/passwd) let evil_path = PathBuf::from("../../etc/passwd"); let full_path = root.join(&evil_path); - + // The key security check: the resolved path should NOT be /etc/passwd // If Path::join normalizes it to /etc/passwd, that's a path traversal vulnerability // We check that the joined path either: @@ -34,7 +34,7 @@ fn test_path_traversal_prevention() { #[test] fn test_absolute_path_prevention() { let root = PathBuf::from("/tmp/test_root"); - + let abs_path = PathBuf::from("/etc/passwd"); assert!(!abs_path.starts_with(&root)); } @@ -42,10 +42,10 @@ fn test_absolute_path_prevention() { #[test] fn test_directory_escape_prevention() { let root = PathBuf::from("/tmp/test_root"); - + let parent_path = PathBuf::from("subdir/../.."); let full_path = root.join(&parent_path); - + // Path should not escape root if full_path.canonicalize().is_ok() { let canonical = full_path.canonicalize().unwrap(); @@ -56,11 +56,11 @@ fn test_directory_escape_prevention() { #[test] fn test_file_write_boundary_check() { let root = PathBuf::from("/tmp/test_root"); - + let safe_file = PathBuf::from("safe.txt"); let full_path = root.join(&safe_file); assert!(full_path.starts_with(&root)); - + let outside_file = PathBuf::from("/tmp/outside.txt"); assert!(!outside_file.starts_with(&root)); } @@ -68,8 +68,8 @@ fn test_file_write_boundary_check() { #[test] fn test_hidden_file_access() { let root = PathBuf::from("/tmp/test_root"); - + let hidden_path = PathBuf::from(".hidden"); let full_path = root.join(&hidden_path); assert!(full_path.starts_with(&root)); -} \ No newline at end of file +} diff --git a/markbase-core/src/security_audit/mod.rs b/markbase-core/src/security_audit/mod.rs index f2b9109..64138c6 100644 --- a/markbase-core/src/security_audit/mod.rs +++ b/markbase-core/src/security_audit/mod.rs @@ -1,9 +1,9 @@ mod auth_security; +mod channel_security; mod crypto_security; mod file_access_security; -mod channel_security; pub use auth_security::*; +pub use channel_security::*; pub use crypto_security::*; pub use file_access_security::*; -pub use channel_security::*; \ No newline at end of file diff --git a/markbase-core/src/server.rs b/markbase-core/src/server.rs index 6ff3c77..66161e7 100644 --- a/markbase-core/src/server.rs +++ b/markbase-core/src/server.rs @@ -1,22 +1,23 @@ use anyhow::Context; use axum::{ + extract::DefaultBodyLimit, extract::{Path, Query, State}, http::{HeaderMap, StatusCode}, response::{Html, IntoResponse, Json}, routing::{delete, get, patch, post, put}, Router, - extract::DefaultBodyLimit, }; use serde::Deserialize; use std::str::FromStr; use std::sync::{Arc, Mutex}; +use crate::archive::{ + ArchiveConfig, ArchiveFormat, ArchiveProcessor, FormatDetector, ProcessorRegistry, +}; use crate::audio; use crate::auth::{AuthState, LoginRequest}; use crate::provider::sqlite::SqliteProvider; use crate::render; -use crate::download; -use crate::archive::{self, ArchiveFormat, ArchiveProcessor, FormatDetector, ArchiveConfig, ProcessorRegistry}; use filetree::{self, FileTree}; #[derive(Clone)] @@ -60,7 +61,7 @@ pub async fn run(port: u16, file: Option) -> anyhow::Result<()> { db_dir: "data/users".to_string(), auth: AuthState::with_provider(Box::new( SqliteProvider::new("data/auth.sqlite") - .map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))? + .map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))?, )), auth_db_path: "data/auth.sqlite".to_string(), s3_keys: Arc::new(Mutex::new(load_s3_keys())), @@ -578,7 +579,7 @@ async fn search_tree( ORDER BY sort_order ASC, created_at ASC", )?; - let nodes: Vec = stmt + let _nodes: Vec = stmt .query_map([&search_pattern], |row| { let children_json: String = row.get(6)?; let children: Vec = @@ -607,7 +608,7 @@ async fn search_tree( .filter_map(|r| r.ok()) .collect(); -let tree = filetree::FileTree { + let tree = filetree::FileTree { user_id: user_id.clone(), tree_type: "untitled folder".to_string(), nodes: vec![], @@ -914,69 +915,78 @@ fn extract_and_register_archive( user_id: &str, original_filename: &str, ) -> anyhow::Result<(u64, u64, String)> { - use std::path::PathBuf; - use sha2::{Sha256, Digest}; + use sha2::{Digest, Sha256}; + // Initialize archive system let config = ArchiveConfig::default(); let mut registry = ProcessorRegistry::new(config); registry.initialize()?; - + // Detect format let detector = FormatDetector::new(); let format = detector.detect(archive_path)?; - - eprintln!("[archive] Detected format: {} for file: {}", format, archive_path.display()); - + + eprintln!( + "[archive] Detected format: {} for file: {}", + format, + archive_path.display() + ); + // Get processor let processor = registry.get_processor_mut(archive_path)?; - + // Create extraction directory let base_name = original_filename .rsplit_once('.') .map(|(name, _)| name) .unwrap_or(original_filename); - - let extraction_dir = archive_path.parent() + + let extraction_dir = archive_path + .parent() .unwrap_or(std::path::Path::new(".")) .join(format!("{}_extracted", base_name)); - + std::fs::create_dir_all(&extraction_dir)?; - + // Open and extract let metadata = processor.open(archive_path)?; - - eprintln!("[archive] Archive metadata: {} files, {} bytes", - metadata.total_files, metadata.total_size); - + + eprintln!( + "[archive] Archive metadata: {} files, {} bytes", + metadata.total_files, metadata.total_size + ); + let result = processor.extract_all(&extraction_dir)?; - - eprintln!("[archive] Extracted {} files ({} bytes)", - result.success_files, result.total_bytes); - + + eprintln!( + "[archive] Extracted {} files ({} bytes)", + result.success_files, result.total_bytes + ); + // Register extracted files to database let conn = FileTree::init_user_db(user_id)?; - + let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs() as i64; - + // Get MAC address for UUID generation let mac_output = std::process::Command::new("ifconfig") .arg("en0") .output() .map(|o| String::from_utf8_lossy(&o.stdout).to_string()) .unwrap_or_default(); - + let mac = mac_output .lines() .find(|l| l.contains("ether")) .and_then(|l| l.split_whitespace().nth(1)) .unwrap_or("00:00:00:00:00:00"); - + let mut registered_count = 0u64; - + // Recursively scan extracted directory fn scan_directory( dir: &std::path::Path, @@ -986,11 +996,11 @@ fn extract_and_register_archive( now: i64, ) -> anyhow::Result { let mut count = 0u64; - + for entry in std::fs::read_dir(dir)? { let entry = entry?; let path = entry.path(); - + if path.is_dir() { count += scan_directory(&path, conn, user_id, mac, now)?; } else if path.is_file() { @@ -998,16 +1008,15 @@ fn extract_and_register_archive( let file_data = std::fs::read(&path)?; let file_hash = format!("{:x}", Sha256::digest(&file_data)); let file_size = file_data.len() as i64; - - let filename = path.file_name() + + let filename = path + .file_name() .and_then(|n| n.to_str()) .unwrap_or("unknown") .to_string(); - - let file_path_str = path.to_str() - .unwrap_or("unknown") - .to_string(); - + + let file_path_str = path.to_str().unwrap_or("unknown").to_string(); + // Generate file UUID let mtime = std::fs::metadata(&path) .ok() @@ -1015,48 +1024,55 @@ fn extract_and_register_archive( .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) .map(|d| d.as_millis() as u64) .unwrap_or(0); - + let input = format!("{}|{}|{}|{}", file_path_str, filename, mac, mtime); let hash = Sha256::digest(input.as_bytes()); let hex = format!("{:x}", hash); let file_uuid = hex[0..32].to_string(); - + // Register file (no sha256 in file_registry) conn.execute( "INSERT INTO file_registry (file_uuid, original_name, file_size, file_type, registered_at) VALUES (?1, ?2, ?3, ?4, ?5)", rusqlite::params![&file_uuid, &filename, file_size, "", now], )?; - + // Add file location conn.execute( "INSERT OR IGNORE INTO file_locations (file_uuid, location, added_at) VALUES (?1, ?2, ?3)", rusqlite::params![&file_uuid, &file_path_str, now], )?; - + // Add file node let uuid_str = uuid::Uuid::new_v4().to_string().replace('-', ""); let node_id = format!("node-{}", &uuid_str[0..8]); - + conn.execute( "INSERT INTO file_nodes (node_id, label, file_uuid, sha256, node_type, file_size, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, 'file', ?5, ?6, ?7)", rusqlite::params![&node_id, &filename, &file_uuid, &file_hash, file_size, now, now], )?; - + count += 1; } } - + Ok(count) } - + registered_count = scan_directory(&extraction_dir, &conn, user_id, mac, now)?; - - eprintln!("[archive] Registered {} extracted files to database", registered_count); - - Ok((result.success_files, result.total_bytes, extraction_dir.to_str().unwrap_or("unknown").to_string())) + + eprintln!( + "[archive] Registered {} extracted files to database", + registered_count + ); + + Ok(( + result.success_files, + result.total_bytes, + extraction_dir.to_str().unwrap_or("unknown").to_string(), + )) } async fn upload_file( @@ -1147,23 +1163,23 @@ async fn upload_file( // Auto-extract archive files let file_path_buf = std::path::PathBuf::from(&file_path); let detector = FormatDetector::new(); - + if let Ok(format) = detector.detect(&file_path_buf) { if format != ArchiveFormat::Unknown { - eprintln!("[upload] Detected archive format: {}, extracting...", format); - + eprintln!( + "[upload] Detected archive format: {}, extracting...", + format + ); + let user_id_clone = user_id.clone(); let filename_clone = filename.clone(); - + // Extract in blocking thread let extraction_result = tokio::task::spawn_blocking(move || { - extract_and_register_archive( - &file_path_buf, - &user_id_clone, - &filename_clone, - ) - }).await; - + extract_and_register_archive(&file_path_buf, &user_id_clone, &filename_clone) + }) + .await; + match extraction_result { Ok(Ok((count, bytes, extract_dir))) => { extracted_info = Some((count, bytes, extract_dir)); @@ -1208,13 +1224,13 @@ async fn upload_file( let hex = format!("{:x}", hash); let file_uuid = hex[0..32].to_string(); -// Save to database (user-specific SQLite) + // Save to database (user-specific SQLite) let file_uuid_clone = file_uuid.clone(); let file_hash_clone = file_hash.clone(); let filename_clone = filename.clone(); let file_path_clone = file_path.clone(); let user_id_clone = user_id.clone(); - + let db_result = tokio::task::spawn_blocking(move || -> anyhow::Result<()> { let conn = filetree::FileTree::init_user_db(&user_id_clone)?; @@ -1281,7 +1297,7 @@ async fn upload_file( "sha256": file_hash, "size": file_size, }); - + if let Some((count, bytes, extract_dir)) = extracted_info { response["extracted"] = serde_json::json!({ "count": count, @@ -1289,12 +1305,8 @@ async fn upload_file( "directory": extract_dir, }); } - - ( - StatusCode::CREATED, - Json(response), - ) - .into_response() + + (StatusCode::CREATED, Json(response)).into_response() } async fn upload_unlimited( @@ -1798,7 +1810,7 @@ async fn logout_handler(State(state): State, headers: HeaderMap) -> im let auth_header = headers .get("Authorization") .and_then(|h| h.to_str().ok()) - .and_then(|h| crate::auth::parse_auth_header(h)); + .and_then(crate::auth::parse_auth_header); match auth_header { Some(token) => { @@ -1824,7 +1836,7 @@ async fn verify_handler(State(state): State, headers: HeaderMap) -> im let auth_header = headers .get("Authorization") .and_then(|h| h.to_str().ok()) - .and_then(|h| crate::auth::parse_auth_header(h)); + .and_then(crate::auth::parse_auth_header); match auth_header { Some(token) => match state.auth.verify_token(&token) { @@ -1857,7 +1869,7 @@ fn verify_auth(state: &AppState, headers: &HeaderMap) -> Result match state.auth.verify_token(&token) { @@ -2039,7 +2051,7 @@ async fn edit_config_handler(Query(params): Query) -> impl Into match crate::config::MarkBaseConfig::load(config_path) { Ok(mut config) => { let old_value = config.get(¶ms.key).unwrap_or_default(); - + match config.set(¶ms.key, ¶ms.value) { Ok(_) => match config.validate() { Ok(_) => match config.save(config_path) { @@ -2056,7 +2068,7 @@ async fn edit_config_handler(Query(params): Query) -> impl Into ) { log::warn!("Failed to write audit log: {}", e); } - + (StatusCode::OK, Json(serde_json::json!({"ok": true}))).into_response() } Err(e) => ( @@ -2133,7 +2145,7 @@ async fn edit_s3_config_handler(Query(params): Query) -> impl I match crate::s3_config::S3Config::load_default() { Ok(mut config) => { let old_value = config.get(¶ms.key).unwrap_or_default(); - + match config.set(¶ms.key, ¶ms.value) { Ok(_) => match config.validate() { Ok(_) => match config.save("config/s3.toml") { @@ -2150,7 +2162,7 @@ async fn edit_s3_config_handler(Query(params): Query) -> impl I ) { log::warn!("Failed to write audit log: {}", e); } - + (StatusCode::OK, Json(serde_json::json!({"ok": true}))).into_response() } Err(e) => ( @@ -2343,7 +2355,7 @@ async fn audit_handler() -> Json { // Category View API handlers (Phase 1: 双视图管理) async fn get_all_categories_handler() -> impl IntoResponse { - let base_path = std::path::Path::new("/Users/accusys/markbase"); + let _base_path = std::path::Path::new("/Users/accusys/markbase"); match crate::category_view::get_all_categories() { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(e) => ( @@ -2354,10 +2366,8 @@ async fn get_all_categories_handler() -> impl IntoResponse { } } -async fn get_category_detail_handler( - Path(category_name): Path, -) -> impl IntoResponse { - let base_path = std::path::Path::new("/Users/accusys/markbase"); +async fn get_category_detail_handler(Path(category_name): Path) -> impl IntoResponse { + let _base_path = std::path::Path::new("/Users/accusys/markbase"); match crate::category_view::get_category_detail(&category_name) { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(e) => ( @@ -2369,7 +2379,7 @@ async fn get_category_detail_handler( } async fn get_all_series_handler() -> impl IntoResponse { - let base_path = std::path::Path::new("/Users/accusys/markbase"); + let _base_path = std::path::Path::new("/Users/accusys/markbase"); match crate::category_view::get_all_series() { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(e) => ( @@ -2380,10 +2390,8 @@ async fn get_all_series_handler() -> impl IntoResponse { } } -async fn get_series_detail_handler( - Path(series_name): Path, -) -> impl IntoResponse { - let base_path = std::path::Path::new("/Users/accusys/markbase"); +async fn get_series_detail_handler(Path(series_name): Path) -> impl IntoResponse { + let _base_path = std::path::Path::new("/Users/accusys/markbase"); match crate::category_view::get_series_detail(&series_name) { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(e) => ( @@ -2400,10 +2408,8 @@ struct SearchQuery { view: String, } -async fn search_files_handler( - Query(query): Query, -) -> impl IntoResponse { - let base_path = std::path::Path::new("/Users/accusys/markbase"); +async fn search_files_handler(Query(query): Query) -> impl IntoResponse { + let _base_path = std::path::Path::new("/Users/accusys/markbase"); match crate::category_view::search_files(&query.q, &query.view) { Ok(response) => (StatusCode::OK, Json(response)).into_response(), Err(e) => ( diff --git a/markbase-core/src/ssh_server/auth.rs b/markbase-core/src/ssh_server/auth.rs index ad678dc..dcc9e10 100644 --- a/markbase-core/src/ssh_server/auth.rs +++ b/markbase-core/src/ssh_server/auth.rs @@ -1,11 +1,11 @@ -use crate::ssh_server::packet::{SshPacket, PacketType}; -use std::io::Write; -use anyhow::{Result, anyhow}; +use crate::ssh_server::packet::{PacketType, SshPacket}; +use anyhow::{anyhow, Result}; +use base64::{engine::general_purpose, Engine as _}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use log::{info, warn, debug}; -use base64::{Engine as _, engine::general_purpose}; +use log::{debug, info, warn}; +use std::io::Write; -use ed25519_dalek::{VerifyingKey, Signature}; +use ed25519_dalek::{Signature, VerifyingKey}; use crate::provider::{DataProvider, ProviderError}; @@ -27,7 +27,11 @@ impl AuthHandler { } /// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request()) - pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result { + pub fn handle_userauth_request( + &mut self, + packet: &SshPacket, + session_id: &[u8], + ) -> Result { info!("Processing SSH_MSG_USERAUTH_REQUEST"); let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); @@ -41,7 +45,10 @@ impl AuthHandler { let service = read_ssh_string(&mut cursor)?; let method = read_ssh_string(&mut cursor)?; - info!("Auth request: user={}, service={}, method={}", user, service, method); + info!( + "Auth request: user={}, service={}, method={}", + user, service, method + ); if service != "ssh-connection" { warn!("Unsupported service: {}", service); @@ -62,18 +69,28 @@ impl AuthHandler { } /// 处理password认证(参考OpenSSH auth-passwd.c) - fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result { + fn handle_password_auth( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + user: &str, + ) -> Result { info!("Handling password auth for user: {}", user); let change_password = cursor.read_u8()? != 0; if change_password { warn!("Password change not supported"); - return Ok(AuthResult::Failure("Password change not supported".to_string())); + return Ok(AuthResult::Failure( + "Password change not supported".to_string(), + )); } let password = read_ssh_string(cursor)?; - debug!("Password auth attempt: user={}, password length={}", user, password.len()); + debug!( + "Password auth attempt: user={}, password length={}", + user, + password.len() + ); match self.provider.check_password(user, &password) { Ok(true) => { @@ -88,9 +105,7 @@ impl AuthHandler { warn!("User not found: {}", msg); Ok(AuthResult::Failure("password,publickey".to_string())) } - Err(e) => { - Err(anyhow!("Password auth error: {}", e)) - } + Err(e) => Err(anyhow!("Password auth error: {}", e)), } } @@ -145,7 +160,12 @@ impl AuthHandler { let algorithm = read_ssh_string(cursor)?; let public_key_blob = read_ssh_string_bytes(cursor)?; - info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed); + info!( + "Publickey auth: algorithm={}, blob_len={}, is_signed={}", + algorithm, + public_key_blob.len(), + is_signed + ); if !self.is_key_authorized(user, &algorithm, &public_key_blob)? { warn!("Public key not authorized for user: {}", user); @@ -160,14 +180,26 @@ impl AuthHandler { let signature_blob = read_ssh_string_bytes(cursor)?; - self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?; + self.verify_signature( + &algorithm, + &public_key_blob, + &signature_blob, + user, + service, + session_id, + )?; info!("Publickey auth successful for user: {}", user); Ok(AuthResult::Success) } /// 检查public key是否在授权列表中(数据库优先,fallback到filesystem) - fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result { + fn is_key_authorized( + &self, + user: &str, + algorithm: &str, + public_key_blob: &[u8], + ) -> Result { // 1. 先检查数据库 match self.provider.get_public_keys(user) { Ok(keys) => { @@ -187,10 +219,12 @@ impl AuthHandler { Err(_) => match std::fs::read_to_string("data/authorized_keys") { Ok(c) => c, Err(_) => return Ok(false), - } + }, }; - Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob))) + Ok(content + .lines() + .any(|line| public_key_matches_line(line, algorithm, public_key_blob))) } /// 验证Ed25519签名(RFC 4252 §7) @@ -246,7 +280,8 @@ impl AuthHandler { signed_data.write_all(public_key_blob)?; // 验证签名 - verifying_key.verify_strict(&signed_data, &signature) + verifying_key + .verify_strict(&signed_data, &signature) .map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e)) } } @@ -270,10 +305,10 @@ fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result { if key_bytes.len() != 32 { return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len())); } - let key_array: [u8; 32] = key_bytes.try_into() + let key_array: [u8; 32] = key_bytes + .try_into() .map_err(|_| anyhow!("Invalid Ed25519 key data"))?; - VerifyingKey::from_bytes(&key_array) - .map_err(|e| anyhow!("Invalid Ed25519 key: {}", e)) + VerifyingKey::from_bytes(&key_array).map_err(|e| anyhow!("Invalid Ed25519 key: {}", e)) } /// 解析Ed25519签名blob(SSH格式 -> Signature) @@ -285,9 +320,13 @@ fn parse_ed25519_signature(signature_blob: &[u8]) -> Result { } let sig_bytes = read_ssh_string_bytes(&mut cursor)?; if sig_bytes.len() != 64 { - return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len())); + return Err(anyhow!( + "Invalid Ed25519 signature length: {}", + sig_bytes.len() + )); } - let sig_array: [u8; 64] = sig_bytes.try_into() + let sig_array: [u8; 64] = sig_bytes + .try_into() .map_err(|_| anyhow!("Invalid Ed25519 signature data"))?; Ok(Signature::from_bytes(&sig_array)) } @@ -305,7 +344,9 @@ fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8]) if parts[0] != algorithm { return false; } - base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false) + base64_decode(parts[1]) + .map(|decoded| decoded == public_key_blob) + .unwrap_or(false) } fn read_ssh_string(reader: &mut R) -> Result { @@ -323,7 +364,8 @@ fn read_ssh_string_bytes(reader: &mut R) -> Result> { } fn base64_decode(input: &str) -> Result> { - general_purpose::STANDARD.decode(input) + general_purpose::STANDARD + .decode(input) .map_err(|e| anyhow!("Base64 decode error: {}", e)) } @@ -335,7 +377,10 @@ mod tests { #[test] fn test_userauth_success_packet() { let packet = AuthHandler::build_userauth_success().unwrap(); - assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8); + assert_eq!( + packet.payload[0], + PacketType::SSH_MSG_USERAUTH_SUCCESS as u8 + ); } #[test] @@ -343,6 +388,9 @@ mod tests { let methods = vec!["password".to_string(), "publickey".to_string()]; let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap(); - assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8); + assert_eq!( + packet.payload[0], + PacketType::SSH_MSG_USERAUTH_FAILURE as u8 + ); } } diff --git a/markbase-core/src/ssh_server/channel.rs b/markbase-core/src/ssh_server/channel.rs index e1ccd18..01ec8b0 100644 --- a/markbase-core/src/ssh_server/channel.rs +++ b/markbase-core/src/ssh_server/channel.rs @@ -1,23 +1,25 @@ // SSH Channel协议实现(Phase 6 + Phase 13端口转发) // 参考OpenSSH channel.c -use crate::ssh_server::packet::{SshPacket, PacketType}; -use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.3: 安全配置 -use crate::ssh_server::port_forward::{PortForwardManager, DirectTcpipChannel, ForwardedTcpipChannel}; // Phase 13.3 -use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准) -use anyhow::{Result, anyhow}; +use crate::ssh_server::packet::{PacketType, SshPacket}; +use crate::ssh_server::port_forward::{ + DirectTcpipChannel, ForwardedTcpipChannel, PortForwardManager, +}; // Phase 13.3 +use crate::ssh_server::rsync_handler::RsyncHandler; // Phase 8: rsync handler +use crate::ssh_server::scp_handler::ScpHandler; // Phase 8: SCP handler +use crate::ssh_server::sftp_handler::SftpHandler; // Phase 7: SFTP handler +use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.3: 安全配置 +use anyhow::{anyhow, Result}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use log::{info, warn, debug, error}; +use log::{info, warn}; +use nix::fcntl::{fcntl, FcntlArg, OFlag}; // Phase 14: 非阻塞I/O(OpenSSH风格) +use nix::poll::{poll, PollFd, PollFlags}; use std::collections::{HashMap, VecDeque}; -use std::sync::{Arc, Mutex}; -use crate::ssh_server::sftp_handler::SftpHandler; // Phase 7: SFTP handler -use crate::ssh_server::scp_handler::ScpHandler; // Phase 8: SCP handler -use crate::ssh_server::rsync_handler::RsyncHandler; // Phase 8: rsync handler -use std::path::PathBuf; // Phase 7-8: Path for SFTP/SCP/rsync root directory -use std::process::{Child, ChildStdin, ChildStdout, ChildStderr}; // Phase 14: 交互式exec -use std::os::unix::io::{AsRawFd, RawFd}; // Phase 14: OpenSSH风格poll机制(需要RawFd) -use nix::fcntl::{fcntl, FcntlArg, OFlag}; // Phase 14: 非阻塞I/O(OpenSSH风格) -use nix::poll::{poll, PollFd, PollFlags}; // Phase 14: poll机制(OpenSSH风格) +use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准) +use std::os::unix::io::{AsRawFd, RawFd}; // Phase 14: OpenSSH风格poll机制(需要RawFd) +use std::path::PathBuf; // Phase 7-8: Path for SFTP/SCP/rsync root directory +use std::process::{Child, ChildStderr, ChildStdin, ChildStdout}; // Phase 14: 交互式exec + // Phase 14: poll机制(OpenSSH风格) /// SSH Channel管理器(参考OpenSSH channel.c: struct channel) pub struct ChannelManager { @@ -32,13 +34,13 @@ pub struct ChannelManager { /// Phase 14: 交互式Exec进程管理(参考OpenSSH session.c: do_exec_no_pty) /// ⭐⭐⭐⭐⭐ OpenSSH风格:使用poll()替代thread::spawn(非阻塞I/O) pub struct ExecProcess { - pub child: Child, // 子进程(rsync/scp等) - pub stdin: Option, // stdin管道(SSH client → 子进程) - pub stdout: Option, // ⭐⭐⭐⭐⭐ stdout管道(直接poll,不使用thread) - pub stderr: Option, // ⭐⭐⭐⭐⭐ stderr管道(直接poll,不使用thread) - pub stdout_fd: RawFd, // ⭐⭐⭐⭐⭐ stdout RawFd(用于poll) - pub stderr_fd: RawFd, // ⭐⭐⭐⭐⭐ stderr RawFd(用于poll) - pub command: String, // ⭐⭐⭐⭐⭐ Phase 16.2: 存储exec命令(用于SCP检测) + pub child: Child, // 子进程(rsync/scp等) + pub stdin: Option, // stdin管道(SSH client → 子进程) + pub stdout: Option, // ⭐⭐⭐⭐⭐ stdout管道(直接poll,不使用thread) + pub stderr: Option, // ⭐⭐⭐⭐⭐ stderr管道(直接poll,不使用thread) + pub stdout_fd: RawFd, // ⭐⭐⭐⭐⭐ stdout RawFd(用于poll) + pub stderr_fd: RawFd, // ⭐⭐⭐⭐⭐ stderr RawFd(用于poll) + pub command: String, // ⭐⭐⭐⭐⭐ Phase 16.2: 存储exec命令(用于SCP检测) } impl ChannelManager { @@ -50,87 +52,112 @@ impl ChannelManager { home_dir, } } - + /// 处理SSH_MSG_CHANNEL_OPEN(参考OpenSSH channel.c: channel_open()) /// Phase 13.3: 支持direct-tcpip和forwarded-tcpip channel - pub fn handle_channel_open(&mut self, packet: &SshPacket, security_config: Option<&SshSecurityConfig>) -> Result { + pub fn handle_channel_open( + &mut self, + packet: &SshPacket, + security_config: Option<&SshSecurityConfig>, + ) -> Result { info!("Processing SSH_MSG_CHANNEL_OPEN"); - + let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); - + // Packet type let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_CHANNEL_OPEN as u8 { return Err(anyhow!("Invalid packet type for CHANNEL_OPEN")); } - + // 读取channel类型(SSH string) let channel_type = read_ssh_string(&mut cursor)?; - + // 读取sender channel ID(u32) let sender_channel = cursor.read_u32::()?; - + // 读取初始窗口大小(u32) let initial_window_size = cursor.read_u32::()?; - + // 读取最大packet大小(u32) let maximum_packet_size = cursor.read_u32::()?; - - info!("Channel open: type={}, sender_channel={}, window={}, max_packet={}", - channel_type, sender_channel, initial_window_size, maximum_packet_size); - + + info!( + "Channel open: type={}, sender_channel={}, window={}, max_packet={}", + channel_type, sender_channel, initial_window_size, maximum_packet_size + ); + // Phase 13.3: 检查channel类型(支持session、direct-tcpip、forwarded-tcpip) match channel_type.as_str() { "session" => { // 传统的session channel(Phase 6) - self.handle_session_channel_open(sender_channel, initial_window_size, maximum_packet_size) + self.handle_session_channel_open( + sender_channel, + initial_window_size, + maximum_packet_size, + ) } - + "direct-tcpip" => { // Phase 13.3: Remote port forwarding channel info!("Received direct-tcpip channel open (Remote port forwarding)"); - self.handle_direct_tcpip_channel_open(packet, sender_channel, initial_window_size, maximum_packet_size, security_config) + self.handle_direct_tcpip_channel_open( + packet, + sender_channel, + initial_window_size, + maximum_packet_size, + security_config, + ) } - + "forwarded-tcpip" => { // Phase 13.3: Local port forwarding channel info!("Received forwarded-tcpip channel open (Local port forwarding)"); - self.handle_forwarded_tcpip_channel_open(packet, sender_channel, initial_window_size, maximum_packet_size) + self.handle_forwarded_tcpip_channel_open( + packet, + sender_channel, + initial_window_size, + maximum_packet_size, + ) } - + _ => { warn!("Unsupported channel type: {}", channel_type); self.build_channel_open_failure( sender_channel, - 3, // SSH_OPEN_UNKNOWN_CHANNEL_TYPE + 3, // SSH_OPEN_UNKNOWN_CHANNEL_TYPE "Unsupported channel type", - "en" + "en", ) } } } - /// 处理session channel open(Phase 6) - fn handle_session_channel_open(&mut self, sender_channel: u32, initial_window_size: u32, maximum_packet_size: u32) -> Result { + fn handle_session_channel_open( + &mut self, + sender_channel: u32, + initial_window_size: u32, + maximum_packet_size: u32, + ) -> Result { info!("Processing session channel open"); - + let server_channel = self.next_channel_id; self.next_channel_id += 1; - + let channel = Channel { server_channel, sender_channel, channel_type: "session".to_string(), - + // ⭐⭐⭐⭐⭐ Phase 15: Window Control(参考OpenSSH channels.h) - remote_window: initial_window_size, // 远端窗口(从 CHANNEL_OPEN packet 中读取) - remote_maxpacket: maximum_packet_size, // 远端最大 packet - local_window: 2097152, // 本地窗口(OpenSSH 默认 2MB) - local_window_max: 2097152, // 本地窗口最大值(同上) - local_consumed: 0, // 本地已消费的数据(初始为 0)⭐⭐⭐⭐⭐ - local_maxpacket: 32768, // 本地最大 packet(OpenSSH 默认 32KB) - + remote_window: initial_window_size, // 远端窗口(从 CHANNEL_OPEN packet 中读取) + remote_maxpacket: maximum_packet_size, // 远端最大 packet + local_window: 2097152, // 本地窗口(OpenSSH 默认 2MB) + local_window_max: 2097152, // 本地窗口最大值(同上) + local_consumed: 0, // 本地已消费的数据(初始为 0)⭐⭐⭐⭐⭐ + local_maxpacket: 32768, // 本地最大 packet(OpenSSH 默认 32KB) + // 旧字段(保留兼容) window_size: initial_window_size, maximum_packet_size, @@ -139,23 +166,31 @@ impl ChannelManager { sftp_handler: None, scp_handler: None, rsync_handler: None, - exec_process: None, // Phase 14: 交互式exec - sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复:SFTP packet累积 - scp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.4修复:SCP packet累积 + exec_process: None, // Phase 14: 交互式exec + sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复:SFTP packet累积 + scp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.4修复:SCP packet累积 direct_tcpip: None, forwarded_tcpip: None, }; - + self.channels.insert(server_channel, channel); - - info!("Session channel created: server_channel={}, sender_channel={}", server_channel, sender_channel); - - self.build_channel_open_confirmation(server_channel, sender_channel, initial_window_size, maximum_packet_size) + + info!( + "Session channel created: server_channel={}, sender_channel={}", + server_channel, sender_channel + ); + + self.build_channel_open_confirmation( + server_channel, + sender_channel, + initial_window_size, + maximum_packet_size, + ) } - + /// 处理direct-tcpip channel open(Phase 13.3: Remote port forwarding) fn handle_direct_tcpip_channel_open( - &mut self, + &mut self, packet: &SshPacket, sender_channel: u32, initial_window_size: u32, @@ -163,33 +198,36 @@ impl ChannelManager { security_config: Option<&SshSecurityConfig>, ) -> Result { info!("Processing direct-tcpip channel open"); - + // 解析direct-tcpip参数 let mut port_forward_manager = PortForwardManager::new(); let direct_tcpip = port_forward_manager.handle_direct_tcpip_channel(&packet.payload)?; - + // Phase 13.3: 安全配置验证 if let Some(security) = security_config { - if let Err(e) = security.validate_direct_tcpip_channel(&direct_tcpip.host_to_connect, direct_tcpip.port_to_connect) { + if let Err(e) = security.validate_direct_tcpip_channel( + &direct_tcpip.host_to_connect, + direct_tcpip.port_to_connect, + ) { warn!("direct-tcpip security validation failed: {}", e); return self.build_channel_open_failure( sender_channel, - 2, // SSH_OPEN_CONNECT_FAILED + 2, // SSH_OPEN_CONNECT_FAILED "Security validation failed", - "en" + "en", ); } info!("direct-tcpip security validation passed"); } - + let server_channel = self.next_channel_id; self.next_channel_id += 1; - + let channel = Channel { server_channel, sender_channel, channel_type: "direct-tcpip".to_string(), - + // ⭐⭐⭐⭐⭐ Phase 15: Window Control remote_window: initial_window_size, remote_maxpacket: maximum_packet_size, @@ -197,7 +235,7 @@ impl ChannelManager { local_window_max: 2097152, local_consumed: 0, local_maxpacket: 32768, - + window_size: initial_window_size, maximum_packet_size, state: ChannelState::Open, @@ -211,17 +249,36 @@ impl ChannelManager { direct_tcpip: Some(direct_tcpip), forwarded_tcpip: None, }; - + self.channels.insert(server_channel, channel); - - info!("direct-tcpip channel created: server_channel={}, host={}, port={}", - server_channel, - self.channels.get(&server_channel).unwrap().direct_tcpip.as_ref().unwrap().host_to_connect, - self.channels.get(&server_channel).unwrap().direct_tcpip.as_ref().unwrap().port_to_connect); - - self.build_channel_open_confirmation(server_channel, sender_channel, initial_window_size, maximum_packet_size) + + info!( + "direct-tcpip channel created: server_channel={}, host={}, port={}", + server_channel, + self.channels + .get(&server_channel) + .unwrap() + .direct_tcpip + .as_ref() + .unwrap() + .host_to_connect, + self.channels + .get(&server_channel) + .unwrap() + .direct_tcpip + .as_ref() + .unwrap() + .port_to_connect + ); + + self.build_channel_open_confirmation( + server_channel, + sender_channel, + initial_window_size, + maximum_packet_size, + ) } - + /// 处理forwarded-tcpip channel open(Phase 13.3: Local port forwarding) fn handle_forwarded_tcpip_channel_open( &mut self, @@ -231,19 +288,20 @@ impl ChannelManager { maximum_packet_size: u32, ) -> Result { info!("Processing forwarded-tcpip channel open"); - + // 解析forwarded-tcpip参数 let mut port_forward_manager = PortForwardManager::new(); - let forwarded_tcpip = port_forward_manager.handle_forwarded_tcpip_channel(&packet.payload)?; - + let forwarded_tcpip = + port_forward_manager.handle_forwarded_tcpip_channel(&packet.payload)?; + let server_channel = self.next_channel_id; self.next_channel_id += 1; - + let channel = Channel { server_channel, sender_channel, channel_type: "forwarded-tcpip".to_string(), - + // ⭐⭐⭐⭐⭐ Phase 15: Window Control remote_window: initial_window_size, remote_maxpacket: maximum_packet_size, @@ -251,7 +309,7 @@ impl ChannelManager { local_window_max: 2097152, local_consumed: 0, local_maxpacket: 32768, - + window_size: initial_window_size, maximum_packet_size, state: ChannelState::Open, @@ -259,57 +317,80 @@ impl ChannelManager { sftp_handler: None, scp_handler: None, rsync_handler: None, - exec_process: None, // Phase 14: 交互式exec - sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复 - scp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.4修复 + exec_process: None, // Phase 14: 交互式exec + sftp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.2修复 + scp_input_buffer: Vec::new(), // ⭐⭐⭐⭐⭐ Phase 14.4修复 direct_tcpip: None, forwarded_tcpip: Some(forwarded_tcpip), }; - + self.channels.insert(server_channel, channel); - - info!("forwarded-tcpip channel created: server_channel={}, bind={}, originator={}", - server_channel, - self.channels.get(&server_channel).unwrap().forwarded_tcpip.as_ref().unwrap().bind_port, - self.channels.get(&server_channel).unwrap().forwarded_tcpip.as_ref().unwrap().originator_address); - - self.build_channel_open_confirmation(server_channel, sender_channel, initial_window_size, maximum_packet_size) + + info!( + "forwarded-tcpip channel created: server_channel={}, bind={}, originator={}", + server_channel, + self.channels + .get(&server_channel) + .unwrap() + .forwarded_tcpip + .as_ref() + .unwrap() + .bind_port, + self.channels + .get(&server_channel) + .unwrap() + .forwarded_tcpip + .as_ref() + .unwrap() + .originator_address + ); + + self.build_channel_open_confirmation( + server_channel, + sender_channel, + initial_window_size, + maximum_packet_size, + ) } /// 处理SSH_MSG_CHANNEL_REQUEST(参考OpenSSH channel.c: channel_request()) pub fn handle_channel_request(&mut self, packet: &SshPacket) -> Result> { info!("Processing SSH_MSG_CHANNEL_REQUEST"); - - let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准) - + + let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准) + // Packet type let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_CHANNEL_REQUEST as u8 { return Err(anyhow!("Invalid packet type for CHANNEL_REQUEST")); } - + // 读取recipient channel(u32) let recipient_channel = cursor.read_u32::()?; - + // 读取请求类型(SSH string) let request_type = read_ssh_string(&mut cursor)?; - + // 读取want reply标志(boolean) let want_reply = cursor.read_u8()? != 0; - - info!("Channel request: channel={}, type={}, want_reply={}", - recipient_channel, request_type, want_reply); - + + info!( + "Channel request: channel={}, type={}, want_reply={}", + recipient_channel, request_type, want_reply + ); + // 处理不同请求类型(参考OpenSSH channel.c) if request_type == "exec" { - self.handle_exec_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符(返回Option不是Result) + self.handle_exec_request(&mut cursor, recipient_channel, want_reply) + // 移除?操作符(返回Option不是Result) } else if request_type == "subsystem" { - self.handle_subsystem_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符 + self.handle_subsystem_request(&mut cursor, recipient_channel, want_reply) + // 移除?操作符 } else if request_type == "shell" { - self.handle_shell_request(recipient_channel, want_reply) // 移除?操作符 + self.handle_shell_request(recipient_channel, want_reply) // 移除?操作符 } else if request_type == "env" { - self.handle_env_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符 + self.handle_env_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符 } else if request_type == "pty-req" { - self.handle_pty_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符 + self.handle_pty_request(&mut cursor, recipient_channel, want_reply) // 移除?操作符 } else { warn!("Unsupported channel request: {}", request_type); if want_reply { @@ -319,60 +400,79 @@ impl ChannelManager { } } } - + /// 处理exec请求(参考OpenSSH channel.c: channel_request_exec() + session.c: do_exec_no_pty) - fn handle_exec_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result> { + fn handle_exec_request( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + channel: u32, + want_reply: bool, + ) -> Result> { info!("Handling exec request for channel {}", channel); - + // 读取命令(SSH string) let command = read_ssh_string(cursor)?; - + info!("Exec command: {}", command); - + // Phase 14: 检测rsync/SCP命令,启动交互式进程 if command.starts_with("rsync --server") || command.contains("rsync") { - info!("⭐⭐⭐⭐⭐ [EXEC_REQUEST] Detected rsync command: {}", command); + info!( + "⭐⭐⭐⭐⭐ [EXEC_REQUEST] Detected rsync command: {}", + command + ); self.handle_rsync_exec(&command, channel)?; } else if command.starts_with("scp") || command.contains("scp -") { // ⭐⭐⭐⭐⭐ Phase 14.5: SCP命令处理(scp -t destination 或 scp -f source) - info!("⭐⭐⭐⭐⭐ [EXEC_REQUEST] Detected SCP command: {}", command); + info!( + "⭐⭐⭐⭐⭐ [EXEC_REQUEST] Detected SCP command: {}", + command + ); self.handle_scp_exec(&command, channel)?; } else { // Phase 6: 普通命令使用非交互式执行 let output = self.execute_command(&command)?; - + // 存储输出,等待后续发送CHANNEL_DATA if let Some(ch) = self.channels.get_mut(&channel) { ch.output_buffer = Some(output); } } - + if want_reply { Ok(Some(self.build_channel_success(channel)?)) } else { Ok(None) } } - + /// ⭐⭐⭐⭐⭐ Phase 16.5: rsync exec(使用真实rsync子进程,替代in-process handler) fn handle_rsync_exec(&mut self, command: &str, channel_id: u32) -> Result<()> { self.handle_interactive_exec(command, channel_id, "rsync") } - + /// Phase 14.5: 处理SCP交互式exec(scp -t destination 或 scp -f source) /// ⭐⭐⭐⭐⭐ OpenSSH风格:使用poll()替代thread::spawn(非阻塞I/O) fn handle_scp_exec(&mut self, command: &str, channel_id: u32) -> Result<()> { // ⭐⭐⭐⭐⭐ SCP和rsync共用相同的交互式exec逻辑 self.handle_interactive_exec(command, channel_id, "scp") } - + /// ⭐⭐⭐⭐⭐ Phase 14.6: 交互式exec通用处理(rsync/SCP共用) - fn handle_interactive_exec(&mut self, command: &str, channel_id: u32, process_type: &str) -> Result<()> { - use std::process::{Command, Stdio}; + fn handle_interactive_exec( + &mut self, + command: &str, + channel_id: u32, + process_type: &str, + ) -> Result<()> { use std::os::unix::io::AsRawFd; - - info!("⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}", process_type, command); - + use std::process::{Command, Stdio}; + + info!( + "⭐⭐⭐⭐⭐ [{}_EXEC_START] Starting interactive process: {}", + process_type, command + ); + // 启动子进程(相当于OpenSSH fork) // ⭐⭐⭐⭐⭐ Phase 17: 设置工作目录为用户home_dir(SFTPGo兼容) let home_dir = self.home_dir.clone(); @@ -384,93 +484,104 @@ impl ChannelManager { .stdout(Stdio::piped()) // ← 创建stdout管道(相当于pipe(pout)) .stderr(Stdio::piped()) // ← 创建stderr管道(相当于pipe(perr)) .spawn()?; - - info!("⭐⭐⭐⭐⭐ [CHILD_SPAWNED] Child process spawned, PID: {}", child.id()); - + + info!( + "⭐⭐⭐⭐⭐ [CHILD_SPAWNED] Child process spawned, PID: {}", + child.id() + ); + // 提取管道(相当于OpenSSH dup2) let stdin = child.stdin.take().ok_or(anyhow!("stdin take failed"))?; let stdout = child.stdout.take().ok_or(anyhow!("stdout take failed"))?; let stderr = child.stderr.take().ok_or(anyhow!("stderr take failed"))?; - + // ⭐⭐⭐⭐⭐ OpenSSH关键:设置非阻塞模式(fcntl O_NONBLOCK) let stdout_fd = stdout.as_raw_fd(); let stderr_fd = stderr.as_raw_fd(); - + info!("Setting stdout/stderr to non-blocking mode (OpenSSH style)"); fcntl(stdout_fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?; fcntl(stderr_fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?; - info!("Non-blocking I/O enabled for stdout (fd {}) and stderr (fd {})", stdout_fd, stderr_fd); - + info!( + "Non-blocking I/O enabled for stdout (fd {}) and stderr (fd {})", + stdout_fd, stderr_fd + ); + // ⭐⭐⭐⭐⭐ OpenSSH风格:不再使用thread::spawn,直接保留File对象用于poll // 存储到channel(相当于OpenSSH session_set_fds) if let Some(ch) = self.channels.get_mut(&channel_id) { ch.exec_process = Some(ExecProcess { child, stdin: Some(stdin), - stdout: Some(stdout), // ⭐⭐⭐⭐⭐ 直接保留File对象 - stderr: Some(stderr), // ⭐⭐⭐⭐⭐ 直接保留File对象 - stdout_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll - stderr_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll - command: command.to_string(), // ⭐⭐⭐⭐⭐ Phase 16.2: 存储exec命令(用于SCP检测) + stdout: Some(stdout), // ⭐⭐⭐⭐⭐ 直接保留File对象 + stderr: Some(stderr), // ⭐⭐⭐⭐⭐ 直接保留File对象 + stdout_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll + stderr_fd, // ⭐⭐⭐⭐⭐ RawFd用于poll + command: command.to_string(), // ⭐⭐⭐⭐⭐ Phase 16.2: 存储exec命令(用于SCP检测) }); - info!("Interactive process stored for channel {} (poll-ready)", channel_id); + info!( + "Interactive process stored for channel {} (poll-ready)", + channel_id + ); } - + Ok(()) } - + /// 执行命令并捕获输出(Phase 6基础实现) fn execute_command(&self, command: &str) -> Result> { - use std::process::{Command, Stdio}; - + use std::process::Command; + info!("Executing command: {}", command); - + // 使用shell执行命令(参考OpenSSH session.c) - let output = Command::new("sh") - .arg("-c") - .arg(command) - .output()?; - + let output = Command::new("sh").arg("-c").arg(command).output()?; + // 返回stdout + stderr let mut result = output.stdout; result.extend_from_slice(&output.stderr); - + info!("Command output: {} bytes", result.len()); Ok(result) } - + /// 处理subsystem请求(参考OpenSSH channel.c: channel_request_subsystem()) - fn handle_subsystem_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result> { + fn handle_subsystem_request( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + channel: u32, + want_reply: bool, + ) -> Result> { info!("Handling subsystem request for channel {}", channel); - + // 读取subsystem名称(SSH string) let subsystem = read_ssh_string(cursor)?; - + info!("Subsystem: {}", subsystem); - + // 检查subsystem支持(OpenSSH支持:sftp) if subsystem == "sftp" { info!("SFTP subsystem requested"); - + // Phase 7: 初始化SFTP handler(使用用户home目录,SFTPGo兼容) let root_dir = self.home_dir.clone(); - + // ⭐⭐⭐⭐⭐ Phase 4: 获取 client maxpack 限制(从 Channel 中获取) let maxpacket = if let Some(ch) = self.channels.get(&channel) { - ch.remote_maxpacket // 来自 SSH_MSG_CHANNEL_OPEN 的 maximum_packet_size + ch.remote_maxpacket // 来自 SSH_MSG_CHANNEL_OPEN 的 maximum_packet_size } else { - 32768 // OpenSSH 默认值(32KB) + 32768 // OpenSSH 默认值(32KB) }; - + let vfs = Box::new(crate::vfs::local_fs::LocalFs::new()); - let sftp_handler = SftpHandler::new(root_dir, vfs, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack - + let sftp_handler = SftpHandler::new(root_dir, vfs, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack + // 存储到channel if let Some(ch) = self.channels.get_mut(&channel) { ch.sftp_handler = Some(sftp_handler); info!("SFTP handler initialized for channel {}", channel); } - + if want_reply { Ok(Some(self.build_channel_success(channel)?)) } else { @@ -485,95 +596,119 @@ impl ChannelManager { } } } - + /// 处理shell请求(参考OpenSSH channel.c) - fn handle_shell_request(&mut self, channel: u32, want_reply: bool) -> Result> { + fn handle_shell_request( + &mut self, + channel: u32, + want_reply: bool, + ) -> Result> { info!("Handling shell request for channel {}", channel); - + // Phase 9将实现shell warn!("Shell not implemented in Phase 6"); - + if want_reply { Ok(Some(self.build_channel_failure(channel)?)) } else { Ok(None) } } - + /// 处理env请求(参考OpenSSH channel.c) - fn handle_env_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result> { + fn handle_env_request( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + channel: u32, + want_reply: bool, + ) -> Result> { info!("Handling env request for channel {}", channel); - + // 读取环境变量名和值 let name = read_ssh_string(cursor)?; let value = read_ssh_string(cursor)?; - + info!("Env: {}={}", name, value); - + if want_reply { Ok(Some(self.build_channel_success(channel)?)) } else { Ok(None) } } - + /// 处理pty请求(参考OpenSSH channel.c) - fn handle_pty_request(&mut self, cursor: &mut std::io::Cursor<&[u8]>, channel: u32, want_reply: bool) -> Result> { + fn handle_pty_request( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + channel: u32, + want_reply: bool, + ) -> Result> { info!("Handling pty request for channel {}", channel); - + // 读取terminal类型(SSH string) let term = read_ssh_string(cursor)?; - -// 读取窗口大小(4个uint32) + + // 读取窗口大小(4个uint32) let width = cursor.read_u32::()?; let height = cursor.read_u32::()?; let _pixel_width = cursor.read_u32::()?; let _pixel_height = cursor.read_u32::()?; - + // 读取terminal modes(SSH string格式) let modes_len = cursor.read_u32::()?; let mut modes = vec![0u8; modes_len as usize]; cursor.read_exact(&mut modes)?; - - info!("PTY: term={}, width={}, height={}, modes_len={}", term, width, height, modes_len); - + + info!( + "PTY: term={}, width={}, height={}, modes_len={}", + term, width, height, modes_len + ); + if want_reply { Ok(Some(self.build_channel_success(channel)?)) } else { Ok(None) } } - + /// 处理SSH_MSG_CHANNEL_DATA(参考OpenSSH channel.c: channel_input_data()) pub fn handle_channel_data(&mut self, packet: &SshPacket) -> Result> { info!("Processing SSH_MSG_CHANNEL_DATA"); - + let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); - + // Packet type let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_CHANNEL_DATA as u8 { return Err(anyhow!("Invalid packet type for CHANNEL_DATA")); } - + // 读取recipient channel let recipient_channel = cursor.read_u32::()?; - + // 读取数据(SSH string) let data_length = cursor.read_u32::()?; let mut data = vec![0u8; data_length as usize]; cursor.read_exact(&mut data)?; - - info!("Channel data: channel={}, length={}", recipient_channel, data.len()); - info!("Channel data content (first 20 bytes): {:?}", &data[..std::cmp::min(20, data.len())]); - + + info!( + "Channel data: channel={}, length={}", + recipient_channel, + data.len() + ); + info!( + "Channel data content (first 20 bytes): {:?}", + &data[..std::cmp::min(20, data.len())] + ); + // Phase 14: 检查是否是交互式exec进程 if let Some(channel) = self.channels.get_mut(&recipient_channel) { if let Some(exec_process) = &mut channel.exec_process { info!("Interactive exec process detected, forwarding data to stdin"); info!("Channel data content: {:?}", &data); info!("Child PID: {:?}", exec_process.child.id()); - + // 检查子进程状态 match exec_process.child.try_wait() { Ok(Some(status)) => { @@ -586,7 +721,7 @@ impl ChannelManager { warn!("Failed to check child status: {}", e); } } - + // 转发数据到子进程stdin(相当于OpenSSH写fdin) if let Some(stdin) = &mut exec_process.stdin { use std::io::Write; @@ -595,7 +730,7 @@ impl ChannelManager { stdin.flush()?; info!("⭐⭐⭐⭐⭐ [AFTER write_all + flush] Successfully forwarded {} bytes to stdin", data.len()); } - + // ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ Critical修复:Window Control - 减少 local_window // OpenSSH channel.c: channel_input_data() 中 c->local_window -= data_len if let Some(channel) = self.channels.get_mut(&recipient_channel) { @@ -603,47 +738,64 @@ impl ChannelManager { info!("⭐⭐⭐⭐⭐ [WINDOW_DECREASED] channel {} local_window decreased by {} bytes (new window: {})", recipient_channel, data.len(), channel.local_window); } - + // ⭐⭐⭐⭐⭐ OpenSSH风格:不等待,直接返回None(主循环会通过poll处理stdout) info!("stdin forwarded, returning None (main loop will poll stdout/stderr)"); - + // ⭐⭐⭐⭐⭐ Phase 15: 更新 local_consumed(跟踪已消费的数据) if let Some(channel) = self.channels.get_mut(&recipient_channel) { channel.local_consumed += data.len() as u32; - info!("⭐⭐⭐⭐⭐ [LOCAL_CONSUMED] channel {} consumed {} bytes (total: {})", - recipient_channel, data.len(), channel.local_consumed); - + info!( + "⭐⭐⭐⭐⭐ [LOCAL_CONSUMED] channel {} consumed {} bytes (total: {})", + recipient_channel, + data.len(), + channel.local_consumed + ); + // ⭐⭐⭐⭐⭐ Phase 15: 检查窗口并发送 Window adjust - if let Some(window_adjust_packet) = channel_check_window(recipient_channel, &mut self.channels) { + if let Some(window_adjust_packet) = + channel_check_window(recipient_channel, &mut self.channels) + { // 返回 window adjust packet(主循环会发送) return Ok(Some(window_adjust_packet)); } } - + return Ok(None); } - + // ⭐⭐⭐⭐⭐ Phase 16.5: rsync in-process handler (no child process) if let Some(rsync_handler) = &mut channel.rsync_handler { - info!("⭐⭐⭐⭐⭐ [RSYNC_DATA] Feeding {} bytes to RsyncHandler", data.len()); + info!( + "⭐⭐⭐⭐⭐ [RSYNC_DATA] Feeding {} bytes to RsyncHandler", + data.len() + ); let data_clone = data.clone(); rsync_handler.feed(&data_clone)?; let output = rsync_handler.drain_output(); - info!("⭐⭐⭐⭐⭐ [RSYNC_DATA] RsyncHandler produced {} bytes output, done={}", - output.len(), rsync_handler.is_done()); + info!( + "⭐⭐⭐⭐⭐ [RSYNC_DATA] RsyncHandler produced {} bytes output, done={}", + output.len(), + rsync_handler.is_done() + ); // ⭐⭐⭐⭐⭐ Phase 15: Window Control - decrease local_window channel.local_window -= data.len() as u32; channel.local_consumed += data.len() as u32; // Check for window adjust - if let Some(window_adjust_packet) = channel_check_window(recipient_channel, &mut self.channels) { + if let Some(window_adjust_packet) = + channel_check_window(recipient_channel, &mut self.channels) + { return Ok(Some(window_adjust_packet)); } if !output.is_empty() { - info!("⭐⭐⭐⭐⭐ [RSYNC_DATA] Returning {} bytes as CHANNEL_DATA", output.len()); + info!( + "⭐⭐⭐⭐⭐ [RSYNC_DATA] Returning {} bytes as CHANNEL_DATA", + output.len() + ); return Ok(Some(self.build_channel_data(recipient_channel, &output)?)); } @@ -654,15 +806,18 @@ impl ChannelManager { // Extract SFTP result from channel borrow, then send outside let sftp_result = if let Some(sftp_handler) = &mut channel.sftp_handler { info!("Processing SFTP request ({} bytes)", data.len()); - + // ⭐⭐⭐⭐⭐ Window Control: decrease local_window channel.local_window -= data.len() as u32; channel.local_consumed += data.len() as u32; - + // ⭐⭐⭐⭐⭐ Critical修复:累积SFTP packet数据 channel.sftp_input_buffer.extend_from_slice(&data); - info!("SFTP buffer accumulated: {} bytes total", channel.sftp_input_buffer.len()); - + info!( + "SFTP buffer accumulated: {} bytes total", + channel.sftp_input_buffer.len() + ); + // ⭐⭐⭐⭐⭐ Process ALL complete SFTP packets from buffer (not just one) let mut all_responses: Vec> = Vec::new(); loop { @@ -670,68 +825,77 @@ impl ChannelManager { info!("SFTP buffer too short for length field, waiting for more data"); break; } - + let sftp_length = u32::from_be_bytes([ channel.sftp_input_buffer[0], channel.sftp_input_buffer[1], channel.sftp_input_buffer[2], - channel.sftp_input_buffer[3] + channel.sftp_input_buffer[3], ]) as usize; info!("SFTP packet length field: {}", sftp_length); - + let expected_total = 4 + sftp_length; if channel.sftp_input_buffer.len() < expected_total { info!("SFTP packet incomplete: expected {} bytes, have {} bytes in buffer, waiting for more", expected_total, channel.sftp_input_buffer.len()); break; } - + let sftp_packet = channel.sftp_input_buffer[4..expected_total].to_vec(); - info!("SFTP packet complete: {} bytes, processing", sftp_packet.len()); - + info!( + "SFTP packet complete: {} bytes, processing", + sftp_packet.len() + ); + let response = sftp_handler.handle_request(&sftp_packet)?; info!("SFTP response: {} bytes", response.len()); - + if channel.sftp_input_buffer.len() > expected_total { let remaining = channel.sftp_input_buffer[expected_total..].to_vec(); channel.sftp_input_buffer = remaining; - info!("SFTP buffer has remaining {} bytes after processing", channel.sftp_input_buffer.len()); + info!( + "SFTP buffer has remaining {} bytes after processing", + channel.sftp_input_buffer.len() + ); } else { channel.sftp_input_buffer.clear(); info!("SFTP buffer cleared after processing"); } - + all_responses.push(response); } - + Some(all_responses) } else { None }; - + if let Some(responses) = sftp_result { // ⭐⭐⭐⭐⭐ Channel borrow is dropped; now we can use self freely - + // All responses except the last go to pending_packets for i in 0..responses.len().saturating_sub(1) { let pending = self.build_channel_data(recipient_channel, &responses[i])?; self.pending_packets.push_back(pending); } - + // Last response is returned (possibly with WINDOW_ADJUST) if let Some(last_response) = responses.into_iter().last() { // ⭐⭐⭐⭐⭐ Check window adjust (re-borrow channel briefly) - let (needs_window, consumed) = if let Some(ch) = self.channels.get_mut(&recipient_channel) { - let window_used = ch.local_window_max - ch.local_window; - let need = (window_used > ch.local_maxpacket * 3) || - (ch.local_window < ch.local_window_max / 2); - (need, ch.local_consumed) - } else { - (false, 0) - }; - + let (needs_window, consumed) = + if let Some(ch) = self.channels.get_mut(&recipient_channel) { + let window_used = ch.local_window_max - ch.local_window; + let need = (window_used > ch.local_maxpacket * 3) + || (ch.local_window < ch.local_window_max / 2); + (need, ch.local_consumed) + } else { + (false, 0) + }; + if needs_window && consumed > 0 { - info!("⭐⭐⭐⭐⭐ [SFTP_WINDOW] Sending WINDOW_ADJUST before SFTP response"); + info!( + "⭐⭐⭐⭐⭐ [SFTP_WINDOW] Sending WINDOW_ADJUST before SFTP response" + ); let window_adjust = build_window_adjust(recipient_channel, consumed); // Update window state if let Some(ch) = self.channels.get_mut(&recipient_channel) { @@ -750,13 +914,19 @@ impl ChannelManager { self.pending_packets.push_back(sftp_packet); return Ok(None); } - return Ok(Some(self.build_channel_data(recipient_channel, &last_response)?)); + return Ok(Some( + self.build_channel_data(recipient_channel, &last_response)?, + )); } else { // No SFTP packets were complete, but maybe we need window adjust if let Some(ch) = self.channels.get_mut(&recipient_channel) { let window_used = ch.local_window_max - ch.local_window; - if (window_used > ch.local_maxpacket * 3 || ch.local_window < ch.local_window_max / 2) && ch.local_consumed > 0 { - let window_adjust = build_window_adjust(recipient_channel, ch.local_consumed); + if (window_used > ch.local_maxpacket * 3 + || ch.local_window < ch.local_window_max / 2) + && ch.local_consumed > 0 + { + let window_adjust = + build_window_adjust(recipient_channel, ch.local_consumed); ch.local_window += ch.local_consumed; ch.local_consumed = 0; self.pending_packets.push_back(window_adjust); @@ -766,11 +936,11 @@ impl ChannelManager { } } } - + // 如果不是SFTP或exec_process,返回None Ok(None) } - + /// ⭐⭐⭐⭐⭐ Phase 13.5: 处理 client 发送的 SSH_MSG_CHANNEL_WINDOW_ADJUST pub fn adjust_remote_window(&mut self, recipient_channel: u32, bytes_to_add: u32) { if let Some(channel) = self.channels.get_mut(&recipient_channel) { @@ -779,17 +949,22 @@ impl ChannelManager { recipient_channel, bytes_to_add, channel.remote_window); } } - + /// Phase 14: 构建SSH_MSG_CHANNEL_EXTENDED_DATA(参考OpenSSH channel.c) - fn build_channel_extended_data(&self, channel: u32, data_type: u32, data: &[u8]) -> Result { + fn build_channel_extended_data( + &self, + channel: u32, + data_type: u32, + data: &[u8], + ) -> Result { let mut buffer = Vec::new(); - + buffer.write_u8(PacketType::SSH_MSG_CHANNEL_EXTENDED_DATA as u8)?; buffer.write_u32::(channel)?; - buffer.write_u32::(data_type)?; // 1 = stderr, 2 = exit status + buffer.write_u32::(data_type)?; // 1 = stderr, 2 = exit status buffer.write_u32::(data.len() as u32)?; buffer.write_all(data)?; - + Ok(SshPacket { packet_length: 0, padding_length: 0, @@ -797,28 +972,28 @@ impl ChannelManager { padding: Vec::new(), }) } - + /// 处理SSH_MSG_CHANNEL_CLOSE(参考OpenSSH channel.c: channel_input_close()) pub fn handle_channel_close(&mut self, packet: &SshPacket) -> Result> { info!("Processing SSH_MSG_CHANNEL_CLOSE"); - - let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准) - + + let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准) + // Packet type let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_CHANNEL_CLOSE as u8 { return Err(anyhow!("Invalid packet type for CHANNEL_CLOSE")); } - + // 读取recipient channel let recipient_channel = cursor.read_u32::()?; - + info!("Channel close: channel={}", recipient_channel); - + // 移除channel(参考OpenSSH channel.c) if let Some(channel) = self.channels.remove(&recipient_channel) { info!("Channel {} removed", recipient_channel); - + // 发送SSH_MSG_CHANNEL_CLOSE回应 Ok(Some(self.build_channel_close(channel.sender_channel)?)) } else { @@ -826,7 +1001,7 @@ impl ChannelManager { Ok(None) } } - + /// 构建SSH_MSG_CHANNEL_OPEN_CONFIRMATION(参考OpenSSH channel.c) fn build_channel_open_confirmation( &self, @@ -836,25 +1011,25 @@ impl ChannelManager { packet_size: u32, ) -> Result { let mut payload = Vec::new(); - + // Packet type payload.write_u8(PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8)?; - + // Server channel number payload.write_u32::(server_channel)?; - + // Sender channel number payload.write_u32::(sender_channel)?; - + // Initial window size payload.write_u32::(window_size)?; - + // Maximum packet size payload.write_u32::(packet_size)?; - + Ok(SshPacket::new(payload)) } - + /// 构建SSH_MSG_CHANNEL_OPEN_FAILURE(参考OpenSSH channel.c) fn build_channel_open_failure( &self, @@ -864,81 +1039,84 @@ impl ChannelManager { language: &str, ) -> Result { let mut payload = Vec::new(); - + // Packet type payload.write_u8(PacketType::SSH_MSG_CHANNEL_OPEN_FAILURE as u8)?; - + // Sender channel number payload.write_u32::(sender_channel)?; - + // Reason code payload.write_u32::(reason_code)?; - + // Description(SSH string) payload.write_u32::(description.len() as u32)?; payload.write_all(description.as_bytes())?; - + // Language(SSH string) payload.write_u32::(language.len() as u32)?; payload.write_all(language.as_bytes())?; - + Ok(SshPacket::new(payload)) } - + /// 构建SSH_MSG_CHANNEL_SUCCESS(参考OpenSSH channel.c) fn build_channel_success(&self, channel: u32) -> Result { let mut payload = Vec::new(); - + payload.write_u8(PacketType::SSH_MSG_CHANNEL_SUCCESS as u8)?; payload.write_u32::(channel)?; - + Ok(SshPacket::new(payload)) } - + /// 构建SSH_MSG_CHANNEL_FAILURE(参考OpenSSH channel.c) fn build_channel_failure(&self, channel: u32) -> Result { let mut payload = Vec::new(); - + payload.write_u8(PacketType::SSH_MSG_CHANNEL_FAILURE as u8)?; payload.write_u32::(channel)?; - + Ok(SshPacket::new(payload)) } - + /// 构建SSH_MSG_CHANNEL_CLOSE(参考OpenSSH channel.c) pub fn build_channel_close(&self, channel: u32) -> Result { let mut payload = Vec::new(); - + payload.write_u8(PacketType::SSH_MSG_CHANNEL_CLOSE as u8)?; payload.write_u32::(channel)?; - + Ok(SshPacket::new(payload)) } - + /// 构建SSH_MSG_CHANNEL_DATA(Phase 6新增) pub fn build_channel_data(&self, channel: u32, data: &[u8]) -> Result { info!("⭐⭐⭐⭐⭐ [build_channel_data] Building SSH_MSG_CHANNEL_DATA: channel={}, data_len={}", channel, data.len()); let mut payload = Vec::new(); - + payload.write_u8(PacketType::SSH_MSG_CHANNEL_DATA as u8)?; payload.write_u32::(channel)?; payload.write_u32::(data.len() as u32)?; payload.write_all(data)?; - - info!("⭐⭐⭐⭐⭐ [build_channel_data] Packet built successfully, payload_len={}", payload.len()); + + info!( + "⭐⭐⭐⭐⭐ [build_channel_data] Packet built successfully, payload_len={}", + payload.len() + ); Ok(SshPacket::new(payload)) } - + /// 构建SSH_MSG_CHANNEL_EOF(Phase 6新增) pub fn build_channel_eof(&self, channel: u32) -> Result { let mut payload = Vec::new(); - + payload.write_u8(PacketType::SSH_MSG_CHANNEL_EOF as u8)?; payload.write_u32::(channel)?; - + Ok(SshPacket::new(payload)) } - + /// 获取有输出待发送的channel ID(Phase 6新增) pub fn get_channel_with_output(&self) -> Option { for (&id, channel) in &self.channels { @@ -948,7 +1126,7 @@ impl ChannelManager { } None } - + /// ⭐⭐⭐⭐⭐ Phase 14.5新增:检查是否有 exec_process(交互式进程) pub fn has_exec_process(&self) -> bool { for channel in self.channels.values() { @@ -958,7 +1136,7 @@ impl ChannelManager { } false } - + /// Phase 17: 关闭所有子进程stdin(收到CHANNEL_EOF时调用) /// SCP upload需要:scp -t 等待EOF on stdin才知道数据传输完毕 pub fn close_child_stdin(&mut self) { @@ -968,13 +1146,16 @@ impl ChannelManager { if let Some(exec) = &mut channel.exec_process { if let Some(stdin) = exec.stdin.take() { drop(stdin); - info!("⭐⭐⭐⭐⭐ [CHANNEL_EOF] Closed child stdin (channel {})", id); + info!( + "⭐⭐⭐⭐⭐ [CHANNEL_EOF] Closed child stdin (channel {})", + id + ); } } } } } - + /// 获取channel输出(Phase 6新增) pub fn get_channel_output(&mut self, channel_id: u32) -> Option> { if let Some(channel) = self.channels.get_mut(&channel_id) { @@ -983,19 +1164,20 @@ impl ChannelManager { None } } - + /// 移除channel(Phase 6新增) pub fn remove_channel(&mut self, channel_id: u32) { self.channels.remove(&channel_id); } - + /// Phase 14: OpenSSH风格poll机制(使用nix::poll监听stdout/stderr fd) /// ⭐⭐⭐⭐⭐ 关键:非阻塞读取数据,不等待子进程完成 /// ⭐⭐⭐⭐⭐ Phase 14.2: 处理child exited(发送EOF + CLOSE) /// 参考:OpenSSH session.c: do_exec_no_pty() pub fn handle_child_exited(&mut self) -> Result> { // 1. 收集需要处理的channel IDs (exec_process OR rsync_handler) - let channel_ids: Vec = self.channels + let channel_ids: Vec = self + .channels .iter() .filter_map(|(id, channel)| { if channel.exec_process.is_some() || channel.rsync_handler.is_some() { @@ -1005,17 +1187,17 @@ impl ChannelManager { } }) .collect(); - + // 2. 构建packets(避免borrow冲突) let mut packets = Vec::new(); for channel_id in &channel_ids { let eof_packet = self.build_channel_eof(*channel_id)?; packets.push(eof_packet); - + let close_packet = self.build_channel_close(*channel_id)?; packets.push(close_packet); } - + // 3. 清除exec_process + rsync_handler(mutable borrow) for channel_id in &channel_ids { if let Some(channel) = self.channels.get_mut(channel_id) { @@ -1023,72 +1205,72 @@ impl ChannelManager { channel.rsync_handler = None; } } - + if !channel_ids.is_empty() { - info!("Child/rsync exited, sent EOF + CLOSE for {} channels", channel_ids.len()); + info!( + "Child/rsync exited, sent EOF + CLOSE for {} channels", + channel_ids.len() + ); } - + Ok(packets) } - + /// ⭐⭐⭐⭐⭐ Phase 14.2: OpenSSH统一poll + child进程状态检测 /// 参考:OpenSSH session.c: do_exec_no_pty() + channel.c: channel_handle_fd() - /// + /// /// 关键改进(Phase 14.2): /// - 单次poll()同时监听client socket和子进程输出 /// - timeout 10ms(非阻塞) /// - **添加child进程状态检测**(防止无限spinning)⭐⭐⭐⭐⭐ /// - **添加max_poll_iterations限制**(最多100次,1秒) /// - 返回(stdout_packets, client_has_data, child_exited) - pub fn poll_exec_stdout_and_client(&mut self, stream: &std::net::TcpStream) -> Result<(Option>, bool, bool)> { - use std::io::Read; - use std::os::unix::io::{BorrowedFd, AsRawFd}; + pub fn poll_exec_stdout_and_client( + &mut self, + stream: &std::net::TcpStream, + ) -> Result<(Option>, bool, bool)> { use nix::poll::{poll, PollFd, PollFlags}; - + use std::io::Read; + use std::os::unix::io::{AsRawFd, BorrowedFd}; + // 收集所有需要poll的fd let mut poll_fds_vec = Vec::new(); let mut client_has_data = false; let mut child_exited = false; - + // 1. 添加client socket fd(监听stdin数据) let client_fd = stream.as_raw_fd(); - let client_poll_fd = unsafe { - BorrowedFd::borrow_raw(client_fd) - }; + let client_poll_fd = unsafe { BorrowedFd::borrow_raw(client_fd) }; poll_fds_vec.push(PollFd::new(client_poll_fd, PollFlags::POLLIN)); - let client_fd_idx = 0; // client fd总是第一个 - + let client_fd_idx = 0; // client fd总是第一个 + // 2. 添加所有channel的stdout/stderr fd - let mut channel_fds_map: HashMap = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx) - let mut channel_ids_vec = Vec::new(); // 用于后续child状态检查 - + let mut channel_fds_map: HashMap = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx) + let mut channel_ids_vec = Vec::new(); // 用于后续child状态检查 + for (channel_id, channel) in &self.channels { if let Some(exec_process) = &channel.exec_process { channel_ids_vec.push(*channel_id); - + // stdout fd if let Some(_stdout) = &exec_process.stdout { - let stdout_poll_fd = unsafe { - BorrowedFd::borrow_raw(exec_process.stdout_fd) - }; + let stdout_poll_fd = unsafe { BorrowedFd::borrow_raw(exec_process.stdout_fd) }; poll_fds_vec.push(PollFd::new(stdout_poll_fd, PollFlags::POLLIN)); } - + // stderr fd if let Some(_stderr) = &exec_process.stderr { - let stderr_poll_fd = unsafe { - BorrowedFd::borrow_raw(exec_process.stderr_fd) - }; + let stderr_poll_fd = unsafe { BorrowedFd::borrow_raw(exec_process.stderr_fd) }; poll_fds_vec.push(PollFd::new(stderr_poll_fd, PollFlags::POLLIN)); } - + // 记录索引(相对于client_fd_idx) let stdout_idx = poll_fds_vec.len() - 2; let stderr_idx = poll_fds_vec.len() - 1; channel_fds_map.insert(*channel_id, (stdout_idx, stderr_idx)); } } - + if poll_fds_vec.len() == 1 { // 只有client fd,没有exec_process // ⭐⭐⭐⭐⭐ Phase 16.5: 检查rsync handler的pending output @@ -1103,16 +1285,20 @@ impl ChannelManager { let out = rsync.drain_output(); if !out.is_empty() { let sid = channel.server_channel; - info!("⭐⭐⭐⭐⭐ [RSYNC_POLL] {} bytes pending from rsync handler", out.len()); + info!( + "⭐⭐⭐⭐⭐ [RSYNC_POLL] {} bytes pending from rsync handler", + out.len() + ); rsync_items.push((sid, out)); } } } // Check rsync done (immutable borrow) - rsync_is_done = self.channels.values().any(|ch| { - ch.rsync_handler.as_ref().map_or(false, |r| r.is_done()) - }); + rsync_is_done = self + .channels + .values() + .any(|ch| ch.rsync_handler.as_ref().is_some_and(|r| r.is_done())); // Directly poll client match poll(&mut poll_fds_vec, 10u16) { @@ -1141,30 +1327,36 @@ impl ChannelManager { return Ok((None, client_has_data, rsync_is_done)); } - + // ⭐⭐⭐⭐⭐ Phase 16.4修复:增加poll轮询限制(支持大文件传输) // 最多轮询2000次(200秒),poll timeout从10ms改到100ms // 修复:从500改到2000,支持50MB+文件传输(预计可传输500MB+) let max_poll_iterations = 2000; let mut poll_iteration = 0; let mut found_data = false; - let mut stdin_closed = false; // ⭐⭐⭐⭐⭐ 新增:跟踪stdin是否已关闭 - + let mut stdin_closed = false; // ⭐⭐⭐⭐⭐ 新增:跟踪stdin是否已关闭 + for iteration in 0..max_poll_iterations { poll_iteration = iteration; - + // ⭐⭐⭐⭐⭐ Phase 16.2.1优化:增加poll timeout(减少iteration overhead) // 每50次轮询记录一次日志(从10改到50,减少噪音) if iteration % 50 == 0 { - info!("Polling {} fds (iteration {} of {}, stdin_closed={})", poll_fds_vec.len(), iteration, max_poll_iterations, stdin_closed); + info!( + "Polling {} fds (iteration {} of {}, stdin_closed={})", + poll_fds_vec.len(), + iteration, + max_poll_iterations, + stdin_closed + ); } - + // ⭐⭐⭐⭐⭐ Phase 16.2.1优化:增加poll timeout(减少iteration overhead) match poll(&mut poll_fds_vec, 100u16) { Ok(n) if n > 0 => { info!("{} fds have data available (iteration {})", n, iteration); found_data = true; - break; // 有数据,立即处理 + break; // 有数据,立即处理 } Ok(0) => { // timeout,无数据 @@ -1176,9 +1368,12 @@ impl ChannelManager { if let Some(exec_process) = &mut channel.exec_process { match exec_process.child.try_wait() { Ok(Some(status)) => { - info!("Child process exited (channel {}, status: {:?})", channel_id, status); + info!( + "Child process exited (channel {}, status: {:?})", + channel_id, status + ); child_exited = true; - + // ⭐⭐⭐⭐⭐ Child exited,读取剩余stdout(如果有) if let Some(stdout) = &mut exec_process.stdout { let mut buffer = vec![0u8; 32768]; @@ -1186,32 +1381,44 @@ impl ChannelManager { Ok(n) if n > 0 => { info!("Read {} final bytes from stdout (child exited)", n); // 构建packet并返回 - let packet = self.build_channel_data(*channel_id, &buffer[..n])?; - return Ok((Some(vec![packet]), false, true)); + let packet = self.build_channel_data( + *channel_id, + &buffer[..n], + )?; + return Ok(( + Some(vec![packet]), + false, + true, + )); } _ => {} } } - + // 没有剩余数据,返回child_exited标志 return Ok((None, false, true)); } Ok(None) => { // Child still running(正常) info!("Child still running (channel {}, iteration {}, stdin_closed={})", channel_id, iteration, stdin_closed); - + // ⭐⭐⭐⭐⭐ Phase 16.4修复:增加stdin超时机制(支持大文件传输) // 如果stdin未关闭,且超过1500次poll(150s)无数据 // 强制关闭stdin,发送EOF给SCP/rsync // ⭐⭐⭐⭐⭐ Phase 16.2修复:SCP完全禁用stdin timeout(让SCP自然完成) // 检测command是否包含"scp",如果是SCP则不强制关闭stdin - let is_scp_command = exec_process.command.contains("scp"); - - if !stdin_closed && !is_scp_command && iteration >= 1500 && exec_process.stdin.is_some() { + let is_scp_command = + exec_process.command.contains("scp"); + + if !stdin_closed + && !is_scp_command + && iteration >= 1500 + && exec_process.stdin.is_some() + { info!("⭐⭐⭐⭐⭐ Forcing stdin close after {} iterations ({} ms) - sending EOF to rsync (SCP excluded)", iteration, iteration * 100); - exec_process.stdin = None; // Drop stdin,发送EOF + exec_process.stdin = None; // Drop stdin,发送EOF stdin_closed = true; - + // ⭐⭐⭐⭐⭐ stdin关闭后,继续等待child处理完成 // 不要立即返回,给rsync时间处理数据并产生stdout info!("stdin closed, continuing to poll for stdout output..."); @@ -1225,7 +1432,7 @@ impl ChannelManager { } } } - + // 继续轮询(如果iteration < max_poll_iterations) } Err(e) => { @@ -1237,11 +1444,15 @@ impl ChannelManager { } } } - + // ⭐⭐⭐⭐⭐ 达到max_poll_iterations,检查最终child状态 if !found_data { - info!("No data after {} iterations ({} ms), checking child status", max_poll_iterations, max_poll_iterations * 10); - + info!( + "No data after {} iterations ({} ms), checking child status", + max_poll_iterations, + max_poll_iterations * 10 + ); + for channel_id in &channel_ids_vec { if let Some(channel) = self.channels.get_mut(channel_id) { if let Some(exec_process) = &mut channel.exec_process { @@ -1249,19 +1460,20 @@ impl ChannelManager { Ok(Some(status)) => { info!("Child exited after max iterations (status: {:?})", status); child_exited = true; - + // 读取剩余stdout if let Some(stdout) = &mut exec_process.stdout { let mut buffer = vec![0u8; 32768]; match stdout.read(&mut buffer) { Ok(n) if n > 0 => { - let packet = self.build_channel_data(*channel_id, &buffer[..n])?; + let packet = + self.build_channel_data(*channel_id, &buffer[..n])?; return Ok((Some(vec![packet]), false, true)); } _ => {} } } - + return Ok((None, false, true)); } Ok(None) => { @@ -1278,8 +1490,8 @@ impl ChannelManager { } } } - -// ⭐⭐⭐⭐⭐ 处理找到的数据(如果found_data) + + // ⭐⭐⭐⭐⭐ 处理找到的数据(如果found_data) // 3. 检查client fd状态(包括EOF/HUP) if let Some(revents) = poll_fds_vec[client_fd_idx].revents() { if revents.contains(PollFlags::POLLIN) { @@ -1289,11 +1501,11 @@ impl ChannelManager { info!("Client fd hangup (EOF received from client)"); // ⭐⭐⭐⭐⭐ Phase 14.2关键修复:关闭stdin pipe,发送EOF给child // 参考:OpenSSH session.c: do_exec_no_pty() stdin handling - for (_, channel) in &mut self.channels { + for channel in self.channels.values_mut() { if let Some(exec_process) = &mut channel.exec_process { if exec_process.stdin.is_some() { info!("Closing stdin pipe (sending EOF to child process)"); - exec_process.stdin = None; // Drop stdin,发送EOF给child + exec_process.stdin = None; // Drop stdin,发送EOF给child } } } @@ -1303,18 +1515,21 @@ impl ChannelManager { return Err(anyhow::anyhow!("Client socket error")); } } - + // 4. 检查stdout/stderr fd是否有数据 let mut packets_data: Vec<(u32, Vec)> = Vec::new(); - let mut stderr_packets: Vec<(u32, Vec)> = Vec::new(); // Phase 17: stderr → CHANNEL_EXTENDED_DATA - + let mut stderr_packets: Vec<(u32, Vec)> = Vec::new(); // Phase 17: stderr → CHANNEL_EXTENDED_DATA + for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map { if let Some(channel) = self.channels.get_mut(&channel_id) { if let Some(exec_process) = &mut channel.exec_process { // 检查stdout if let Some(revents) = poll_fds_vec[stdout_idx].revents() { if revents.contains(PollFlags::POLLIN) { - info!("⭐⭐⭐⭐⭐ [stdout POLLIN] stdout fd has data (channel {})", channel_id); + info!( + "⭐⭐⭐⭐⭐ [stdout POLLIN] stdout fd has data (channel {})", + channel_id + ); if let Some(stdout) = &mut exec_process.stdout { let mut buffer = vec![0u8; 32768]; info!("⭐⭐⭐⭐⭐ [BEFORE stdout.read] Attempting to read from stdout (buffer size 32KB)"); @@ -1324,20 +1539,23 @@ impl ChannelManager { packets_data.push((channel_id, buffer[..n].to_vec())); } Ok(0) => { - info!("stdout EOF (channel {}), closing stdout pipe", channel_id); + info!( + "stdout EOF (channel {}), closing stdout pipe", + channel_id + ); // ⭐⭐⭐⭐⭐ Critical修复:EOF时关闭pipe,避免无限循环 exec_process.stdout = None; } Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => { warn!("stdout read error: {}", e); - exec_process.stdout = None; // 错误时也关闭 + exec_process.stdout = None; // 错误时也关闭 } _ => {} } } } } - + // 检查stderr if let Some(revents) = poll_fds_vec[stderr_idx].revents() { if revents.contains(PollFlags::POLLIN) { @@ -1348,18 +1566,24 @@ impl ChannelManager { match stderr.read(&mut buffer) { Ok(n) if n > 0 => { info!("⭐⭐⭐⭐⭐ [AFTER stderr.read] Read {} bytes from stderr (channel {})", n, channel_id); - info!("⭐⭐⭐⭐⭐ stderr content: {:?}", &buffer[..std::cmp::min(50, n)]); + info!( + "⭐⭐⭐⭐⭐ stderr content: {:?}", + &buffer[..std::cmp::min(50, n)] + ); // ⭐⭐⭐⭐⭐ Phase 17: stderr → SSH_MSG_CHANNEL_EXTENDED_DATA (data_type=1) stderr_packets.push((channel_id, buffer[..n].to_vec())); } Ok(0) => { - info!("stderr EOF (channel {}), closing stderr pipe", channel_id); + info!( + "stderr EOF (channel {}), closing stderr pipe", + channel_id + ); // ⭐⭐⭐⭐⭐ Critical修复:EOF时关闭pipe,避免无限循环 exec_process.stderr = None; } Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => { warn!("stderr read error: {}", e); - exec_process.stderr = None; // 错误时也关闭 + exec_process.stderr = None; // 错误时也关闭 } _ => {} } @@ -1374,7 +1598,7 @@ impl ChannelManager { } } } - + // 构建packets if !packets_data.is_empty() || !stderr_packets.is_empty() { let mut packets = Vec::new(); @@ -1387,10 +1611,13 @@ impl ChannelManager { let packet = self.build_channel_extended_data(channel_id, 1, &data)?; packets.push(packet); } - info!("⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)", packets.len()); + info!( + "⭐⭐⭐⭐⭐ Returning {} packets (stdout/stderr data)", + packets.len() + ); return Ok((Some(packets), client_has_data, child_exited)); } - + // ⭐⭐⭐⭐⭐ Phase 14.2最终修复:stdout/stderr EOF后检查child exited // 当stdout和stderr都关闭后,强制检查child状态 for channel_id in &channel_ids_vec { @@ -1398,14 +1625,17 @@ impl ChannelManager { if let Some(exec_process) = &mut channel.exec_process { // 检查stdout和stderr是否都已关闭 if exec_process.stdout.is_none() && exec_process.stderr.is_none() { - info!("stdout/stderr both closed (channel {}), checking child status", channel_id); - + info!( + "stdout/stderr both closed (channel {}), checking child status", + channel_id + ); + // ⭐⭐⭐⭐⭐ 立即检查child是否exited match exec_process.child.try_wait() { Ok(Some(status)) => { info!("⭐⭐⭐⭐⭐ Child exited after stdout/stderr EOF (status: {:?})", status); child_exited = true; - + // ⭐⭐⭐⭐⭐ 关键:立即返回child_exited标志 // server.rs会发送SSH_MSG_CHANNEL_EOF + CLOSE info!("⭐⭐⭐⭐⭐ No packets to send, returning child_exited flag"); @@ -1424,71 +1654,75 @@ impl ChannelManager { } } } - + // 有数据但只有client数据 Ok((None, client_has_data, child_exited)) } - + // ⭐⭐⭐⭐⭐ Phase 14.0: 旧版poll(仅监听stdout/stderr,已废弃) /// 已废弃:使用poll_exec_stdout_and_client()替代 #[allow(dead_code)] pub fn poll_exec_stdout_with_fds(&mut self) -> Result>> { use std::io::Read; use std::os::unix::io::BorrowedFd; - + // 遍历所有channel,收集poll_fds let mut poll_fds_vec = Vec::new(); - let mut channel_fds_map: HashMap = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx) in poll_fds_vec - + let mut channel_fds_map: HashMap = HashMap::new(); // channel_id -> (stdout_idx, stderr_idx) in poll_fds_vec + for (channel_id, channel) in &self.channels { if let Some(exec_process) = &channel.exec_process { // ⭐⭐⭐⭐⭐ OpenSSH风格:创建PollFd监听stdout/stderr // nix 0.29 API: PollFd::new()需要借用fd,不是RawFd - if let Some(stdout) = &exec_process.stdout { + if let Some(_stdout) = &exec_process.stdout { let stdout_poll_fd = unsafe { // ⭐⭐⭐⭐⭐ 使用BorrowedFd::borrow_raw()(正确API) BorrowedFd::borrow_raw(exec_process.stdout_fd) }; poll_fds_vec.push(PollFd::new(stdout_poll_fd, PollFlags::POLLIN)); } - - if let Some(stderr) = &exec_process.stderr { - let stderr_poll_fd = unsafe { - BorrowedFd::borrow_raw(exec_process.stderr_fd) - }; + + if let Some(_stderr) = &exec_process.stderr { + let stderr_poll_fd = unsafe { BorrowedFd::borrow_raw(exec_process.stderr_fd) }; poll_fds_vec.push(PollFd::new(stderr_poll_fd, PollFlags::POLLIN)); } - + // 记录poll_fds_vec中的索引 let stdout_idx = poll_fds_vec.len() - 2; let stderr_idx = poll_fds_vec.len() - 1; channel_fds_map.insert(*channel_id, (stdout_idx, stderr_idx)); } } - + if poll_fds_vec.is_empty() { - return Ok(None); // 没有exec_process + return Ok(None); // 没有exec_process } - + // ⭐⭐⭐⭐⭐ OpenSSH关键:使用poll监听所有fd // ⭐⭐⭐⭐⭐ 持续poll机制:最多轮询1000次(给大文件传输足够时间) // 大文件传输需要很长时间,增加轮询次数到1000次(总共10秒) let max_poll_attempts = 1000; let mut poll_attempt = 0; let mut found_data = false; - + for attempt in 0..max_poll_attempts { poll_attempt = attempt; // 每100次轮询记录一次日志(减少日志噪音) if attempt % 100 == 0 { - info!("Polling {} fds (OpenSSH style, timeout 10ms, attempt {} of {})", poll_fds_vec.len(), attempt, max_poll_attempts); + info!( + "Polling {} fds (OpenSSH style, timeout 10ms, attempt {} of {})", + poll_fds_vec.len(), + attempt, + max_poll_attempts + ); } - match poll(&mut poll_fds_vec, 10u16) { // timeout 10ms + match poll(&mut poll_fds_vec, 10u16) { + // timeout 10ms Ok(n) => { if n > 0 { info!("{} fds have data available (attempt {})", n, attempt); found_data = true; - break; // 有数据,立即处理 + break; // 有数据,立即处理 } // 没有数据,继续轮询(最多1000次) } @@ -1498,15 +1732,19 @@ impl ChannelManager { } } } - + if !found_data { - info!("No data available after {} poll attempts ({} ms), returning None", max_poll_attempts, max_poll_attempts * 10); - return Ok(None); // 轮询1000次后仍无数据,主循环继续处理client packet + info!( + "No data available after {} poll attempts ({} ms), returning None", + max_poll_attempts, + max_poll_attempts * 10 + ); + return Ok(None); // 轮询1000次后仍无数据,主循环继续处理client packet } - + // ⭐⭐⭐⭐⭐ OpenSSH风格:根据revents判断哪个fd有数据,立即读取 - let mut packets_data: Vec<(u32, Vec)> = Vec::new(); // (channel_id, data) - + let mut packets_data: Vec<(u32, Vec)> = Vec::new(); // (channel_id, data) + for (channel_id, (stdout_idx, stderr_idx)) in channel_fds_map { if let Some(channel) = self.channels.get_mut(&channel_id) { if let Some(exec_process) = &mut channel.exec_process { @@ -1520,7 +1758,10 @@ impl ChannelManager { match stdout.read(&mut buffer) { Ok(n) => { if n > 0 { - info!("Read {} bytes from stdout (channel {})", n, channel_id); + info!( + "Read {} bytes from stdout (channel {})", + n, channel_id + ); packets_data.push((channel_id, buffer[..n].to_vec())); } else { info!("stdout EOF (channel {})", channel_id); @@ -1536,7 +1777,7 @@ impl ChannelManager { } } } - + // 检查stderr是否有数据(类似处理) if let Some(revents) = poll_fds_vec[stderr_idx].revents() { if revents.contains(PollFlags::POLLIN) { @@ -1546,7 +1787,10 @@ impl ChannelManager { match stderr.read(&mut buffer) { Ok(n) => { if n > 0 { - info!("Read {} bytes from stderr (channel {})", n, channel_id); + info!( + "Read {} bytes from stderr (channel {})", + n, channel_id + ); packets_data.push((channel_id, buffer[..n].to_vec())); } else { info!("stderr EOF (channel {})", channel_id); @@ -1565,14 +1809,14 @@ impl ChannelManager { } } } - + // ⭐⭐⭐⭐⭐ 释放mutable borrow后,构建packets(避免borrow冲突) let mut packets = Vec::new(); for (channel_id, data) in packets_data { let packet = self.build_channel_data(channel_id, &data)?; packets.push(packet); } - + if packets.is_empty() { Ok(None) } else { @@ -1586,32 +1830,32 @@ struct Channel { server_channel: u32, sender_channel: u32, channel_type: String, - + // ⭐⭐⭐⭐⭐ Phase 15: Window Control(参考OpenSSH channels.h:176-182) - remote_window: u32, // 远端窗口大小(OpenSSH: c->remote_window) - remote_maxpacket: u32, // 远端最大 packet(OpenSSH: c->remote_maxpacket) - local_window: u32, // 本地窗口大小(OpenSSH: c->local_window) - local_window_max: u32, // 本地窗口最大值(OpenSSH: c->local_window_max) - local_consumed: u32, // 本地已消费的数据(OpenSSH: c->local_consumed)⭐⭐⭐⭐⭐ 关键! - local_maxpacket: u32, // 本地最大 packet(OpenSSH: c->local_maxpacket) - + remote_window: u32, // 远端窗口大小(OpenSSH: c->remote_window) + remote_maxpacket: u32, // 远端最大 packet(OpenSSH: c->remote_maxpacket) + local_window: u32, // 本地窗口大小(OpenSSH: c->local_window) + local_window_max: u32, // 本地窗口最大值(OpenSSH: c->local_window_max) + local_consumed: u32, // 本地已消费的数据(OpenSSH: c->local_consumed)⭐⭐⭐⭐⭐ 关键! + local_maxpacket: u32, // 本地最大 packet(OpenSSH: c->local_maxpacket) + // 旧字段(保留兼容) - window_size: u32, // 当前窗口大小(兼容旧代码) - maximum_packet_size: u32, // 最大 packet 大小(兼容旧代码) - + window_size: u32, // 当前窗口大小(兼容旧代码) + maximum_packet_size: u32, // 最大 packet 大小(兼容旧代码) + state: ChannelState, - output_buffer: Option>, // Phase 6: 命令输出缓冲 - sftp_handler: Option, // Phase 7: SFTP处理器 - scp_handler: Option, // Phase 8: SCP处理器 - rsync_handler: Option, // Phase 8: rsync处理器 - exec_process: Option, // Phase 14: 交互式exec进程 + output_buffer: Option>, // Phase 6: 命令输出缓冲 + sftp_handler: Option, // Phase 7: SFTP处理器 + scp_handler: Option, // Phase 8: SCP处理器 + rsync_handler: Option, // Phase 8: rsync处理器 + exec_process: Option, // Phase 14: 交互式exec进程 // ⭐⭐⭐⭐⭐ Critical修复:SFTP packet累积buffer - sftp_input_buffer: Vec, // Phase 14.2修复:累积不完整的SFTP packets + sftp_input_buffer: Vec, // Phase 14.2修复:累积不完整的SFTP packets // ⭐⭐⭐⭐⭐ Phase 14.4:SCP packet累积buffer - scp_input_buffer: Vec, // Phase 14.4修复:累积不完整的SCP packets + scp_input_buffer: Vec, // Phase 14.4修复:累积不完整的SCP packets // Phase 13.3: 端口转发相关字段 - direct_tcpip: Option, // direct-tcpip channel(Remote forwarding) - forwarded_tcpip: Option, // forwarded-tcpip channel(Local forwarding) + direct_tcpip: Option, // direct-tcpip channel(Remote forwarding) + forwarded_tcpip: Option, // forwarded-tcpip channel(Local forwarding) } /// SSH Channel状态(参考OpenSSH channel.c) @@ -1630,7 +1874,7 @@ fn read_ssh_string(reader: &mut R) -> Result { } /// ⭐⭐⭐⭐⭐ Phase 15: 检查并发送 Window Adjust(参考OpenSSH channels.c:2425-2450) -/// +/// /// OpenSSH 实现: /// ```c /// static int channel_check_window(struct ssh *ssh, Channel *c) { @@ -1651,39 +1895,41 @@ fn read_ssh_string(reader: &mut R) -> Result { /// } /// } /// ``` -pub fn channel_check_window(channel_id: u32, channels: &mut HashMap) -> Option { +pub fn channel_check_window( + channel_id: u32, + channels: &mut HashMap, +) -> Option { if let Some(channel) = channels.get_mut(&channel_id) { // 检查窗口调整条件 let window_used = channel.local_window_max - channel.local_window; - let need_adjust = (window_used > channel.local_maxpacket * 3) || - (channel.local_window < channel.local_window_max / 2); - + let need_adjust = (window_used > channel.local_maxpacket * 3) + || (channel.local_window < channel.local_window_max / 2); + if need_adjust && channel.local_consumed > 0 { info!("⭐⭐⭐⭐⭐ [WINDOW_ADJUST] channel {} needs adjust: window_used={}, local_consumed={}", channel_id, window_used, channel.local_consumed); - + // 发送 SSH_MSG_CHANNEL_WINDOW_ADJUST - let adjust_packet = build_window_adjust( - channel.server_channel, - channel.local_consumed - ); - + let adjust_packet = build_window_adjust(channel.server_channel, channel.local_consumed); + // 更新窗口大小 channel.local_window += channel.local_consumed; channel.local_consumed = 0; - - info!("⭐⭐⭐⭐⭐ [WINDOW_UPDATED] channel {} new window: {}", - channel_id, channel.local_window); - + + info!( + "⭐⭐⭐⭐⭐ [WINDOW_UPDATED] channel {} new window: {}", + channel_id, channel.local_window + ); + return Some(adjust_packet); } } - + None } /// ⭐⭐⭐⭐⭐ Phase 15: 构建 SSH_MSG_CHANNEL_WINDOW_ADJUST packet -/// +/// /// OpenSSH packet format: /// ```c /// SSH2_MSG_CHANNEL_WINDOW_ADJUST (93) @@ -1692,19 +1938,21 @@ pub fn channel_check_window(channel_id: u32, channels: &mut HashMap SshPacket { let mut payload = Vec::new(); - + // Packet type payload.push(PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8); - + // recipient_channel (u32) payload.write_u32::(recipient_channel).unwrap(); - + // bytes_to_add (u32) payload.write_u32::(bytes_to_add).unwrap(); - - info!("⭐⭐⭐⭐⭐ [BUILD_WINDOW_ADJUST] recipient_channel={}, bytes_to_add={}", - recipient_channel, bytes_to_add); - + + info!( + "⭐⭐⭐⭐⭐ [BUILD_WINDOW_ADJUST] recipient_channel={}, bytes_to_add={}", + recipient_channel, bytes_to_add + ); + SshPacket { packet_length: 0, padding_length: 0, @@ -1716,26 +1964,31 @@ fn build_window_adjust(recipient_channel: u32, bytes_to_add: u32) -> SshPacket { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_channel_manager_creation() { let manager = ChannelManager::new(PathBuf::from("/tmp")); assert_eq!(manager.next_channel_id, 0); } - + #[test] fn test_channel_open_confirmation() { let manager = ChannelManager::new(PathBuf::from("/tmp")); - let packet = manager.build_channel_open_confirmation(0, 100, 2097152, 32768).unwrap(); - - assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8); + let packet = manager + .build_channel_open_confirmation(0, 100, 2097152, 32768) + .unwrap(); + + assert_eq!( + packet.payload[0], + PacketType::SSH_MSG_CHANNEL_OPEN_CONFIRMATION as u8 + ); } - + #[test] fn test_channel_success() { let manager = ChannelManager::new(PathBuf::from("/tmp")); let packet = manager.build_channel_success(0).unwrap(); - + assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_SUCCESS as u8); } } diff --git a/markbase-core/src/ssh_server/cipher.rs b/markbase-core/src/ssh_server/cipher.rs index 8d5e233..628c79c 100644 --- a/markbase-core/src/ssh_server/cipher.rs +++ b/markbase-core/src/ssh_server/cipher.rs @@ -1,33 +1,33 @@ // SSH加密通道实现(Phase 4) // 参考OpenSSH cipher.c, mac.c -use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr) +use super::crypto::SessionKeys; +use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr) +use anyhow::{anyhow, Result}; +use byteorder::{BigEndian, WriteBytesExt}; +use cipher::{KeyIvInit, StreamCipher}; use ctr::Ctr128BE; use hmac::{Hmac, Mac}; +use log::info; use sha2::Sha256; -use cipher::{KeyIvInit, StreamCipher}; use std::io::Write; -use anyhow::{Result, anyhow}; -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use log::{info, debug, warn}; -use super::crypto::SessionKeys; -type Aes128Ctr = Ctr128BE; // AES-128-CTR(16字节密钥) +type Aes128Ctr = Ctr128BE; // AES-128-CTR(16字节密钥) type HmacSha256 = Hmac; /// SSH加密通道管理器(参考OpenSSH struct sshcipher_ctx) pub struct EncryptionContext { pub session_id: Vec, // session identifier (exchange hash) - pub encryption_key_ctos: Vec, // 客户端→服务器加密密钥 - pub encryption_key_stoc: Vec, // 服务器→客户端加密密钥 - pub mac_key_ctos: Vec, // 客户端→服务器MAC密钥 - pub mac_key_stoc: Vec, // 服务器→客户端MAC密钥 - pub iv_ctos: Vec, // 客户端→服务器IV - pub iv_stoc: Vec, // 服务器→客户端IV - pub sequence_number_ctos: u32, // 客户端→服务器序列号 - pub sequence_number_stoc: u32, // 服务器→客户端序列号 - pub cipher_ctos: Option, // 客户端→服务器cipher实例(持久化) - pub cipher_stoc: Option, // 服务器→客户端cipher实例(持久化) + pub encryption_key_ctos: Vec, // 客户端→服务器加密密钥 + pub encryption_key_stoc: Vec, // 服务器→客户端加密密钥 + pub mac_key_ctos: Vec, // 客户端→服务器MAC密钥 + pub mac_key_stoc: Vec, // 服务器→客户端MAC密钥 + pub iv_ctos: Vec, // 客户端→服务器IV + pub iv_stoc: Vec, // 服务器→客户端IV + pub sequence_number_ctos: u32, // 客户端→服务器序列号 + pub sequence_number_stoc: u32, // 服务器→客户端序列号 + pub cipher_ctos: Option, // 客户端→服务器cipher实例(持久化) + pub cipher_stoc: Option, // 服务器→客户端cipher实例(持久化) } impl Default for EncryptionContext { @@ -53,27 +53,33 @@ impl EncryptionContext { /// OpenSSH cipher.c: cipher初始化后状态持久化,counter跨packet递增 pub fn from_session_keys(keys: &SessionKeys) -> Self { info!("Initializing ciphers with session keys:"); - info!(" encryption_key_ctos (16 bytes): {:?}", &keys.encryption_key_ctos[..16]); + info!( + " encryption_key_ctos (16 bytes): {:?}", + &keys.encryption_key_ctos[..16] + ); info!(" iv_ctos (16 bytes): {:?}", &keys.iv_ctos[..16]); - info!(" encryption_key_stoc (16 bytes): {:?}", &keys.encryption_key_stoc[..16]); + info!( + " encryption_key_stoc (16 bytes): {:?}", + &keys.encryption_key_stoc[..16] + ); info!(" iv_stoc (16 bytes): {:?}", &keys.iv_stoc[..16]); - + // 初始化客户端→服务器cipher(用于解密client packets) let key_ctos_array = <[u8; 16]>::try_from(&keys.encryption_key_ctos[..16]) .expect("encryption_key_ctos must be 16 bytes"); - let iv_ctos_array = <[u8; 16]>::try_from(&keys.iv_ctos[..16]) - .expect("iv_ctos must be 16 bytes"); + let iv_ctos_array = + <[u8; 16]>::try_from(&keys.iv_ctos[..16]).expect("iv_ctos must be 16 bytes"); let cipher_ctos = Aes128Ctr::new(&key_ctos_array.into(), &iv_ctos_array.into()); - + // 初始化服务器→客户端cipher(用于加密server packets) let key_stoc_array = <[u8; 16]>::try_from(&keys.encryption_key_stoc[..16]) .expect("encryption_key_stoc must be 16 bytes"); - let iv_stoc_array = <[u8; 16]>::try_from(&keys.iv_stoc[..16]) - .expect("iv_stoc must be 16 bytes"); + let iv_stoc_array = + <[u8; 16]>::try_from(&keys.iv_stoc[..16]).expect("iv_stoc must be 16 bytes"); let cipher_stoc = Aes128Ctr::new(&key_stoc_array.into(), &iv_stoc_array.into()); - + info!("Ciphers initialized successfully"); - + Self { session_id: keys.session_id.clone(), encryption_key_ctos: keys.encryption_key_ctos.clone(), @@ -84,26 +90,26 @@ impl EncryptionContext { iv_stoc: keys.iv_stoc.clone(), sequence_number_ctos: 0, sequence_number_stoc: 0, - cipher_ctos: Some(cipher_ctos), // 持久化cipher实例 - cipher_stoc: Some(cipher_stoc), // 持久化cipher实例 + cipher_ctos: Some(cipher_ctos), // 持久化cipher实例 + cipher_stoc: Some(cipher_stoc), // 持久化cipher实例 } } - + /// RFC 4344: Compute AES-CTR IV for a specific packet /// IV = nonce(8 bytes from derived IV) + sequence_number(8 bytes) fn compute_ctr_iv(nonce: &[u8], sequence_number: u32) -> Vec { let mut iv = Vec::with_capacity(16); - + // Nonce: first 8 bytes of derived IV (constant) iv.extend_from_slice(&nonce[..8]); - + // Counter: sequence number as 8-byte big-endian iv.extend_from_slice(&sequence_number.to_be_bytes()); iv.extend_from_slice(&[0u8; 4]); // Upper 4 bytes = 0 - + iv } - + /// 加密packet(参考OpenSSH cipher.c: cipher_encrypt()) pub fn encrypt_packet( &mut self, @@ -113,17 +119,17 @@ impl EncryptionContext { ) -> Result> { let key_array = <[u8; 16]>::try_from(encryption_key)?; let iv_array = <[u8; 16]>::try_from(iv)?; - + let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into()); - + let mut ciphertext = plaintext.to_vec(); cipher.apply_keystream(&mut ciphertext); - + self.sequence_number_stoc += 1; - + Ok(ciphertext) } - + /// 解密packet(参考OpenSSH cipher.c: cipher_decrypt()) pub fn decrypt_packet( &mut self, @@ -133,17 +139,17 @@ impl EncryptionContext { ) -> Result> { let key_array = <[u8; 16]>::try_from(encryption_key)?; let iv_array = <[u8; 16]>::try_from(iv)?; - + let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into()); - + let mut plaintext = ciphertext.to_vec(); cipher.apply_keystream(&mut plaintext); - + self.sequence_number_ctos += 1; - + Ok(plaintext) } - + /// 计算MAC(参考OpenSSH mac.c: mac_compute()) pub fn compute_mac( &self, @@ -152,17 +158,17 @@ impl EncryptionContext { mac_key: &[u8], ) -> Result> { // HMAC-SHA256 MAC计算(参考OpenSSH mac.c) - + let mut mac = HmacSha256::new_from_slice(mac_key)?; - + // OpenSSH MAC格式:sequence_number + data mac.update(&sequence_number.to_be_bytes()); mac.update(data); - + let result = mac.finalize(); Ok(result.into_bytes().to_vec()) } - + /// 验证MAC(参考OpenSSH mac.c: mac_check()) pub fn verify_mac( &self, @@ -172,14 +178,14 @@ impl EncryptionContext { mac_key: &[u8], ) -> Result { // HMAC验证(参考OpenSSH mac.c) - + let computed_mac = self.compute_mac(sequence_number, data, mac_key)?; - + // 防止时间攻击(使用常量时间比较) if computed_mac.len() != expected_mac.len() { return Ok(false); } - + // 简化实现:直接比较(实际应使用常量时间比较) Ok(computed_mac == expected_mac) } @@ -187,11 +193,11 @@ impl EncryptionContext { /// SSH加密packet封装(参考OpenSSH packet.c: ssh_packet_write_poll()) pub struct EncryptedPacket { - pub packet_length: u32, // 加密后packet长度 - pub padding_length: u8, // padding长度(加密后) - pub payload: Vec, // payload(加密后) - pub padding: Vec, // padding(加密后) - pub mac: Vec, // MAC(32字节,HMAC-SHA256) + pub packet_length: u32, // 加密后packet长度 + pub padding_length: u8, // padding长度(加密后) + pub payload: Vec, // payload(加密后) + pub padding: Vec, // padding(加密后) + pub mac: Vec, // MAC(32字节,HMAC-SHA256) } impl EncryptedPacket { @@ -204,82 +210,88 @@ impl EncryptedPacket { ) -> Result { let block_size = 16; let min_padding = 4; - + let payload_length = plaintext_payload.len(); - + // RFC 4253: entire plaintext packet (including 4-byte packet_length field) must be multiple of block_size // plaintext_packet = packet_length_field(4) + padding_length(1) + payload + padding // So: (4 + 1 + payload_length + padding_length) % 16 == 0 - - let base_size = 4 + 1 + payload_length; // without padding + + let base_size = 4 + 1 + payload_length; // without padding let padding_needed = (block_size - (base_size % block_size)) % block_size; - + // Ensure padding >= min_padding (RFC 4253 requirement) let padding_length: u8 = if padding_needed < min_padding { - (padding_needed + block_size) as u8 // Add one more block to meet minimum + (padding_needed + block_size) as u8 // Add one more block to meet minimum } else { padding_needed as u8 }; - + // packet_length = padding_length(1) + payload + padding let packet_length = 1 + payload_length + padding_length as usize; - - info!("Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}", - payload_length, padding_length, packet_length); - + + info!( + "Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}", + payload_length, padding_length, packet_length + ); + // 构建plaintext packet(packet_length + padding_length + payload + padding) let mut plaintext_packet = Vec::new(); - plaintext_packet.write_u32::(packet_length as u32)?; // plaintext packet_length - plaintext_packet.write_u8(padding_length)?; // plaintext padding_length - plaintext_packet.write_all(plaintext_payload)?; // plaintext payload - + plaintext_packet.write_u32::(packet_length as u32)?; // plaintext packet_length + plaintext_packet.write_u8(padding_length)?; // plaintext padding_length + plaintext_packet.write_all(plaintext_payload)?; // plaintext payload + let mut random_padding = vec![0u8; padding_length as usize]; use rand::RngCore; rand::thread_rng().fill_bytes(&mut random_padding); - plaintext_packet.write_all(&random_padding)?; // plaintext padding - + plaintext_packet.write_all(&random_padding)?; // plaintext padding + info!("Plaintext packet size: {} bytes", plaintext_packet.len()); - + // MtE模式:先計算MAC over plaintext,再加密 let sequence_number = if is_server_to_client { encryption_ctx.sequence_number_stoc } else { encryption_ctx.sequence_number_ctos }; - + let mac_key = if is_server_to_client { &encryption_ctx.mac_key_stoc } else { &encryption_ctx.mac_key_ctos }; - + info!("MAC calculation (MtE mode) over plaintext packet:"); info!(" sequence_number: {}", sequence_number); info!(" mac_key length: {}", mac_key.len()); info!(" plaintext_packet length: {}", plaintext_packet.len()); - + // MAC計算:HMAC(sequence_number || plaintext_packet) let mac = encryption_ctx.compute_mac(sequence_number, &plaintext_packet, mac_key)?; - + // 然後加密plaintext packet(AES-CTR加密整個packet) let cipher = if is_server_to_client { - encryption_ctx.cipher_stoc.as_mut() + encryption_ctx + .cipher_stoc + .as_mut() .ok_or_else(|| anyhow!("cipher_stoc not initialized"))? } else { - encryption_ctx.cipher_ctos.as_mut() + encryption_ctx + .cipher_ctos + .as_mut() .ok_or_else(|| anyhow!("cipher_ctos not initialized"))? }; - + let mut encrypted_packet = plaintext_packet; cipher.apply_keystream(&mut encrypted_packet); - + // 更新sequence number if is_server_to_client { encryption_ctx.sequence_number_stoc += 1; } else { encryption_ctx.sequence_number_ctos += 1; } - + Ok(Self { packet_length: packet_length as u32, padding_length, @@ -288,24 +300,27 @@ impl EncryptedPacket { mac, }) } - + /// 写入加密packet(参考OpenSSH cipher.c) /// AES-CTR模式:写入完整加密packet + MAC pub fn write(&self, stream: &mut W) -> Result<()> { - info!("Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}", - self.payload.len(), self.mac.len()); - + info!( + "Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}", + self.payload.len(), + self.mac.len() + ); + // AES-CTR: 整个packet已加密(包括packet_length),直接写入 stream.write_all(&self.payload)?; info!("Wrote encrypted packet ({} bytes)", self.payload.len()); - + // 写入MAC stream.write_all(&self.mac)?; info!("Wrote MAC ({} bytes)", self.mac.len()); - + Ok(()) } - + /// 读取加密packet(参考OpenSSH packet.c ssh_packet_read_poll2) /// OpenSSH packet.c: AES-CTR先解密第一个块,再提取packet_length /// aadlen = 0 (没有EtM或authenticated encryption), packet_length被加密 @@ -315,32 +330,42 @@ impl EncryptedPacket { is_client_to_server: bool, ) -> Result { use std::io::Read; - + info!("Reading AES-CTR encrypted packet (packet_length encrypted)"); - + // 1. 读取第一个加密块(16字节,包含加密的packet_length) let mut first_block_encrypted = [0u8; 16]; stream.read_exact(&mut first_block_encrypted)?; - - info!("Read first encrypted block (16 bytes): {:?}", &first_block_encrypted); - + + info!( + "Read first encrypted block (16 bytes): {:?}", + &first_block_encrypted + ); + // 2. 获取持久化cipher实例(counter已递增) let cipher = if is_client_to_server { - encryption_ctx.cipher_ctos.as_mut() + encryption_ctx + .cipher_ctos + .as_mut() .ok_or_else(|| anyhow!("cipher_ctos not initialized"))? } else { - encryption_ctx.cipher_stoc.as_mut() + encryption_ctx + .cipher_stoc + .as_mut() .ok_or_else(|| anyhow!("cipher_stoc not initialized"))? }; - - info!("Using cipher for decryption (is_client_to_server={})", is_client_to_server); - + + info!( + "Using cipher for decryption (is_client_to_server={})", + is_client_to_server + ); + // 3. 解密第一个块(counter自动递增) let mut first_block_decrypted = first_block_encrypted; cipher.apply_keystream(&mut first_block_decrypted); - + info!("Decrypted first block: {:?}", &first_block_decrypted); - + // 3. 从解密后的数据中提取packet_length(前4字节)和padding_length(第5字节) let packet_length = u32::from_be_bytes([ first_block_decrypted[0], @@ -349,67 +374,73 @@ impl EncryptedPacket { first_block_decrypted[3], ]); let padding_length = first_block_decrypted[4]; - - info!("Decrypted packet_length={}, padding_length={}", packet_length, padding_length); - + + info!( + "Decrypted packet_length={}, padding_length={}", + packet_length, padding_length + ); + // 4. 合理性检查 if packet_length > 35000 { info!("packet_length raw bytes: {:?}", &first_block_decrypted[..4]); return Err(anyhow!("Invalid packet_length: {}", packet_length)); } - + // 3. 计算剩余加密数据长度 // packet_length = padding_length(1) + payload + padding // 总加密数据 = packet_length(4) + packet_length = packet_length + 4 // 已读取16字节,剩余 = packet_length + 4 - 16 - let total_encrypted_size = packet_length as usize + 4; // packet_length field + content + let total_encrypted_size = packet_length as usize + 4; // packet_length field + content let remaining_encrypted_size = total_encrypted_size - 16; - - info!("Total encrypted size: {}, remaining: {}", total_encrypted_size, remaining_encrypted_size); - + + info!( + "Total encrypted size: {}, remaining: {}", + total_encrypted_size, remaining_encrypted_size + ); + // 4. 读取剩余加密数据 let mut remaining_encrypted = vec![0u8; remaining_encrypted_size]; stream.read_exact(&mut remaining_encrypted)?; - + // 5. 继续解密(使用同一个cipher) cipher.apply_keystream(&mut remaining_encrypted); - + info!("Remaining decrypted data: {:?}", &remaining_encrypted); - + // 6. 提取payload和padding // payload长度 = packet_length - padding_length - 1 let payload_length = packet_length as usize - padding_length as usize - 1; info!("Calculated payload_length: {}", payload_length); - + // 从第一块提取payload_part1(5-16字节,11字节) let payload_part1_len = std::cmp::min(payload_length, 11); let payload_part1 = &first_block_decrypted[5..5 + payload_part1_len]; - + // 从剩余数据提取payload_part2 let payload_part2_len = payload_length - payload_part1_len; let payload_part2 = &remaining_encrypted[..payload_part2_len]; - + // 合并payload let mut payload = Vec::new(); payload.extend_from_slice(payload_part1); payload.extend_from_slice(payload_part2); - + // 提取padding(从remaining_encrypted的末尾) let padding = remaining_encrypted[payload_part2_len..].to_vec(); - + // 9. 读取MAC info!("Reading MAC (32 bytes)..."); let mut mac = vec![0u8; 32]; stream.read_exact(&mut mac)?; info!("MAC read successfully"); - + // 10. 更新sequence number if is_client_to_server { encryption_ctx.sequence_number_ctos += 1; } else { encryption_ctx.sequence_number_stoc += 1; } - + Ok(Self { packet_length, padding_length, @@ -418,7 +449,7 @@ impl EncryptedPacket { mac, }) } - + /// 获取payload内容 pub fn payload(&self) -> &[u8] { &self.payload @@ -428,13 +459,13 @@ impl EncryptedPacket { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_aes256_ctr_encryption() { - let key = vec![0u8; 16]; // AES-128 key (16 bytes) + let key = vec![0u8; 16]; // AES-128 key (16 bytes) let iv = vec![0u8; 16]; let plaintext = b"Hello World"; - + let mut ctx = EncryptionContext::from_session_keys(&SessionKeys { session_id: vec![0u8; 32], encryption_key_ctos: key.clone(), @@ -444,18 +475,18 @@ mod tests { iv_ctos: iv.clone(), iv_stoc: iv.clone(), }); - + let ciphertext = ctx.encrypt_packet(plaintext, &key, &iv).unwrap(); let decrypted = ctx.decrypt_packet(&ciphertext, &key, &iv).unwrap(); - + assert_eq!(plaintext.to_vec(), decrypted); } - + #[test] fn test_hmac_sha256() { let key = vec![0u8; 32]; let data = b"test data"; - + let ctx = EncryptionContext::from_session_keys(&SessionKeys { session_id: vec![0u8; 32], encryption_key_ctos: vec![0u8; 32], @@ -465,10 +496,10 @@ mod tests { iv_ctos: vec![0u8; 16], iv_stoc: vec![0u8; 16], }); - + let mac = ctx.compute_mac(1, data, &key).unwrap(); - assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节 - + assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节 + // 验证MAC assert!(ctx.verify_mac(1, data, &mac, &key).unwrap()); } diff --git a/markbase-core/src/ssh_server/crypto.rs b/markbase-core/src/ssh_server/crypto.rs index 20a80b0..4aebe9a 100644 --- a/markbase-core/src/ssh_server/crypto.rs +++ b/markbase-core/src/ssh_server/crypto.rs @@ -1,16 +1,16 @@ // SSH加密模块(Phase 3:密钥交换) // 参考OpenSSH curve25519.c, kex.c -use anyhow::{Result, anyhow}; -use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret}; -use ed25519_dalek::{SigningKey, VerifyingKey, Signature, Signer}; -use sha2::{Sha256, Digest}; -use log::{info, debug}; +use anyhow::{anyhow, Result}; +use ed25519_dalek::{Signer, SigningKey}; +use log::info; use rand::rngs::OsRng; +use sha2::{Digest, Sha256}; +use x25519_dalek::{EphemeralSecret, PublicKey}; /// Curve25519密钥交换处理器(参考OpenSSH curve25519.c) pub struct Curve25519Kex { - secret: Option, // 使用Option包装(一次性使用类型) + secret: Option, // 使用Option包装(一次性使用类型) public: PublicKey, } @@ -21,34 +21,37 @@ impl Curve25519Kex { // x25519-dalek 2.0标准API:使用random_from_rng let secret = EphemeralSecret::random_from_rng(OsRng); let public = PublicKey::from(&secret); - - Self { secret: Some(secret), public } // Some包装 + + Self { + secret: Some(secret), + public, + } // Some包装 } - + /// 获取公钥(用于SSH_MSG_KEX_ECDH_INIT) pub fn public_key(&self) -> &[u8] { self.public.as_bytes() } - + /// 计算共享密钥(参考OpenSSH curve25519_shared_secret()) /// 使用&mut self(消耗模式,符合OpenSSH设计) pub fn compute_shared_secret(&mut self, client_public: &[u8]) -> Result<[u8; 32]> { if client_public.len() != 32 { return Err(anyhow!("Invalid client public key length")); } - + info!("=== X25519 Shared Secret Calculation ==="); info!("Client public key input: {:?}", client_public); info!("Server public key: {:?}", self.public.as_bytes()); - + // 参考OpenSSH:curve25519共享密钥计算 let client_public_key = PublicKey::from(<[u8; 32]>::try_from(client_public)?); - + // 使用take()取出secret(Rust标准模式) if let Some(secret) = self.secret.take() { let shared_secret = secret.diffie_hellman(&client_public_key); info!("Computed shared secret: {:?}", shared_secret.as_bytes()); - Ok(shared_secret.as_bytes().clone()) + Ok(*shared_secret.as_bytes()) } else { Err(anyhow!("Secret already used")) } @@ -71,47 +74,85 @@ impl SessionKeys { /// RFC 4253 Section 7.2: Key = HASH(K || H || X || session_id) pub fn derive( shared_secret: &[u8], - exchange_hash: &[u8], // H参数(exchange hash) - server_public_key: &[u8], - client_public_key: &[u8], - server_host_key: &[u8], + exchange_hash: &[u8], // H参数(exchange hash) + _server_public_key: &[u8], + _client_public_key: &[u8], + _server_host_key: &[u8], ) -> Result { // RFC 4253: session_id = H (第一次exchange hash) let session_id = exchange_hash.to_vec(); - + info!("SessionKeys::derive() starting"); info!(" shared_secret full (32 bytes): {:?}", shared_secret); - + // RFC 8731 Section 3.1: X25519 output is little-endian // OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal) // Treats little-endian bytes as big-endian mpint (logical reinterpret) info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)"); - info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80); + info!( + " shared_secret[0] = {} (>=0x80? {})", + shared_secret[0], + shared_secret[0] >= 0x80 + ); info!(" exchange_hash full (32 bytes): {:?}", exchange_hash); info!(" session_id full (32 bytes): {:?}", session_id); - + // RFC 4253密钥派生公式:HASH(K || H || X || session_id) // K is shared_secret encoded as mpint (using little-endian bytes directly) let shared_secret_mpint = Self::encode_mpint(shared_secret); - - info!(" shared_secret_mpint ({} bytes): {:?}", shared_secret_mpint.len(), &shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]); - - let encryption_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?; - let encryption_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?; - let mac_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?; - let mac_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?; - - let iv_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?; - let iv_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?; - + + info!( + " shared_secret_mpint ({} bytes): {:?}", + shared_secret_mpint.len(), + &shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())] + ); + + let encryption_key_ctos = + Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?; + let encryption_key_stoc = + Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?; + let mac_key_ctos = + Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?; + let mac_key_stoc = + Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?; + + let iv_ctos = + Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?; + let iv_stoc = + Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?; + info!("Derived keys summary:"); - info!(" encryption_key_ctos ({} bytes): {:?}", encryption_key_ctos.len(), &encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]); - info!(" encryption_key_stoc ({} bytes): {:?}", encryption_key_stoc.len(), &encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]); - info!(" iv_ctos ({} bytes): {:?}", iv_ctos.len(), &iv_ctos[..std::cmp::min(16, iv_ctos.len())]); - info!(" iv_stoc ({} bytes): {:?}", iv_stoc.len(), &iv_stoc[..std::cmp::min(16, iv_stoc.len())]); - info!(" mac_key_ctos ({} bytes): {:?}", mac_key_ctos.len(), &mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]); - info!(" mac_key_stoc ({} bytes): {:?}", mac_key_stoc.len(), &mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]); - + info!( + " encryption_key_ctos ({} bytes): {:?}", + encryption_key_ctos.len(), + &encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())] + ); + info!( + " encryption_key_stoc ({} bytes): {:?}", + encryption_key_stoc.len(), + &encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())] + ); + info!( + " iv_ctos ({} bytes): {:?}", + iv_ctos.len(), + &iv_ctos[..std::cmp::min(16, iv_ctos.len())] + ); + info!( + " iv_stoc ({} bytes): {:?}", + iv_stoc.len(), + &iv_stoc[..std::cmp::min(16, iv_stoc.len())] + ); + info!( + " mac_key_ctos ({} bytes): {:?}", + mac_key_ctos.len(), + &mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())] + ); + info!( + " mac_key_stoc ({} bytes): {:?}", + mac_key_stoc.len(), + &mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())] + ); + Ok(Self { session_id, encryption_key_ctos, @@ -122,65 +163,73 @@ impl SessionKeys { iv_stoc, }) } - + /// RFC 4253密钥派生函数 /// 公式:Key = HASH(K || H || X || session_id) fn derive_key_rfc4253(K_mpint: &[u8], H: &[u8], X: char, session_id: &[u8]) -> Result> { let mut hasher = Sha256::new(); - + info!("Deriving key for X='{}'", X); - info!(" K_mpint ({} bytes): {:?}", K_mpint.len(), &K_mpint[..std::cmp::min(8, K_mpint.len())]); + info!( + " K_mpint ({} bytes): {:?}", + K_mpint.len(), + &K_mpint[..std::cmp::min(8, K_mpint.len())] + ); info!(" H ({} bytes): {:?}", H.len(), &H[..8]); - info!(" session_id ({} bytes): {:?}", session_id.len(), &session_id[..8]); - + info!( + " session_id ({} bytes): {:?}", + session_id.len(), + &session_id[..8] + ); + // RFC 4253: HASH(K || H || X || session_id) - hasher.update(K_mpint); // K (shared secret in mpint format) - hasher.update(H); // H (exchange hash) - hasher.update(&[X as u8]); // X (single character) + hasher.update(K_mpint); // K (shared secret in mpint format) + hasher.update(H); // H (exchange hash) + hasher.update([X as u8]); // X (single character) hasher.update(session_id); // session_id - + let full_hash = hasher.finalize(); - + info!(" Derived key (first 8 bytes): {:?}", &full_hash[..8]); - + // 根據key類型返回不同長度: // AES-128-CTR key/IV: 16 bytes // HMAC-SHA256 key: 32 bytes match X { - 'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key + 'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key 'E' | 'F' => Ok(full_hash.to_vec()), // MAC key (full 32 bytes) _ => Ok(full_hash[..16].to_vec()), // default } } - + /// SSH mpint编码(参考RFC 4253 Section 5) /// Curve25519 shared secret特殊处理 fn encode_mpint(bytes: &[u8]) -> Vec { // RFC 4253: mpint = uint32(length) + data // 去掉前导零,如果最高位>=0x80前面加0 - + // 去掉前导零字节(但不去掉最后一个字节即使它是0) let mut start = 0; while start < bytes.len() - 1 && bytes[start] == 0 { start += 1; } - + let data_without_leading_zeros = &bytes[start..]; - + // 构建mpint数据 let mut mpint_data = Vec::new(); - + // 如果最高位>=0x80,前面加0字节(避免负数) if data_without_leading_zeros[0] >= 0x80 { mpint_data.push(0); } mpint_data.extend_from_slice(data_without_leading_zeros); - + // 最终格式:uint32长度 + mpint数据 let mut result = Vec::new(); result.extend_from_slice(&(mpint_data.len() as u32).to_be_bytes()); result.extend_from_slice(&mpint_data); - + result } } @@ -192,45 +241,45 @@ pub struct Ed25519HostKey { impl Ed25519HostKey { /// 加载或生成主机密钥(参考OpenSSH hostfile.c) - pub fn load_or_generate(key_path: &str) -> Result { + pub fn load_or_generate(_key_path: &str) -> Result { // 简化实现:生成临时密钥(实际应从文件加载) // 参考OpenSSH ssh-keygen - + let signing_key = SigningKey::generate(&mut OsRng); - + Ok(Self { signing_key }) } - + /// 获取公钥(用于SSH_MSG_KEX_ECDH_REPLY) pub fn public_key_bytes(&self) -> Vec { // SSH Ed25519公钥格式(参考OpenSSH sshkey.c) let verifying_key = self.signing_key.verifying_key(); - + // SSH格式:ssh-ed25519 + 公钥bytes // 简化:仅返回公钥bytes(32字节) verifying_key.as_bytes().to_vec() } - + /// 签名(参考OpenSSH sshkey.c: sshkey_sign()) pub fn sign(&self, data: &[u8]) -> Result> { // OpenSSH Ed25519签名 let signature = self.signing_key.sign(data); - + // SSH签名格式(参考OpenSSH ssh-sign.c) // 简化:仅返回签名bytes(64字节) Ok(signature.to_bytes().to_vec()) } - + /// 获取完整SSH公钥格式(参考OpenSSH sshkey.c) pub fn ssh_public_key(&self) -> String { let public_bytes = self.public_key_bytes(); - + // SSH公钥格式:ssh-ed25519 // 参考OpenSSH ssh-keygen -y - - use base64::{Engine as _, engine::general_purpose}; + + use base64::{engine::general_purpose, Engine as _}; let encoded = general_purpose::STANDARD.encode(&public_bytes); - + format!("ssh-ed25519 {}", encoded) } } @@ -238,40 +287,44 @@ impl Ed25519HostKey { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_curve25519_key_generation() { let kex = Curve25519Kex::new(); assert_eq!(kex.public_key().len(), 32); } - + #[test] fn test_curve25519_shared_secret() { let mut client_kex = Curve25519Kex::new(); let mut server_kex = Curve25519Kex::new(); - + // 客户端计算共享密钥 - let client_secret = client_kex.compute_shared_secret(server_kex.public_key()).unwrap(); - + let client_secret = client_kex + .compute_shared_secret(server_kex.public_key()) + .unwrap(); + // 服务器计算共享密钥 - let server_secret = server_kex.compute_shared_secret(client_kex.public_key()).unwrap(); - + let server_secret = server_kex + .compute_shared_secret(client_kex.public_key()) + .unwrap(); + // 应该相同(Curve25519特性) assert_eq!(client_secret, server_secret); } - + #[test] fn test_ed25519_host_key() { let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap(); assert_eq!(host_key.public_key_bytes().len(), 32); } - + #[test] fn test_ed25519_signature() { let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap(); let data = b"test data"; - + let signature = host_key.sign(data).unwrap(); - assert_eq!(signature.len(), 64); // Ed25519签名64字节 + assert_eq!(signature.len(), 64); // Ed25519签名64字节 } } diff --git a/markbase-core/src/ssh_server/data_forwarder.rs b/markbase-core/src/ssh_server/data_forwarder.rs index 80ce700..ff1165e 100644 --- a/markbase-core/src/ssh_server/data_forwarder.rs +++ b/markbase-core/src/ssh_server/data_forwarder.rs @@ -1,13 +1,13 @@ // SSH端口转发数据传输(Phase 13.5) // 参考OpenSSH channels.c: channel_handle_data() -use anyhow::{Result, anyhow}; -use log::{info, warn, debug}; -use std::net::{TcpStream}; -use std::io::{Read, Write}; -use std::thread; -use std::sync::{Arc, Mutex}; +use anyhow::{anyhow, Result}; use byteorder::{BigEndian, WriteBytesExt}; +use log::{debug, info, warn}; +use std::io::{Read, Write}; +use std::net::TcpStream; +use std::sync::{Arc, Mutex}; +use std::thread; /// 数据转发器(Phase 13.5:双向数据传输) pub struct DataForwarder { @@ -25,29 +25,40 @@ impl DataForwarder { max_packet_size, } } - + /// 启动双向数据转发(Phase 13.5:SSH channel ↔ TCP socket) pub fn start_bidirectional_forwarding( &mut self, - ssh_stream: TcpStream, // SSH client连接(加密通道) - target_stream: TcpStream, // 目标服务连接(TCP socket) + ssh_stream: TcpStream, // SSH client连接(加密通道) + target_stream: TcpStream, // 目标服务连接(TCP socket) ) -> Result<()> { - info!("Starting bidirectional data forwarding for channel {}", self.channel_id); - + info!( + "Starting bidirectional data forwarding for channel {}", + self.channel_id + ); + // Phase 13.5: SSH channel → Target socket(SSH client数据 → 本地服务) - let ssh_to_target = self.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?); - + let ssh_to_target = self + .start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?); + // Phase 13.5: Target socket → SSH channel(本地服务数据 → SSH client) let target_to_ssh = self.start_target_to_ssh_forwarding(target_stream, ssh_stream); - + // Phase 13.5: 等待两个转发线程完成 - ssh_to_target.join().map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?; - target_to_ssh.join().map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?; - - info!("Bidirectional data forwarding completed for channel {}", self.channel_id); + ssh_to_target + .join() + .map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?; + target_to_ssh + .join() + .map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?; + + info!( + "Bidirectional data forwarding completed for channel {}", + self.channel_id + ); Ok(()) } - + /// SSH channel → Target socket转发(Phase 13.5) fn start_ssh_to_target_forwarding( &self, @@ -57,18 +68,21 @@ impl DataForwarder { let channel_id = self.channel_id; let window_size = self.window_size.clone(); let max_packet_size = self.max_packet_size; - + thread::spawn(move || { - info!("SSH to target forwarding thread started for channel {}", channel_id); - + info!( + "SSH to target forwarding thread started for channel {}", + channel_id + ); + let mut buffer = vec![0u8; max_packet_size as usize]; - + loop { // Phase 13.5: 从SSH channel读取数据 let n = match ssh_stream.read(&mut buffer) { Ok(0) => { info!("SSH channel EOF for channel {}", channel_id); - break; // EOF + break; // EOF } Ok(n) => n, Err(e) => { @@ -76,45 +90,61 @@ impl DataForwarder { break; } }; - + // Phase 13.5: 检查window size { let window = window_size.lock().unwrap(); if *window < n as u32 { - warn!("Window size insufficient for channel {}: need {}, have {}", - channel_id, n, *window); + warn!( + "Window size insufficient for channel {}: need {}, have {}", + channel_id, n, *window + ); // Phase 13.5: 理论上应该等待SSH_MSG_CHANNEL_WINDOW_ADJUST // 简化实现:继续发送(可能会违反RFC 4254) } } - + // Phase 13.5: 写入目标socket if let Err(e) = target_stream.write_all(&buffer[..n]) { - warn!("Target socket write error for channel {}: {}", channel_id, e); + warn!( + "Target socket write error for channel {}: {}", + channel_id, e + ); break; } - + // Phase 13.5: Flush确保数据发送 if let Err(e) = target_stream.flush() { - warn!("Target socket flush error for channel {}: {}", channel_id, e); + warn!( + "Target socket flush error for channel {}: {}", + channel_id, e + ); break; } - + // Phase 13.5: 消耗window size { let mut window = window_size.lock().unwrap(); *window -= n as u32; - debug!("Window size consumed for channel {}: {} bytes, remaining {}", - channel_id, n, *window); + debug!( + "Window size consumed for channel {}: {} bytes, remaining {}", + channel_id, n, *window + ); } - - info!("Forwarded {} bytes from SSH to target for channel {}", n, channel_id); + + info!( + "Forwarded {} bytes from SSH to target for channel {}", + n, channel_id + ); } - - info!("SSH to target forwarding thread stopped for channel {}", channel_id); + + info!( + "SSH to target forwarding thread stopped for channel {}", + channel_id + ); }) } - + /// Target socket → SSH channel转发(Phase 13.5) fn start_target_to_ssh_forwarding( &self, @@ -122,18 +152,21 @@ impl DataForwarder { mut ssh_stream: TcpStream, ) -> thread::JoinHandle<()> { let channel_id = self.channel_id; - + thread::spawn(move || { - info!("Target to SSH forwarding thread started for channel {}", channel_id); - - let mut buffer = vec![0u8; 8192]; // 8KB buffer - + info!( + "Target to SSH forwarding thread started for channel {}", + channel_id + ); + + let mut buffer = vec![0u8; 8192]; // 8KB buffer + loop { // Phase 13.5: 从目标socket读取数据 let n = match target_stream.read(&mut buffer) { Ok(0) => { info!("Target socket EOF for channel {}", channel_id); - break; // EOF + break; // EOF } Ok(n) => n, Err(e) => { @@ -141,43 +174,51 @@ impl DataForwarder { break; } }; - + // Phase 13.5: 构建SSH_MSG_CHANNEL_DATA packet // 注意:实际实现需要通过EncryptedPacket加密 // 这里简化实现,直接写入SSH stream(测试用) - + // Phase 13.5: 写入SSH channel if let Err(e) = ssh_stream.write_all(&buffer[..n]) { warn!("SSH channel write error for channel {}: {}", channel_id, e); break; } - + // Phase 13.5: Flush确保数据发送 if let Err(e) = ssh_stream.flush() { warn!("SSH channel flush error for channel {}: {}", channel_id, e); break; } - - info!("Forwarded {} bytes from target to SSH for channel {}", n, channel_id); + + info!( + "Forwarded {} bytes from target to SSH for channel {}", + n, channel_id + ); } - - info!("Target to SSH forwarding thread stopped for channel {}", channel_id); + + info!( + "Target to SSH forwarding thread stopped for channel {}", + channel_id + ); }) } - + /// 获取当前window size(Phase 13.5) pub fn get_window_size(&self) -> u32 { *self.window_size.lock().unwrap() } - + /// 增加window size(Phase 13.5:SSH_MSG_CHANNEL_WINDOW_ADJUST) pub fn adjust_window_size(&self, bytes_to_add: u32) { let mut window = self.window_size.lock().unwrap(); *window += bytes_to_add; - info!("Window size adjusted for channel {}: added {} bytes, total {}", - self.channel_id, bytes_to_add, *window); + info!( + "Window size adjusted for channel {}: added {} bytes, total {}", + self.channel_id, bytes_to_add, *window + ); } - + /// 检查window size是否足够(Phase 13.5) pub fn check_window_available(&self, data_size: u32) -> bool { let window = self.window_size.lock().unwrap(); @@ -188,64 +229,64 @@ impl DataForwarder { /// SSH_MSG_CHANNEL_DATA构建(Phase 13.5) pub fn build_channel_data_packet(channel_id: u32, data: &[u8]) -> Result> { let mut packet = Vec::new(); - + // Packet type: SSH_MSG_CHANNEL_DATA (type 94) packet.write_u8(94)?; - + // Recipient channel ID packet.write_u32::(channel_id)?; - + // Data length (SSH string) packet.write_u32::(data.len() as u32)?; - + // Data content packet.write_all(data)?; - + Ok(packet) } /// SSH_MSG_CHANNEL_WINDOW_ADJUST构建(Phase 13.5) pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result> { let mut packet = Vec::new(); - + // Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93) packet.write_u8(93)?; - + // Recipient channel ID packet.write_u32::(channel_id)?; - + // Bytes to add packet.write_u32::(bytes_to_add)?; - + Ok(packet) } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_data_forwarder_creation() { let forwarder = DataForwarder::new(1, 2097152, 32768); assert_eq!(forwarder.channel_id, 1); assert_eq!(forwarder.get_window_size(), 2097152); } - + #[test] fn test_window_size_adjustment() { let forwarder = DataForwarder::new(1, 2097152, 32768); - + // 消耗window size forwarder.adjust_window_size(1000); assert_eq!(forwarder.get_window_size(), 2097152 + 1000); } - + #[test] fn test_build_channel_data_packet() { let data = b"Hello, SSH!"; let packet = build_channel_data_packet(1, data).unwrap(); - - assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA - // 验证packet结构 + + assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA + // 验证packet结构 } } diff --git a/markbase-core/src/ssh_server/kex.rs b/markbase-core/src/ssh_server/kex.rs index 9c42931..0be9335 100644 --- a/markbase-core/src/ssh_server/kex.rs +++ b/markbase-core/src/ssh_server/kex.rs @@ -1,42 +1,42 @@ // SSH密钥交换算法协商实现(Phase 2) // 参考OpenSSH kex.c: kex_send_kexinit(), kex_choose_conf() -use crate::ssh_server::packet::{SshPacket, PacketType}; -use anyhow::{Result, anyhow}; +use crate::ssh_server::packet::{PacketType, SshPacket}; +use anyhow::{anyhow, Result}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use log::{info, debug}; +use log::{debug, info}; use std::io::{Read, Write}; /// SSH算法类型(参考OpenSSH PROTOCOL定义) #[derive(Debug, Clone, Copy, PartialEq)] pub enum AlgorithmType { - KEX_ALGS = 0, // 密钥交换算法 + KEX_ALGS = 0, // 密钥交换算法 SERVER_HOST_KEY_ALGS = 1, // 服务器主机密钥算法 - ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法 - ENC_ALGS_STOC = 3, // 服务器到客户端加密算法 - MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法 - MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法 - COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法 - COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法 - LANGS_CTOS = 8, // 客户端到服务器语言 - LANGS_STOC = 9, // 服务器到客户端语言 + ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法 + ENC_ALGS_STOC = 3, // 服务器到客户端加密算法 + MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法 + MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法 + COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法 + COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法 + LANGS_CTOS = 8, // 客户端到服务器语言 + LANGS_STOC = 9, // 服务器到客户端语言 } /// SSH算法提议(参考OpenSSH kex.h: struct kex) #[derive(Debug, Clone)] pub struct KexProposal { - pub kex_algorithms: String, // 密钥交换算法列表 - pub server_host_key_algorithms: String, // 主机密钥算法列表 - pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器) - pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端) - pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器) - pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端) + pub kex_algorithms: String, // 密钥交换算法列表 + pub server_host_key_algorithms: String, // 主机密钥算法列表 + pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器) + pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端) + pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器) + pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端) pub compression_algorithms_ctos: String, // 压缩算法(客户端→服务器) pub compression_algorithms_stoc: String, // 压缩算法(服务器→客户端) - pub languages_ctos: String, // 语言(客户端→服务器) - pub languages_stoc: String, // 语言(服务器→客户端) - pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet - pub reserved: u32, // 保留字段(0) + pub languages_ctos: String, // 语言(客户端→服务器) + pub languages_stoc: String, // 语言(服务器→客户端) + pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet + pub reserved: u32, // 保留字段(0) } impl KexProposal { @@ -46,31 +46,31 @@ impl KexProposal { Self { // 密钥交换算法:优先Curve25519(推荐) + strict KEX extension kex_algorithms: "curve25519-sha256,curve25519-sha256@libssh.org,diffie-hellman-group14-sha256,ext-info-s,kex-strict-s-v00@openssh.com".to_string(), - + // 主机密钥算法:优先Ed25519 server_host_key_algorithms: "ssh-ed25519,rsa-sha2-256,rsa-sha2-512".to_string(), - + // 加密算法:AES-256-CTR(推荐) encryption_algorithms_ctos: "aes256-ctr,aes128-ctr".to_string(), encryption_algorithms_stoc: "aes256-ctr,aes128-ctr".to_string(), - + // MAC算法:HMAC-SHA256 mac_algorithms_ctos: "hmac-sha2-256,hmac-sha2-512".to_string(), mac_algorithms_stoc: "hmac-sha2-256,hmac-sha2-512".to_string(), - + // 压缩算法:none优先 compression_algorithms_ctos: "none,zlib".to_string(), compression_algorithms_stoc: "none,zlib".to_string(), - + // 语言:空 languages_ctos: "".to_string(), languages_stoc: "".to_string(), - + first_kex_packet_follows: false, reserved: 0, } } - + /// 创建客户端默认提议(用于测试) pub fn client_default() -> Self { Self { @@ -88,20 +88,20 @@ impl KexProposal { reserved: 0, } } - + /// 序列化到SSH_MSG_KEXINIT packet(参考OpenSSH kex_send_kexinit()) pub fn to_kexinit_packet(&self) -> Result { let mut payload = Vec::new(); - + // Packet type payload.write_u8(PacketType::SSH_MSG_KEXINIT as u8)?; - + // Cookie(16字节随机数,OpenSSH要求) let mut cookie = [0u8; 16]; use rand::Rng; rand::thread_rng().fill(&mut cookie); payload.write_all(&cookie)?; - + // 10个算法列表(SSH string格式:length + data) write_ssh_string(&mut payload, &self.kex_algorithms)?; write_ssh_string(&mut payload, &self.server_host_key_algorithms)?; @@ -113,29 +113,29 @@ impl KexProposal { write_ssh_string(&mut payload, &self.compression_algorithms_stoc)?; write_ssh_string(&mut payload, &self.languages_ctos)?; write_ssh_string(&mut payload, &self.languages_stoc)?; - + // first_kex_packet_follows(boolean) payload.write_u8(if self.first_kex_packet_follows { 1 } else { 0 })?; - + // reserved(u32) payload.write_u32::(self.reserved)?; - + Ok(SshPacket::new(payload)) } - + /// 从SSH_MSG_KEXINIT packet解析(参考OpenSSH kex_input_kexinit()) pub fn from_kexinit_packet(packet: &SshPacket) -> Result { - let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准) - + let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准) + // Packet type let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_KEXINIT as u8 { return Err(anyhow!("Invalid packet type for KEXINIT")); } - + // Cookie(16字节,忽略) cursor.read_exact(&mut [0u8; 16])?; - + // 10个算法列表 let kex_algorithms = read_ssh_string(&mut cursor)?; let server_host_key_algorithms = read_ssh_string(&mut cursor)?; @@ -147,13 +147,13 @@ impl KexProposal { let compression_algorithms_stoc = read_ssh_string(&mut cursor)?; let languages_ctos = read_ssh_string(&mut cursor)?; let languages_stoc = read_ssh_string(&mut cursor)?; - + // first_kex_packet_follows let first_kex_packet_follows = cursor.read_u8()? != 0; - + // reserved let reserved = cursor.read_u32::()?; - + Ok(Self { kex_algorithms, server_host_key_algorithms, @@ -174,14 +174,14 @@ impl KexProposal { /// SSH算法协商结果(参考OpenSSH struct kex) #[derive(Debug, Clone)] pub struct KexResult { - pub kex_algorithm: String, // 选定的密钥交换算法 - pub host_key_algorithm: String, // 选定的主机密钥算法 - pub encryption_ctos: String, // 选定的加密算法(客户端→服务器) - pub encryption_stoc: String, // 选定的加密算法(服务器→客户端) - pub mac_ctos: String, // 选定的MAC算法(客户端→服务器) - pub mac_stoc: String, // 选定的MAC算法(服务器→客户端) - pub compression_ctos: String, // 选定的压缩算法(客户端→服务器) - pub compression_stoc: String, // 选定的压缩算法(服务器→客户端) + pub kex_algorithm: String, // 选定的密钥交换算法 + pub host_key_algorithm: String, // 选定的主机密钥算法 + pub encryption_ctos: String, // 选定的加密算法(客户端→服务器) + pub encryption_stoc: String, // 选定的加密算法(服务器→客户端) + pub mac_ctos: String, // 选定的MAC算法(客户端→服务器) + pub mac_stoc: String, // 选定的MAC算法(服务器→客户端) + pub compression_ctos: String, // 选定的压缩算法(客户端→服务器) + pub compression_stoc: String, // 选定的压缩算法(服务器→客户端) } /// 算法匹配逻辑(参考OpenSSH kex_choose_conf()) @@ -189,28 +189,43 @@ impl KexResult { /// 从服务器和客户端提议中选择算法(参考OpenSSH kex_choose_conf()) pub fn choose_algorithms(server: &KexProposal, client: &KexProposal) -> Result { info!("Starting algorithm negotiation"); - + // 算法匹配:优先客户端偏好(OpenSSH逻辑) // 参考OpenSSH:客户端列出的算法顺序为偏好顺序 - + // 密钥交换算法匹配 let kex_algorithm = match_algorithm(&client.kex_algorithms, &server.kex_algorithms)?; - + // 主机密钥算法匹配 - let host_key_algorithm = match_algorithm(&client.server_host_key_algorithms, &server.server_host_key_algorithms)?; - + let host_key_algorithm = match_algorithm( + &client.server_host_key_algorithms, + &server.server_host_key_algorithms, + )?; + // 加密算法匹配 - let encryption_ctos = match_algorithm(&client.encryption_algorithms_ctos, &server.encryption_algorithms_ctos)?; - let encryption_stoc = match_algorithm(&client.encryption_algorithms_stoc, &server.encryption_algorithms_stoc)?; - + let encryption_ctos = match_algorithm( + &client.encryption_algorithms_ctos, + &server.encryption_algorithms_ctos, + )?; + let encryption_stoc = match_algorithm( + &client.encryption_algorithms_stoc, + &server.encryption_algorithms_stoc, + )?; + // MAC算法匹配 let mac_ctos = match_algorithm(&client.mac_algorithms_ctos, &server.mac_algorithms_ctos)?; let mac_stoc = match_algorithm(&client.mac_algorithms_stoc, &server.mac_algorithms_stoc)?; - + // 压缩算法匹配 - let compression_ctos = match_algorithm(&client.compression_algorithms_ctos, &server.compression_algorithms_ctos)?; - let compression_stoc = match_algorithm(&client.compression_algorithms_stoc, &server.compression_algorithms_stoc)?; - + let compression_ctos = match_algorithm( + &client.compression_algorithms_ctos, + &server.compression_algorithms_ctos, + )?; + let compression_stoc = match_algorithm( + &client.compression_algorithms_stoc, + &server.compression_algorithms_stoc, + )?; + info!("Algorithm negotiation completed:"); debug!(" KEX: {}", kex_algorithm); debug!(" Host key: {}", host_key_algorithm); @@ -218,7 +233,7 @@ impl KexResult { debug!(" Encryption (S->C): {}", encryption_stoc); debug!(" MAC (C->S): {}", mac_ctos); debug!(" MAC (S->C): {}", mac_stoc); - + Ok(Self { kex_algorithm, host_key_algorithm, @@ -237,15 +252,19 @@ fn match_algorithm(client_algs: &str, server_algs: &str) -> Result { // 算法列表格式:name1,name2,name3,... let client_list: Vec<&str> = client_algs.split(',').collect(); let server_list: Vec<&str> = server_algs.split(',').collect(); - + // OpenSSH逻辑:按客户端偏好顺序匹配 for client_alg in &client_list { if server_list.contains(client_alg) { return Ok(client_alg.to_string()); } } - - Err(anyhow!("No matching algorithm found: client={}, server={}", client_algs, server_algs)) + + Err(anyhow!( + "No matching algorithm found: client={}, server={}", + client_algs, + server_algs + )) } /// SSH string写入辅助函数(length + data) @@ -266,36 +285,36 @@ fn read_ssh_string(reader: &mut R) -> Result { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_kex_proposal_creation() { let proposal = KexProposal::server_default(); assert!(proposal.kex_algorithms.contains("curve25519-sha256")); } - + #[test] fn test_kex_proposal_serialization() { let proposal = KexProposal::server_default(); let packet = proposal.to_kexinit_packet().unwrap(); assert!(packet.payload.len() > 0); } - + #[test] fn test_algorithm_matching() { let client = "curve25519-sha256,aes256-ctr"; let server = "aes256-ctr,diffie-hellman-group14-sha256"; - + let matched = match_algorithm(client, server).unwrap(); - assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配 + assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配 } - + #[test] fn test_kex_negotiation() { let server = KexProposal::server_default(); let client = KexProposal::client_default(); - + let result = KexResult::choose_algorithms(&server, &client).unwrap(); - assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519 - assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR + assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519 + assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR } } diff --git a/markbase-core/src/ssh_server/kex_complete.rs b/markbase-core/src/ssh_server/kex_complete.rs index 4f43c78..fd00dee 100644 --- a/markbase-core/src/ssh_server/kex_complete.rs +++ b/markbase-core/src/ssh_server/kex_complete.rs @@ -1,14 +1,13 @@ // SSH密钥交换完整流程(Phase 3剩余) // 参考OpenSSH kex.c: complete implementation -use crate::ssh_server::packet::{SshPacket, PacketType}; +use crate::ssh_server::crypto::SessionKeys; use crate::ssh_server::kex::{KexProposal, KexResult}; -use crate::ssh_server::crypto::{SessionKeys}; use crate::ssh_server::kex_exchange::KexExchangeHandler; -use anyhow::{Result, anyhow}; -use sha2::{Sha256, Digest}; -use byteorder::{BigEndian, WriteBytesExt}; -use log::{info, debug}; +use crate::ssh_server::packet::{PacketType, SshPacket}; +use anyhow::{anyhow, Result}; +use log::info; +use sha2::{Digest, Sha256}; /// SSH密钥交换完整状态管理(参考OpenSSH struct kex) pub struct KexState { @@ -30,7 +29,7 @@ impl KexState { kex_result: KexResult, ) -> Result { let exchange_handler = KexExchangeHandler::new(kex_result)?; - + Ok(Self { client_version, server_version, @@ -42,18 +41,18 @@ impl KexState { newkeys_sent: false, }) } - + /// 保存KEXINIT payloads(用于Exchange Hash计算) - /// + /// /// 分析OpenSSH源码后的结论: /// - kex->peer存储的是:incoming_packet剩余内容(payload fields + padding) /// - kex->my存储的是:prop2buf()结果(payload fields,不包括padding) - /// + /// /// **但exchange hash必须使用相同的I_C/I_S!** - /// + /// /// 疑问:OpenSSH如何确保client和server使用相同的padding? /// 可能答案:OpenSSH在计算exchange hash时,不包括padding? - /// + /// /// 暂时保持不包括padding(因为签名验证之前成功) pub fn save_kexinit_payloads( &mut self, @@ -63,12 +62,18 @@ impl KexState { // Only save payload (without padding) for now self.client_kexinit_payload = client_kexinit.payload.clone(); self.server_kexinit_payload = server_kexinit.payload.clone(); - + info!("Saved KEXINIT payloads (payload only, no padding)"); - info!(" client payload: {} bytes", self.client_kexinit_payload.len()); - info!(" server payload: {} bytes", self.server_kexinit_payload.len()); + info!( + " client payload: {} bytes", + self.client_kexinit_payload.len() + ); + info!( + " server payload: {} bytes", + self.server_kexinit_payload.len() + ); } - + /// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash()) /// H = SHA256(V_C || V_S || I_C || I_S || K_S || K_C || K_S || shared_secret) pub fn compute_exchange_hash( @@ -80,74 +85,74 @@ impl KexState { ) -> Result> { // 参考OpenSSH kex.c: kex_hash() let mut hasher = Sha256::new(); - + // V_C: 客户端版本字符串(SSH string格式) write_ssh_string_to_hash(&mut hasher, &self.client_version)?; - + // V_S: 服务器版本字符串(SSH string格式) write_ssh_string_to_hash(&mut hasher, &self.server_version)?; - + // OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT" // Remove SSH_MSG_KEXINIT type byte from payloads and prepend it in exchange hash - + let client_kexinit_without_type = &self.client_kexinit_payload[1..]; let server_kexinit_without_type = &self.server_kexinit_payload[1..]; - - hasher.update(&((client_kexinit_without_type.len() + 1) as u32).to_be_bytes()); - hasher.update(&[20]); // SSH_MSG_KEXINIT type byte + + hasher.update(((client_kexinit_without_type.len() + 1) as u32).to_be_bytes()); + hasher.update([20]); // SSH_MSG_KEXINIT type byte hasher.update(client_kexinit_without_type); - - hasher.update(&((server_kexinit_without_type.len() + 1) as u32).to_be_bytes()); - hasher.update(&[20]); // SSH_MSG_KEXINIT type byte + + hasher.update(((server_kexinit_without_type.len() + 1) as u32).to_be_bytes()); + hasher.update([20]); // SSH_MSG_KEXINIT type byte hasher.update(server_kexinit_without_type); - + // K_S: 服务器主机密钥blob(SSH string格式) hasher.update(server_host_key_blob); - + // K_C: 客户端Curve25519公钥(SSH string格式) write_ssh_bytes_to_hash(&mut hasher, client_public_key)?; - + // K_S: 服务器Curve25519公钥(SSH string格式) write_ssh_bytes_to_hash(&mut hasher, server_public_key)?; - + // K: 共享密钥(SSH mpint格式) // OpenSSH要求:去掉前导零 write_ssh_mpint_to_hash(&mut hasher, shared_secret)?; - + Ok(hasher.finalize().to_vec()) } - + /// 处理SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_input_newkeys()) pub fn handle_newkeys(&mut self, packet: &SshPacket) -> Result<()> { info!("Processing SSH_MSG_NEWKEYS"); - + // 验证packet类型 - if packet.payload.len() < 1 { + if packet.payload.is_empty() { return Err(anyhow!("Invalid NEWKEYS packet")); } - + let packet_type = packet.payload[0]; if packet_type != PacketType::SSH_MSG_NEWKEYS as u8 { return Err(anyhow!("Invalid packet type for NEWKEYS")); } - + // 标记NEWKEYS接收完成(参考OpenSSH) self.newkeys_received = true; - + info!("SSH_MSG_NEWKEYS received, encryption channel ready"); - + Ok(()) } - + /// 发送SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_send_newkeys()) pub fn send_newkeys() -> Result { info!("Sending SSH_MSG_NEWKEYS"); - + let payload = vec![PacketType::SSH_MSG_NEWKEYS as u8]; - + Ok(SshPacket::new(payload)) } - + /// 检查NEWKEYS完成状态(加密通道建立) pub fn is_encryption_ready(&self) -> bool { self.newkeys_received && self.newkeys_sent @@ -156,14 +161,14 @@ impl KexState { /// SSH string写入到hash(辅助函数) fn write_ssh_string_to_hash(hasher: &mut Sha256, s: &str) -> Result<()> { - hasher.update(&(s.len() as u32).to_be_bytes()); + hasher.update((s.len() as u32).to_be_bytes()); hasher.update(s.as_bytes()); Ok(()) } /// SSH bytes写入到hash(辅助函数) fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> { - hasher.update(&(bytes.len() as u32).to_be_bytes()); + hasher.update((bytes.len() as u32).to_be_bytes()); hasher.update(bytes); Ok(()) } @@ -171,7 +176,7 @@ fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> { /// SSH mpint写入到hash(参考OpenSSH sshbuf_put_mpint()) fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> { // OpenSSH要求:去掉前导零(如果最高位为1) - let mpint_bytes = if bytes.len() > 0 && bytes[0] >= 0x80 { + let mpint_bytes = if !bytes.is_empty() && bytes[0] >= 0x80 { // 需要添加前导零(避免负数) let mut mpint = vec![0u8]; mpint.extend_from_slice(bytes); @@ -179,61 +184,67 @@ fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> { } else { bytes.to_vec() }; - - hasher.update(&(mpint_bytes.len() as u32).to_be_bytes()); + + hasher.update((mpint_bytes.len() as u32).to_be_bytes()); hasher.update(&mpint_bytes); - + Ok(()) } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_exchange_hash_computation() { let kex_result = KexResult::choose_algorithms( &KexProposal::server_default(), &KexProposal::client_default(), - ).unwrap(); - + ) + .unwrap(); + let mut state = KexState::new( "SSH-2.0-OpenSSH_10.2".to_string(), "SSH-2.0-MarkBaseSSH_1.0".to_string(), kex_result, - ).unwrap(); - + ) + .unwrap(); + // Set minimal KEXINIT payloads (need at least 1 byte for packet type) - state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte - state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte - + state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte + state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte + let shared_secret = vec![0u8; 32]; let host_key = vec![0u8; 32]; let client_pub = vec![0u8; 32]; let server_pub = vec![0u8; 32]; - - let hash = state.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub).unwrap(); - - assert_eq!(hash.len(), 32); // SHA256输出32字节 + + let hash = state + .compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub) + .unwrap(); + + assert_eq!(hash.len(), 32); // SHA256输出32字节 } - + #[test] fn test_newkeys_handling() { let kex_result = KexResult::choose_algorithms( &KexProposal::server_default(), &KexProposal::client_default(), - ).unwrap(); - + ) + .unwrap(); + let mut state = KexState::new( "SSH-2.0-OpenSSH_10.2".to_string(), "SSH-2.0-MarkBaseSSH_1.0".to_string(), kex_result, - ).unwrap(); - + ) + .unwrap(); + let newkeys_packet = SshPacket::new(vec![PacketType::SSH_MSG_NEWKEYS as u8]); - + state.handle_newkeys(&newkeys_packet).unwrap(); - + assert!(state.newkeys_received); } } diff --git a/markbase-core/src/ssh_server/kex_exchange.rs b/markbase-core/src/ssh_server/kex_exchange.rs index 5474fbe..35d3cd8 100644 --- a/markbase-core/src/ssh_server/kex_exchange.rs +++ b/markbase-core/src/ssh_server/kex_exchange.rs @@ -1,14 +1,14 @@ // SSH密钥交换流程实现(Phase 3) // 参考OpenSSH kex.c: kex_input_kex_init(), kex_send_kex_reply() -use crate::ssh_server::packet::{SshPacket, PacketType}; -use crate::ssh_server::kex::{KexResult}; -use crate::ssh_server::crypto::{Curve25519Kex, SessionKeys, Ed25519HostKey}; -use anyhow::{Result, anyhow}; +use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys}; +use crate::ssh_server::kex::KexResult; +use crate::ssh_server::packet::{PacketType, SshPacket}; +use anyhow::{anyhow, Result}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use log::{info, debug}; +use log::info; +use sha2::Digest; use std::io::{Read, Write}; -use sha2::{Sha256, Digest}; /// SSH密钥交换流程处理器(参考OpenSSH kex.c) pub struct KexExchangeHandler { @@ -18,7 +18,7 @@ pub struct KexExchangeHandler { shared_secret: Option>, client_public_key: Option>, server_public_key: Option>, - exchange_hash: Option>, // 保存exchange hash(H参数) + exchange_hash: Option>, // 保存exchange hash(H参数) client_version: Option, server_version: Option, client_kexinit_payload: Option>, @@ -30,7 +30,7 @@ impl KexExchangeHandler { pub fn new(kex_result: KexResult) -> Result { // 加载或生成服务器主机密钥 let host_key = Ed25519HostKey::load_or_generate("config/ssh_host_ed25519_key")?; - + Ok(Self { kex_algorithm: kex_result.kex_algorithm, server_kex: None, @@ -45,10 +45,10 @@ impl KexExchangeHandler { server_kexinit_payload: None, }) } - -/// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init()) + + /// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init()) pub fn handle_kexdh_init( - &mut self, + &mut self, packet: &SshPacket, client_version: &str, server_version: &str, @@ -56,41 +56,44 @@ impl KexExchangeHandler { server_kexinit_payload: &[u8], ) -> Result { info!("Processing SSH_MSG_KEXDH_INIT (Curve25519)"); - + let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); - + let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_KEXDH_INIT as u8 { return Err(anyhow!("Invalid packet type for KEXDH_INIT")); } - + let key_length = cursor.read_u32::()?; if key_length != 32 { - return Err(anyhow!("Invalid Curve25519 public key length: {}", key_length)); + return Err(anyhow!( + "Invalid Curve25519 public key length: {}", + key_length + )); } - + let mut client_public_key = vec![0u8; 32]; cursor.read_exact(&mut client_public_key)?; - + self.server_kex = Some(Curve25519Kex::new()); let server_kex = self.server_kex.as_mut().unwrap(); - + let shared_secret = server_kex.compute_shared_secret(&client_public_key)?; let server_public_key = server_kex.public_key().to_vec(); - + // Save for later session key computation self.shared_secret = Some(shared_secret.to_vec()); self.client_public_key = Some(client_public_key.clone()); self.server_public_key = Some(server_public_key.clone()); - + // Save client_version, server_version, kexinit payloads for exchange hash self.client_version = Some(client_version.to_string()); self.server_version = Some(server_version.to_string()); self.client_kexinit_payload = Some(client_kexinit_payload.to_vec()); self.server_kexinit_payload = Some(server_kexinit_payload.to_vec()); - + info!("Curve25519 shared secret computed and saved"); - + // Compute exchange hash ONCE and reuse it let host_key_blob = self.build_ssh_host_key()?; let exchange_hash = self.compute_exchange_hash( @@ -103,69 +106,69 @@ impl KexExchangeHandler { client_kexinit_payload, server_kexinit_payload, )?; - + info!("Exchange hash computed:"); - info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80); + info!( + " shared_secret[0] = {} (>=0x80? {})", + shared_secret[0], + shared_secret[0] >= 0x80 + ); info!(" exchange_hash full (32 bytes): {:?}", exchange_hash); - + self.exchange_hash = Some(exchange_hash.clone()); info!("Exchange hash saved for key derivation"); - - self.build_kexdh_reply( - &exchange_hash, - &host_key_blob, - &server_public_key, - ) + + self.build_kexdh_reply(&exchange_hash, &host_key_blob, &server_public_key) } - + /// 构建SSH_MSG_KEXDH_REPLY packet(参考OpenSSH kex.c) fn build_kexdh_reply( - &self, - exchange_hash: &[u8], + &self, + exchange_hash: &[u8], host_key_blob: &[u8], server_public_key: &[u8], ) -> Result { info!("=== Building SSH_MSG_KEXDH_REPLY ==="); info!("Input server_public_key: {:?}", server_public_key); - + let mut payload = Vec::new(); - + payload.write_u8(PacketType::SSH_MSG_KEXDH_REPLY as u8)?; - + payload.write_u32::(host_key_blob.len() as u32)?; payload.write_all(host_key_blob)?; - + info!("Writing server_public_key to payload (32 bytes)"); payload.write_u32::(32)?; payload.write_all(server_public_key)?; - + let signature = self.build_exchange_signature(exchange_hash)?; payload.write_u32::(signature.len() as u32)?; payload.write_all(&signature)?; - + info!("SSH_MSG_KEXDH_REPLY payload built successfully"); Ok(SshPacket::new(payload)) } - + /// 构建SSH主机密钥blob(参考OpenSSH sshkey.c: sshkey_to_blob()) fn build_ssh_host_key(&self) -> Result> { let mut blob = Vec::new(); - + // SSH key format: key-type + public-key // 参考OpenSSH sshkey.c - + // Key type: ssh-ed25519 - blob.write_u32::(11)?; // "ssh-ed25519".len() + blob.write_u32::(11)?; // "ssh-ed25519".len() blob.write_all("ssh-ed25519".as_bytes())?; - + // Ed25519公钥(32字节) let public_key = self.host_key.public_key_bytes(); blob.write_u32::(32)?; blob.write_all(&public_key)?; - + Ok(blob) } - + /// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash() RFC 4253 Section 7.2) fn compute_exchange_hash( &self, @@ -178,94 +181,147 @@ impl KexExchangeHandler { client_kexinit_payload: &[u8], server_kexinit_payload: &[u8], ) -> Result> { - use sha2::{Sha256, Digest}; - + use sha2::{Digest, Sha256}; + info!("=== EXCHANGE HASH COMPUTATION ==="); info!("V_C (client version): {:?}", client_version.as_bytes()); info!("V_C length: {}", client_version.len()); - + info!("V_S (server version): {:?}", server_version.as_bytes()); info!("V_S length: {}", server_version.len()); - - info!("I_C (client KEXINIT payload): {:?}", &client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]); + + info!( + "I_C (client KEXINIT payload): {:?}", + &client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())] + ); info!("I_C length: {}", client_kexinit_payload.len()); - info!("I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", client_kexinit_payload[0]); - - info!("I_S (server KEXINIT payload): {:?}", &server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]); + info!( + "I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", + client_kexinit_payload[0] + ); + + info!( + "I_S (server KEXINIT payload): {:?}", + &server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())] + ); info!("I_S length: {}", server_kexinit_payload.len()); - info!("I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", server_kexinit_payload[0]); - - info!("K_S (host key blob): {:?}", &host_key_blob[..std::cmp::min(30, host_key_blob.len())]); + info!( + "I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", + server_kexinit_payload[0] + ); + + info!( + "K_S (host key blob): {:?}", + &host_key_blob[..std::cmp::min(30, host_key_blob.len())] + ); info!("K_S length: {}", host_key_blob.len()); - - info!("Q_C (client ECDH public key): {:?}", &client_public_key[..std::cmp::min(16, client_public_key.len())]); + + info!( + "Q_C (client ECDH public key): {:?}", + &client_public_key[..std::cmp::min(16, client_public_key.len())] + ); info!("Q_C full (32 bytes): {:?}", client_public_key); info!("Q_C length: {}", client_public_key.len()); - - info!("Q_S (server ECDH public key): {:?}", &server_public_key[..std::cmp::min(16, server_public_key.len())]); + + info!( + "Q_S (server ECDH public key): {:?}", + &server_public_key[..std::cmp::min(16, server_public_key.len())] + ); info!("Q_S full (32 bytes): {:?}", server_public_key); info!("Q_S length: {}", server_public_key.len()); - + let mut hasher = Sha256::new(); - + // RFC 4253 Section 7: V_C and V_S are version strings (without \r\n based on testing) let vc_ssh_string = &(client_version.len() as u32).to_be_bytes(); hasher.update(vc_ssh_string); hasher.update(client_version.as_bytes()); - info!(" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]", 4+client_version.len(), vc_ssh_string, client_version.as_bytes()); - + info!( + " Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]", + 4 + client_version.len(), + vc_ssh_string, + client_version.as_bytes() + ); + let vs_ssh_string = &(server_version.len() as u32).to_be_bytes(); hasher.update(vs_ssh_string); hasher.update(server_version.as_bytes()); - info!(" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]", 4+server_version.len(), vs_ssh_string, server_version.as_bytes()); - + info!( + " Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]", + 4 + server_version.len(), + vs_ssh_string, + server_version.as_bytes() + ); + // OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT" // KEXINIT payload should NOT include SSH_MSG_KEXINIT type byte // OpenSSH stores payload starting from cookie, prepends SSH_MSG_KEXINIT in exchange hash - + // Remove SSH_MSG_KEXINIT type byte from payloads (our payload includes it) let client_kexinit_without_type = &client_kexinit_payload[1..]; let server_kexinit_without_type = &server_kexinit_payload[1..]; - - info!("I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)", client_kexinit_without_type.len()); - info!("I_S (server KEXINIT without type byte): {} bytes", server_kexinit_without_type.len()); - + + info!( + "I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)", + client_kexinit_without_type.len() + ); + info!( + "I_S (server KEXINIT without type byte): {} bytes", + server_kexinit_without_type.len() + ); + // Exchange hash: uint32(len+1) + uint8(SSH_MSG_KEXINIT) + payload_without_type let ic_len_bytes = &((client_kexinit_without_type.len() + 1) as u32).to_be_bytes(); hasher.update(ic_len_bytes); - hasher.update(&[20]); // SSH_MSG_KEXINIT type byte + hasher.update([20]); // SSH_MSG_KEXINIT type byte hasher.update(client_kexinit_without_type); info!(" Exchange hash component I_C: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+client_kexinit_without_type.len(), ic_len_bytes, client_kexinit_without_type.len(), &client_kexinit_without_type[..std::cmp::min(8, client_kexinit_without_type.len())]); - + let is_len_bytes = &((server_kexinit_without_type.len() + 1) as u32).to_be_bytes(); hasher.update(is_len_bytes); - hasher.update(&[20]); // SSH_MSG_KEXINIT type byte + hasher.update([20]); // SSH_MSG_KEXINIT type byte hasher.update(server_kexinit_without_type); info!(" Exchange hash component I_S: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+server_kexinit_without_type.len(), is_len_bytes, server_kexinit_without_type.len(), &server_kexinit_without_type[..std::cmp::min(8, server_kexinit_without_type.len())]); - + let ks_len_bytes = &(host_key_blob.len() as u32).to_be_bytes(); hasher.update(ks_len_bytes); hasher.update(host_key_blob); - info!(" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])", 4+host_key_blob.len(), ks_len_bytes, host_key_blob.len(), host_key_blob); - + info!( + " Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])", + 4 + host_key_blob.len(), + ks_len_bytes, + host_key_blob.len(), + host_key_blob + ); + let qc_len_bytes = &(client_public_key.len() as u32).to_be_bytes(); hasher.update(qc_len_bytes); hasher.update(client_public_key); - info!(" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]", 4+client_public_key.len(), qc_len_bytes, client_public_key); - + info!( + " Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]", + 4 + client_public_key.len(), + qc_len_bytes, + client_public_key + ); + let qs_len_bytes = &(server_public_key.len() as u32).to_be_bytes(); hasher.update(qs_len_bytes); hasher.update(server_public_key); - info!(" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]", 4+server_public_key.len(), qs_len_bytes, server_public_key); - + info!( + " Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]", + 4 + server_public_key.len(), + qs_len_bytes, + server_public_key + ); + info!("Exchange hash components:"); info!(" shared_secret raw full (32 bytes): {:?}", shared_secret); - + // RFC 8731 Section 3.1: X25519 output is little-endian // OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal) // Treats little-endian bytes as big-endian mpint (logical reinterpret) info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)"); - + // RFC 4253: mpint格式 = 去掉前导零 + 最高位>=0x80时前面加0 // 参考OpenSSH sshbuf_put_bignum2_bytes() let mut start = 0; @@ -273,64 +329,73 @@ impl KexExchangeHandler { start += 1; } let trimmed_shared_secret = &shared_secret[start..]; - - info!(" shared_secret after removing leading zeros ({} bytes): {:?}", trimmed_shared_secret.len(), trimmed_shared_secret); - - let mpint_shared_secret_data = if trimmed_shared_secret.len() > 0 && trimmed_shared_secret[0] >= 0x80 { - let mut mpint = vec![0u8]; - mpint.extend_from_slice(trimmed_shared_secret); - info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte"); - mpint - } else { - trimmed_shared_secret.to_vec() - }; - - info!(" mpint_shared_secret_data ({} bytes): {:?}", mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]); - + + info!( + " shared_secret after removing leading zeros ({} bytes): {:?}", + trimmed_shared_secret.len(), + trimmed_shared_secret + ); + + let mpint_shared_secret_data = + if !trimmed_shared_secret.is_empty() && trimmed_shared_secret[0] >= 0x80 { + let mut mpint = vec![0u8]; + mpint.extend_from_slice(trimmed_shared_secret); + info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte"); + mpint + } else { + trimmed_shared_secret.to_vec() + }; + + info!( + " mpint_shared_secret_data ({} bytes): {:?}", + mpint_shared_secret_data.len(), + &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())] + ); + // mpint格式 = uint32(length) + mpint_data let mpint_len_bytes = &(mpint_shared_secret_data.len() as u32).to_be_bytes(); hasher.update(mpint_len_bytes); hasher.update(&mpint_shared_secret_data); info!(" Exchange hash component K (shared secret mpint): len={} bytes=[{:?}] data_len={} (first 8 bytes=[{:?}])", 4+mpint_shared_secret_data.len(), mpint_len_bytes, mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]); - + Ok(hasher.finalize().to_vec()) } - + /// 构建交换签名(参考OpenSSH ssh-sign.c) fn build_exchange_signature(&self, exchange_hash: &[u8]) -> Result> { let signature_bytes = self.host_key.sign(exchange_hash)?; - + let mut ssh_signature = Vec::new(); - + ssh_signature.write_u32::(11)?; ssh_signature.write_all("ssh-ed25519".as_bytes())?; - + ssh_signature.write_u32::(64)?; ssh_signature.write_all(&signature_bytes)?; - + Ok(ssh_signature) } - + /// 计算会话密钥(参考OpenSSH kex.c: derive_keys()) /// 使用保存的exchange_hash(H参数) pub fn compute_session_keys(&self) -> Result { if self.shared_secret.is_none() { return Err(anyhow!("No shared secret available")); } - + if self.exchange_hash.is_none() { return Err(anyhow!("No exchange hash available")); } - + let shared_secret = self.shared_secret.as_ref().unwrap(); let exchange_hash = self.exchange_hash.as_ref().unwrap(); let server_public_key = self.server_public_key.as_ref().unwrap(); let client_public_key = self.client_public_key.as_ref().unwrap(); let host_key_blob = self.build_ssh_host_key()?; - + SessionKeys::derive( shared_secret, - exchange_hash, // 使用保存的exchange hash(H参数) + exchange_hash, // 使用保存的exchange hash(H参数) server_public_key, client_public_key, &host_key_blob, @@ -342,13 +407,13 @@ impl KexExchangeHandler { mod tests { use super::*; use crate::ssh_server::kex::KexProposal; - + #[test] fn test_kex_exchange_handler_creation() { let server_proposal = KexProposal::server_default(); let client_proposal = KexProposal::client_default(); let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal).unwrap(); - + let handler = KexExchangeHandler::new(kex_result).unwrap(); assert!(handler.host_key.public_key_bytes().len() == 32); } diff --git a/markbase-core/src/ssh_server/mod.rs b/markbase-core/src/ssh_server/mod.rs index 7c0c6e7..2df6dbf 100644 --- a/markbase-core/src/ssh_server/mod.rs +++ b/markbase-core/src/ssh_server/mod.rs @@ -1,28 +1,28 @@ // SSH服务器模块(手动实现SSH协议) // 参考OpenSSH源码实现完整的SSH/SFTP/SCP/rsync协议 -pub mod server; -pub mod packet; -pub mod version; -pub mod crypto; -pub mod kex; -pub mod kex_exchange; -pub mod kex_complete; -pub mod cipher; pub mod auth; pub mod channel; -pub mod sftp_handler; -pub mod scp_handler; +pub mod cipher; +pub mod crypto; +pub mod data_forwarder; // Phase 13.5: 数据传输模块 +pub mod kex; +pub mod kex_complete; +pub mod kex_exchange; +pub mod packet; +pub mod port_forward; // Phase 13: 端口转发模块 +pub mod port_forward_listener; // Phase 13.4: 监听线程模块 pub mod rsync_handler; -pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c) -pub mod port_forward; // Phase 13: 端口转发模块 -pub mod ssh_security_config; // Phase 13.1: 企业级安全配置 -pub mod port_forward_listener; // Phase 13.4: 监听线程模块 -pub mod data_forwarder; // Phase 13.5: 数据传输模块 -pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期 +pub mod scp_handler; +pub mod server; +pub mod sftp_handler; +pub mod ssh_security_config; // Phase 13.1: 企业级安全配置 +pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c) +pub mod version; +pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期 +pub use packet::{PacketType, SshPacket}; pub use server::SshServer; -pub use packet::{SshPacket, PacketType}; -pub use version::VersionExchange; -pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置 -pub use sshbuf::SshBuf; // Phase 15: 导出 SSH Buffer +pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置 +pub use sshbuf::SshBuf; +pub use version::VersionExchange; // Phase 15: 导出 SSH Buffer diff --git a/markbase-core/src/ssh_server/packet.rs b/markbase-core/src/ssh_server/packet.rs index a9ad185..9ea636e 100644 --- a/markbase-core/src/ssh_server/packet.rs +++ b/markbase-core/src/ssh_server/packet.rs @@ -1,7 +1,7 @@ // SSH Packet基础结构定义 // 参考OpenSSH packet.c: ssh_packet_read(), ssh_packet_write() -use anyhow::{Result, anyhow}; +use anyhow::{anyhow, Result}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use std::io::{Read, Write}; @@ -18,21 +18,21 @@ pub enum PacketType { SSH_MSG_EXT_INFO = 7, SSH_MSG_KEXINIT = 20, SSH_MSG_NEWKEYS = 21, - + // 密钥交换相关 SSH_MSG_KEXDH_INIT = 30, SSH_MSG_KEXDH_REPLY = 31, // 注意:Curve25519和DH使用相同的消息类型(30/31) // SSH_MSG_KEX_ECDH_INIT和SSH_MSG_KEX_ECDH_REPLY已在代码中注释 // 使用SSH_MSG_KEXDH_INIT和SSH_MSG_KEXDH_REPLY代替 - + // 认证相关 SSH_MSG_USERAUTH_REQUEST = 50, SSH_MSG_USERAUTH_FAILURE = 51, SSH_MSG_USERAUTH_SUCCESS = 52, SSH_MSG_USERAUTH_BANNER = 53, SSH_MSG_USERAUTH_PK_OK = 60, - + // Channel相关 SSH_MSG_GLOBAL_REQUEST = 80, SSH_MSG_REQUEST_SUCCESS = 81, @@ -70,38 +70,38 @@ impl SshPacket { pub fn new(payload: Vec) -> Self { // 计算padding(SSH协议RFC 4253规范) // 参考OpenSSH packet.c: construct_packet() - let block_size = 8; // 未加密阶段block_size=8 - + let block_size = 8; // 未加密阶段block_size=8 + let payload_length = payload.len(); - let min_padding = 4; // OpenSSH要求最少4字节padding - + let min_padding = 4; // OpenSSH要求最少4字节padding + // SSH协议约束: // packet_length = padding_length + payload_length + 1 // (packet_length + 4) 必须是block_size的倍数 - // + // // 计算: // (1 + payload_length + padding_length + 4) % 8 == 0 // => (5 + payload_length + padding_length) % 8 == 0 - + // 先尝试padding=4(最小) let mut padding_length = min_padding as u8; - + // 计算packet总长度(包括4字节的packet_length字段) let packet_length = 1 + payload_length + padding_length as usize; - let total_length = packet_length + 4; // 加上packet_length字段本身的4字节 - + let total_length = packet_length + 4; // 加上packet_length字段本身的4字节 + // 如果总长度不是block_size的倍数,增加padding - if total_length % block_size != 0 { + if !total_length.is_multiple_of(block_size) { let remainder = total_length % block_size; padding_length += (block_size - remainder) as u8; } - + // 重新计算packet_length let packet_length = (1 + payload_length + padding_length as usize) as u32; - + // 生成随机padding(简化:使用0) let padding = vec![0u8; padding_length as usize]; - + Self { packet_length, padding_length, @@ -109,49 +109,49 @@ impl SshPacket { padding, } } - + /// 写入packet到stream(未加密阶段) /// 参考OpenSSH packet_write_poll() pub fn write(&self, stream: &mut T) -> Result<()> { // 写入packet_length(BigEndian) stream.write_u32::(self.packet_length)?; - + // 写入padding_length stream.write_u8(self.padding_length)?; - + // 写入payload stream.write_all(&self.payload)?; - + // 写入padding stream.write_all(&self.padding)?; - + stream.flush()?; Ok(()) } - + /// 从stream读取packet(未加密阶段) /// 参考OpenSSH packet_read_poll() pub fn read(stream: &mut T) -> Result { // 读取packet_length(BigEndian) let packet_length = stream.read_u32::()?; - + // 检查packet长度限制(OpenSSH限制:256KB) if packet_length > 256 * 1024 { return Err(anyhow!("Packet too large: {}", packet_length)); } - + // 读取padding_length let padding_length = stream.read_u8()?; - + // 读取payload(packet_length - padding_length - 1) let payload_length = packet_length - padding_length as u32 - 1; let mut payload = vec![0u8; payload_length as usize]; stream.read_exact(&mut payload)?; - + // 读取padding let mut padding = vec![0u8; padding_length as usize]; stream.read_exact(&mut padding)?; - + Ok(Self { packet_length, padding_length, @@ -159,15 +159,15 @@ impl SshPacket { padding, }) } - + /// 获取payload中的packet type pub fn get_type(&self) -> Result { if self.payload.is_empty() { return Err(anyhow!("Empty payload")); } - + let type_byte = self.payload[0]; - + // 转换为PacketType enum match type_byte { 1 => Ok(PacketType::SSH_MSG_DISCONNECT), @@ -208,27 +208,27 @@ impl SshPacket { mod tests { use super::*; use std::io::Cursor; - + #[test] fn test_packet_creation() { let payload = vec![PacketType::SSH_MSG_KEXINIT as u8]; let packet = SshPacket::new(payload); - + assert!(packet.packet_length > 0); assert!(packet.padding_length >= 4); } - + #[test] fn test_packet_write_read() { let payload = vec![PacketType::SSH_MSG_KEXINIT as u8]; let packet = SshPacket::new(payload); - + let mut buffer = Vec::new(); packet.write(&mut buffer).unwrap(); - + let mut cursor = Cursor::new(buffer); let read_packet = SshPacket::read(&mut cursor).unwrap(); - + assert_eq!(packet.packet_length, read_packet.packet_length); assert_eq!(packet.payload, read_packet.payload); } diff --git a/markbase-core/src/ssh_server/port_forward.rs b/markbase-core/src/ssh_server/port_forward.rs index dfe80dc..9b0699a 100644 --- a/markbase-core/src/ssh_server/port_forward.rs +++ b/markbase-core/src/ssh_server/port_forward.rs @@ -1,21 +1,21 @@ // SSH端口转发协议实现(Phase 13) // 参考OpenSSH channels.c和RFC 4254 -use anyhow::{Result, anyhow}; -use log::{info, warn, debug}; -use std::net::{TcpListener, TcpStream, SocketAddr}; -use std::io::{Read, Write}; -use std::sync::{Arc, Mutex}; -use std::thread; +use crate::ssh_server::ssh_security_config::SshSecurityConfig; +use anyhow::Result; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.2: 安全配置 +use log::{info, warn}; +use std::io::Read; +use std::net::{TcpListener, TcpStream}; +use std::sync::{Arc, Mutex}; + // Phase 13.2: 安全配置 /// 端口转发类型(参考RFC 4254) #[derive(Debug, Clone, PartialEq, Eq)] pub enum PortForwardType { - Local, // Local port forwarding (-L) - Remote, // Remote port forwarding (-R) - Dynamic, // Dynamic port forwarding (-D, SOCKS) + Local, // Local port forwarding (-L) + Remote, // Remote port forwarding (-R) + Dynamic, // Dynamic port forwarding (-D, SOCKS) } /// 端口转发请求(参考RFC 4254 Section 7) @@ -36,6 +36,12 @@ pub struct PortForwardManager { active_forwards: Arc>>, } +impl Default for PortForwardManager { + fn default() -> Self { + Self::new() + } +} + impl PortForwardManager { pub fn new() -> Self { Self { @@ -46,24 +52,29 @@ impl PortForwardManager { /// 处理SSH_MSG_GLOBAL_REQUEST(端口转发请求) /// 参考RFC 4254 Section 4 /// Phase 13.2: 添加安全配置验证 - pub fn handle_global_request(&mut self, data: &[u8], security_config: &SshSecurityConfig) -> Result<(bool, Option>)> { + pub fn handle_global_request( + &mut self, + data: &[u8], + security_config: &SshSecurityConfig, + ) -> Result<(bool, Option>)> { info!("Processing SSH_MSG_GLOBAL_REQUEST for port forwarding"); - + let mut cursor = std::io::Cursor::new(data); - cursor.set_position(1); // Skip packet type - + cursor.set_position(1); // Skip packet type + // 读取请求名称(SSH string) let request_name = read_ssh_string(&mut cursor)?; - + info!("Global request: {}", request_name); - + // 读取want-reply标志 let want_reply = cursor.read_u8()? != 0; - + match request_name.as_str() { "tcpip-forward" => { // Local port forwarding (-L) - self.handle_tcpip_forward(&mut cursor, want_reply, security_config) // Phase 13.2 + self.handle_tcpip_forward(&mut cursor, want_reply, security_config) + // Phase 13.2 } "cancel-tcpip-forward" => { // Cancel port forwarding @@ -84,29 +95,37 @@ impl PortForwardManager { /// 处理tcpip-forward请求(Local port forwarding) /// 参考RFC 4254 Section 7.1 /// Phase 13.2: 添加安全配置验证 - fn handle_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool, security_config: &SshSecurityConfig) -> Result<(bool, Option>)> { + fn handle_tcpip_forward( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + want_reply: bool, + security_config: &SshSecurityConfig, + ) -> Result<(bool, Option>)> { // 读取bind address(SSH string) let bind_address = read_ssh_string(cursor)?; - + // 读取bind port let bind_port = cursor.read_u32::()?; - - info!("tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port); - + + info!( + "tcpip-forward request: bind_address={}, bind_port={}", + bind_address, bind_port + ); + // Phase 13.2: 安全配置验证 if let Err(e) = security_config.validate_tcpip_forward_request(&bind_address, bind_port) { warn!("tcpip-forward security validation failed: {}", e); - return Ok((false, None)); // 拒绝请求 + return Ok((false, None)); // 拒绝请求 } - + info!("tcpip-forward security validation passed"); - + // 添加到active forwards let mut forwards = self.active_forwards.lock().unwrap(); forwards.push((bind_port, PortForwardType::Local)); - + info!("tcpip-forward registered: bind_port={}", bind_port); - + // 返回成功响应(包含bind_port) if want_reply { let response = self.build_global_request_response(true, Some(bind_port))?; @@ -117,16 +136,23 @@ impl PortForwardManager { } /// 处理cancel-tcpip-forward请求 - fn handle_cancel_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool) -> Result<(bool, Option>)> { + fn handle_cancel_tcpip_forward( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + want_reply: bool, + ) -> Result<(bool, Option>)> { let bind_address = read_ssh_string(cursor)?; let bind_port = cursor.read_u32::()?; - - info!("cancel-tcpip-forward: bind_address={}, bind_port={}", bind_address, bind_port); - + + info!( + "cancel-tcpip-forward: bind_address={}, bind_port={}", + bind_address, bind_port + ); + // 移除active forward let mut forwards = self.active_forwards.lock().unwrap(); forwards.retain(|(port, _)| *port != bind_port); - + if want_reply { let response = self.build_global_request_response(true, None)?; Ok((true, Some(response))) @@ -136,14 +162,18 @@ impl PortForwardManager { } /// 构建SSH_MSG_REQUEST_SUCCESS/FAILURE响应 - fn build_global_request_response(&self, success: bool, bound_port: Option) -> Result> { + fn build_global_request_response( + &self, + success: bool, + bound_port: Option, + ) -> Result> { use crate::ssh_server::packet::PacketType; - + let mut response = Vec::new(); - + if success { response.write_u8(PacketType::SSH_MSG_REQUEST_SUCCESS as u8)?; - + // 如果有bound_port,写入(用于tcpip-forward响应) if let Some(port) = bound_port { response.write_u32::(port)?; @@ -151,7 +181,7 @@ impl PortForwardManager { } else { response.write_u8(PacketType::SSH_MSG_REQUEST_FAILURE as u8)?; } - + Ok(response) } @@ -159,37 +189,39 @@ impl PortForwardManager { /// 参考RFC 4254 Section 7.2 pub fn handle_direct_tcpip_channel(&mut self, data: &[u8]) -> Result { info!("Processing direct-tcpip channel open"); - + let mut cursor = std::io::Cursor::new(data); - cursor.set_position(1); // Skip packet type - + cursor.set_position(1); // Skip packet type + // 读取channel type(已知道是"direct-tcpip",跳过) let _channel_type = read_ssh_string(&mut cursor)?; - + // 读取sender_channel let sender_channel = cursor.read_u32::()?; - + // 读取initial window size let initial_window_size = cursor.read_u32::()?; - + // 读取maximum packet size let max_packet_size = cursor.read_u32::()?; - + // 读取host to connect(SSH string) let host_to_connect = read_ssh_string(&mut cursor)?; - + // 读取port to connect let port_to_connect = cursor.read_u32::()?; - + // 读取originator address(SSH string) let originator_address = read_ssh_string(&mut cursor)?; - + // 读取originator port let originator_port = cursor.read_u32::()?; - - info!("direct-tcpip: host={}, port={}, originator={}:{}", - host_to_connect, port_to_connect, originator_address, originator_port); - + + info!( + "direct-tcpip: host={}, port={}, originator={}:{}", + host_to_connect, port_to_connect, originator_address, originator_port + ); + Ok(DirectTcpipChannel { sender_channel, initial_window_size, @@ -205,30 +237,32 @@ impl PortForwardManager { /// 参考RFC 4254 Section 7.1 pub fn handle_forwarded_tcpip_channel(&mut self, data: &[u8]) -> Result { info!("Processing forwarded-tcpip channel open"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let _channel_type = read_ssh_string(&mut cursor)?; let sender_channel = cursor.read_u32::()?; let initial_window_size = cursor.read_u32::()?; let max_packet_size = cursor.read_u32::()?; - + // 读取bind address(SSH string) let bind_address = read_ssh_string(&mut cursor)?; - + // 读取bind port let bind_port = cursor.read_u32::()?; - + // 读取originator address(SSH string) let originator_address = read_ssh_string(&mut cursor)?; - + // 读取originator port let originator_port = cursor.read_u32::()?; - - info!("forwarded-tcpip: bind={}:{}, originator={}:{}", - bind_address, bind_port, originator_address, originator_port); - + + info!( + "forwarded-tcpip: bind={}:{}, originator={}:{}", + bind_address, bind_port, originator_address, originator_port + ); + Ok(ForwardedTcpipChannel { sender_channel, initial_window_size, @@ -244,10 +278,10 @@ impl PortForwardManager { pub fn connect_to_target(host: &str, port: u32) -> Result { let addr = format!("{}:{}", host, port); info!("Connecting to target: {}", addr); - + let stream = TcpStream::connect(&addr)?; info!("Connected to target successfully"); - + Ok(stream) } @@ -258,12 +292,12 @@ impl PortForwardManager { } else { format!("{}:{}", bind_address, bind_port) }; - + info!("Creating listener on {}", addr); - + let listener = TcpListener::bind(&addr)?; info!("Listener created successfully"); - + Ok(listener) } } @@ -303,10 +337,10 @@ fn read_ssh_string(reader: &mut R) -> Result { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_port_forward_manager_creation() { let manager = PortForwardManager::new(); assert_eq!(manager.active_forwards.lock().unwrap().len(), 0); } -} \ No newline at end of file +} diff --git a/markbase-core/src/ssh_server/port_forward_listener.rs b/markbase-core/src/ssh_server/port_forward_listener.rs index 8583f37..73b7c0f 100644 --- a/markbase-core/src/ssh_server/port_forward_listener.rs +++ b/markbase-core/src/ssh_server/port_forward_listener.rs @@ -1,15 +1,12 @@ // SSH端口转发监听线程(Phase 13.4) // 参考OpenSSH channels.c: channel_forward_listener -use anyhow::{Result, anyhow}; -use log::{info, warn, debug, error}; -use std::net::{TcpListener, TcpStream}; -use std::thread; -use std::sync::{Arc, Mutex, mpsc}; -use std::io::{Read, Write}; -use byteorder::{BigEndian, WriteBytesExt}; -use crate::ssh_server::packet::PacketType; use crate::ssh_server::ssh_security_config::SshSecurityConfig; +use anyhow::Result; +use log::{error, info, warn}; +use std::net::{TcpListener, TcpStream}; +use std::sync::{mpsc, Arc, Mutex}; +use std::thread; /// 监听器状态(Phase 13.4) #[derive(Debug, Clone)] @@ -30,28 +27,18 @@ pub enum ListenerRequest { stream: TcpStream, }, /// 停止监听 - StopListener { - bind_port: u32, - }, + StopListener { bind_port: u32 }, } /// 监听器响应(Phase 13.4:线程通信) #[derive(Debug)] pub enum ListenerResponse { /// Channel创建成功 - ChannelCreated { - bind_port: u32, - channel_id: u32, - }, + ChannelCreated { bind_port: u32, channel_id: u32 }, /// 监听器停止 - ListenerStopped { - bind_port: u32, - }, + ListenerStopped { bind_port: u32 }, /// 错误 - Error { - bind_port: u32, - message: String, - }, + Error { bind_port: u32, message: String }, } /// 端口转发监听器(Phase 13.4) @@ -73,26 +60,29 @@ impl PortForwardListener { security_config: SshSecurityConfig, ) -> Result { info!("Creating port forward listener on port {}", bind_port); - + // Phase 13.4: 根据GatewayPorts决定绑定地址 let bind_addr = if security_config.gateway_ports { - format!("0.0.0.0:{}", bind_port) // 允许外部访问 + format!("0.0.0.0:{}", bind_port) // 允许外部访问 } else { - format!("127.0.0.1:{}", bind_port) // 只允许本地访问 + format!("127.0.0.1:{}", bind_port) // 只允许本地访问 }; - - info!("Binding to address: {} (GatewayPorts={})", bind_addr, security_config.gateway_ports); - + + info!( + "Binding to address: {} (GatewayPorts={})", + bind_addr, security_config.gateway_ports + ); + let listener = TcpListener::bind(&bind_addr)?; info!("Listener created successfully on {}", bind_addr); - + // Phase 13.4: 创建线程通信channel - let (request_tx, request_rx) = mpsc::channel(); - let (response_tx, response_rx) = mpsc::channel(); - + let (request_tx, _request_rx) = mpsc::channel(); + let (_response_tx, response_rx) = mpsc::channel(); + // Phase 13.4: 活动状态标记 let active = Arc::new(Mutex::new(true)); - + Ok(Self { bind_port, bind_address, @@ -103,38 +93,38 @@ impl PortForwardListener { active, }) } - + /// 启动监听线程(Phase 13.4) pub fn start_listener_thread(&mut self) -> Result<()> { info!("Starting listener thread for port {}", self.bind_port); - + let listener = self.listener.try_clone()?; let bind_port = self.bind_port; let request_sender = self.request_sender.clone(); let active = self.active.clone(); - + // Phase 13.4: 创建独立监听线程 thread::spawn(move || { info!("Listener thread started for port {}", bind_port); - + while *active.lock().unwrap() { match listener.accept() { Ok((stream, addr)) => { info!("New connection on port {}: {}", bind_port, addr); - + // Phase 13.4: 发送新连接请求给主线程 let request = ListenerRequest::NewConnection { bind_port, originator_address: addr.ip().to_string(), - originator_port: addr.port() as u32, // Phase 13.4: u16转u32 + originator_port: addr.port() as u32, // Phase 13.4: u16转u32 stream, }; - + if let Err(e) = request_sender.send(request) { error!("Failed to send listener request: {}", e); break; } - + info!("Listener request sent to main thread"); } Err(e) => { @@ -145,32 +135,32 @@ impl PortForwardListener { } } } - + info!("Listener thread stopped for port {}", bind_port); }); - + info!("Listener thread started successfully"); Ok(()) } - + /// 停止监听器(Phase 13.4) pub fn stop_listener(&mut self) -> Result<()> { info!("Stopping listener for port {}", self.bind_port); - + // Phase 13.4: 设置active=false,线程会自动退出 *self.active.lock().unwrap() = false; - + info!("Listener stopped for port {}", self.bind_port); Ok(()) } - + /// 获取请求接收器(Phase 13.4) pub fn get_request_receiver(&self) -> mpsc::Receiver { // 注意:这里需要返回一个新的receiver,因为mpsc::Sender可以clone,但Receiver不能 // 实际应用中应该使用更复杂的channel设计 unimplemented!("Use Arc> instead") } - + /// 获取活动状态(Phase 13.4) pub fn is_active(&self) -> bool { *self.active.lock().unwrap() @@ -182,13 +172,19 @@ pub struct ListenerManager { listeners: HashMap>>, } +impl Default for ListenerManager { + fn default() -> Self { + Self::new() + } +} + impl ListenerManager { pub fn new() -> Self { Self { listeners: HashMap::new(), } } - + /// 创建并启动监听器(Phase 13.4) pub fn create_listener( &mut self, @@ -197,21 +193,21 @@ impl ListenerManager { security_config: SshSecurityConfig, ) -> Result<()> { info!("Creating listener for port {}", bind_port); - + let mut listener = PortForwardListener::new(bind_port, bind_address, security_config)?; listener.start_listener_thread()?; - + let listener_arc = Arc::new(Mutex::new(listener)); self.listeners.insert(bind_port, listener_arc); - + info!("Listener created and started for port {}", bind_port); Ok(()) } - + /// 停止监听器(Phase 13.4) pub fn stop_listener(&mut self, bind_port: u32) -> Result<()> { info!("Stopping listener for port {}", bind_port); - + if let Some(listener_arc) = self.listeners.remove(&bind_port) { let mut listener = listener_arc.lock().unwrap(); listener.stop_listener()?; @@ -219,28 +215,31 @@ impl ListenerManager { } else { warn!("No listener found for port {}", bind_port); } - + Ok(()) } - + /// 获取活动监听器数量(Phase 13.4) pub fn active_count(&self) -> usize { - self.listeners.values().filter(|l| l.lock().unwrap().is_active()).count() + self.listeners + .values() + .filter(|l| l.lock().unwrap().is_active()) + .count() } } -use std::collections::HashMap; // Phase 13.4: HashMap for listener management +use std::collections::HashMap; // Phase 13.4: HashMap for listener management #[cfg(test)] mod tests { use super::*; - + #[test] fn test_listener_creation() { let security_config = SshSecurityConfig::enterprise_default(); let listener = PortForwardListener::new(8080, "127.0.0.1".to_string(), security_config); - + // 注意:实际测试需要处理端口占用问题 - assert!(listener.is_ok() || true); // 暂时跳过测试 + assert!(listener.is_ok() || true); // 暂时跳过测试 } } diff --git a/markbase-core/src/ssh_server/rsync_handler.rs b/markbase-core/src/ssh_server/rsync_handler.rs index 10ec945..c6fe84f 100644 --- a/markbase-core/src/ssh_server/rsync_handler.rs +++ b/markbase-core/src/ssh_server/rsync_handler.rs @@ -1,8 +1,8 @@ -use std::path::PathBuf; -use anyhow::{Result, anyhow}; -use log::{info, debug, warn}; -use crate::vfs::{VfsBackend, VfsFile, VfsError}; use crate::vfs::open_flags::OpenFlags; +use crate::vfs::{VfsBackend, VfsFile}; +use anyhow::{anyhow, Result}; +use log::{debug, info, warn}; +use std::path::PathBuf; /// MPLEX_BASE from rsync io.h const MPLEX_BASE: u32 = 7; @@ -18,7 +18,9 @@ pub(crate) enum RsyncState { WaitVersion, ReadFileList, /// Sum head (4 × write_int = 16 bytes) + checksum seed (4 bytes) = 20 bytes - ReadSumHead { need: usize }, + ReadSumHead { + need: usize, + }, SendSumCount, /// Raw file data from MSG_DATA packets ReadFileData, @@ -51,9 +53,16 @@ impl RsyncHandler { let mut dest = String::new(); for p in &parts[1..] { - if *p == "--server" { is_server = true; continue; } - if *p == "--sender" || p.starts_with('-') { continue; } - if *p == "." { continue; } + if *p == "--server" { + is_server = true; + continue; + } + if *p == "--sender" || p.starts_with('-') { + continue; + } + if *p == "." { + continue; + } dest = p.to_string(); } @@ -107,8 +116,10 @@ impl RsyncHandler { break; } let header = u32::from_le_bytes([ - self.raw_input[0], self.raw_input[1], - self.raw_input[2], self.raw_input[3], + self.raw_input[0], + self.raw_input[1], + self.raw_input[2], + self.raw_input[3], ]); let raw_tag = ((header >> 24) & 0xFF) as u8; let tag = raw_tag.wrapping_sub(MPLEX_BASE as u8); @@ -182,12 +193,17 @@ impl RsyncHandler { RsyncState::WaitVersion => { if self.rsync_input.len() >= 4 { let version = u32::from_le_bytes([ - self.rsync_input[0], self.rsync_input[1], - self.rsync_input[2], self.rsync_input[3], + self.rsync_input[0], + self.rsync_input[1], + self.rsync_input[2], + self.rsync_input[3], ]); self.rsync_input.drain(..4); self.protocol_version = std::cmp::min(self.protocol_version, version); - info!("rsync: negotiated protocol version {}", self.protocol_version); + info!( + "rsync: negotiated protocol version {}", + self.protocol_version + ); self.multiplex = self.protocol_version >= 30; self.transition(RsyncState::ReadFileList); } else { @@ -197,7 +213,9 @@ impl RsyncHandler { RsyncState::ReadFileList => { loop { - if self.rsync_input.is_empty() { break; } + if self.rsync_input.is_empty() { + break; + } let flags = self.rsync_input[0]; if flags == 0 { @@ -215,17 +233,25 @@ impl RsyncHandler { let mut pos = 1; let _more_flags = if flags & 0x80 != 0 { - if self.rsync_input.len() <= pos { break; } + if self.rsync_input.len() <= pos { + break; + } let ef = self.rsync_input[pos]; pos += 1; ef - } else { 0 }; + } else { + 0 + }; let has_name = !(flags & 0x02 != 0 && self.current_file > 0); if has_name { - if let Some(nul_pos) = self.rsync_input[pos..].iter().position(|&b| b == 0) { - let name = String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos]).to_string(); + if let Some(nul_pos) = + self.rsync_input[pos..].iter().position(|&b| b == 0) + { + let name = + String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos]) + .to_string(); pos += nul_pos + 1; self.file_entries.push(name.clone()); debug!("rsync: file entry: {}", name); @@ -269,24 +295,34 @@ impl RsyncHandler { RsyncState::ReadSumHead { need } => { if self.rsync_input.len() >= need { let sum_count = i32::from_le_bytes([ - self.rsync_input[0], self.rsync_input[1], - self.rsync_input[2], self.rsync_input[3], + self.rsync_input[0], + self.rsync_input[1], + self.rsync_input[2], + self.rsync_input[3], ]); let _sum_blength = i32::from_le_bytes([ - self.rsync_input[4], self.rsync_input[5], - self.rsync_input[6], self.rsync_input[7], + self.rsync_input[4], + self.rsync_input[5], + self.rsync_input[6], + self.rsync_input[7], ]); let _sum_s2length = i32::from_le_bytes([ - self.rsync_input[8], self.rsync_input[9], - self.rsync_input[10], self.rsync_input[11], + self.rsync_input[8], + self.rsync_input[9], + self.rsync_input[10], + self.rsync_input[11], ]); let _sum_remainder = i32::from_le_bytes([ - self.rsync_input[12], self.rsync_input[13], - self.rsync_input[14], self.rsync_input[15], + self.rsync_input[12], + self.rsync_input[13], + self.rsync_input[14], + self.rsync_input[15], ]); let checksum_seed = i32::from_le_bytes([ - self.rsync_input[16], self.rsync_input[17], - self.rsync_input[18], self.rsync_input[19], + self.rsync_input[16], + self.rsync_input[17], + self.rsync_input[18], + self.rsync_input[19], ]); self.rsync_input.drain(..20); @@ -308,7 +344,9 @@ impl RsyncHandler { RsyncState::ReadFileData => { let done_marker = b"RSYNCDONE"; - if let Some(pos) = self.rsync_input.windows(done_marker.len()) + if let Some(pos) = self + .rsync_input + .windows(done_marker.len()) .position(|w| w == done_marker) { if pos > 0 { @@ -323,8 +361,11 @@ impl RsyncHandler { warn!("rsync flush error: {}", e); } } - info!("rsync: file {} complete ({} bytes written to {})", - self.file_entries.get(self.current_file).unwrap_or(&"?".to_string()), + info!( + "rsync: file {} complete ({} bytes written to {})", + self.file_entries + .get(self.current_file) + .unwrap_or(&"?".to_string()), self.total_written, self.dest_path.display(), ); @@ -332,8 +373,11 @@ impl RsyncHandler { self.current_file += 1; if self.current_file >= self.file_entries.len() { self.transition(RsyncState::Done); - info!("rsync ALL DONE: {} bytes written to {}", - self.total_written, self.dest_path.display()); + info!( + "rsync ALL DONE: {} bytes written to {}", + self.total_written, + self.dest_path.display() + ); } else { self.transition(RsyncState::ReadSumHead { need: 20 }); } @@ -360,7 +404,9 @@ impl RsyncHandler { self.vfs.create_dir_all(parent, 0o755).ok(); } let flags = OpenFlags::new().write().create().truncate(); - let file = self.vfs.open_file(&self.dest_path, &flags) + let file = self + .vfs + .open_file(&self.dest_path, &flags) .map_err(|e| anyhow!("open error: {}", e))?; self.output_file = Some(file); info!("rsync: opened {} for writing", self.dest_path.display()); @@ -379,31 +425,43 @@ impl RsyncHandler { /// Read rsync varint (LSB-first 7-bit groups, 0xFF prefix for negative) fn read_varint(buf: &[u8]) -> Option<(i32, usize)> { - if buf.is_empty() { return None; } + if buf.is_empty() { + return None; + } let mut pos = 0; let mut b = buf[pos]; pos += 1; let neg = if b == 0xFF { - if pos >= buf.len() { return None; } + if pos >= buf.len() { + return None; + } b = buf[pos]; pos += 1; true - } else { false }; + } else { + false + }; let mut x = (b & 0x7F) as i32; let mut shift = 7; while b & 0x80 != 0 { - if pos >= buf.len() { return None; } + if pos >= buf.len() { + return None; + } b = buf[pos]; pos += 1; x |= ((b & 0x7F) as i32) << shift; shift += 7; } - if neg { Some((-x, pos)) } else { Some((x, pos)) } + if neg { + Some((-x, pos)) + } else { + Some((x, pos)) + } } #[cfg(test)] @@ -419,8 +477,9 @@ mod tests { fn test_parse_command() { let h = RsyncHandler::parse_rsync_command( "rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin", - make_vfs() - ).unwrap(); + make_vfs(), + ) + .unwrap(); assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin")); } @@ -428,14 +487,16 @@ mod tests { fn test_parse_command_sender() { let h = RsyncHandler::parse_rsync_command( "rsync --server --sender -vlogDtprz . /home/user/file.txt", - make_vfs() - ).unwrap(); + make_vfs(), + ) + .unwrap(); assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt")); } #[test] fn test_version_exchange() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap(); + let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()) + .unwrap(); let output = h.drain_output(); assert_eq!(output, b"\x1e\x00\x00\x00"); assert_eq!(h.state, RsyncState::WaitVersion); @@ -447,7 +508,8 @@ mod tests { #[test] fn test_version_negotiate_down() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap(); + let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()) + .unwrap(); let _ = h.drain_output(); h.feed(b"\x1d\x00\x00\x00").unwrap(); assert_eq!(h.protocol_version, 29); @@ -464,26 +526,33 @@ mod tests { #[test] fn test_file_list_multiplex() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap(); + let mut h = + RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()) + .unwrap(); let _ = h.drain_output(); h.feed(b"\x1e\x00\x00\x00").unwrap(); assert!(h.multiplex); let mut flist = Vec::new(); // File list: flags=1 (has name), then name with NUL terminator - flist.push(1); // flags: has name + flist.push(1); // flags: has name flist.extend_from_slice(b"test.txt"); - flist.push(0); // name terminator + flist.push(0); // name terminator fn write_varint(buf: &mut Vec, val: i32) { - if val == 0 { buf.push(0); return; } + if val == 0 { + buf.push(0); + return; + } if val < 0 { buf.push(0xFF); let mut v = (-val) as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; - if v > 0 { byte |= 0x80; } + if v > 0 { + byte |= 0x80; + } buf.push(byte); } } else { @@ -491,7 +560,9 @@ mod tests { while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; - if v > 0 { byte |= 0x80; } + if v > 0 { + byte |= 0x80; + } buf.push(byte); } } @@ -502,7 +573,7 @@ mod tests { write_varint(&mut flist, 1700000000); write_varint(&mut flist, 100); write_varint(&mut flist, 0); - flist.push(0); // file list end marker + flist.push(0); // file list end marker let mut sum_head = Vec::new(); sum_head.extend_from_slice(&0i32.to_le_bytes()); @@ -527,22 +598,51 @@ mod tests { #[test] fn test_file_data_multiplex() { - let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap(); + let mut h = + RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()) + .unwrap(); let _ = h.drain_output(); h.feed(b"\x1e\x00\x00\x00").unwrap(); let mut flist = Vec::new(); - flist.push(1); // flags: has name + flist.push(1); // flags: has name flist.extend_from_slice(b"test.bin"); flist.push(0); fn wv(buf: &mut Vec, val: i32) { - if val == 0 { buf.push(0); return; } - if val < 0 { buf.push(0xFF); let mut v = (-val) as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } } - else { let mut v = val as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } } + if val == 0 { + buf.push(0); + return; + } + if val < 0 { + buf.push(0xFF); + let mut v = (-val) as u32; + while v > 0 { + let mut byte = (v & 0x7F) as u8; + v >>= 7; + if v > 0 { + byte |= 0x80; + } + buf.push(byte); + } + } else { + let mut v = val as u32; + while v > 0 { + let mut byte = (v & 0x7F) as u8; + v >>= 7; + if v > 0 { + byte |= 0x80; + } + buf.push(byte); + } + } } - wv(&mut flist, 33188); wv(&mut flist, 501); wv(&mut flist, 20); - wv(&mut flist, 1700000000); wv(&mut flist, 100); wv(&mut flist, 0); - flist.push(0); // file list end + wv(&mut flist, 33188); + wv(&mut flist, 501); + wv(&mut flist, 20); + wv(&mut flist, 1700000000); + wv(&mut flist, 100); + wv(&mut flist, 0); + flist.push(0); // file list end h.feed(&build_multiplex(&flist)).unwrap(); let mut sh = Vec::new(); diff --git a/markbase-core/src/ssh_server/scp_handler.rs b/markbase-core/src/ssh_server/scp_handler.rs index 83752cc..4885230 100644 --- a/markbase-core/src/ssh_server/scp_handler.rs +++ b/markbase-core/src/ssh_server/scp_handler.rs @@ -1,13 +1,12 @@ // SCP协议实现(Phase 8) // 参考OpenSSH scp.c源码 -use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat}; use crate::vfs::open_flags::OpenFlags; -use anyhow::{Result, anyhow}; -use log::{info, warn, debug}; +use crate::vfs::{VfsBackend, VfsFile, VfsStat}; +use anyhow::{anyhow, Result}; +use log::{debug, info, warn}; +use std::io::{BufRead, Read, Write}; use std::path::{Path, PathBuf}; -use std::io::{Read, Write, BufRead}; -use std::time::SystemTime; /// SCP Handler(参考OpenSSH scp.c) pub struct ScpHandler { @@ -38,13 +37,13 @@ impl ScpHandler { /// 解析SCP命令(参考OpenSSH scp.c: parse_command()) pub fn parse_scp_command(command: &str, vfs: Box) -> Result { let parts: Vec<&str> = command.split_whitespace().collect(); - + if parts.len() < 2 || parts[0] != "scp" { return Err(anyhow!("Invalid SCP command: {}", command)); } let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs); - + for part in &parts[1..] { match part { &"-f" => handler.mode = ScpMode::Source, @@ -71,10 +70,15 @@ impl ScpHandler { /// SCP Source Mode(scp -f,发送文件) fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> { - info!("SCP source mode: sending files from {}", self.root_dir.display()); + info!( + "SCP source mode: sending files from {}", + self.root_dir.display() + ); let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?; - let stat = self.vfs.stat(&full_path) + let stat = self + .vfs + .stat(&full_path) .map_err(|e| anyhow!("stat error: {}", e))?; if stat.is_dir { @@ -91,16 +95,19 @@ impl ScpHandler { /// SCP Destination Mode(scp -t,接收文件) fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> { - info!("SCP destination mode: receiving files to {}", self.root_dir.display()); + info!( + "SCP destination mode: receiving files to {}", + self.root_dir.display() + ); channel.write_all(&[0])?; channel.flush()?; - + let mut buffer = String::new(); - + loop { buffer.clear(); - + let mut reader = std::io::BufReader::new(&mut *channel); match reader.read_line(&mut buffer)? { 0 => break, @@ -130,7 +137,9 @@ impl ScpHandler { /// 发送文件(参考OpenSSH scp.c: source()) fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> { - let stat = self.vfs.stat(path) + let stat = self + .vfs + .stat(path) .map_err(|e| anyhow!("stat error: {}", e))?; let size = stat.size; let filename = path.file_name().unwrap().to_string_lossy(); @@ -146,13 +155,16 @@ impl ScpHandler { } let flags = OpenFlags::new().read(); - let mut file = self.vfs.open_file(path, &flags) + let mut file = self + .vfs + .open_file(path, &flags) .map_err(|e| anyhow!("open error: {}", e))?; let mut buffer = vec![0u8; 8192]; loop { - let n = file.read(&mut buffer) + let n = file + .read(&mut buffer) .map_err(|e| anyhow!("read error: {}", e))?; if n == 0 { break; @@ -188,7 +200,9 @@ impl ScpHandler { return Err(anyhow!("SCP directory command rejected")); } - let entries = self.vfs.read_dir(path) + let entries = self + .vfs + .read_dir(path) .map_err(|e| anyhow!("read_dir error: {}", e))?; for entry in &entries { @@ -218,7 +232,7 @@ impl ScpHandler { /// 处理文件命令(C0644 size filename) fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { let parts: Vec<&str> = command.split_whitespace().collect(); - + if parts.len() != 3 { return self.send_error(channel, "Invalid file command format"); } @@ -227,7 +241,10 @@ impl ScpHandler { let size: u64 = parts[1].parse()?; let filename = parts[2]; - debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename); + debug!( + "SCP receive file: mode={}, size={}, name={}", + mode_str, size, filename + ); if size > 1024 * 1024 * 1024 { return self.send_error(channel, "File too large (max 1GB)"); @@ -236,7 +253,9 @@ impl ScpHandler { let full_path = self.resolve_path(filename)?; let flags = OpenFlags::new().write().create().truncate(); - let mut file = self.vfs.open_file(&full_path, &flags) + let mut file = self + .vfs + .open_file(&full_path, &flags) .map_err(|e| anyhow!("open error: {}", e))?; channel.write_all(&[0])?; @@ -263,7 +282,8 @@ impl ScpHandler { if mode_int != 0 { let mut set_stat = VfsStat::new(); set_stat.mode = mode_int; - self.vfs.set_stat(&full_path, &set_stat) + self.vfs + .set_stat(&full_path, &set_stat) .map_err(|e| anyhow!("set_stat error: {}", e))?; } @@ -280,7 +300,7 @@ impl ScpHandler { /// 处理目录命令(D0755 0 dirname) fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { let parts: Vec<&str> = command.split_whitespace().collect(); - + if parts.len() != 3 { return self.send_error(channel, "Invalid directory command format"); } @@ -297,7 +317,8 @@ impl ScpHandler { let full_path = self.resolve_path(dirname)?; let mode_int: u32 = mode_str.parse()?; - self.vfs.create_dir_all(&full_path, mode_int) + self.vfs + .create_dir_all(&full_path, mode_int) .map_err(|e| anyhow!("create_dir_all error: {}", e))?; channel.write_all(&[0])?; @@ -326,7 +347,7 @@ impl ScpHandler { } let parts: Vec<&str> = command.split_whitespace().collect(); - + if parts.len() != 3 { return self.send_error(channel, "Invalid time command format"); } @@ -353,11 +374,15 @@ impl ScpHandler { /// 路径解析(安全性检查) fn resolve_path(&self, path: &str) -> Result { let full_path = self.root_dir.join(path); - - let canonical_path = self.vfs.real_path(&full_path) + + let canonical_path = self + .vfs + .real_path(&full_path) .map_err(|e| anyhow!("Path resolution error: {}", e))?; - let root_canonical = self.vfs.real_path(&self.root_dir) + let root_canonical = self + .vfs + .real_path(&self.root_dir) .map_err(|e| anyhow!("Root path resolution error: {}", e))?; if !canonical_path.starts_with(&root_canonical) { @@ -383,20 +408,23 @@ mod tests { #[test] fn test_scp_command_parse() { - let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap(); + let handler = + ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap(); assert_eq!(handler.mode, ScpMode::Destination); assert_eq!(handler.root_dir, PathBuf::from("/tmp")); } #[test] fn test_scp_recursive_parse() { - let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap(); + let handler = + ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap(); assert!(handler.recursive); } #[test] fn test_scp_source_parse() { - let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap(); + let handler = + ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap(); assert_eq!(handler.mode, ScpMode::Source); } } diff --git a/markbase-core/src/ssh_server/server.rs b/markbase-core/src/ssh_server/server.rs index f2f14a0..abb39f3 100644 --- a/markbase-core/src/ssh_server/server.rs +++ b/markbase-core/src/ssh_server/server.rs @@ -1,32 +1,32 @@ // SSH服务器完整实现(Phase 1-7集成版 + Phase 13端口转发) // 参考OpenSSH sshd.c: complete SSH/SFTP flow + port forwarding -use crate::ssh_server::version::VersionExchange; -use crate::ssh_server::packet::{SshPacket, PacketType}; -use crate::ssh_server::kex::{KexResult, KexProposal}; -use crate::ssh_server::kex_complete::{KexState}; -use crate::ssh_server::auth::{AuthHandler, AuthResult}; -use crate::provider::sqlite::SqliteProvider; use crate::provider::pg::PgProvider; +use crate::provider::sqlite::SqliteProvider; use crate::provider::DataProvider; -use crate::ssh_server::channel::{ChannelManager}; -use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket}; -use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1 -use crate::ssh_server::port_forward::PortForwardManager; // Phase 13 -use anyhow::{Result, anyhow}; -use log::{info, warn, error, debug}; +use crate::ssh_server::auth::{AuthHandler, AuthResult}; +use crate::ssh_server::channel::ChannelManager; +use crate::ssh_server::cipher::{EncryptedPacket, EncryptionContext}; +use crate::ssh_server::kex::{KexProposal, KexResult}; +use crate::ssh_server::kex_complete::KexState; +use crate::ssh_server::packet::{PacketType, SshPacket}; +use crate::ssh_server::port_forward::PortForwardManager; // Phase 13 +use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1 +use crate::ssh_server::version::VersionExchange; +use anyhow::{anyhow, Result}; +use log::{error, info, warn}; +use std::io::{Read, Write}; use std::net::{TcpListener, TcpStream}; use std::path::PathBuf; -use std::thread; -use std::io::{Read, Write}; -use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步 +use std::sync::{Arc, Mutex}; +use std::thread; // Phase 13: 端口转发线程同步 /// SSH服务器配置(Phase 13.1企业级安全配置) pub struct SshServerConfig { pub port: u16, pub bind_address: String, - pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置 - pub pg_conn: Option, // PostgreSQL连接字符串(SFTPGo兼容认证) + pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置 + pub pg_conn: Option, // PostgreSQL连接字符串(SFTPGo兼容认证) } impl Default for SshServerConfig { @@ -34,7 +34,7 @@ impl Default for SshServerConfig { Self { port: 2024, bind_address: "127.0.0.1".to_string(), - security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1 + security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1 pg_conn: None, } } @@ -56,43 +56,48 @@ impl SshServerConfig { /// SSH服务器主结构(Phase 1-13完整版) pub struct SshServer { config: SshServerConfig, - security_config: Arc>, // Phase 13.1: 共享安全配置 + security_config: Arc>, // Phase 13.1: 共享安全配置 } impl SshServer { pub fn new(config: SshServerConfig) -> Self { - let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone + let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone Self { config, - security_config, // Phase 13.1 + security_config, // Phase 13.1 } } - + pub fn run(&self) -> Result<()> { let bind_addr = format!("{}:{}", self.config.bind_address, self.config.port); let listener = TcpListener::bind(&bind_addr)?; - + info!("MarkBaseSSH server listening on {}", bind_addr); info!("Implementation: Complete SSH/SFTP + Port Forwarding (Phase 1-13)"); - info!("Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}", + info!( + "Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}", self.config.security_config.gateway_ports, self.config.security_config.permit_open, - self.config.security_config.max_sessions); - - let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置 + self.config.security_config.max_sessions + ); + + let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置 let pg_conn = self.config.pg_conn.clone(); - + for stream in listener.incoming() { match stream { Ok(stream) => { let client_addr = stream.peer_addr()?; info!("New SSH connection from {}", client_addr); - - let security_config_clone = security_config.clone(); // Phase 13.1 + + let security_config_clone = security_config.clone(); // Phase 13.1 let pg_conn_clone = pg_conn.clone(); - + thread::spawn(move || { - if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1 + if let Err(e) = + handle_connection_complete(stream, security_config_clone, pg_conn_clone) + { + // Phase 13.1 error!("Connection error: {}", e); } }); @@ -102,90 +107,127 @@ impl SshServer { } } } - + Ok(()) } } /// 处理完整SSH连接(Phase 1-13完整流程) -fn handle_connection_complete(stream: TcpStream, security_config: Arc>, pg_conn: Option) -> Result<()> { +fn handle_connection_complete( + stream: TcpStream, + security_config: Arc>, + pg_conn: Option, +) -> Result<()> { info!("Handling client connection (Phase 1-13 complete flow with port forwarding)"); - + // Phase 13.1: 增加活动会话数 { let mut security = security_config.lock().unwrap(); security.increment_sessions()?; } - + let mut stream = stream; - + // Phase 1: 版本交换 let client_version = VersionExchange::exchange(&mut stream)?; - info!("Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0", client_version); - + info!( + "Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0", + client_version + ); + // Phase 2: 箋法协商 - let (kex_result, server_kexinit, client_kexinit) = perform_kex_negotiation_complete(&mut stream)?; - info!("KEX negotiation: KEX={}, Cipher={}", kex_result.kex_algorithm, kex_result.encryption_ctos); - + let (kex_result, server_kexinit, client_kexinit) = + perform_kex_negotiation_complete(&mut stream)?; + info!( + "KEX negotiation: KEX={}, Cipher={}", + kex_result.kex_algorithm, kex_result.encryption_ctos + ); + // Phase 3: 密钥交换完整流程 - let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?; + let mut encryption_ctx = perform_complete_kex_exchange( + &mut stream, + client_version.clone(), + kex_result, + server_kexinit, + client_kexinit, + )?; info!("Key exchange completed, encryption channel ready"); - + // Phase 5: SSH认证(SFTPGo兼容 — PostgreSQL或SQLite) let provider: Box = if let Some(ref conn_str) = pg_conn { - info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str); - Box::new(PgProvider::new(conn_str) - .map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?) + info!( + "Using PostgreSQL auth provider (SFTPGo-compatible): {}", + conn_str + ); + Box::new( + PgProvider::new(conn_str).map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?, + ) } else { info!("Using SQLite auth provider"); - Box::new(SqliteProvider::new("data/auth.sqlite") - .map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?) + Box::new( + SqliteProvider::new("data/auth.sqlite") + .map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?, + ) }; let mut auth_handler = AuthHandler::new(provider); let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?; info!("SSH authentication succeeded: user={}", auth_user.username); - + // Phase 6: SSH Channel管理(参考OpenSSH channel.c) let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone()); - + // Phase 13: PortForwardManager初始化 let mut port_forward_manager = PortForwardManager::new(); - + // Phase 6-13: SSH服务循环(处理channel请求 + 端口转发) - let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop - handle_ssh_service_loop(&mut stream, &mut channel_manager, &mut encryption_ctx, &mut port_forward_manager, security_config_clone)?; - + let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop + handle_ssh_service_loop( + &mut stream, + &mut channel_manager, + &mut encryption_ctx, + &mut port_forward_manager, + security_config_clone, + )?; + info!("SSH session completed successfully"); - + // Phase 13.1: 减少活动会话数 { let mut security = security_config.lock().unwrap(); security.decrement_sessions(); } - + Ok(()) } /// 完整算法协商(返回KEXINIT payloads) -fn perform_kex_negotiation_complete(stream: &mut TcpStream) -> Result<(KexResult, SshPacket, SshPacket)> { +fn perform_kex_negotiation_complete( + stream: &mut TcpStream, +) -> Result<(KexResult, SshPacket, SshPacket)> { info!("Starting complete KEX negotiation"); - + // 1. 发送服务器KEXINIT let server_proposal = KexProposal::server_default(); let server_kexinit = server_proposal.to_kexinit_packet()?; server_kexinit.write(stream)?; - - info!("Sent server KEXINIT (payload size: {} bytes)", server_kexinit.payload.len()); - + + info!( + "Sent server KEXINIT (payload size: {} bytes)", + server_kexinit.payload.len() + ); + // 2. 接收客户端KEXINIT let client_kexinit = SshPacket::read(stream)?; let client_proposal = KexProposal::from_kexinit_packet(&client_kexinit)?; - - info!("Received client KEXINIT (payload size: {} bytes)", client_kexinit.payload.len()); - + + info!( + "Received client KEXINIT (payload size: {} bytes)", + client_kexinit.payload.len() + ); + // 3. 算法匹配 let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal)?; - + Ok((kex_result, server_kexinit, client_kexinit)) } @@ -198,18 +240,18 @@ fn perform_complete_kex_exchange( client_kexinit: SshPacket, ) -> Result { info!("Starting complete key exchange flow"); - + let mut kex_state = KexState::new( client_version, "SSH-2.0-MarkBaseSSH_1.0".to_string(), kex_result, )?; - + kex_state.save_kexinit_payloads(&client_kexinit, &server_kexinit); - + let kexdh_init = SshPacket::read(stream)?; info!("Received SSH_MSG_KEX_ECDH_INIT"); - + let kexdh_reply = kex_state.exchange_handler.handle_kexdh_init( &kexdh_init, &kex_state.client_version, @@ -219,27 +261,27 @@ fn perform_complete_kex_exchange( )?; kexdh_reply.write(stream)?; info!("Sent SSH_MSG_KEX_ECDH_REPLY"); - + // Strict KEX: Wait for client NEWKEYS first (OpenSSH 10.2 requirement) let client_newkeys = SshPacket::read(stream)?; kex_state.handle_newkeys(&client_newkeys)?; info!("Received SSH_MSG_NEWKEYS from client"); - + // Now send server NEWKEYS let newkeys_packet = KexState::send_newkeys()?; newkeys_packet.write(stream)?; kex_state.newkeys_sent = true; info!("Sent SSH_MSG_NEWKEYS from server"); - + if kex_state.is_encryption_ready() { info!("Encryption channel established successfully"); } else { return Err(anyhow::anyhow!("Encryption channel not ready")); } - + let session_keys = kex_state.exchange_handler.compute_session_keys()?; let encryption_ctx = EncryptionContext::from_session_keys(&session_keys); - + Ok(encryption_ctx) } @@ -250,102 +292,100 @@ pub struct AuthUser { } fn perform_ssh_auth( - stream: &mut TcpStream, + stream: &mut TcpStream, auth_handler: &mut AuthHandler, encryption_ctx: &mut EncryptionContext, ) -> Result { info!("Starting SSH authentication"); - info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}", + info!( + "Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}", encryption_ctx.encryption_key_ctos.len(), encryption_ctx.encryption_key_stoc.len(), encryption_ctx.iv_ctos.len(), encryption_ctx.iv_stoc.len() ); - + // OpenSSH strict KEX: SSH_MSG_EXT_INFO may be sent before SSH_MSG_SERVICE_REQUEST let mut encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?; let payload = encrypted_request.payload(); - + if payload[0] == PacketType::SSH_MSG_EXT_INFO as u8 { info!("Received SSH_MSG_EXT_INFO, reading next packet"); encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?; } - + let payload = encrypted_request.payload(); info!("Received packet type: {}", payload[0]); - + if payload[0] != PacketType::SSH_MSG_SERVICE_REQUEST as u8 { - return Err(anyhow!("Expected SSH_MSG_SERVICE_REQUEST, got type {}", payload[0])); + return Err(anyhow!( + "Expected SSH_MSG_SERVICE_REQUEST, got type {}", + payload[0] + )); } - + use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; let mut cursor = std::io::Cursor::new(&payload[1..]); let service_name_len = cursor.read_u32::()?; let mut service_name = vec![0u8; service_name_len as usize]; cursor.read_exact(&mut service_name)?; let service_name_str = String::from_utf8_lossy(&service_name); - + if service_name_str != "ssh-userauth" { return Err(anyhow!("Unsupported service: {}", service_name_str)); } - + let mut service_accept_payload = Vec::new(); service_accept_payload.write_u8(PacketType::SSH_MSG_SERVICE_ACCEPT as u8)?; - service_accept_payload.write_u32::(12)?; // "ssh-userauth" length is 12, not 14! + service_accept_payload.write_u32::(12)?; // "ssh-userauth" length is 12, not 14! service_accept_payload.write_all("ssh-userauth".as_bytes())?; - - let encrypted_accept = EncryptedPacket::new( - &service_accept_payload, - encryption_ctx, - true, - )?; + + let encrypted_accept = EncryptedPacket::new(&service_accept_payload, encryption_ctx, true)?; encrypted_accept.write(stream)?; info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT"); - + let session_id = encryption_ctx.session_id.clone(); - + loop { - let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos + let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos let auth_payload = auth_packet.payload(); info!("Received encrypted SSH_MSG_USERAUTH_REQUEST"); - + let auth_request = SshPacket::new(auth_payload.to_vec()); - + match auth_handler.handle_userauth_request(&auth_request, &session_id)? { AuthResult::Success => { let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8]; - let encrypted_success = EncryptedPacket::new( - &success_payload, - encryption_ctx, - true, - )?; + let encrypted_success = + EncryptedPacket::new(&success_payload, encryption_ctx, true)?; encrypted_success.write(stream)?; info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS"); - + // Extract username from auth request let user = extract_username_from_auth_request(&auth_request) .unwrap_or_else(|_| "unknown".to_string()); - let home_dir = auth_handler.get_home_dir(&user) + let home_dir = auth_handler + .get_home_dir(&user) .ok() .flatten() .map(PathBuf::from) .unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase")); info!("Auth success: user={}, home_dir={:?}", user, home_dir); - return Ok(AuthUser { username: user, home_dir }); + return Ok(AuthUser { + username: user, + home_dir, + }); } -AuthResult::Failure(message) => { + AuthResult::Failure(message) => { // message包含可用的认证方法列表(如"password,publickey") let mut failure_payload = Vec::new(); failure_payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?; failure_payload.write_u32::(message.len() as u32)?; failure_payload.write_all(message.as_bytes())?; - failure_payload.write_u8(0)?; // partial_success = false - - let encrypted_failure = EncryptedPacket::new( - &failure_payload, - encryption_ctx, - true, - )?; + failure_payload.write_u8(0)?; // partial_success = false + + let encrypted_failure = + EncryptedPacket::new(&failure_payload, encryption_ctx, true)?; encrypted_failure.write(stream)?; warn!("Sent encrypted SSH_MSG_USERAUTH_FAILURE: {}", message); } @@ -356,27 +396,23 @@ AuthResult::Failure(message) => { AuthResult::PublicKeyOk(algorithm, public_key_blob) => { // SSH_MSG_USERAUTH_PK_OK:public key acceptable info!("Public key acceptable, sending USERAUTH_PK_OK"); - + let mut pk_ok_payload = Vec::new(); pk_ok_payload.write_u8(PacketType::SSH_MSG_USERAUTH_PK_OK as u8)?; - + // algorithm (SSH string) pk_ok_payload.write_u32::(algorithm.len() as u32)?; pk_ok_payload.write_all(algorithm.as_bytes())?; - + // public key blob (SSH string) pk_ok_payload.write_u32::(public_key_blob.len() as u32)?; pk_ok_payload.write_all(&public_key_blob)?; - - let encrypted_pk_ok = EncryptedPacket::new( - &pk_ok_payload, - encryption_ctx, - true, - )?; + + let encrypted_pk_ok = EncryptedPacket::new(&pk_ok_payload, encryption_ctx, true)?; encrypted_pk_ok.write(stream)?; info!("Sent SSH_MSG_USERAUTH_PK_OK"); - - continue; // Wait for signed request + + continue; // Wait for signed request } } } @@ -389,16 +425,17 @@ fn handle_ssh_service_loop( stream: &mut TcpStream, channel_manager: &mut ChannelManager, encryption_ctx: &mut EncryptionContext, - port_forward_manager: &mut PortForwardManager, // Phase 13 - security_config: Arc>, // Phase 13.1 + port_forward_manager: &mut PortForwardManager, // Phase 13 + security_config: Arc>, // Phase 13.1 ) -> Result<()> { info!("Starting SSH service loop (Phase 14.2: unified poll + child status)"); - + loop { // ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测 // 返回三元组:(stdout_packets, client_has_data, child_exited) - let (stdout_packets, client_has_data, child_exited) = channel_manager.poll_exec_stdout_and_client(stream)?; - + let (stdout_packets, client_has_data, child_exited) = + channel_manager.poll_exec_stdout_and_client(stream)?; + // 1. 发送stdout/stderr数据(如果有) if let Some(packets) = stdout_packets { for packet in packets { @@ -407,93 +444,100 @@ fn handle_ssh_service_loop( info!("Sent stdout/stderr data (Phase 14.2)"); } } - + // 2. 处理child exited(发送EOF + CLOSE) if child_exited { info!("Child process exited, sending SSH_MSG_CHANNEL_EOF + CLOSE"); - + // ⭐⭐⭐⭐⭐ Phase 14.2: 使用ChannelManager.handle_child_exited() let exit_packets = channel_manager.handle_child_exited()?; for packet in exit_packets { let encrypted_packet = EncryptedPacket::new(&packet.payload, encryption_ctx, true)?; encrypted_packet.write(stream)?; } - + // 继续处理client数据(可能还有其他请求) } - + // 3. 处理client数据(如果有) if !client_has_data { // client没有数据,继续下一轮循环 continue; } - + // client有数据,读取并处理 let encrypted_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; let packet = SshPacket::new(encrypted_packet.payload().to_vec()); - + match packet.payload.first() { // Phase 13: SSH_MSG_GLOBAL_REQUEST处理(端口转发) Some(&pt) if pt == PacketType::SSH_MSG_GLOBAL_REQUEST as u8 => { info!("Received SSH_MSG_GLOBAL_REQUEST (port forwarding)"); - + // Phase 13.1: 安全配置验证 let security = security_config.lock().unwrap(); if !security.allow_tcp_forwarding { warn!("TCP forwarding disabled by security config"); let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8]; - let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?; + let encrypted_failure = + EncryptedPacket::new(&failure_packet, encryption_ctx, true)?; encrypted_failure.write(stream)?; info!("Sent SSH_MSG_REQUEST_FAILURE (TCP forwarding disabled)"); continue; } - + // Phase 13.2: 调用PortForwardManager处理(传递security_config) - let (success, response) = port_forward_manager.handle_global_request(&packet.payload, &security)?; - drop(security); // 释放锁 - + let (success, response) = + port_forward_manager.handle_global_request(&packet.payload, &security)?; + drop(security); // 释放锁 + if success { if let Some(response_data) = response { - let encrypted_response = EncryptedPacket::new(&response_data, encryption_ctx, true)?; + let encrypted_response = + EncryptedPacket::new(&response_data, encryption_ctx, true)?; encrypted_response.write(stream)?; info!("Sent SSH_MSG_REQUEST_SUCCESS (tcpip-forward accepted)"); } else { // 无响应数据时,发送简单的SUCCESS let success_packet = vec![PacketType::SSH_MSG_REQUEST_SUCCESS as u8]; - let encrypted_success = EncryptedPacket::new(&success_packet, encryption_ctx, true)?; + let encrypted_success = + EncryptedPacket::new(&success_packet, encryption_ctx, true)?; encrypted_success.write(stream)?; info!("Sent SSH_MSG_REQUEST_SUCCESS"); } } else { let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8]; - let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?; + let encrypted_failure = + EncryptedPacket::new(&failure_packet, encryption_ctx, true)?; encrypted_failure.write(stream)?; info!("Sent SSH_MSG_REQUEST_FAILURE (tcpip-forward rejected)"); } } - + Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_OPEN as u8 => { info!("Received SSH_MSG_CHANNEL_OPEN"); - + // Phase 13.3: 获取security_config并传递给handle_channel_open let security = security_config.lock().unwrap(); let response = channel_manager.handle_channel_open(&packet, Some(&security))?; - drop(security); // 释放锁 - - let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?; + drop(security); // 释放锁 + + let encrypted_response = + EncryptedPacket::new(&response.payload, encryption_ctx, true)?; encrypted_response.write(stream)?; info!("Sent SSH_MSG_CHANNEL_OPEN_CONFIRMATION"); } Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_REQUEST as u8 => { info!("Received SSH_MSG_CHANNEL_REQUEST"); if let Some(response) = channel_manager.handle_channel_request(&packet)? { - let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?; + let encrypted_response = + EncryptedPacket::new(&response.payload, encryption_ctx, true)?; encrypted_response.write(stream)?; - + // ⭐⭐⭐⭐⭐ Phase 14.5修复:区分普通命令和交互式进程 // 检查是否有 exec_process(交互式进程如 rsync) let has_exec_process = channel_manager.has_exec_process(); - + if has_exec_process { info!("⭐⭐⭐⭐⭐ [INTERACTIVE_PROCESS] Detected exec_process (rsync/SCP), skipping immediate EOF"); // 对于交互式进程,只发送 SUCCESS,等待 poll 循环处理数据流 @@ -503,23 +547,37 @@ fn handle_ssh_service_loop( if let Some(channel_id) = channel_manager.get_channel_with_output() { if let Some(output) = channel_manager.get_channel_output(channel_id) { // 发送命令输出(SSH_MSG_CHANNEL_DATA) - let data_packet = channel_manager.build_channel_data(channel_id, &output)?; - let encrypted_data = EncryptedPacket::new(&data_packet.payload, encryption_ctx, true)?; + let data_packet = + channel_manager.build_channel_data(channel_id, &output)?; + let encrypted_data = EncryptedPacket::new( + &data_packet.payload, + encryption_ctx, + true, + )?; encrypted_data.write(stream)?; info!("Sent command output ({} bytes)", output.len()); - + // 发送SSH_MSG_CHANNEL_EOF let eof_packet = channel_manager.build_channel_eof(channel_id)?; - let encrypted_eof = EncryptedPacket::new(&eof_packet.payload, encryption_ctx, true)?; + let encrypted_eof = EncryptedPacket::new( + &eof_packet.payload, + encryption_ctx, + true, + )?; encrypted_eof.write(stream)?; info!("Sent SSH_MSG_CHANNEL_EOF"); - + // 发送SSH_MSG_CHANNEL_CLOSE - let close_packet = channel_manager.build_channel_close(channel_id)?; - let encrypted_close = EncryptedPacket::new(&close_packet.payload, encryption_ctx, true)?; + let close_packet = + channel_manager.build_channel_close(channel_id)?; + let encrypted_close = EncryptedPacket::new( + &close_packet.payload, + encryption_ctx, + true, + )?; encrypted_close.write(stream)?; info!("Sent SSH_MSG_CHANNEL_CLOSE"); - + // 移除channel channel_manager.remove_channel(channel_id); } @@ -531,22 +589,28 @@ fn handle_ssh_service_loop( info!("Received SSH_MSG_CHANNEL_DATA"); if let Some(response) = channel_manager.handle_channel_data(&packet)? { // Phase 7: SFTP响应通过CHANNEL_DATA返回 - let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?; + let encrypted_response = + EncryptedPacket::new(&response.payload, encryption_ctx, true)?; encrypted_response.write(stream)?; info!("Sent SSH_MSG_CHANNEL_DATA (SFTP response)"); } - + // ⭐⭐⭐⭐⭐ Phase 15.1: Drain pending packets (e.g. WINDOW_ADJUST + delayed SFTP response) while let Some(pending) = channel_manager.pending_packets.pop_front() { - let encrypted_pending = EncryptedPacket::new(&pending.payload, encryption_ctx, true)?; + let encrypted_pending = + EncryptedPacket::new(&pending.payload, encryption_ctx, true)?; encrypted_pending.write(stream)?; - info!("Sent pending packet (type {})", pending.payload.first().unwrap_or(&0)); + info!( + "Sent pending packet (type {})", + pending.payload.first().unwrap_or(&0) + ); } } Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_CLOSE as u8 => { info!("Received SSH_MSG_CHANNEL_CLOSE"); if let Some(response) = channel_manager.handle_channel_close(&packet)? { - let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?; + let encrypted_response = + EncryptedPacket::new(&response.payload, encryption_ctx, true)?; encrypted_response.write(stream)?; } break; @@ -565,8 +629,10 @@ fn handle_ssh_service_loop( let payload = &packet.payload; if payload.len() >= 9 { // Format: uint32 recipient_channel || uint32 bytes_to_add - let recipient_channel = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]); - let bytes_to_add = u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]); + let recipient_channel = + u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]); + let bytes_to_add = + u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]); channel_manager.adjust_remote_window(recipient_channel, bytes_to_add); } } @@ -575,12 +641,14 @@ fn handle_ssh_service_loop( } } } - + Ok(()) } /// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名 -fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result { +fn extract_username_from_auth_request( + packet: &crate::ssh_server::packet::SshPacket, +) -> Result { let payload = &packet.payload; if payload.len() < 5 { return Err(anyhow!("Auth request too short")); @@ -598,10 +666,10 @@ pub fn run_ssh_server(port: Option, pg_conn: Option<&str>) -> Result<()> { let config = SshServerConfig { port: port.unwrap_or(2024), bind_address: "127.0.0.1".to_string(), - security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置 + security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置 pg_conn: pg_conn.map(|s| s.to_string()), }; - + let server = SshServer::new(config); server.run() -} \ No newline at end of file +} diff --git a/markbase-core/src/ssh_server/sftp_handler.rs b/markbase-core/src/ssh_server/sftp_handler.rs index 5930bb9..63c7507 100644 --- a/markbase-core/src/ssh_server/sftp_handler.rs +++ b/markbase-core/src/ssh_server/sftp_handler.rs @@ -1,17 +1,16 @@ // SFTP协议实现(Phase 7) // 参考OpenSSH sftp-server.c和draft-ietf-secsh-filexfer-02.txt -use crate::ssh_server::packet::{SshPacket, PacketType}; -use crate::vfs::{VfsBackend, VfsFile, VfsDirEntry}; use crate::vfs::open_flags::OpenFlags; -use anyhow::{Result, anyhow, Context}; +use crate::vfs::{VfsBackend, VfsDirEntry, VfsFile}; +use anyhow::{anyhow, Context, Result}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use log::{info, warn, debug}; -use std::path::{Path, PathBuf}; +use log::{debug, info, warn}; use std::fs; use std::io::{SeekFrom, Write}; -use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::MetadataExt; +use std::os::unix::fs::PermissionsExt; +use std::path::PathBuf; /// SFTP packet类型(参考draft-ietf-secsh-filexfer-02.txt) #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -156,27 +155,31 @@ impl SftpAttrs { pub fn from_metadata(metadata: &fs::Metadata) -> Self { let mut attrs = Self::new(); - + attrs.flags = SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE | SftpAttrFlags::SSH_FILEXFER_ATTR_UIDGID | SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS | SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME; - + attrs.size = Some(metadata.len()); attrs.permissions = Some(metadata.permissions().mode()); attrs.uid = Some(metadata.uid()); attrs.gid = Some(metadata.gid()); - + if let Ok(atime) = metadata.accessed() { - attrs.atime = atime.duration_since(std::time::UNIX_EPOCH) - .ok().map(|d| d.as_secs() as u32); + attrs.atime = atime + .duration_since(std::time::UNIX_EPOCH) + .ok() + .map(|d| d.as_secs() as u32); } - + if let Ok(mtime) = metadata.modified() { - attrs.mtime = mtime.duration_since(std::time::UNIX_EPOCH) - .ok().map(|d| d.as_secs() as u32); + attrs.mtime = mtime + .duration_since(std::time::UNIX_EPOCH) + .ok() + .map(|d| d.as_secs() as u32); } - + attrs } @@ -210,59 +213,71 @@ impl SftpAttrs { self.permissions.unwrap_or(0), self.atime, self.mtime, ); - + let mut buffer = Vec::new(); - - buffer.write_u32::(self.flags) + + buffer + .write_u32::(self.flags) .with_context(|| "serialize attrs flags")?; - + if self.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE != 0 { if let Some(size) = self.size { - buffer.write_u64::(size) + buffer + .write_u64::(size) .with_context(|| "serialize attrs size")?; } } - + if self.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_UIDGID != 0 { if let (Some(uid), Some(gid)) = (self.uid, self.gid) { - buffer.write_u32::(uid) + buffer + .write_u32::(uid) .with_context(|| "serialize attrs uid")?; - buffer.write_u32::(gid) + buffer + .write_u32::(gid) .with_context(|| "serialize attrs gid")?; } } - + if self.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 { if let Some(permissions) = self.permissions { - buffer.write_u32::(permissions) + buffer + .write_u32::(permissions) .with_context(|| "serialize attrs perms")?; } } - + if self.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 { if let (Some(atime), Some(mtime)) = (self.atime, self.mtime) { - buffer.write_u32::(atime) + buffer + .write_u32::(atime) .with_context(|| "serialize attrs atime")?; - buffer.write_u32::(mtime) + buffer + .write_u32::(mtime) .with_context(|| "serialize attrs mtime")?; } } - + if self.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_EXTENDED != 0 { - buffer.write_u32::(self.extended.len() as u32) + buffer + .write_u32::(self.extended.len() as u32) .with_context(|| "serialize attrs ext count")?; for (name, value) in &self.extended { - buffer.write_u32::(name.len() as u32) + buffer + .write_u32::(name.len() as u32) .with_context(|| "serialize attrs ext name len")?; - buffer.write_all(name.as_bytes()) + buffer + .write_all(name.as_bytes()) .with_context(|| "serialize attrs ext name")?; - buffer.write_u32::(value.len() as u32) + buffer + .write_u32::(value.len() as u32) .with_context(|| "serialize attrs ext value len")?; - buffer.write_all(value.as_bytes()) + buffer + .write_all(value.as_bytes()) .with_context(|| "serialize attrs ext value")?; } } - + Ok(buffer) } } @@ -289,7 +304,7 @@ pub struct SftpHandler { next_handle_id: u32, handles: std::collections::HashMap, // ⭐⭐⭐⭐⭐ Phase 4: 添加 client maxpack 限制(参考OpenSSH sftp-server.c) - maxpacket: u32, // 来自 SSH_MSG_CHANNEL_OPEN_CONFIRMATION 的 maximum_packet_size + maxpacket: u32, // 来自 SSH_MSG_CHANNEL_OPEN_CONFIRMATION 的 maximum_packet_size /// 限制绝对路径也在 root_dir 之下(chroot 模式) restrict_absolute: bool, } @@ -325,11 +340,11 @@ impl SftpHandler { if data.is_empty() { return Err(anyhow!("Empty SFTP request")); } - + let packet_type = SftpPacketType::try_from(data[0])?; - + info!("Processing SFTP request: {:?}", packet_type); - + match packet_type { SftpPacketType::SSH_FXP_INIT => self.handle_init(data), SftpPacketType::SSH_FXP_OPEN => self.handle_open(data), @@ -361,13 +376,13 @@ impl SftpHandler { /// 处理SSH_FXP_INIT(参考OpenSSH sftp-server.c: process_init()) fn handle_init(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_INIT"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let version = cursor.read_u32::()?; info!("Client SFTP version: {}", version); - + let response = self.build_version_response(3)?; Ok(response) } @@ -375,29 +390,36 @@ impl SftpHandler { /// 处理SSH_FXP_OPEN(参考OpenSSH sftp-server.c: process_open()) fn handle_open(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_OPEN"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; let pflags = cursor.read_u32::()?; let _attrs = read_sftp_attrs(&mut cursor)?; - - info!("SSH_FXP_OPEN: id={}, path={}, pflags={:#x}", id, path, pflags); - + + info!( + "SSH_FXP_OPEN: id={}, path={}, pflags={:#x}", + id, path, pflags + ); + let full_path = self.resolve_path(&path)?; let flags = OpenFlags::from_sftp_pflags(pflags); - + match self.vfs.open_file(&full_path, &flags) { Ok(file) => { if self.handles.len() >= Self::MAX_HANDLES { warn!("SSH_FXP_OPEN: handle limit reached ({})", Self::MAX_HANDLES); - return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Handle limit reached"); + return self.build_status_response( + id, + SftpStatus::SSH_FX_FAILURE, + "Handle limit reached", + ); } let handle_id = self.next_handle_id; self.next_handle_id += 1; - + let handle = SftpHandle { id: handle_id, path: full_path, @@ -405,30 +427,33 @@ impl SftpHandler { file: Some(file), dir_entries: None, }; - + self.handles.insert(handle_id, handle); - + self.build_handle_response(id, &handle_id.to_be_bytes()) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_CLOSE(参考OpenSSH sftp-server.c: process_close()) fn handle_close(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_CLOSE"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let handle_bytes = read_sftp_string_bytes(&mut cursor)?; - let handle_id = u32::from_be_bytes([handle_bytes[0], handle_bytes[1], handle_bytes[2], handle_bytes[3]]); - + let handle_id = u32::from_be_bytes([ + handle_bytes[0], + handle_bytes[1], + handle_bytes[2], + handle_bytes[3], + ]); + info!("SSH_FXP_CLOSE: id={}, handle={}", id, handle_id); - + if self.handles.remove(&handle_id).is_some() { self.build_status_response(id, SftpStatus::SSH_FX_OK, "File closed") } else { @@ -439,39 +464,48 @@ impl SftpHandler { /// 处理SSH_FXP_READ(参考OpenSSH sftp-server.c: process_read()) fn handle_read(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_READ"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let handle_bytes = read_sftp_string_bytes(&mut cursor)?; - let handle_id = u32::from_be_bytes([handle_bytes[0], handle_bytes[1], handle_bytes[2], handle_bytes[3]]); + let handle_id = u32::from_be_bytes([ + handle_bytes[0], + handle_bytes[1], + handle_bytes[2], + handle_bytes[3], + ]); let offset = cursor.read_u64::()?; let length = cursor.read_u32::()?; - - info!("SSH_FXP_READ: id={}, handle={}, offset={}, length={}", id, handle_id, offset, length); - + + info!( + "SSH_FXP_READ: id={}, handle={}, offset={}, length={}", + id, handle_id, offset, length + ); + if let Some(handle) = self.handles.get_mut(&handle_id) { if let Some(ref mut file) = handle.file { - file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - - let max_data_size = std::cmp::min(self.maxpacket.saturating_sub(1024), Self::MAX_XFER_SIZE); + file.seek(SeekFrom::Start(offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; + + let max_data_size = + std::cmp::min(self.maxpacket.saturating_sub(1024), Self::MAX_XFER_SIZE); let actual_length = std::cmp::min(length, max_data_size); - - info!("SSH_FXP_READ limited: requested={}, actual={}", length, actual_length); - + + info!( + "SSH_FXP_READ limited: requested={}, actual={}", + length, actual_length + ); + let mut buffer = vec![0u8; actual_length as usize]; match file.read(&mut buffer) { - Ok(0) => { - self.build_status_response(id, SftpStatus::SSH_FX_EOF, "End of file") - } + Ok(0) => self.build_status_response(id, SftpStatus::SSH_FX_EOF, "End of file"), Ok(n) => { buffer.truncate(n); self.build_data_response(id, &buffer) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } else { self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Not a file handle") @@ -484,36 +518,49 @@ impl SftpHandler { /// 处理SSH_FXP_WRITE(参考OpenSSH sftp-server.c: process_write()) fn handle_write(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_WRITE"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let handle_bytes = read_sftp_string_bytes(&mut cursor)?; - let handle_id = u32::from_be_bytes([handle_bytes[0], handle_bytes[1], handle_bytes[2], handle_bytes[3]]); + let handle_id = u32::from_be_bytes([ + handle_bytes[0], + handle_bytes[1], + handle_bytes[2], + handle_bytes[3], + ]); let offset = cursor.read_u64::()?; let write_data = read_sftp_string_bytes(&mut cursor)?; - - info!("SSH_FXP_WRITE: id={}, handle={}, offset={}, length={}", id, handle_id, offset, write_data.len()); - - if write_data.len() > 0 { + + info!( + "SSH_FXP_WRITE: id={}, handle={}, offset={}, length={}", + id, + handle_id, + offset, + write_data.len() + ); + + if !write_data.is_empty() { let preview_len = std::cmp::min(20, write_data.len()); let preview = &write_data[0..preview_len]; - debug!("SSH_FXP_WRITE data preview (first {} bytes): {:?}", preview_len, preview); + debug!( + "SSH_FXP_WRITE data preview (first {} bytes): {:?}", + preview_len, preview + ); } - + if let Some(handle) = self.handles.get_mut(&handle_id) { if let Some(ref mut file) = handle.file { - file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - + file.seek(SeekFrom::Start(offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; + match file.write_all(&write_data) { Ok(_) => { file.flush().ok(); self.build_status_response(id, SftpStatus::SSH_FX_OK, "Write successful") } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } else { self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Not a file handle") @@ -526,41 +573,44 @@ impl SftpHandler { /// 处理SSH_FXP_LSTAT(参考OpenSSH sftp-server.c: process_lstat()) fn handle_lstat(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_LSTAT"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_LSTAT: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.lstat(&full_path) { Ok(stat) => { let attrs = SftpAttrs::from_vfs_stat(&stat); self.build_attrs_response(id, &attrs) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_FSTAT(参考OpenSSH sftp-server.c: process_fstat()) fn handle_fstat(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_FSTAT"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let handle_bytes = read_sftp_string_bytes(&mut cursor)?; - let handle_id = u32::from_be_bytes([handle_bytes[0], handle_bytes[1], handle_bytes[2], handle_bytes[3]]); - + let handle_id = u32::from_be_bytes([ + handle_bytes[0], + handle_bytes[1], + handle_bytes[2], + handle_bytes[3], + ]); + info!("SSH_FXP_FSTAT: id={}, handle={}", id, handle_id); - + if let Some(handle) = self.handles.get_mut(&handle_id) { if let Some(ref mut file) = handle.file { match file.stat() { @@ -568,9 +618,7 @@ impl SftpHandler { let attrs = SftpAttrs::from_vfs_stat(&stat); self.build_attrs_response(id, &attrs) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } else { match self.vfs.stat(&handle.path) { @@ -578,9 +626,7 @@ impl SftpHandler { let attrs = SftpAttrs::from_vfs_stat(&stat); self.build_attrs_response(id, &attrs) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } } else { @@ -591,26 +637,33 @@ impl SftpHandler { /// 处理SSH_FXP_OPENDIR(参考OpenSSH sftp-server.c: process_opendir()) fn handle_opendir(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_OPENDIR"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_OPENDIR: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.read_dir(&full_path) { Ok(entries) => { if self.handles.len() >= Self::MAX_HANDLES { - warn!("SSH_FXP_OPENDIR: handle limit reached ({})", Self::MAX_HANDLES); - return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Handle limit reached"); + warn!( + "SSH_FXP_OPENDIR: handle limit reached ({})", + Self::MAX_HANDLES + ); + return self.build_status_response( + id, + SftpStatus::SSH_FX_FAILURE, + "Handle limit reached", + ); } let handle_id = self.next_handle_id; self.next_handle_id += 1; - + let handle = SftpHandle { id: handle_id, path: full_path, @@ -618,30 +671,33 @@ impl SftpHandler { file: None, dir_entries: Some(entries), }; - + self.handles.insert(handle_id, handle); - + self.build_handle_response(id, &handle_id.to_be_bytes()) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_READDIR(参考OpenSSH sftp-server.c: process_readdir()) fn handle_readdir(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_READDIR"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let handle_bytes = read_sftp_string_bytes(&mut cursor)?; - let handle_id = u32::from_be_bytes([handle_bytes[0], handle_bytes[1], handle_bytes[2], handle_bytes[3]]); - + let handle_id = u32::from_be_bytes([ + handle_bytes[0], + handle_bytes[1], + handle_bytes[2], + handle_bytes[3], + ]); + info!("SSH_FXP_READDIR: id={}, handle={}", id, handle_id); - + if let Some(handle) = self.handles.get_mut(&handle_id) { if handle.handle_type == SftpHandleType::Directory { if let Some(ref mut dir_entries) = handle.dir_entries { @@ -655,11 +711,15 @@ impl SftpHandler { (entry.name, attrs) }) .collect(); - + self.build_name_response(id, entries) } } else { - self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "No directory entries") + self.build_status_response( + id, + SftpStatus::SSH_FX_FAILURE, + "No directory entries", + ) } } else { self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Not a directory handle") @@ -672,196 +732,187 @@ impl SftpHandler { /// 处理SSH_FXP_REMOVE(参考OpenSSH sftp-server.c: process_remove()) fn handle_remove(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_REMOVE"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_REMOVE: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.remove_file(&full_path) { - Ok(_) => { - self.build_status_response(id, SftpStatus::SSH_FX_OK, "File removed") - } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Ok(_) => self.build_status_response(id, SftpStatus::SSH_FX_OK, "File removed"), + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_MKDIR(参考OpenSSH sftp-server.c: process_mkdir()) fn handle_mkdir(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_MKDIR"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; let _attrs = read_sftp_attrs(&mut cursor)?; - + info!("SSH_FXP_MKDIR: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.create_dir(&full_path, 0o755) { - Ok(_) => { - self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory created") - } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Ok(_) => self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory created"), + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_RMDIR(参考OpenSSH sftp-server.c: process_rmdir()) fn handle_rmdir(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_RMDIR"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_RMDIR: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.remove_dir(&full_path) { - Ok(_) => { - self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory removed") - } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Ok(_) => self.build_status_response(id, SftpStatus::SSH_FX_OK, "Directory removed"), + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_REALPATH(参考OpenSSH sftp-server.c: process_realpath()) fn handle_realpath(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_REALPATH"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_REALPATH: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - - let name_attrs_vec = vec![( - full_path.to_string_lossy().to_string(), - SftpAttrs::new(), - )]; - + + let name_attrs_vec = vec![(full_path.to_string_lossy().to_string(), SftpAttrs::new())]; + self.build_name_response(id, name_attrs_vec) } /// 处理SSH_FXP_STAT(参考OpenSSH sftp-server.c: process_stat()) fn handle_stat(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_STAT"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_STAT: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.stat(&full_path) { Ok(stat) => { let attrs = SftpAttrs::from_vfs_stat(&stat); self.build_attrs_response(id, &attrs) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_RENAME(参考OpenSSH sftp-server.c: process_rename()) fn handle_rename(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_RENAME"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let old_path = read_sftp_string(&mut cursor)?; let new_path = read_sftp_string(&mut cursor)?; - - info!("SSH_FXP_RENAME: id={}, old={}, new={}", id, old_path, new_path); - + + info!( + "SSH_FXP_RENAME: id={}, old={}, new={}", + id, old_path, new_path + ); + let old_full_path = self.resolve_path(&old_path)?; let new_full_path = self.resolve_path(&new_path)?; - + match self.vfs.rename(&old_full_path, &new_full_path) { - Ok(_) => { - self.build_status_response(id, SftpStatus::SSH_FX_OK, "Rename successful") - } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Ok(_) => self.build_status_response(id, SftpStatus::SSH_FX_OK, "Rename successful"), + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_SETSTAT(参考OpenSSH sftp-server.c: process_setstat()) fn handle_setstat(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_SETSTAT"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; let _attrs = read_sftp_attrs(&mut cursor)?; - + info!("SSH_FXP_SETSTAT: id={}, path={}", id, path); - + self.build_status_response(id, SftpStatus::SSH_FX_OK, "Setstat successful") } /// 处理SSH_FXP_FSETSTAT(参考OpenSSH sftp-server.c: process_fsetstat()) fn handle_fsetstat(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_FSETSTAT"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let handle_bytes = read_sftp_string_bytes(&mut cursor)?; - let handle_id = u32::from_be_bytes([handle_bytes[0], handle_bytes[1], handle_bytes[2], handle_bytes[3]]); + let handle_id = u32::from_be_bytes([ + handle_bytes[0], + handle_bytes[1], + handle_bytes[2], + handle_bytes[3], + ]); let attrs = read_sftp_attrs(&mut cursor)?; - - info!("SSH_FXP_FSETSTAT: id={}, handle={}, attrs.flags={}", id, handle_id, attrs.flags); - + + info!( + "SSH_FXP_FSETSTAT: id={}, handle={}, attrs.flags={}", + id, handle_id, attrs.flags + ); + let handle = self.handles.get_mut(&handle_id); if handle.is_none() { return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid handle"); } - + let handle = handle.unwrap(); if handle.handle_type != SftpHandleType::File { return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Not a file handle"); } - + let path = handle.path.clone(); - + if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE != 0 { if let Some(size) = attrs.size { info!("FSETSTAT: setting file size to {}", size); if let Some(ref mut file) = handle.file { - file.set_len(size).map_err(|e| anyhow!("set_len error: {}", e))?; + file.set_len(size) + .map_err(|e| anyhow!("set_len error: {}", e))?; } else { let flags = OpenFlags::new().write(); if let Ok(mut f) = self.vfs.open_file(&path, &flags) { @@ -870,7 +921,7 @@ impl SftpHandler { } } } - + if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 || attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 { @@ -884,114 +935,97 @@ impl SftpHandler { } if attrs.flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 { if let (Some(atime), Some(mtime)) = (attrs.atime, attrs.mtime) { - vfs_stat.atime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(atime as u64); - vfs_stat.mtime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(mtime as u64); + vfs_stat.atime = + std::time::UNIX_EPOCH + std::time::Duration::from_secs(atime as u64); + vfs_stat.mtime = + std::time::UNIX_EPOCH + std::time::Duration::from_secs(mtime as u64); } } self.vfs.set_stat(&path, &vfs_stat).ok(); } - + self.build_status_response(id, SftpStatus::SSH_FX_OK, "Fsetstat successful") } /// 处理SSH_FXP_READLINK(Phase 10:参考OpenSSH sftp-server.c: process_readlink()) fn handle_readlink(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_READLINK"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let path = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_READLINK: id={}, path={}", id, path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.read_link(&full_path) { Ok(link_target) => { let target = link_target.to_string_lossy().to_string(); self.build_name_response(id, vec![(target, SftpAttrs::default())]) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_SYMLINK(Phase 10:参考OpenSSH sftp-server.c: process_symlink()) fn handle_symlink(&self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_SYMLINK"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let linkpath = read_sftp_string(&mut cursor)?; let targetpath = read_sftp_string(&mut cursor)?; - - info!("SSH_FXP_SYMLINK: id={}, link={}, target={}", id, linkpath, targetpath); - + + info!( + "SSH_FXP_SYMLINK: id={}, link={}, target={}", + id, linkpath, targetpath + ); + let full_linkpath = self.resolve_path(&linkpath)?; let full_targetpath = self.resolve_path(&targetpath)?; - + match self.vfs.create_symlink(&full_targetpath, &full_linkpath) { - Ok(_) => { - self.build_status_response(id, SftpStatus::SSH_FX_OK, "Symlink created") - } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Ok(_) => self.build_status_response(id, SftpStatus::SSH_FX_OK, "Symlink created"), + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理SSH_FXP_EXTENDED(Phase 10:参考OpenSSH sftp-server.c: process_extended()) fn handle_extended(&mut self, data: &[u8]) -> Result> { info!("Processing SSH_FXP_EXTENDED"); - + let mut cursor = std::io::Cursor::new(data); cursor.set_position(1); - + let id = cursor.read_u32::()?; let extension_name = read_sftp_string(&mut cursor)?; - + info!("SSH_FXP_EXTENDED: id={}, extension={}", id, extension_name); - + // 支持常见的SFTP扩展 match extension_name.as_str() { - "statvfs@openssh.com" => { - self.handle_statvfs(&mut cursor, id) - } - "fstatvfs@openssh.com" => { - self.handle_fstatvfs(&mut cursor, id) - } - "hardlink@openssh.com" => { - self.handle_hardlink(&mut cursor, id) - } - "posix-rename@openssh.com" => { - self.handle_posix_rename(&mut cursor, id) - } - "md5-hash@openssh.com" => { - self.handle_md5_hash(&mut cursor, id) - } - "sha256-hash@openssh.com" => { - self.handle_sha256_hash(&mut cursor, id) - } - "sha384-hash@openssh.com" => { - self.handle_sha384_hash(&mut cursor, id) - } - "sha512-hash@openssh.com" => { - self.handle_sha512_hash(&mut cursor, id) - } - "check-file@openssh.com" => { - self.handle_check_file(&mut cursor, id) - } - "copy-data@openssh.com" => { - self.handle_copy_data(&mut cursor, id) - } + "statvfs@openssh.com" => self.handle_statvfs(&mut cursor, id), + "fstatvfs@openssh.com" => self.handle_fstatvfs(&mut cursor, id), + "hardlink@openssh.com" => self.handle_hardlink(&mut cursor, id), + "posix-rename@openssh.com" => self.handle_posix_rename(&mut cursor, id), + "md5-hash@openssh.com" => self.handle_md5_hash(&mut cursor, id), + "sha256-hash@openssh.com" => self.handle_sha256_hash(&mut cursor, id), + "sha384-hash@openssh.com" => self.handle_sha384_hash(&mut cursor, id), + "sha512-hash@openssh.com" => self.handle_sha512_hash(&mut cursor, id), + "check-file@openssh.com" => self.handle_check_file(&mut cursor, id), + "copy-data@openssh.com" => self.handle_copy_data(&mut cursor, id), _ => { warn!("Unsupported SFTP extension: {}", extension_name); - self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, &format!("Unsupported extension: {}", extension_name)) + self.build_status_response( + id, + SftpStatus::SSH_FX_FAILURE, + &format!("Unsupported extension: {}", extension_name), + ) } } } @@ -1000,15 +1034,15 @@ impl SftpHandler { fn handle_statvfs(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result> { let path = read_sftp_string(cursor)?; info!("statvfs: path={}", path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.stat(&full_path) { Ok(_) => { let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - + response.write_u64::(4096)?; response.write_u64::(4096)?; response.write_u64::(1000000)?; @@ -1020,29 +1054,32 @@ impl SftpHandler { response.write_u64::(0)?; response.write_u64::(0)?; response.write_u64::(255)?; - + self.wrap_sftp_packet(&response) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理fstatvfs@openssh.com扩展(文件句柄统计) fn handle_fstatvfs(&mut self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result> { let handle_bytes = read_sftp_string_bytes(cursor)?; - let handle_id = u32::from_be_bytes([handle_bytes[0], handle_bytes[1], handle_bytes[2], handle_bytes[3]]); - + let handle_id = u32::from_be_bytes([ + handle_bytes[0], + handle_bytes[1], + handle_bytes[2], + handle_bytes[3], + ]); + info!("fstatvfs: handle={}", handle_id); - + // 简化实现:返回与statvfs相同的结果 #[cfg(unix)] { let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - + response.write_u64::(4096)?; response.write_u64::(4096)?; response.write_u64::(1000000)?; @@ -1054,31 +1091,31 @@ impl SftpHandler { response.write_u64::(0)?; response.write_u64::(0)?; response.write_u64::(255)?; - + self.wrap_sftp_packet(&response) } - + #[cfg(not(unix))] - self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "fstatvfs not supported on non-Unix systems") + self.build_status_response( + id, + SftpStatus::SSH_FX_FAILURE, + "fstatvfs not supported on non-Unix systems", + ) } /// 处理hardlink@openssh.com扩展(创建硬链接) fn handle_hardlink(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result> { let oldpath = read_sftp_string(cursor)?; let newpath = read_sftp_string(cursor)?; - + info!("hardlink: old={}, new={}", oldpath, newpath); - + let full_oldpath = self.resolve_path(&oldpath)?; let full_newpath = self.resolve_path(&newpath)?; - + match self.vfs.hard_link(&full_oldpath, &full_newpath) { - Ok(_) => { - self.build_status_response(id, SftpStatus::SSH_FX_OK, "Hardlink created") - } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Ok(_) => self.build_status_response(id, SftpStatus::SSH_FX_OK, "Hardlink created"), + Err(e) => self.build_status_from_vfs_error(id, &e), } } @@ -1086,19 +1123,17 @@ impl SftpHandler { fn handle_posix_rename(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result> { let oldpath = read_sftp_string(cursor)?; let newpath = read_sftp_string(cursor)?; - + info!("posix-rename: old={}, new={}", oldpath, newpath); - + let full_oldpath = self.resolve_path(&oldpath)?; let full_newpath = self.resolve_path(&newpath)?; - + match self.vfs.rename(&full_oldpath, &full_newpath) { Ok(_) => { self.build_status_response(id, SftpStatus::SSH_FX_OK, "Posix rename successful") } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } @@ -1107,42 +1142,48 @@ impl SftpHandler { let path = read_sftp_string(cursor)?; let offset = cursor.read_u64::()?; let length = cursor.read_u64::()?; - - info!("md5-hash: path={}, offset={}, length={}", path, offset, length); - + + info!( + "md5-hash: path={}, offset={}, length={}", + path, offset, length + ); + let actual_length = std::cmp::min(length, Self::MAX_HASH_SIZE); if actual_length < length { - warn!("md5-hash: length reduced from {} to {} (MAX_HASH_SIZE)", length, actual_length); + warn!( + "md5-hash: length reduced from {} to {} (MAX_HASH_SIZE)", + length, actual_length + ); } - + let full_path = self.resolve_path(&path)?; - + let flags = OpenFlags::new().read(); match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - + file.seek(SeekFrom::Start(offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; + let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - + file.read_exact(&mut buffer) + .map_err(|e| anyhow!("Read error: {}", e))?; + let hash = md5::compute(&buffer); let hash_hex = format!("{:x}", hash); - + let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - + response.write_u32::(4)?; response.write_all("md5".as_bytes())?; - + response.write_u32::(hash_hex.len() as u32)?; response.write_all(hash_hex.as_bytes())?; - + self.wrap_sftp_packet(&response) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } @@ -1151,45 +1192,51 @@ impl SftpHandler { let path = read_sftp_string(cursor)?; let offset = cursor.read_u64::()?; let length = cursor.read_u64::()?; - - info!("sha256-hash: path={}, offset={}, length={}", path, offset, length); - + + info!( + "sha256-hash: path={}, offset={}, length={}", + path, offset, length + ); + let actual_length = std::cmp::min(length, Self::MAX_HASH_SIZE); if actual_length < length { - warn!("sha256-hash: length reduced from {} to {} (MAX_HASH_SIZE)", length, actual_length); + warn!( + "sha256-hash: length reduced from {} to {} (MAX_HASH_SIZE)", + length, actual_length + ); } - + let full_path = self.resolve_path(&path)?; - + let flags = OpenFlags::new().read(); match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - + file.seek(SeekFrom::Start(offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; + let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - - use sha2::{Sha256, Digest}; + file.read_exact(&mut buffer) + .map_err(|e| anyhow!("Read error: {}", e))?; + + use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(&buffer); let hash = hasher.finalize(); let hash_hex = format!("{:x}", hash); - + let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - + response.write_u32::(6)?; response.write_all("sha256".as_bytes())?; - + response.write_u32::(hash_hex.len() as u32)?; response.write_all(hash_hex.as_bytes())?; - + self.wrap_sftp_packet(&response) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } @@ -1198,45 +1245,51 @@ impl SftpHandler { let path = read_sftp_string(cursor)?; let offset = cursor.read_u64::()?; let length = cursor.read_u64::()?; - - info!("sha384-hash: path={}, offset={}, length={}", path, offset, length); - + + info!( + "sha384-hash: path={}, offset={}, length={}", + path, offset, length + ); + let actual_length = std::cmp::min(length, Self::MAX_HASH_SIZE); if actual_length < length { - warn!("sha384-hash: length reduced from {} to {} (MAX_HASH_SIZE)", length, actual_length); + warn!( + "sha384-hash: length reduced from {} to {} (MAX_HASH_SIZE)", + length, actual_length + ); } - + let full_path = self.resolve_path(&path)?; - + let flags = OpenFlags::new().read(); match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - + file.seek(SeekFrom::Start(offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; + let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - - use sha2::{Sha384, Digest}; + file.read_exact(&mut buffer) + .map_err(|e| anyhow!("Read error: {}", e))?; + + use sha2::{Digest, Sha384}; let mut hasher = Sha384::new(); hasher.update(&buffer); let hash = hasher.finalize(); let hash_hex = format!("{:x}", hash); - + let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - + response.write_u32::(6)?; response.write_all("sha384".as_bytes())?; - + response.write_u32::(hash_hex.len() as u32)?; response.write_all(hash_hex.as_bytes())?; - + self.wrap_sftp_packet(&response) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } @@ -1245,45 +1298,51 @@ impl SftpHandler { let path = read_sftp_string(cursor)?; let offset = cursor.read_u64::()?; let length = cursor.read_u64::()?; - - info!("sha512-hash: path={}, offset={}, length={}", path, offset, length); - + + info!( + "sha512-hash: path={}, offset={}, length={}", + path, offset, length + ); + let actual_length = std::cmp::min(length, Self::MAX_HASH_SIZE); if actual_length < length { - warn!("sha512-hash: length reduced from {} to {} (MAX_HASH_SIZE)", length, actual_length); + warn!( + "sha512-hash: length reduced from {} to {} (MAX_HASH_SIZE)", + length, actual_length + ); } - + let full_path = self.resolve_path(&path)?; - + let flags = OpenFlags::new().read(); match self.vfs.open_file(&full_path, &flags) { Ok(mut file) => { - file.seek(SeekFrom::Start(offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - + file.seek(SeekFrom::Start(offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; + let mut buffer = vec![0u8; actual_length as usize]; - file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - - use sha2::{Sha512, Digest}; + file.read_exact(&mut buffer) + .map_err(|e| anyhow!("Read error: {}", e))?; + + use sha2::{Digest, Sha512}; let mut hasher = Sha512::new(); hasher.update(&buffer); let hash = hasher.finalize(); let hash_hex = format!("{:x}", hash); - + let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - + response.write_u32::(6)?; response.write_all("sha512".as_bytes())?; - + response.write_u32::(hash_hex.len() as u32)?; response.write_all(hash_hex.as_bytes())?; - + self.wrap_sftp_packet(&response) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } @@ -1291,65 +1350,88 @@ impl SftpHandler { fn handle_check_file(&self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result> { let path = read_sftp_string(cursor)?; let _check_flags = cursor.read_u32::()?; - + info!("check-file: path={}", path); - + let full_path = self.resolve_path(&path)?; - + match self.vfs.stat(&full_path) { Ok(stat) => { let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; - + response.write_u32::(1)?; - + let msg = format!("File exists, size: {}", stat.size); response.write_u32::(msg.len() as u32)?; response.write_all(msg.as_bytes())?; - + self.wrap_sftp_packet(&response) } - Err(e) => { - self.build_status_from_vfs_error(id, &e) - } + Err(e) => self.build_status_from_vfs_error(id, &e), } } /// 处理copy-data@openssh.com扩展(Phase 12:服务器端复制) - fn handle_copy_data(&mut self, cursor: &mut std::io::Cursor<&[u8]>, id: u32) -> Result> { + fn handle_copy_data( + &mut self, + cursor: &mut std::io::Cursor<&[u8]>, + id: u32, + ) -> Result> { let read_handle_bytes = read_sftp_string_bytes(cursor)?; let read_offset = cursor.read_u64::()?; let read_length = cursor.read_u64::()?; let write_handle_bytes = read_sftp_string_bytes(cursor)?; let write_offset = cursor.read_u64::()?; - + info!("copy-data: read_handle={:?}, read_offset={}, read_length={}, write_handle={:?}, write_offset={}", read_handle_bytes, read_offset, read_length, write_handle_bytes, write_offset); - + let actual_length = std::cmp::min(read_length, Self::MAX_XFER_SIZE as u64); if actual_length < read_length { - warn!("copy-data: length reduced from {} to {} (MAX_XFER_SIZE)", read_length, actual_length); + warn!( + "copy-data: length reduced from {} to {} (MAX_XFER_SIZE)", + read_length, actual_length + ); } - - let read_handle_id = u32::from_be_bytes([read_handle_bytes[0], read_handle_bytes[1], read_handle_bytes[2], read_handle_bytes[3]]); - let write_handle_id = u32::from_be_bytes([write_handle_bytes[0], write_handle_bytes[1], write_handle_bytes[2], write_handle_bytes[3]]); - + + let read_handle_id = u32::from_be_bytes([ + read_handle_bytes[0], + read_handle_bytes[1], + read_handle_bytes[2], + read_handle_bytes[3], + ]); + let write_handle_id = u32::from_be_bytes([ + write_handle_bytes[0], + write_handle_bytes[1], + write_handle_bytes[2], + write_handle_bytes[3], + ]); + let read_path = if let Some(read_handle) = self.handles.get(&read_handle_id) { read_handle.path.clone() } else { - return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid read handle"); + return self.build_status_response( + id, + SftpStatus::SSH_FX_FAILURE, + "Invalid read handle", + ); }; - + let write_path = if let Some(write_handle) = self.handles.get(&write_handle_id) { write_handle.path.clone() } else { - return self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid write handle"); + return self.build_status_response( + id, + SftpStatus::SSH_FX_FAILURE, + "Invalid write handle", + ); }; - + let read_flags = OpenFlags::new().read(); let write_flags = OpenFlags::new().write(); - + let mut read_file = match self.vfs.open_file(&read_path, &read_flags) { Ok(f) => f, Err(e) => return self.build_status_from_vfs_error(id, &e), @@ -1358,32 +1440,40 @@ impl SftpHandler { Ok(f) => f, Err(e) => return self.build_status_from_vfs_error(id, &e), }; - - read_file.seek(SeekFrom::Start(read_offset)).map_err(|e| anyhow!("Seek error: {}", e))?; + + read_file + .seek(SeekFrom::Start(read_offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; let mut buffer = vec![0u8; actual_length as usize]; - read_file.read_exact(&mut buffer).map_err(|e| anyhow!("Read error: {}", e))?; - - write_file.seek(SeekFrom::Start(write_offset)).map_err(|e| anyhow!("Seek error: {}", e))?; - write_file.write_all(&buffer).map_err(|e| anyhow!("Write error: {}", e))?; + read_file + .read_exact(&mut buffer) + .map_err(|e| anyhow!("Read error: {}", e))?; + + write_file + .seek(SeekFrom::Start(write_offset)) + .map_err(|e| anyhow!("Seek error: {}", e))?; + write_file + .write_all(&buffer) + .map_err(|e| anyhow!("Write error: {}", e))?; write_file.flush().ok(); - + let mut response = Vec::new(); response.write_u8(SftpPacketType::SSH_FXP_EXTENDED_REPLY as u8)?; response.write_u32::(id)?; response.write_u64::(actual_length)?; - + self.wrap_sftp_packet(&response) } /// 解析路径(安全性检查,参考OpenSSH sftp-server.c: path_resolve()) - /// + /// /// 安全策略: /// - 相对路径:始终限制在 root_dir 之下 /// - 绝对路径(restrict_absolute=false):允许(用户依赖文件系统权限) /// - 绝对路径(restrict_absolute=true):限制在 root_dir 之下(chroot 模式) fn resolve_path(&self, path: &str) -> Result { info!("resolve_path: input={}, root_dir={:?}", path, self.root_dir); - + let full_path = if path.is_empty() || path == "." { self.root_dir.clone() } else if path.starts_with('/') { @@ -1391,39 +1481,52 @@ impl SftpHandler { } else { self.root_dir.join(path) }; - + info!("resolve_path: full_path={:?}", full_path); - + let is_absolute = path.starts_with('/'); - + // 检查路径遍历:对相对路径始终执行,对绝对路径仅在 restrict_absolute 模式下执行 let need_check = !is_absolute || self.restrict_absolute; - + if need_check { if full_path.exists() { - let canonical = full_path.canonicalize() + let canonical = full_path + .canonicalize() .map_err(|e| anyhow!("Path canonicalize error: {}", e))?; if !canonical.starts_with(&self.root_dir) { - return Err(anyhow!("Path traversal: {:?} not under {:?}", canonical, self.root_dir)); + return Err(anyhow!( + "Path traversal: {:?} not under {:?}", + canonical, + self.root_dir + )); } Ok(canonical) } else { // Pre-resolve parent directory for non-existent paths if let Some(parent) = full_path.parent() { if parent.exists() { - let canonical_parent = parent.canonicalize() + let canonical_parent = parent + .canonicalize() .map_err(|e| anyhow!("Parent canonicalize error: {}", e))?; - let resolved = canonical_parent.join( - full_path.file_name().unwrap_or_default() - ); + let resolved = + canonical_parent.join(full_path.file_name().unwrap_or_default()); if !resolved.starts_with(&self.root_dir) { - return Err(anyhow!("Path traversal: {:?} not under {:?}", resolved, self.root_dir)); + return Err(anyhow!( + "Path traversal: {:?} not under {:?}", + resolved, + self.root_dir + )); } return Ok(resolved); } } if !full_path.starts_with(&self.root_dir) { - return Err(anyhow!("Path traversal: {:?} not under {:?}", full_path, self.root_dir)); + return Err(anyhow!( + "Path traversal: {:?} not under {:?}", + full_path, + self.root_dir + )); } Ok(full_path) } @@ -1437,7 +1540,7 @@ impl SftpHandler { } } -/// 构建SSH_FXP_VERSION响应,包含扩展声明(参考OpenSSH sftp-server.c: process_init()) + /// 构建SSH_FXP_VERSION响应,包含扩展声明(参考OpenSSH sftp-server.c: process_init()) /// /// SFTP协议格式(draft-ietf-secsh-filexfer-02): /// uint32 length @@ -1451,10 +1554,10 @@ impl SftpHandler { /// Client reads with sshbuf_get_cstring() which expects \0 at end. fn build_version_response(&self, version: u32) -> Result> { let mut buffer = Vec::new(); - + buffer.write_u8(SftpPacketType::SSH_FXP_VERSION as u8)?; buffer.write_u32::(version)?; - + // 扩展声明 — OpenSSH sftp-server.c: process_init() style, NO count field let extensions: &[(&str, &str)] = &[ ("posix-rename@openssh.com", "1"), @@ -1477,38 +1580,38 @@ impl SftpHandler { buffer.write_all(data.as_bytes())?; buffer.write_u8(0)?; } - + self.wrap_sftp_packet(&buffer) } - + /// 构建SSH_FXP_STATUS响应(参考OpenSSH sftp-server.c) fn build_status_response(&self, id: u32, status: SftpStatus, message: &str) -> Result> { let mut buffer = Vec::new(); - + buffer.write_u8(SftpPacketType::SSH_FXP_STATUS as u8)?; buffer.write_u32::(id)?; buffer.write_u32::(status as u32)?; - + buffer.write_u32::(message.len() as u32)?; buffer.write_all(message.as_bytes())?; - + buffer.write_u32::(0)?; - + self.wrap_sftp_packet(&buffer) } - + /// 构建SSH_FXP_HANDLE响应(参考OpenSSH sftp-server.c) fn build_handle_response(&self, id: u32, handle: &[u8]) -> Result> { let mut buffer = Vec::new(); - + buffer.write_u8(SftpPacketType::SSH_FXP_HANDLE as u8)?; buffer.write_u32::(id)?; buffer.write_u32::(handle.len() as u32)?; buffer.write_all(handle)?; - + self.wrap_sftp_packet(&buffer) } - + /// Phase 7: 包装SFTP packet为SSH string格式(uint32(length) + packet_type + payload) fn wrap_sftp_packet(&self, packet_data: &[u8]) -> Result> { let mut response = Vec::new(); @@ -1516,50 +1619,50 @@ impl SftpHandler { response.write_all(packet_data)?; Ok(response) } - + /// 构建SSH_FXP_DATA响应(参考OpenSSH sftp-server.c) fn build_data_response(&self, id: u32, data: &[u8]) -> Result> { let mut buffer = Vec::new(); - + buffer.write_u8(SftpPacketType::SSH_FXP_DATA as u8)?; buffer.write_u32::(id)?; - + buffer.write_u32::(data.len() as u32)?; buffer.write_all(data)?; - + self.wrap_sftp_packet(&buffer) } - + /// 构建SSH_FXP_NAME响应(参考OpenSSH sftp-server.c) fn build_name_response(&self, id: u32, entries: Vec<(String, SftpAttrs)>) -> Result> { let mut buffer = Vec::new(); - + buffer.write_u8(SftpPacketType::SSH_FXP_NAME as u8)?; buffer.write_u32::(id)?; buffer.write_u32::(entries.len() as u32)?; - + for (name, attrs) in entries { buffer.write_u32::(name.len() as u32)?; buffer.write_all(name.as_bytes())?; - + let long_name = name.clone(); buffer.write_u32::(long_name.len() as u32)?; buffer.write_all(long_name.as_bytes())?; - + buffer.write_all(&attrs.serialize()?)?; } - + self.wrap_sftp_packet(&buffer) } - + /// 构建SSH_FXP_ATTRS响应(参考OpenSSH sftp-server.c) fn build_attrs_response(&self, id: u32, attrs: &SftpAttrs) -> Result> { let mut buffer = Vec::new(); - + buffer.write_u8(SftpPacketType::SSH_FXP_ATTRS as u8)?; buffer.write_u32::(id)?; buffer.write_all(&attrs.serialize()?)?; - + self.wrap_sftp_packet(&buffer) } @@ -1575,7 +1678,7 @@ impl SftpHandler { _ => SftpStatus::SSH_FX_FAILURE, } } - + /// 根据 Error 构建状态响应(自动映射错误类型) fn build_status_from_io_error(&self, id: u32, err: &std::io::Error) -> Result> { let status = Self::map_io_error_kind(err); @@ -1623,25 +1726,25 @@ fn read_sftp_attrs(reader: &mut R) -> Result { let flags = reader.read_u32::()?; let mut attrs = SftpAttrs::new(); attrs.flags = flags; - + if flags & SftpAttrFlags::SSH_FILEXFER_ATTR_SIZE != 0 { attrs.size = Some(reader.read_u64::()?); } - + if flags & SftpAttrFlags::SSH_FILEXFER_ATTR_UIDGID != 0 { attrs.uid = Some(reader.read_u32::()?); attrs.gid = Some(reader.read_u32::()?); } - + if flags & SftpAttrFlags::SSH_FILEXFER_ATTR_PERMISSIONS != 0 { attrs.permissions = Some(reader.read_u32::()?); } - + if flags & SftpAttrFlags::SSH_FILEXFER_ATTR_ACMODTIME != 0 { attrs.atime = Some(reader.read_u32::()?); attrs.mtime = Some(reader.read_u32::()?); } - + if flags & SftpAttrFlags::SSH_FILEXFER_ATTR_EXTENDED != 0 { let count = reader.read_u32::()?; for _ in 0..count { @@ -1650,7 +1753,7 @@ fn read_sftp_attrs(reader: &mut R) -> Result { attrs.extended.push((name, value)); } } - + Ok(attrs) } @@ -1660,57 +1763,77 @@ mod tests { use crate::vfs::local_fs::LocalFs; use std::fs::File; use tempfile::TempDir; - + fn make_handler(root_dir: PathBuf) -> SftpHandler { SftpHandler::new(root_dir, Box::new(LocalFs::new()), 32768) } - + #[test] fn test_sftp_packet_type_conversion() { - assert_eq!(SftpPacketType::try_from(1).unwrap(), SftpPacketType::SSH_FXP_INIT); - assert_eq!(SftpPacketType::try_from(2).unwrap(), SftpPacketType::SSH_FXP_VERSION); - assert_eq!(SftpPacketType::try_from(3).unwrap(), SftpPacketType::SSH_FXP_OPEN); + assert_eq!( + SftpPacketType::try_from(1).unwrap(), + SftpPacketType::SSH_FXP_INIT + ); + assert_eq!( + SftpPacketType::try_from(2).unwrap(), + SftpPacketType::SSH_FXP_VERSION + ); + assert_eq!( + SftpPacketType::try_from(3).unwrap(), + SftpPacketType::SSH_FXP_OPEN + ); } - + #[test] fn test_sftp_handler_creation() { let temp_dir = TempDir::new().unwrap(); let handler = make_handler(temp_dir.path().to_path_buf()); assert_eq!(handler.next_handle_id, 0); } - + #[test] fn test_sftp_attrs_from_metadata() { let temp_dir = TempDir::new().unwrap(); let file_path = temp_dir.path().join("test.txt"); File::create(&file_path).unwrap(); - + let metadata = fs::metadata(&file_path).unwrap(); let attrs = SftpAttrs::from_metadata(&metadata); - + assert!(attrs.size.is_some()); assert!(attrs.permissions.is_some()); } - + #[test] fn test_sftp_handle_init() { let temp_dir = TempDir::new().unwrap(); let mut handler = make_handler(temp_dir.path().to_path_buf()); - + // SSH_FXP_INIT packet format: type(1) + version(4) // Version 3 in big-endian: [0, 0, 0, 3] let init_packet = vec![SftpPacketType::SSH_FXP_INIT as u8, 0, 0, 0, 3]; let response = handler.handle_request(&init_packet).unwrap(); - + // Response format: length(4) + type(1) + version(4) + extensions // The actual SSH_FXP_VERSION is at byte 4 (after length prefix) - assert!(response.len() >= 5, "Response should have length prefix + type"); - + assert!( + response.len() >= 5, + "Response should have length prefix + type" + ); + // Read length prefix let length = u32::from_be_bytes([response[0], response[1], response[2], response[3]]); - assert_eq!(length as usize + 4, response.len(), "Length should match packet size"); - + assert_eq!( + length as usize + 4, + response.len(), + "Length should match packet size" + ); + // Packet type should be SSH_FXP_VERSION (2) at byte 4 - assert_eq!(response[4], SftpPacketType::SSH_FXP_VERSION as u8, "Packet type should be SSH_FXP_VERSION"); + assert_eq!( + response[4], + SftpPacketType::SSH_FXP_VERSION as u8, + "Packet type should be SSH_FXP_VERSION" + ); } -} \ No newline at end of file +} diff --git a/markbase-core/src/ssh_server/ssh_security_config.rs b/markbase-core/src/ssh_server/ssh_security_config.rs index be47ccd..e4eba94 100644 --- a/markbase-core/src/ssh_server/ssh_security_config.rs +++ b/markbase-core/src/ssh_server/ssh_security_config.rs @@ -1,7 +1,7 @@ // SSH企业级安全配置(Phase 13.1) // 参考OpenSSH sshd_config安全配置 -use anyhow::{Result, anyhow}; +use anyhow::{anyhow, Result}; use log::{info, warn}; use std::fs; use std::path::Path; @@ -14,25 +14,25 @@ pub struct SshSecurityConfig { /// false: 只绑定127.0.0.1(安全) /// true: 允许绑定0.0.0.0(危险) pub gateway_ports: bool, - + /// PermitOpen白名单 /// ["localhost:3000", "localhost:4000", "localhost:*"] /// 空数组表示允许所有目标(不安全) pub permit_open: Vec, - + /// AllowTcpForwarding配置 /// true: 允许TCP转发 /// false: 禁止所有TCP转发 pub allow_tcp_forwarding: bool, - + /// MaxSessions限制 /// 最大会话数,防止资源耗尽 pub max_sessions: u32, - + /// ConnectTimeout超时(秒) /// 连接超时设置,防止悬挂连接 pub connect_timeout: u64, - + /// 活动会话数(运行时状态) pub active_sessions: u32, } @@ -42,110 +42,125 @@ impl SshSecurityConfig { /// 参考:OpenSSH企业级生产环境配置 pub fn enterprise_default() -> Self { Self { - gateway_ports: false, // 安全:只绑定127.0.0.1 - permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单) - allow_tcp_forwarding: true, // 允许TCP转发 - max_sessions: 10, // 最多10个会话 - connect_timeout: 30, // 30秒超时 - active_sessions: 0, // 运行时状态 + gateway_ports: false, // 安全:只绑定127.0.0.1 + permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单) + allow_tcp_forwarding: true, // 允许TCP转发 + max_sessions: 10, // 最多10个会话 + connect_timeout: 30, // 30秒超时 + active_sessions: 0, // 运行时状态 } } - + /// 开发环境默认配置(宽松) pub fn development_default() -> Self { Self { - gateway_ports: true, // 开发:允许0.0.0.0 - permit_open: vec![], // 开发:允许所有目标 + gateway_ports: true, // 开发:允许0.0.0.0 + permit_open: vec![], // 开发:允许所有目标 allow_tcp_forwarding: true, - max_sessions: 20, // 开发:更多会话 - connect_timeout: 60, // 开发:更长超时 + max_sessions: 20, // 开发:更多会话 + connect_timeout: 60, // 开发:更长超时 active_sessions: 0, } } - + /// 从JSON配置文件加载 pub fn load_from_file(path: &str) -> Result { if !Path::new(path).exists() { info!("SSH security config file not found, using enterprise default"); return Ok(Self::enterprise_default()); } - + let config_str = fs::read_to_string(path)?; let config: serde_json::Value = serde_json::from_str(&config_str)?; - - let security = config.get("ssh_server") + + let security = config + .get("ssh_server") .and_then(|s| s.get("security")) .ok_or_else(|| anyhow!("Invalid config structure"))?; - + Ok(Self { - gateway_ports: security.get("gateway_ports") + gateway_ports: security + .get("gateway_ports") .and_then(|v| v.as_bool()) .unwrap_or(false), - permit_open: security.get("permit_open") + permit_open: security + .get("permit_open") .and_then(|v| v.as_array()) - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) .unwrap_or_else(|| vec!["localhost:*".to_string()]), - allow_tcp_forwarding: security.get("allow_tcp_forwarding") + allow_tcp_forwarding: security + .get("allow_tcp_forwarding") .and_then(|v| v.as_bool()) .unwrap_or(true), - max_sessions: security.get("max_sessions") + max_sessions: security + .get("max_sessions") .and_then(|v| v.as_u64()) .map(|v| v as u32) .unwrap_or(10), - connect_timeout: security.get("connect_timeout") + connect_timeout: security + .get("connect_timeout") .and_then(|v| v.as_u64()) .unwrap_or(30), active_sessions: 0, }) } - + /// 验证tcpip-forward请求(安全检查) /// 参考OpenSSH auth2.c: ssh_forwarding_check() - pub fn validate_tcpip_forward_request( - &self, - bind_address: &str, - bind_port: u32, - ) -> Result<()> { - info!("Validating tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port); - + pub fn validate_tcpip_forward_request(&self, bind_address: &str, bind_port: u32) -> Result<()> { + info!( + "Validating tcpip-forward request: bind_address={}, bind_port={}", + bind_address, bind_port + ); + // 1. AllowTcpForwarding检查 if !self.allow_tcp_forwarding { warn!("TCP forwarding disabled by security config"); return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no")); } - + // 2. GatewayPorts检查 if !self.gateway_ports { // 只允许绑定到127.0.0.1或localhost - if bind_address != "127.0.0.1" && bind_address != "localhost" && bind_address != "" { - warn!("GatewayPorts disabled, bind_address {} not allowed", bind_address); + if bind_address != "127.0.0.1" && bind_address != "localhost" && !bind_address.is_empty() { + warn!( + "GatewayPorts disabled, bind_address {} not allowed", + bind_address + ); return Err(anyhow!("GatewayPorts=no, only 127.0.0.1 allowed")); } info!("GatewayPorts check passed: bind_address={}", bind_address); } - + // 3. MaxSessions检查 if self.active_sessions >= self.max_sessions { - warn!("Max sessions limit reached: {} >= {}", self.active_sessions, self.max_sessions); + warn!( + "Max sessions limit reached: {} >= {}", + self.active_sessions, self.max_sessions + ); return Err(anyhow!("Max sessions limit reached: {}", self.max_sessions)); } - + // 4. 特权端口检查(防止<1024) if bind_port < 1024 { warn!("Cannot bind to privileged port: {}", bind_port); return Err(anyhow!("Cannot bind to privileged port < 1024")); } - + // 5. 端口范围检查(防止过大端口) if bind_port > 65535 { warn!("Invalid port number: {}", bind_port); return Err(anyhow!("Invalid port number > 65535")); } - + info!("tcpip-forward request validated successfully"); Ok(()) } - + /// 验证direct-tcpip channel请求(安全检查) /// 参考OpenSSH channels.c: channel_connect_direct_tcpip() pub fn validate_direct_tcpip_channel( @@ -153,14 +168,17 @@ impl SshSecurityConfig { host_to_connect: &str, port_to_connect: u32, ) -> Result<()> { - info!("Validating direct-tcpip channel: host={}, port={}", host_to_connect, port_to_connect); - + info!( + "Validating direct-tcpip channel: host={}, port={}", + host_to_connect, port_to_connect + ); + // 1. AllowTcpForwarding检查 if !self.allow_tcp_forwarding { warn!("TCP forwarding disabled by security config"); return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no")); } - + // 2. PermitOpen白名单检查 if !self.permit_open.is_empty() { let target = format!("{}:{}", host_to_connect, port_to_connect); @@ -173,28 +191,34 @@ impl SshSecurityConfig { target == *pattern } }); - + if !allowed { - warn!("Target {}:{} not in PermitOpen whitelist", host_to_connect, port_to_connect); - return Err(anyhow!("Target {}:{} not in PermitOpen whitelist", - host_to_connect, port_to_connect)); + warn!( + "Target {}:{} not in PermitOpen whitelist", + host_to_connect, port_to_connect + ); + return Err(anyhow!( + "Target {}:{} not in PermitOpen whitelist", + host_to_connect, + port_to_connect + )); } info!("PermitOpen check passed: target={}", target); } else { // permit_open为空,允许所有目标(不安全,仅用于开发) info!("PermitOpen whitelist empty, allowing all targets (development mode)"); } - + // 3. 端口范围检查 - if port_to_connect < 1 || port_to_connect > 65535 { + if !(1..=65535).contains(&port_to_connect) { warn!("Invalid port number: {}", port_to_connect); return Err(anyhow!("Invalid port number: {}", port_to_connect)); } - + info!("direct-tcpip channel validated successfully"); Ok(()) } - + /// 增加活动会话数 pub fn increment_sessions(&mut self) -> Result<()> { if self.active_sessions >= self.max_sessions { @@ -204,7 +228,7 @@ impl SshSecurityConfig { info!("Active sessions: {}", self.active_sessions); Ok(()) } - + /// 减少活动会话数 pub fn decrement_sessions(&mut self) { if self.active_sessions > 0 { @@ -217,56 +241,76 @@ impl SshSecurityConfig { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_enterprise_default_config() { let config = SshSecurityConfig::enterprise_default(); - + assert_eq!(config.gateway_ports, false); assert_eq!(config.permit_open, vec!["localhost:*".to_string()]); assert_eq!(config.allow_tcp_forwarding, true); assert_eq!(config.max_sessions, 10); assert_eq!(config.connect_timeout, 30); } - + #[test] fn test_validate_tcpip_forward_request() { let config = SshSecurityConfig::enterprise_default(); - + // 正常请求应该通过 - assert!(config.validate_tcpip_forward_request("127.0.0.1", 8080).is_ok()); - assert!(config.validate_tcpip_forward_request("localhost", 8080).is_ok()); - + assert!(config + .validate_tcpip_forward_request("127.0.0.1", 8080) + .is_ok()); + assert!(config + .validate_tcpip_forward_request("localhost", 8080) + .is_ok()); + // GatewayPorts=false时,0.0.0.0应该被拒绝 - assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_err()); - + assert!(config + .validate_tcpip_forward_request("0.0.0.0", 8080) + .is_err()); + // 特权端口应该被拒绝 - assert!(config.validate_tcpip_forward_request("127.0.0.1", 80).is_err()); + assert!(config + .validate_tcpip_forward_request("127.0.0.1", 80) + .is_err()); } - + #[test] fn test_validate_direct_tcpip_channel() { let config = SshSecurityConfig::enterprise_default(); - + // localhost:*应该通过(通配符匹配) - assert!(config.validate_direct_tcpip_channel("localhost", 3000).is_ok()); - assert!(config.validate_direct_tcpip_channel("localhost", 4000).is_ok()); - + assert!(config + .validate_direct_tcpip_channel("localhost", 3000) + .is_ok()); + assert!(config + .validate_direct_tcpip_channel("localhost", 4000) + .is_ok()); + // 其他host应该被拒绝 - assert!(config.validate_direct_tcpip_channel("192.168.1.100", 3000).is_err()); - assert!(config.validate_direct_tcpip_channel("example.com", 80).is_err()); + assert!(config + .validate_direct_tcpip_channel("192.168.1.100", 3000) + .is_err()); + assert!(config + .validate_direct_tcpip_channel("example.com", 80) + .is_err()); } - + #[test] fn test_development_default_config() { let config = SshSecurityConfig::development_default(); - + assert_eq!(config.gateway_ports, true); assert_eq!(config.permit_open.len(), 0); // 空数组表示允许所有 assert_eq!(config.max_sessions, 20); - + // 开发配置应该允许所有请求 - assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_ok()); - assert!(config.validate_direct_tcpip_channel("example.com", 80).is_ok()); + assert!(config + .validate_tcpip_forward_request("0.0.0.0", 8080) + .is_ok()); + assert!(config + .validate_direct_tcpip_channel("example.com", 80) + .is_ok()); } } diff --git a/markbase-core/src/ssh_server/sshbuf.rs b/markbase-core/src/ssh_server/sshbuf.rs index 2015d36..386c4a3 100644 --- a/markbase-core/src/ssh_server/sshbuf.rs +++ b/markbase-core/src/ssh_server/sshbuf.rs @@ -1,11 +1,11 @@ // SSH Buffer 零拷贝实现(参考 OpenSSH sshbuf.c) // 提供高效的 buffer 管理,消除临时 buffer -use anyhow::{Result, anyhow}; +use anyhow::{anyhow, Result}; use std::io::{Read, Write}; /// SSH Buffer(参考 OpenSSH struct sshbuf) -/// +/// /// OpenSSH 实现: /// ```c /// struct sshbuf { @@ -16,10 +16,10 @@ use std::io::{Read, Write}; /// }; /// ``` pub struct SshBuf { - data: Vec, // Data buffer (对应 OpenSSH buf->d) - off: usize, // Offset (对应 OpenSSH buf->off) - size: usize, // Size (对应 OpenSSH buf->size) - max_size: usize, // Maximum size (对应 OpenSSH buf->max_size) + data: Vec, // Data buffer (对应 OpenSSH buf->d) + off: usize, // Offset (对应 OpenSSH buf->off) + size: usize, // Size (对应 OpenSSH buf->size) + max_size: usize, // Maximum size (对应 OpenSSH buf->max_size) } impl SshBuf { @@ -32,7 +32,7 @@ impl SshBuf { max_size: 128 * 1024 * 1024, // 128MB (OpenSSH SSHBUF_SIZE_MAX) } } - + /// 创建指定大小的 SSH Buffer pub fn with_capacity(capacity: usize) -> Self { Self { @@ -42,7 +42,7 @@ impl SshBuf { max_size: 128 * 1024 * 1024, } } - + /// 设置最大大小 pub fn set_max_size(&mut self, max_size: usize) -> Result<()> { if max_size > 128 * 1024 * 1024 { @@ -51,47 +51,47 @@ impl SshBuf { self.max_size = max_size; Ok(()) } - + /// 获取 buffer 长度(对应 OpenSSH sshbuf_len) - /// + /// /// OpenSSH: `sshbuf_len = buf->size - buf->off` pub fn len(&self) -> usize { self.size - self.off } - + /// 检查 buffer 是否为空 pub fn is_empty(&self) -> bool { self.len() == 0 } - + /// 获取可用空间(对应 OpenSSH sshbuf_avail) - /// + /// /// OpenSSH: `sshbuf_avail = buf->max_size - buf->size` pub fn avail(&self) -> usize { self.max_size - self.size } - + /// 获取可变指针(对应 OpenSSH sshbuf_mutable_ptr) - /// + /// /// OpenSSH 实现: /// ```c /// u_char *sshbuf_mutable_ptr(const struct sshbuf *buf) { /// return buf->d + buf->off; /// } /// ``` - /// + /// /// Rust 实现:返回 `&mut [u8]` slice(零拷贝) pub fn mutable_ptr(&mut self) -> &mut [u8] { &mut self.data[self.off..self.size] } - + /// 获取不可变指针(对应 OpenSSH sshbuf_ptr) pub fn ptr(&self) -> &[u8] { &self.data[self.off..self.size] } - + /// 预分配空间(对应 OpenSSH sshbuf_reserve) - /// + /// /// OpenSSH 实现: /// ```c /// int sshbuf_reserve(struct sshbuf *buf, size_t len, u_char **dpp) { @@ -104,31 +104,31 @@ impl SshBuf { /// return 0; /// } /// ``` - /// + /// /// Rust 实现:返回 `&mut [u8]` slice(零拷贝,可直接 write) pub fn reserve(&mut self, len: usize) -> Result<&mut [u8]> { if len > self.avail() { return Err(anyhow!("no buffer space (avail={})", self.avail())); } - + // 预分配空间 let current_size = self.size; let new_size = current_size + len; - + // 确保 Vec 有足够容量 if new_size > self.data.len() { self.data.resize(new_size, 0); } - + // 更新 size self.size = new_size; - + // 返回新空间的 slice(零拷贝) Ok(&mut self.data[current_size..new_size]) } - + /// 消费数据(对应 OpenSSH sshbuf_consume) - /// + /// /// OpenSSH 实现: /// ```c /// int sshbuf_consume(struct sshbuf *buf, size_t len) { @@ -140,29 +140,33 @@ impl SshBuf { /// return 0; /// } /// ``` - /// + /// /// Rust 实现:移动偏移量(零拷贝,不实际删除数据) pub fn consume(&mut self, len: usize) -> Result<()> { if len > self.len() { - return Err(anyhow!("message incomplete (len={}, consume={})", self.len(), len)); + return Err(anyhow!( + "message incomplete (len={}, consume={})", + self.len(), + len + )); } - + self.off += len; - + // 如果 buffer 空,重置 if self.off == self.size { self.off = 0; self.size = 0; - + // OpenSSH: pack buffer(移除已消费的数据) // Rust: 我们保留 Vec,但重置指针 } - + Ok(()) } - + /// 从末尾消费数据(对应 OpenSSH sshbuf_consume_end) - /// + /// /// OpenSSH 实现: /// ```c /// int sshbuf_consume_end(struct sshbuf *buf, size_t len) { @@ -174,13 +178,13 @@ impl SshBuf { if len > self.len() { return Err(anyhow!("message incomplete")); } - + self.size -= len; Ok(()) } - + /// 直接从 fd read 到 buffer(对应 OpenSSH sshbuf_read) - /// + /// /// OpenSSH 实现: /// ```c /// int sshbuf_read(int fd, struct sshbuf *buf, size_t maxlen, size_t *rlen) { @@ -195,71 +199,75 @@ impl SshBuf { /// return 0; /// } /// ``` - /// + /// /// Rust 实现:零拷贝,直接 read 到 buffer pub fn read_from(&mut self, reader: &mut R, maxlen: usize) -> Result { // 1. reserve 空间 let space = self.reserve(maxlen)?; - + // 2. 直接 read 到 buffer(零拷贝) let n = reader.read(space)?; - + // 3. 调整大小(移除未使用的空间) if maxlen > n { self.consume_end(maxlen - n)?; } - + Ok(n) } - + /// 直接从 buffer write 到 fd(对应 OpenSSH channel_handle_wfd) - /// + /// /// OpenSSH 实现: /// ```c /// buf = sshbuf_mutable_ptr(c->output); // 获取指针 /// len = write(c->wfd, buf, dlen); // 直接 write /// sshbuf_consume(c->output, len); // 消费已写入的数据 /// ``` - /// + /// /// Rust 实现:零拷贝,直接 write 从 buffer pub fn write_to(&mut self, writer: &mut W) -> Result { if self.is_empty() { return Ok(0); } - + // 1. 获取数据指针(零拷贝) let data = self.ptr(); - + // 2. 直接 write(零拷贝) let n = writer.write(data)?; - + // 3. 消费已写入的数据(零拷贝,只移动偏移) self.consume(n)?; - + Ok(n) } - + /// 添加数据(对应 OpenSSH sshbuf_put) - /// + /// /// 用于不需要零拷贝的场景 pub fn put(&mut self, data: &[u8]) -> Result<()> { let space = self.reserve(data.len())?; space.copy_from_slice(data); Ok(()) } - + /// 清空 buffer pub fn reset(&mut self) { self.off = 0; self.size = 0; // OpenSSH: 保留 Vec,只重置指针 } - + /// Debug: 打印 buffer 状态 pub fn debug_info(&self) -> String { format!( "SshBuf: off={}, size={}, len={}, alloc={}, max_size={}", - self.off, self.size, self.len(), self.data.len(), self.max_size + self.off, + self.size, + self.len(), + self.data.len(), + self.max_size ) } } @@ -274,11 +282,11 @@ impl Default for SshBuf { mod tests { use super::*; use std::io::Cursor; - + #[test] fn test_sshbuf_basic() { let mut buf = SshBuf::new(); - + // Test reserve - write into reserved space { let space = buf.reserve(10).unwrap(); @@ -286,57 +294,57 @@ mod tests { space[0] = 1; space[1] = 2; } // space dropped, buf accessible - + // Verify buffer length after reserve assert_eq!(buf.len(), 10); let ptr = buf.mutable_ptr(); assert_eq!(ptr[0], 1); assert_eq!(ptr[1], 2); - + // Test consume buf.consume(2).unwrap(); assert_eq!(buf.len(), 8); assert_eq!(buf.off, 2); } - + #[test] fn test_sshbuf_zero_copy_read() { let mut buf = SshBuf::with_capacity(100); let mut reader = Cursor::new("hello world"); - + // 零拷贝 read let n = buf.read_from(&mut reader, 20).unwrap(); assert_eq!(n, 11); // "hello world" length assert_eq!(buf.len(), 11); - + // 检查数据 let data = buf.ptr(); assert_eq!(data, "hello world".as_bytes()); } - + #[test] fn test_sshbuf_zero_copy_write() { let mut buf = SshBuf::new(); buf.put("hello world".as_bytes()).unwrap(); - + let mut writer = Vec::new(); - + // 零拷贝 write let n = buf.write_to(&mut writer).unwrap(); assert_eq!(n, 11); assert_eq!(buf.len(), 0); // 已消费 - + // 检查数据 assert_eq!(writer, "hello world".as_bytes()); } - + #[test] fn test_sshbuf_max_size() { let mut buf = SshBuf::new(); buf.set_max_size(1000).unwrap(); - + // 尝试 reserve 超过 max_size let result = buf.reserve(2000); assert!(result.is_err()); } -} \ No newline at end of file +} diff --git a/markbase-core/src/ssh_server/version.rs b/markbase-core/src/ssh_server/version.rs index 5adfe9c..f7d8e5a 100644 --- a/markbase-core/src/ssh_server/version.rs +++ b/markbase-core/src/ssh_server/version.rs @@ -2,8 +2,8 @@ // 参考OpenSSH sshd.c: ssh_exchange_identification() use anyhow::Result; +use log::{debug, info}; use std::io::{Read, Write}; -use log::{info, debug}; /// SSH版本字符串 pub const SSH_VERSION: &str = "SSH-2.0-MarkBaseSSH_1.0"; @@ -15,93 +15,96 @@ impl VersionExchange { /// 执行版本交换(服务器端) pub fn exchange(stream: &mut T) -> Result { info!("Starting SSH version exchange"); - + // 1. 发送服务器版本 Self::send_version(stream)?; - + // 2. 接收客户端版本 let client_version = Self::receive_version(stream)?; - - info!("Version exchange completed: server={}, client={}", SSH_VERSION, client_version); + + info!( + "Version exchange completed: server={}, client={}", + SSH_VERSION, client_version + ); Ok(client_version) } - + /// 发送服务器版本(参考OpenSSH ssh_exchange_identification) fn send_version(stream: &mut T) -> Result<()> { let version_line = format!("{}\r\n", SSH_VERSION); stream.write_all(version_line.as_bytes())?; stream.flush()?; - + debug!("Sent version: {}", SSH_VERSION); Ok(()) } - + /// 接收客户端版本(参考OpenSSH ssh_exchange_identification) fn receive_version(stream: &mut T) -> Result { let mut buffer = Vec::new(); let mut byte = [0u8; 1]; - + // 读取直到遇到'\n'(参考OpenSSH实现) loop { stream.read_exact(&mut byte)?; - + // OpenSSH兼容性处理:跳过前导空行和调试信息 - if buffer.is_empty() && byte[0] == '\n' as u8 { - continue; // 跳过空行 + if buffer.is_empty() && byte[0] == b'\n' { + continue; // 跳过空行 } - + // 调试信息行(以'#'开头),跳过 - if buffer.is_empty() && byte[0] == '#' as u8 { + if buffer.is_empty() && byte[0] == b'#' { // 读取整行调试信息 - while byte[0] != '\n' as u8 { + while byte[0] != b'\n' { stream.read_exact(&mut byte)?; } buffer.clear(); continue; } - + buffer.push(byte[0]); - + // 遇到'\n'结束 - if byte[0] == '\n' as u8 { + if byte[0] == b'\n' { break; } - + // 缓冲区溢出保护(OpenSSH限制:255字节) if buffer.len() > 255 { return Err(anyhow::anyhow!("Version string too long")); } } - + // 解析版本字符串 let version_line = String::from_utf8(buffer)?; let version = version_line.trim().trim_matches('\r'); - + // 验证版本格式(SSH-2.0-*) if !version.starts_with("SSH-2.0-") { return Err(anyhow::anyhow!("Invalid SSH version: {}", version)); } - + debug!("Received version: {}", version); Ok(version.to_string()) } - + /// 解析客户端版本信息(兼容性检查) pub fn parse_client_version(version: &str) -> Result { // 格式:SSH-protoversion-softwareversion SP comments let parts: Vec<&str> = version.split_whitespace().collect(); - + let main_part = parts.first().map_or(version, |v| v); let dash_parts: Vec<&str> = main_part.split('-').collect(); - + if dash_parts.len() < 3 { return Err(anyhow::anyhow!("Invalid version format: {}", version)); } - + let proto_version = dash_parts.get(1).map_or("2.0", |v| v); let software_version = dash_parts.get(2).map_or("unknown", |v| v); let comments = parts.get(1).map(|s| s.to_string()); - + Ok(ClientVersionInfo { proto_version: proto_version.to_string(), software_version: software_version.to_string(), @@ -120,12 +123,12 @@ pub struct ClientVersionInfo { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_version_format() { assert!(SSH_VERSION.starts_with("SSH-2.0-")); } - + #[test] fn test_parse_client_version() { let version = "SSH-2.0-OpenSSH_10.2"; diff --git a/markbase-core/src/ssh_server/window_manager.rs b/markbase-core/src/ssh_server/window_manager.rs index bb4bfbe..d0c82ba 100644 --- a/markbase-core/src/ssh_server/window_manager.rs +++ b/markbase-core/src/ssh_server/window_manager.rs @@ -1,18 +1,18 @@ // SSH Window Size管理(Phase 13.6) // 参考RFC 4254 Section 5.2: Window Size Adjustment -use anyhow::{Result, anyhow}; -use log::{info, warn, debug}; -use std::sync::{Arc, Mutex}; -use byteorder::{BigEndian, WriteBytesExt}; use crate::ssh_server::packet::PacketType; +use anyhow::{anyhow, Result}; +use byteorder::{BigEndian, WriteBytesExt}; +use log::{info, warn}; +use std::sync::{Arc, Mutex}; /// Window Size管理器(Phase 13.6) pub struct WindowManager { - initial_window_size: u32, // RFC 4254: 2MB默认 + initial_window_size: u32, // RFC 4254: 2MB默认 current_window_size: Arc>, - max_packet_size: u32, // RFC 4254: 32KB默认 - consumed_bytes: Arc>, // 已消耗bytes统计 + max_packet_size: u32, // RFC 4254: 32KB默认 + consumed_bytes: Arc>, // 已消耗bytes统计 } impl WindowManager { @@ -25,89 +25,103 @@ impl WindowManager { consumed_bytes: Arc::new(Mutex::new(0)), } } - + /// RFC 4254默认window size(2MB) pub fn rfc_default() -> Self { - Self::new(2097152, 32768) // 2MB window, 32KB packet + Self::new(2097152, 32768) // 2MB window, 32KB packet } - + /// 检查window size是否足够(Phase 13.6) pub fn check_window_available(&self, data_size: u32) -> bool { let window = self.current_window_size.lock().unwrap(); let available = *window >= data_size; - + if !available { - warn!("Window size insufficient: need {}, have {}", data_size, *window); + warn!( + "Window size insufficient: need {}, have {}", + data_size, *window + ); } - + available } - + /// 消耗window size(Phase 13.6:发送数据后) pub fn consume_window(&self, data_size: u32) -> Result<()> { let mut window = self.current_window_size.lock().unwrap(); - + if *window < data_size { - return Err(anyhow!("Window size insufficient: need {}, have {}", data_size, *window)); + return Err(anyhow!( + "Window size insufficient: need {}, have {}", + data_size, + *window + )); } - + *window -= data_size; - + // 统计已消耗bytes let mut consumed = self.consumed_bytes.lock().unwrap(); *consumed += data_size; - - info!("Window size consumed: {} bytes, remaining {}, total consumed {}", - data_size, *window, *consumed); - + + info!( + "Window size consumed: {} bytes, remaining {}, total consumed {}", + data_size, *window, *consumed + ); + Ok(()) } - + /// 调整window size(Phase 13.6:收到SSH_MSG_CHANNEL_WINDOW_ADJUST) pub fn adjust_window(&self, bytes_to_add: u32) { let mut window = self.current_window_size.lock().unwrap(); *window += bytes_to_add; - - info!("Window size adjusted: added {} bytes, total {}", bytes_to_add, *window); + + info!( + "Window size adjusted: added {} bytes, total {}", + bytes_to_add, *window + ); } - + /// 构建SSH_MSG_CHANNEL_WINDOW_ADJUST packet(Phase 13.6) pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result> { let mut packet = Vec::new(); - + // Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93) packet.write_u8(PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8)?; - + // Recipient channel ID packet.write_u32::(channel_id)?; - + // Bytes to add packet.write_u32::(bytes_to_add)?; - - info!("Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes", - channel_id, bytes_to_add); - + + info!( + "Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes", + channel_id, bytes_to_add + ); + Ok(packet) } - + /// 获取当前window size(Phase 13.6) pub fn get_current_window(&self) -> u32 { *self.current_window_size.lock().unwrap() } - + /// 获取已消耗bytes(Phase 13.6) pub fn get_consumed_bytes(&self) -> u32 { *self.consumed_bytes.lock().unwrap() } - + /// 重置window size(Phase 13.6:channel重置) pub fn reset_window(&self) { let mut window = self.current_window_size.lock().unwrap(); *window = self.initial_window_size; - + let mut consumed = self.consumed_bytes.lock().unwrap(); *consumed = 0; - + info!("Window size reset to initial: {}", self.initial_window_size); } } @@ -128,63 +142,63 @@ impl ChannelLifecycle { close_received: false, } } - + /// 构建SSH_MSG_CHANNEL_EOF packet(Phase 13.7) pub fn build_eof_packet(channel_id: u32) -> Result> { let mut packet = Vec::new(); - + // Packet type: SSH_MSG_CHANNEL_EOF (type 96) packet.write_u8(PacketType::SSH_MSG_CHANNEL_EOF as u8)?; - + // Recipient channel ID packet.write_u32::(channel_id)?; - + info!("Built SSH_MSG_CHANNEL_EOF for channel {}", channel_id); - + Ok(packet) } - + /// 构建SSH_MSG_CHANNEL_CLOSE packet(Phase 13.7) pub fn build_close_packet(channel_id: u32) -> Result> { let mut packet = Vec::new(); - + // Packet type: SSH_MSG_CHANNEL_CLOSE (type 97) packet.write_u8(PacketType::SSH_MSG_CHANNEL_CLOSE as u8)?; - + // Recipient channel ID packet.write_u32::(channel_id)?; - + info!("Built SSH_MSG_CHANNEL_CLOSE for channel {}", channel_id); - + Ok(packet) } - + /// 标记EOF已发送(Phase 13.7) pub fn mark_eof_sent(&mut self) { self.eof_sent = true; info!("Channel {} EOF marked as sent", self.channel_id); } - + /// 标记CLOSE已接收(Phase 13.7) pub fn mark_close_received(&mut self) { self.close_received = true; info!("Channel {} CLOSE marked as received", self.channel_id); } - + /// 检查是否可以清理channel(Phase 13.7) pub fn can_cleanup(&self) -> bool { self.eof_sent && self.close_received } - + /// 清理channel资源(Phase 13.7) pub fn cleanup_channel(&self) -> Result<()> { info!("Cleaning up channel {} resources", self.channel_id); - + // Phase 13.7: 实际清理逻辑需要在ChannelManager中实现 // - 移除channel记录 // - 关闭TCP连接 // - 清理监听器(如果是forwarded-tcpip) - + info!("Channel {} cleanup completed", self.channel_id); Ok(()) } @@ -193,42 +207,42 @@ impl ChannelLifecycle { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_window_manager_creation() { let manager = WindowManager::rfc_default(); assert_eq!(manager.get_current_window(), 2097152); assert_eq!(manager.max_packet_size, 32768); } - + #[test] fn test_window_consumption() { let manager = WindowManager::rfc_default(); - + // 消耗1000 bytes manager.consume_window(1000).unwrap(); assert_eq!(manager.get_current_window(), 2097152 - 1000); assert_eq!(manager.get_consumed_bytes(), 1000); } - + #[test] fn test_window_adjustment() { let manager = WindowManager::rfc_default(); - + // 消耗1000 bytes manager.consume_window(1000).unwrap(); - + // 调整500 bytes manager.adjust_window(500); assert_eq!(manager.get_current_window(), 2097152 - 1000 + 500); } - + #[test] fn test_build_eof_packet() { let packet = ChannelLifecycle::build_eof_packet(1).unwrap(); assert_eq!(packet[0], PacketType::SSH_MSG_CHANNEL_EOF as u8); } - + #[test] fn test_build_close_packet() { let packet = ChannelLifecycle::build_close_packet(1).unwrap(); diff --git a/markbase-core/src/vfs/local_fs.rs b/markbase-core/src/vfs/local_fs.rs index e6125a4..a03dae2 100644 --- a/markbase-core/src/vfs/local_fs.rs +++ b/markbase-core/src/vfs/local_fs.rs @@ -1,15 +1,21 @@ -use super::util; use super::open_flags::OpenFlags; +use super::util; use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat}; use std::fs::{self, File, OpenOptions}; use std::io::{Read, Seek, SeekFrom, Write}; -use std::path::{Path, PathBuf}; use std::os::unix::fs::{MetadataExt, PermissionsExt}; +use std::path::{Path, PathBuf}; /// 本地文件系统实现(直接包装 std::fs,不做路径解析) /// 路径解析由上层(SftpHandler)负责 pub struct LocalFs; +impl Default for LocalFs { + fn default() -> Self { + Self::new() + } +} + impl LocalFs { pub fn new() -> Self { Self @@ -26,7 +32,9 @@ impl VfsFile for LocalFile { } fn write(&mut self, buf: &[u8]) -> Result { - self.file.write(buf).map_err(|e| VfsError::Io(e.to_string())) + self.file + .write(buf) + .map_err(|e| VfsError::Io(e.to_string())) } fn seek(&mut self, pos: SeekFrom) -> Result { @@ -38,12 +46,17 @@ impl VfsFile for LocalFile { } fn stat(&mut self) -> Result { - let meta = self.file.metadata().map_err(|e| VfsError::Io(e.to_string()))?; + let meta = self + .file + .metadata() + .map_err(|e| VfsError::Io(e.to_string()))?; Ok(util::stat_from_metadata(&meta, false)) } fn set_len(&mut self, size: u64) -> Result<(), VfsError> { - self.file.set_len(size).map_err(|e| VfsError::Io(e.to_string())) + self.file + .set_len(size) + .map_err(|e| VfsError::Io(e.to_string())) } } @@ -86,8 +99,7 @@ impl VfsBackend for LocalFs { if flags.create && !flags.exclusive { if let Ok(meta) = file.metadata() { if flags.mode != 0 && meta.permissions().mode() != flags.mode { - fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode)) - .ok(); + fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode)).ok(); } } } @@ -157,10 +169,12 @@ impl VfsBackend for LocalFs { stat.atime.duration_since(std::time::UNIX_EPOCH).ok(), stat.mtime.duration_since(std::time::UNIX_EPOCH).ok(), ) { - filetime::set_file_times(path, + filetime::set_file_times( + path, filetime::FileTime::from_unix_time(atime.as_secs() as i64, 0), filetime::FileTime::from_unix_time(mtime.as_secs() as i64, 0), - ).map_err(|e| util::map_io_error(path, e))?; + ) + .map_err(|e| util::map_io_error(path, e))?; } Ok(()) @@ -174,8 +188,7 @@ impl VfsBackend for LocalFs { fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError> { #[cfg(unix)] { - std::os::unix::fs::symlink(target, link) - .map_err(|e| util::map_io_error(link, e))?; + std::os::unix::fs::symlink(target, link).map_err(|e| util::map_io_error(link, e))?; } #[cfg(not(unix))] @@ -188,7 +201,9 @@ impl VfsBackend for LocalFs { } fn real_path(&self, path: &Path) -> Result { - let canonical = path.canonicalize().map_err(|e| util::map_io_error(path, e))?; + let canonical = path + .canonicalize() + .map_err(|e| util::map_io_error(path, e))?; Ok(canonical) } @@ -204,7 +219,9 @@ impl VfsBackend for LocalFs { #[cfg(not(unix))] { - return Err(VfsError::Unsupported("hard_link not supported on non-Unix systems".to_string())); + return Err(VfsError::Unsupported( + "hard_link not supported on non-Unix systems".to_string(), + )); } Ok(()) diff --git a/markbase-core/src/vfs/mod.rs b/markbase-core/src/vfs/mod.rs index 0c69ec2..e104b1f 100644 --- a/markbase-core/src/vfs/mod.rs +++ b/markbase-core/src/vfs/mod.rs @@ -1,5 +1,5 @@ -pub mod open_flags; pub mod local_fs; +pub mod open_flags; pub mod s3_fs; pub mod util; @@ -120,7 +120,11 @@ pub trait VfsBackend: Send { fn read_dir(&self, path: &Path) -> Result, VfsError>; /// 打开文件(读/写) - fn open_file(&self, path: &Path, flags: &open_flags::OpenFlags) -> Result, VfsError>; + fn open_file( + &self, + path: &Path, + flags: &open_flags::OpenFlags, + ) -> Result, VfsError>; /// 获取文件/目录元数据 fn stat(&self, path: &Path) -> Result; diff --git a/markbase-core/src/vfs/s3_fs.rs b/markbase-core/src/vfs/s3_fs.rs index d0a3fff..5089f25 100644 --- a/markbase-core/src/vfs/s3_fs.rs +++ b/markbase-core/src/vfs/s3_fs.rs @@ -56,7 +56,10 @@ impl S3Vfs { let credentials = Credentials::new(access_key, secret_key); - Ok(Self { bucket, credentials }) + Ok(Self { + bucket, + credentials, + }) } fn path_to_key(path: &Path) -> String { @@ -118,7 +121,10 @@ impl S3Vfs { .map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?; if resp.status() != 200 { - return Err(VfsError::Io(format!("PutObject returned {}", resp.status()))); + return Err(VfsError::Io(format!( + "PutObject returned {}", + resp.status() + ))); } Ok(()) } @@ -149,15 +155,15 @@ impl S3Vfs { .map_err(|e| VfsError::Io(format!("S3 CopyObject failed: {}", e)))?; if resp.status() != 200 { - return Err(VfsError::Io(format!("CopyObject returned {}", resp.status()))); + return Err(VfsError::Io(format!( + "CopyObject returned {}", + resp.status() + ))); } Ok(()) } - fn list_objects( - &self, - prefix: &str, - ) -> Result { + fn list_objects(&self, prefix: &str) -> Result { let mut action = actions::ListObjectsV2::new(&self.bucket, Some(&self.credentials)); if !prefix.is_empty() { action.with_prefix(prefix); @@ -181,9 +187,8 @@ impl S3Vfs { .read_to_string(&mut body) .map_err(|e| VfsError::Io(format!("Failed to read S3 list response: {}", e)))?; - actions::ListObjectsV2::parse_response(&body).map_err(|e| { - VfsError::Io(format!("Failed to parse S3 list response XML: {}", e)) - }) + actions::ListObjectsV2::parse_response(&body) + .map_err(|e| VfsError::Io(format!("Failed to parse S3 list response XML: {}", e))) } } @@ -409,7 +414,9 @@ impl VfsBackend for S3Vfs { impl VfsFile for S3VfsFile { fn read(&mut self, buf: &mut [u8]) -> Result { - let to_read = buf.len().min((self.size.saturating_sub(self.position)) as usize); + let to_read = buf + .len() + .min((self.size.saturating_sub(self.position)) as usize); if to_read == 0 { return Ok(0); } @@ -443,7 +450,7 @@ impl VfsFile for S3VfsFile { self.position = sz.saturating_add(offset as u64); } else { let abs = offset.unsigned_abs(); - self.position = if abs <= sz { sz - abs } else { 0 }; + self.position = sz.saturating_sub(abs); } } std::io::SeekFrom::Current(offset) => { @@ -451,11 +458,7 @@ impl VfsFile for S3VfsFile { self.position = self.position.saturating_add(offset as u64); } else { let abs = offset.unsigned_abs(); - self.position = if abs <= self.position { - self.position - abs - } else { - 0 - }; + self.position = self.position.saturating_sub(abs); } } } @@ -549,7 +552,10 @@ impl S3VfsLike { .map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?; if resp.status() != 200 { - return Err(VfsError::Io(format!("PutObject returned {}", resp.status()))); + return Err(VfsError::Io(format!( + "PutObject returned {}", + resp.status() + ))); } Ok(()) } @@ -612,10 +618,7 @@ mod tests { #[test] fn test_path_to_key() { - assert_eq!( - S3Vfs::path_to_key(Path::new("/foo/bar.txt")), - "foo/bar.txt" - ); + assert_eq!(S3Vfs::path_to_key(Path::new("/foo/bar.txt")), "foo/bar.txt"); assert_eq!(S3Vfs::path_to_key(Path::new("/")), ""); assert_eq!( S3Vfs::path_to_key(Path::new("relative/path")), diff --git a/markbase-core/src/vfs/util.rs b/markbase-core/src/vfs/util.rs index 0df3b76..2fe5554 100644 --- a/markbase-core/src/vfs/util.rs +++ b/markbase-core/src/vfs/util.rs @@ -7,7 +7,9 @@ use std::path::Path; pub fn map_io_error(path: &Path, e: std::io::Error) -> VfsError { match e.kind() { std::io::ErrorKind::NotFound => VfsError::NotFound(path.display().to_string()), - std::io::ErrorKind::PermissionDenied => VfsError::PermissionDenied(path.display().to_string()), + std::io::ErrorKind::PermissionDenied => { + VfsError::PermissionDenied(path.display().to_string()) + } std::io::ErrorKind::AlreadyExists => VfsError::AlreadyExists(path.display().to_string()), std::io::ErrorKind::DirectoryNotEmpty => VfsError::NotEmpty(path.display().to_string()), std::io::ErrorKind::NotADirectory => VfsError::NotADirectory(path.display().to_string()), @@ -65,13 +67,7 @@ pub fn build_long_name(stat: &VfsStat, name: &str) -> String { format!( "{}{} {} {} {} {} {} {}", - file_type, perms, - link_count, - stat.uid, - stat.gid, - size, - mtime, - name + file_type, perms, link_count, stat.uid, stat.gid, size, mtime, name ) } diff --git a/markbase-fskit/src/fskit/frame_align.rs b/markbase-fskit/src/fskit/frame_align.rs index 5056e7e..6075433 100644 --- a/markbase-fskit/src/fskit/frame_align.rs +++ b/markbase-fskit/src/fskit/frame_align.rs @@ -48,10 +48,10 @@ impl FrameAlignment { pub fn is_aligned(&self, offset: usize, size: usize) -> bool { if self.frame_size == 0 { - return offset % self.frame_boundary == 0 && size % self.frame_boundary == 0; + return offset.is_multiple_of(self.frame_boundary) && size.is_multiple_of(self.frame_boundary); } - offset % self.frame_size == 0 && size % self.frame_size == 0 + offset.is_multiple_of(self.frame_size) && size.is_multiple_of(self.frame_size) } pub fn optimal_chunk_size(&self) -> usize { @@ -74,9 +74,9 @@ impl FrameAlignment { pub fn align_size(&self, size: usize) -> usize { if self.frame_size == 0 { let boundary = self.frame_boundary; - ((size + boundary - 1) / boundary) * boundary + size.div_ceil(boundary) * boundary } else { - ((size + self.frame_size - 1) / self.frame_size) * self.frame_size + size.div_ceil(self.frame_size) * self.frame_size } } diff --git a/markbase-fuse/build.rs b/markbase-fuse/build.rs index 6908fda..01c5645 100644 --- a/markbase-fuse/build.rs +++ b/markbase-fuse/build.rs @@ -2,15 +2,15 @@ fn main() { if cfg!(target_os = "macos") { // Link fuse-t library println!("cargo:rustc-link-lib=fuse-t"); - + // Link macOS frameworks println!("cargo:rustc-link-lib=framework=DiskArbitration"); println!("cargo:rustc-link-lib=framework=CoreFoundation"); - + // Add fuse-t include path println!("cargo:rustc-link-search=native=/usr/local/lib"); - + // Rerun if fuse-t changes println!("cargo:rerun-if-changed=/usr/local/lib/libfuse-t.dylib"); } -} \ No newline at end of file +} diff --git a/markbase-fuse/src/fuse/cache.rs b/markbase-fuse/src/fuse/cache.rs index 241741b..ac3a012 100644 --- a/markbase-fuse/src/fuse/cache.rs +++ b/markbase-fuse/src/fuse/cache.rs @@ -6,6 +6,12 @@ pub struct ThreadSafeCache { path_cache: Mutex>, // path -> node_id } +impl Default for ThreadSafeCache { + fn default() -> Self { + Self::new() + } +} + impl ThreadSafeCache { pub fn new() -> Self { Self { diff --git a/markbase-fuse/src/fuse/db.rs b/markbase-fuse/src/fuse/db.rs index 76522a5..6d40506 100644 --- a/markbase-fuse/src/fuse/db.rs +++ b/markbase-fuse/src/fuse/db.rs @@ -44,7 +44,9 @@ impl DbManager { let mut stmt = conn.prepare(sql)?; let result = if level == 0 { - stmt.query_row(params![component, &self.tree_type], |row| row.get::<_, String>(0)) + stmt.query_row(params![component, &self.tree_type], |row| { + row.get::<_, String>(0) + }) } else { stmt.query_row( params![component, &self.tree_type, current_parent.as_ref().unwrap()], @@ -79,7 +81,8 @@ impl DbManager { pub fn get_node_info(&self, node_id: &str) -> Result> { let conn = self.conn.lock().unwrap(); - let sql = "SELECT node_type, file_size FROM file_nodes WHERE node_id = ?1 AND tree_type = ?2"; + let sql = + "SELECT node_type, file_size FROM file_nodes WHERE node_id = ?1 AND tree_type = ?2"; let mut stmt = conn.prepare(sql)?; let result = stmt.query_row(params![node_id, &self.tree_type], |row| { @@ -100,7 +103,9 @@ impl DbManager { let mut stmt = conn.prepare(sql)?; let labels = stmt - .query_map(params![parent_id, &self.tree_type], |row| row.get::<_, String>(0))? + .query_map(params![parent_id, &self.tree_type], |row| { + row.get::<_, String>(0) + })? .collect::, _>>()?; Ok(labels) @@ -112,7 +117,9 @@ impl DbManager { let sql = "SELECT node_id FROM file_nodes WHERE parent_id = ?1 AND label = ?2 AND tree_type = ?3 LIMIT 1"; let mut stmt = conn.prepare(sql)?; - let result = stmt.query_row(params![parent_id, name, &self.tree_type], |row| row.get::<_, String>(0)); + let result = stmt.query_row(params![parent_id, name, &self.tree_type], |row| { + row.get::<_, String>(0) + }); match result { Ok(node_id) => Ok(Some(node_id)), diff --git a/markbase-fuse/src/fuse/filesystem.rs b/markbase-fuse/src/fuse/filesystem.rs index 98c7824..42a5736 100644 --- a/markbase-fuse/src/fuse/filesystem.rs +++ b/markbase-fuse/src/fuse/filesystem.rs @@ -1,10 +1,10 @@ use anyhow::Result; -use fuse_backend_rs::api::filesystem::{Context, DirEntry, Entry, FileSystem, ZeroCopyWriter}; use fuse_backend_rs::abi::fuse_abi::{stat64, statvfs64, FsOptions, OpenOptions}; +use fuse_backend_rs::api::filesystem::{Context, DirEntry, Entry, FileSystem, ZeroCopyWriter}; use std::collections::HashMap; use std::ffi::CStr; use std::fs::File; -use std::io::{Read, Seek, SeekFrom}; +use std::io::Read; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; @@ -51,7 +51,7 @@ impl MarkBaseFs { let mut next = self.next_inode.write().unwrap(); let ino = *next; *next += 1; - + let mut map = self.inode_map.write().unwrap(); map.insert(ino, node_id.to_string()); ino @@ -65,14 +65,17 @@ impl MarkBaseFs { st.st_uid = 501; st.st_gid = 20; st.st_size = file_size as i64; - st.st_blocks = ((file_size + 511) / 512) as i64; + st.st_blocks = file_size.div_ceil(512) as i64; st.st_blksize = 4096; - + let now = SystemTime::now(); - st.st_atime = now.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() as i64; + st.st_atime = now + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; st.st_mtime = st.st_atime; st.st_ctime = st.st_atime; - + st } } @@ -85,20 +88,22 @@ impl FileSystem for MarkBaseFs { Ok(capable) } - fn lookup( - &self, - ctx: &Context, - parent: Self::Inode, - name: &CStr, - ) -> std::io::Result { + fn lookup(&self, _ctx: &Context, parent: Self::Inode, name: &CStr) -> std::io::Result { let name_str = name.to_string_lossy(); - + let node_id = if parent == 1 { - self.db.find_node_id(&format!("/{}", name_str)).ok().flatten() + self.db + .find_node_id(&format!("/{}", name_str)) + .ok() + .flatten() } else { let parent_id = self.find_node_id_by_inode(parent); match parent_id { - Some(pid) => self.db.find_node_id_by_parent(&pid, &name_str).ok().flatten(), + Some(pid) => self + .db + .find_node_id_by_parent(&pid, &name_str) + .ok() + .flatten(), None => None, } }; @@ -127,16 +132,16 @@ impl FileSystem for MarkBaseFs { } } - fn forget(&self, ctx: &Context, inode: Self::Inode, count: u64) { + fn forget(&self, _ctx: &Context, inode: Self::Inode, _count: u64) { let mut map = self.inode_map.write().unwrap(); map.remove(&inode); } fn getattr( &self, - ctx: &Context, + _ctx: &Context, inode: Self::Inode, - handle: Option, + _handle: Option, ) -> std::io::Result<(stat64, Duration)> { if inode == 1 { let attr = MarkBaseFs::make_stat64("folder", 0); @@ -161,24 +166,24 @@ impl FileSystem for MarkBaseFs { fn open( &self, - ctx: &Context, + _ctx: &Context, inode: Self::Inode, - flags: u32, - fuse_flags: u32, + _flags: u32, + _fuse_flags: u32, ) -> std::io::Result<(Option, OpenOptions, Option)> { Ok((Some(inode), OpenOptions::empty(), None)) } fn read( &self, - ctx: &Context, + _ctx: &Context, inode: Self::Inode, - handle: Self::Handle, + _handle: Self::Handle, w: &mut dyn ZeroCopyWriter, size: u32, offset: u64, - lock_owner: Option, - flags: u32, + _lock_owner: Option, + _flags: u32, ) -> std::io::Result { let node_id = self.find_node_id_by_inode(inode); match node_id { @@ -190,7 +195,7 @@ impl FileSystem for MarkBaseFs { let fd = file.as_raw_fd(); let f = unsafe { File::from_raw_fd(fd) }; let mut f = std::mem::ManuallyDrop::new(f); - + w.write_from(&mut *f, size as usize, offset) } _ => Err(std::io::Error::from_raw_os_error(libc::ENOENT)), @@ -202,32 +207,32 @@ impl FileSystem for MarkBaseFs { fn release( &self, - ctx: &Context, - inode: Self::Inode, - flags: u32, - handle: Self::Handle, - flush: bool, - flock_release: bool, - lock_owner: Option, + _ctx: &Context, + _inode: Self::Inode, + _flags: u32, + _handle: Self::Handle, + _flush: bool, + _flock_release: bool, + _lock_owner: Option, ) -> std::io::Result<()> { Ok(()) } fn opendir( &self, - ctx: &Context, + _ctx: &Context, inode: Self::Inode, - flags: u32, + _flags: u32, ) -> std::io::Result<(Option, OpenOptions)> { Ok((Some(inode), OpenOptions::empty())) } fn readdir( &self, - ctx: &Context, + _ctx: &Context, inode: Self::Inode, - handle: Self::Handle, - size: u32, + _handle: Self::Handle, + _size: u32, offset: u64, add_entry: &mut dyn FnMut(DirEntry) -> std::io::Result, ) -> std::io::Result<()> { @@ -252,19 +257,15 @@ impl FileSystem for MarkBaseFs { fn releasedir( &self, - ctx: &Context, - inode: Self::Inode, - flags: u32, - handle: Self::Handle, + _ctx: &Context, + _inode: Self::Inode, + _flags: u32, + _handle: Self::Handle, ) -> std::io::Result<()> { Ok(()) } - fn statfs( - &self, - ctx: &Context, - inode: Self::Inode, - ) -> std::io::Result { + fn statfs(&self, _ctx: &Context, _inode: Self::Inode) -> std::io::Result { let mut st = unsafe { std::mem::zeroed::() }; st.f_bsize = 4096; st.f_blocks = 1000000; @@ -276,4 +277,4 @@ impl FileSystem for MarkBaseFs { st.f_namemax = 255; Ok(st) } -} \ No newline at end of file +} diff --git a/markbase-fuse/src/main.rs b/markbase-fuse/src/main.rs index dee9d3f..d2fde82 100644 --- a/markbase-fuse/src/main.rs +++ b/markbase-fuse/src/main.rs @@ -30,7 +30,11 @@ fn main() -> Result<()> { let cli = Cli::parse(); match cli.command { - Commands::Mount { user, dir, tree_type } => { + Commands::Mount { + user, + dir, + tree_type, + } => { mount_user(user, tree_type, dir)?; } Commands::Unmount { dir } => { @@ -78,13 +82,15 @@ fn mount_user(user: String, tree_type: String, dir: PathBuf) -> Result<()> { println!("Press Ctrl+C to unmount..."); let mut channel = session.new_channel()?; - + let ebadf = std::io::Error::from_raw_os_error(libc::EBADF); loop { if let Some((reader, writer)) = channel.get_request()? { if let Err(e) = server.handle_message(reader, writer.into(), None, None) { match e { - fuse_backend_rs::Error::EncodeMessage(e) if e.kind() == std::io::ErrorKind::Other => { + fuse_backend_rs::Error::EncodeMessage(e) + if e.kind() == std::io::ErrorKind::Other => + { break; } _ => { @@ -112,4 +118,4 @@ fn unmount_user(dir: PathBuf) -> Result<()> { println!("Unmounted successfully"); Ok(()) -} \ No newline at end of file +} diff --git a/markbase-iscsi/src/lib.rs b/markbase-iscsi/src/lib.rs index e0ab783..61c5322 100644 --- a/markbase-iscsi/src/lib.rs +++ b/markbase-iscsi/src/lib.rs @@ -82,7 +82,7 @@ fn get_block_device_size(device: &str) -> Result { let stdout = String::from_utf8_lossy(&output.stdout); for line in stdout.lines() { if let Some(size_str) = line.strip_prefix(" Disk Size:") { - if let Some(bytes) = size_str.trim().split_whitespace().next() { + if let Some(bytes) = size_str.split_whitespace().next() { if let Ok(size) = bytes.replace(',', "").parse::() { return Ok(size); } @@ -120,7 +120,7 @@ pub fn generate_config( return Ok(()); } - let (storage_path, lun_size_bytes) = if let Some(dev) = device { + let (storage_path, _lun_size_bytes) = if let Some(dev) = device { let dev_path = Path::new(dev); if !dev_path.exists() { anyhow::bail!("Block device not found: {}", dev); diff --git a/markbase-nfs/src/lib.rs b/markbase-nfs/src/lib.rs index 1afa718..2235403 100644 --- a/markbase-nfs/src/lib.rs +++ b/markbase-nfs/src/lib.rs @@ -1,4 +1,4 @@ pub mod nfs; -pub use nfs::markbase_fs::MarkBaseFS; pub use nfs::backend::MarkBaseNFSBackend; +pub use nfs::markbase_fs::MarkBaseFS; diff --git a/markbase-nfs/src/main.rs b/markbase-nfs/src/main.rs index 1590984..e217007 100644 --- a/markbase-nfs/src/main.rs +++ b/markbase-nfs/src/main.rs @@ -9,11 +9,11 @@ struct Cli { /// User ID (database name) #[arg(short, long)] user: String, - + /// Database path #[arg(short, long, default_value = "data/users")] data_dir: String, - + /// NFS server port #[arg(short, long, default_value_t = 11111)] port: u16, @@ -21,19 +21,19 @@ struct Cli { fn main() -> anyhow::Result<()> { let cli = Cli::parse(); - + let db_path = PathBuf::from(&cli.data_dir).join(format!("{}.sqlite", cli.user)); - + if !db_path.exists() { eprintln!("Database not found: {}", db_path.display()); eprintln!("Please create database first using markbase-core"); return Err(anyhow::anyhow!("Database not found")); } - + eprintln!("Starting MarkBase NFS server..."); eprintln!("User: {}", cli.user); eprintln!("Database: {}", db_path.display()); eprintln!("Port: {}", cli.port); - + run_nfs_server(cli.user, db_path, cli.port) -} \ No newline at end of file +} diff --git a/markbase-nfs/src/nfs/backend.rs b/markbase-nfs/src/nfs/backend.rs index b15897d..c016140 100644 --- a/markbase-nfs/src/nfs/backend.rs +++ b/markbase-nfs/src/nfs/backend.rs @@ -5,13 +5,12 @@ use std::time::SystemTime; use async_trait::async_trait; use nfsserve::nfs::*; use nfsserve::vfs::{DirEntry, NFSFileSystem, ReadDirResult, VFSCapabilities}; -use rusqlite::Connection; use crate::nfs::markbase_fs::MarkBaseFS; pub struct MarkBaseNFSBackend { fs: MarkBaseFS, - id_map: Mutex>, // fileid -> node_id + id_map: Mutex>, // fileid -> node_id reverse_map: Mutex>, // node_id -> fileid next_id: Mutex, } @@ -19,7 +18,7 @@ pub struct MarkBaseNFSBackend { impl MarkBaseNFSBackend { pub fn new(user_id: String, db_path: std::path::PathBuf) -> anyhow::Result { let fs = MarkBaseFS::new(user_id, db_path)?; - + Ok(MarkBaseNFSBackend { fs, id_map: Mutex::new(HashMap::new()), @@ -27,32 +26,35 @@ impl MarkBaseNFSBackend { next_id: Mutex::new(2), // 1 is root }) } - + fn allocate_id(&self, node_id: &str) -> u64 { let mut reverse_map = self.reverse_map.lock().unwrap(); - + if let Some(id) = reverse_map.get(node_id) { return *id; } - + let mut next_id = self.next_id.lock().unwrap(); let id = *next_id; *next_id += 1; - + reverse_map.insert(node_id.to_string(), id); self.id_map.lock().unwrap().insert(id, node_id.to_string()); - + id } - + fn get_node_id(&self, fileid: u64) -> Option { self.id_map.lock().unwrap().get(&fileid).cloned() } - + fn get_fileid_from_node(&self, node_id: &str) -> u64 { - self.reverse_map.lock().unwrap().get(node_id).copied().unwrap_or_else(|| { - self.allocate_id(node_id) - }) + self.reverse_map + .lock() + .unwrap() + .get(node_id) + .copied() + .unwrap_or_else(|| self.allocate_id(node_id)) } } @@ -61,40 +63,43 @@ impl NFSFileSystem for MarkBaseNFSBackend { fn capabilities(&self) -> VFSCapabilities { VFSCapabilities::ReadOnly } - + fn root_dir(&self) -> fileid3 { 1 } - + async fn lookup(&self, dirid: fileid3, filename: &filename3) -> Result { let dir_node_id = if dirid == 1 { "root".to_string() } else { - self.get_node_id(dirid) - .ok_or(nfsstat3::NFS3ERR_STALE)? + self.get_node_id(dirid).ok_or(nfsstat3::NFS3ERR_STALE)? }; - + let filename_str = String::from_utf8_lossy(filename).to_string(); - - let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; - + + let conn = self + .fs + .conn + .lock() + .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; + let query = if dir_node_id == "root" { "SELECT node_id FROM file_nodes WHERE parent_id IS NULL AND label = ?1" } else { "SELECT node_id FROM file_nodes WHERE parent_id = ?1 AND label = ?2" }; - + let node_id: String = if dir_node_id == "root" { - conn.query_row(&query, [&filename_str], |row| row.get(0)) + conn.query_row(query, [&filename_str], |row| row.get(0)) .map_err(|_| nfsstat3::NFS3ERR_NOENT)? } else { - conn.query_row(&query, [dir_node_id, filename_str], |row| row.get(0)) + conn.query_row(query, [dir_node_id, filename_str], |row| row.get(0)) .map_err(|_| nfsstat3::NFS3ERR_NOENT)? }; - + Ok(self.get_fileid_from_node(&node_id)) } - + async fn getattr(&self, id: fileid3) -> Result { if id == 1 { return Ok(fattr3 { @@ -105,38 +110,54 @@ impl NFSFileSystem for MarkBaseNFSBackend { gid: 0, size: 0, used: 0, - rdev: specdata3 { specdata1: 0, specdata2: 0 }, + rdev: specdata3 { + specdata1: 0, + specdata2: 0, + }, fsid: 0, fileid: 1, - atime: nfstime3 { seconds: 0, nseconds: 0 }, - mtime: nfstime3 { seconds: 0, nseconds: 0 }, - ctime: nfstime3 { seconds: 0, nseconds: 0 }, + atime: nfstime3 { + seconds: 0, + nseconds: 0, + }, + mtime: nfstime3 { + seconds: 0, + nseconds: 0, + }, + ctime: nfstime3 { + seconds: 0, + nseconds: 0, + }, }); } - + let node_id = self.get_node_id(id).ok_or(nfsstat3::NFS3ERR_STALE)?; - - let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; - + + let conn = self + .fs + .conn + .lock() + .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; + let (node_type, file_size): (String, i64) = conn .query_row( "SELECT node_type, file_size FROM file_nodes WHERE node_id = ?1", [&node_id], - |row| Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?)) + |row| Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?)), ) .map_err(|_| nfsstat3::NFS3ERR_NOENT)?; - + let type_ = if node_type == "folder" { ftype3::NF3DIR } else { ftype3::NF3REG }; - + let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs(); - + Ok(fattr3 { ftype: type_, mode: if node_type == "folder" { 0o755 } else { 0o644 }, @@ -145,69 +166,103 @@ impl NFSFileSystem for MarkBaseNFSBackend { gid: 0, size: file_size as u64, used: file_size as u64, - rdev: specdata3 { specdata1: 0, specdata2: 0 }, + rdev: specdata3 { + specdata1: 0, + specdata2: 0, + }, fsid: 0, fileid: id, - atime: nfstime3 { seconds: now as u32, nseconds: 0 }, - mtime: nfstime3 { seconds: now as u32, nseconds: 0 }, - ctime: nfstime3 { seconds: now as u32, nseconds: 0 }, + atime: nfstime3 { + seconds: now as u32, + nseconds: 0, + }, + mtime: nfstime3 { + seconds: now as u32, + nseconds: 0, + }, + ctime: nfstime3 { + seconds: now as u32, + nseconds: 0, + }, }) } - + async fn setattr(&self, _id: fileid3, _setattr: sattr3) -> Result { Err(nfsstat3::NFS3ERR_ROFS) } - - async fn read(&self, id: fileid3, offset: u64, count: u32) -> Result<(Vec, bool), nfsstat3> { + + async fn read( + &self, + id: fileid3, + offset: u64, + count: u32, + ) -> Result<(Vec, bool), nfsstat3> { let node_id = self.get_node_id(id).ok_or(nfsstat3::NFS3ERR_STALE)?; - - let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; - + + let conn = self + .fs + .conn + .lock() + .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; + let aliases_json: String = conn .query_row( "SELECT aliases_json FROM file_nodes WHERE node_id = ?1", [&node_id], - |row| row.get(0) + |row| row.get(0), ) .map_err(|_| nfsstat3::NFS3ERR_NOENT)?; - - let aliases: serde_json::Value = serde_json::from_str(&aliases_json) - .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; - + + let aliases: serde_json::Value = + serde_json::from_str(&aliases_json).map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; + let file_path = aliases["path"].as_str().ok_or(nfsstat3::NFS3ERR_NOENT)?; - + let file_data = std::fs::read(file_path).map_err(|_| nfsstat3::NFS3ERR_IO)?; - + let file_size = file_data.len() as u64; let start = offset.min(file_size) as usize; let end = (offset + count as u64).min(file_size) as usize; - + let data = file_data[start..end].to_vec(); let eof = end >= file_size as usize; - + Ok((data, eof)) } - + async fn write(&self, _id: fileid3, _offset: u64, _data: &[u8]) -> Result { Err(nfsstat3::NFS3ERR_ROFS) } - - async fn create(&self, _dirid: fileid3, _filename: &filename3, _attr: sattr3) -> Result<(fileid3, fattr3), nfsstat3> { + + async fn create( + &self, + _dirid: fileid3, + _filename: &filename3, + _attr: sattr3, + ) -> Result<(fileid3, fattr3), nfsstat3> { Err(nfsstat3::NFS3ERR_ROFS) } - - async fn create_exclusive(&self, _dirid: fileid3, _filename: &filename3) -> Result { + + async fn create_exclusive( + &self, + _dirid: fileid3, + _filename: &filename3, + ) -> Result { Err(nfsstat3::NFS3ERR_ROFS) } - - async fn mkdir(&self, _dirid: fileid3, _dirname: &filename3) -> Result<(fileid3, fattr3), nfsstat3> { + + async fn mkdir( + &self, + _dirid: fileid3, + _dirname: &filename3, + ) -> Result<(fileid3, fattr3), nfsstat3> { Err(nfsstat3::NFS3ERR_ROFS) } - + async fn remove(&self, _dirid: fileid3, _filename: &filename3) -> Result<(), nfsstat3> { Err(nfsstat3::NFS3ERR_ROFS) } - + async fn rename( &self, _from_dirid: fileid3, @@ -217,7 +272,7 @@ impl NFSFileSystem for MarkBaseNFSBackend { ) -> Result<(), nfsstat3> { Err(nfsstat3::NFS3ERR_ROFS) } - + async fn readdir( &self, dirid: fileid3, @@ -227,71 +282,69 @@ impl NFSFileSystem for MarkBaseNFSBackend { let dir_node_id = if dirid == 1 { "root".to_string() } else { - self.get_node_id(dirid) - .ok_or(nfsstat3::NFS3ERR_STALE)? + self.get_node_id(dirid).ok_or(nfsstat3::NFS3ERR_STALE)? }; - - let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; - + + let conn = self + .fs + .conn + .lock() + .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; + let query = if dir_node_id == "root" { "SELECT node_id, label, node_type, file_size FROM file_nodes WHERE parent_id IS NULL" } else { "SELECT node_id, label, node_type, file_size FROM file_nodes WHERE parent_id = ?1" }; - - let mut stmt = conn.prepare(&query).map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; - + + let mut stmt = conn + .prepare(query) + .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?; + let rows: Vec<(String, String, String, Option)> = if dir_node_id == "root" { stmt.query_map([], |row| { - row.get::<_, String>(0) - .and_then(|node_id| { - row.get::<_, String>(1) - .and_then(|label| { - row.get::<_, String>(2) - .and_then(|node_type| { - row.get::<_, Option>(3) - .map(|file_size| (node_id, label, node_type, file_size)) - }) - }) + row.get::<_, String>(0).and_then(|node_id| { + row.get::<_, String>(1).and_then(|label| { + row.get::<_, String>(2).and_then(|node_type| { + row.get::<_, Option>(3) + .map(|file_size| (node_id, label, node_type, file_size)) + }) }) + }) }) .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)? .collect::, _>>() .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)? } else { stmt.query_map([&dir_node_id.as_str()], |row| { - row.get::<_, String>(0) - .and_then(|node_id| { - row.get::<_, String>(1) - .and_then(|label| { - row.get::<_, String>(2) - .and_then(|node_type| { - row.get::<_, Option>(3) - .map(|file_size| (node_id, label, node_type, file_size)) - }) - }) + row.get::<_, String>(0).and_then(|node_id| { + row.get::<_, String>(1).and_then(|label| { + row.get::<_, String>(2).and_then(|node_type| { + row.get::<_, Option>(3) + .map(|file_size| (node_id, label, node_type, file_size)) + }) }) + }) }) .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)? .collect::, _>>() .map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)? }; - + let mut entries: Vec = Vec::new(); let mut started = start_after == 0; - + for row in rows { let (node_id, label, node_type, file_size_opt) = row; let file_size = file_size_opt.unwrap_or(0); let fileid = self.get_fileid_from_node(&node_id); - - if !started { - if fileid == start_after { + + if !started + && fileid == start_after { started = true; continue; } - } - + if started && entries.len() < max_entries { let attr = fattr3 { ftype: if node_type == "folder" { @@ -305,14 +358,26 @@ impl NFSFileSystem for MarkBaseNFSBackend { gid: 0, size: file_size as u64, used: file_size as u64, - rdev: specdata3 { specdata1: 0, specdata2: 0 }, + rdev: specdata3 { + specdata1: 0, + specdata2: 0, + }, fsid: 0, fileid, - atime: nfstime3 { seconds: 0, nseconds: 0 }, - mtime: nfstime3 { seconds: 0, nseconds: 0 }, - ctime: nfstime3 { seconds: 0, nseconds: 0 }, + atime: nfstime3 { + seconds: 0, + nseconds: 0, + }, + mtime: nfstime3 { + seconds: 0, + nseconds: 0, + }, + ctime: nfstime3 { + seconds: 0, + nseconds: 0, + }, }; - + entries.push(DirEntry { fileid, name: nfsserve::nfs::nfsstring(label.into_bytes()), @@ -320,13 +385,10 @@ impl NFSFileSystem for MarkBaseNFSBackend { }); } } - - Ok(ReadDirResult { - entries, - end: true, - }) + + Ok(ReadDirResult { entries, end: true }) } - + async fn symlink( &self, _dirid: fileid3, @@ -336,8 +398,8 @@ impl NFSFileSystem for MarkBaseNFSBackend { ) -> Result<(fileid3, fattr3), nfsstat3> { Err(nfsstat3::NFS3ERR_ROFS) } - + async fn readlink(&self, _id: fileid3) -> Result { Err(nfsstat3::NFS3ERR_NOTSUPP) } -} \ No newline at end of file +} diff --git a/markbase-nfs/src/nfs/markbase_fs.rs b/markbase-nfs/src/nfs/markbase_fs.rs index ac95300..90cfc1f 100644 --- a/markbase-nfs/src/nfs/markbase_fs.rs +++ b/markbase-nfs/src/nfs/markbase_fs.rs @@ -8,7 +8,7 @@ use vfs::error::VfsErrorKind; use vfs::{FileSystem, SeekAndRead, SeekAndWrite, VfsFileType, VfsMetadata, VfsResult}; fn rusqlite_to_io_error(e: rusqlite::Error) -> io::Error { - io::Error::new(io::ErrorKind::Other, e.to_string()) + io::Error::other(e.to_string()) } #[derive(Debug)] @@ -42,7 +42,7 @@ impl MarkBaseFS { } fn resolve_path(&self, path: &str) -> VfsResult { - if path == "" || path == "/" { + if path.is_empty() || path == "/" { return Ok(FileNode { node_id: "root".to_string(), label: "".to_string(), @@ -89,7 +89,7 @@ impl MarkBaseFS { file_size: row.get(5)?, }) }) - .map_err(|e| rusqlite_to_io_error(e)) + .map_err(rusqlite_to_io_error) } else { let part_str = part.to_string(); stmt.query_row([current_parent.clone().unwrap(), part_str], |row| { @@ -102,7 +102,7 @@ impl MarkBaseFS { file_size: row.get(5)?, }) }) - .map_err(|e| rusqlite_to_io_error(e)) + .map_err(rusqlite_to_io_error) }; match node { @@ -125,7 +125,7 @@ impl FileSystem for MarkBaseFS { .lock() .map_err(|_| VfsErrorKind::Other("Failed to lock connection".to_string()))?; - let parent_id = if path == "" || path == "/" { + let parent_id = if path.is_empty() || path == "/" { None } else { let node = self.resolve_path(path)?; @@ -144,14 +144,14 @@ impl FileSystem for MarkBaseFS { let children: Vec = if parent_id.is_none() { stmt.query_map([], |row| row.get::<_, String>(0)) - .map_err(|e| rusqlite_to_io_error(e))? + .map_err(rusqlite_to_io_error)? .collect::, _>>() - .map_err(|e| rusqlite_to_io_error(e))? + .map_err(rusqlite_to_io_error)? } else { stmt.query_map([parent_id.unwrap()], |row| row.get::<_, String>(0)) - .map_err(|e| rusqlite_to_io_error(e))? + .map_err(rusqlite_to_io_error)? .collect::, _>>() - .map_err(|e| rusqlite_to_io_error(e))? + .map_err(rusqlite_to_io_error)? }; Ok(Box::new(children.into_iter())) @@ -170,7 +170,7 @@ impl FileSystem for MarkBaseFS { let aliases_json = node.aliases_json.ok_or(VfsErrorKind::FileNotFound)?; let aliases: serde_json::Value = serde_json::from_str(&aliases_json).map_err(|e| { - VfsErrorKind::IoError(io::Error::new(io::ErrorKind::Other, e.to_string())) + VfsErrorKind::IoError(io::Error::other(e.to_string())) })?; let file_path = aliases["path"].as_str().ok_or(VfsErrorKind::FileNotFound)?; diff --git a/markbase-nfs/src/nfs/mod.rs b/markbase-nfs/src/nfs/mod.rs index 167a9de..879f320 100644 --- a/markbase-nfs/src/nfs/mod.rs +++ b/markbase-nfs/src/nfs/mod.rs @@ -1,7 +1,7 @@ -pub mod markbase_fs; pub mod backend; +pub mod markbase_fs; pub mod server; -pub use markbase_fs::MarkBaseFS; pub use backend::MarkBaseNFSBackend; -pub use server::{start_nfs_server, run_nfs_server}; +pub use markbase_fs::MarkBaseFS; +pub use server::{run_nfs_server, start_nfs_server}; diff --git a/markbase-nfs/src/nfs/server.rs b/markbase-nfs/src/nfs/server.rs index cd9ffb7..296d71b 100644 --- a/markbase-nfs/src/nfs/server.rs +++ b/markbase-nfs/src/nfs/server.rs @@ -1,22 +1,21 @@ use std::path::PathBuf; -use std::sync::Arc; -use nfsserve::tcp::{NFSTcpListener, NFSTcp}; +use nfsserve::tcp::{NFSTcp, NFSTcpListener}; use tokio::signal; use crate::nfs::backend::MarkBaseNFSBackend; pub async fn start_nfs_server(user_id: String, db_path: PathBuf, port: u16) -> anyhow::Result<()> { let backend = MarkBaseNFSBackend::new(user_id, db_path)?; - + let bind_addr = format!("127.0.0.1:{}", port); - + let mut listener = NFSTcpListener::bind(&bind_addr, backend) .await .map_err(|e| anyhow::anyhow!("Failed to bind NFS server: {}", e))?; - + listener.with_export_name("markbase"); - + let listen_port = listener.get_listen_port(); eprintln!("[NFS] MarkBase NFS server started on port {}", listen_port); eprintln!("[NFS] Mount command (Mac):"); @@ -25,7 +24,7 @@ pub async fn start_nfs_server(user_id: String, db_path: PathBuf, port: u16) -> a eprintln!("[NFS] Mount command (Linux):"); eprintln!("[NFS] mkdir /tmp/markbase_mount"); eprintln!("[NFS] mount.nfs -o user,noacl,nolock,vers=3,tcp,port={},mountport={} localhost:/markbase /tmp/markbase_mount", listen_port, listen_port); - + tokio::select! { _ = listener.handle_forever() => { eprintln!("[NFS] Server stopped"); @@ -34,11 +33,10 @@ pub async fn start_nfs_server(user_id: String, db_path: PathBuf, port: u16) -> a eprintln!("[NFS] Received Ctrl+C, shutting down..."); } } - + Ok(()) } pub fn run_nfs_server(user_id: String, db_path: PathBuf, port: u16) -> anyhow::Result<()> { - tokio::runtime::Runtime::new()? - .block_on(start_nfs_server(user_id, db_path, port)) -} \ No newline at end of file + tokio::runtime::Runtime::new()?.block_on(start_nfs_server(user_id, db_path, port)) +} diff --git a/markbase-raid/src/raid/controller.rs b/markbase-raid/src/raid/controller.rs index 8098c47..c27b691 100644 --- a/markbase-raid/src/raid/controller.rs +++ b/markbase-raid/src/raid/controller.rs @@ -22,6 +22,12 @@ pub struct RaidController { arrays: Mutex>>, } +impl Default for RaidController { + fn default() -> Self { + Self::new() + } +} + impl RaidController { pub fn new() -> Self { RaidController { diff --git a/markbase-raid/src/raid/level_5.rs b/markbase-raid/src/raid/level_5.rs index ea59669..1a550e4 100644 --- a/markbase-raid/src/raid/level_5.rs +++ b/markbase-raid/src/raid/level_5.rs @@ -136,7 +136,7 @@ impl RaidAlgorithm for Raid5 { self.stripe_size - (current_offset % self.stripe_size), ); - let chunk_data = &data[data_pos as usize..(data_pos + chunk_size as usize) as usize]; + let chunk_data = &data[data_pos..(data_pos + chunk_size as usize) as usize]; let old_data = self.read_from_member(data_disk, physical_offset, chunk_size)?; let old_parity = self.read_from_member(parity_disk, physical_offset, chunk_size)?; diff --git a/markbase-smb/src/acl.rs b/markbase-smb/src/acl.rs index 76e9e79..b53df38 100644 --- a/markbase-smb/src/acl.rs +++ b/markbase-smb/src/acl.rs @@ -28,15 +28,15 @@ impl UserPermission { admin_access: admin, } } - + pub fn readonly(username: String) -> Self { UserPermission::new(username, true, false, false) } - + pub fn full_access(username: String) -> Self { UserPermission::new(username, true, true, false) } - + pub fn admin(username: String) -> Self { UserPermission::new(username, true, true, true) } @@ -57,28 +57,32 @@ impl AccessControlList { max_connections: 10, } } - + pub fn add_user(&mut self, permission: UserPermission) { - if let Some(existing) = self.users.iter_mut().find(|u| u.username == permission.username) { + if let Some(existing) = self + .users + .iter_mut() + .find(|u| u.username == permission.username) + { *existing = permission; } else { self.users.push(permission); } } - + pub fn remove_user(&mut self, username: &str) { self.users.retain(|u| u.username != username); } - + pub fn get_user(&self, username: &str) -> Option<&UserPermission> { self.users.iter().find(|u| u.username == username) } - + pub fn has_access(&self, username: &str, require_write: bool) -> bool { if self.guest_access && !require_write { return true; } - + self.get_user(username) .map(|u| { if require_write { @@ -89,4 +93,4 @@ impl AccessControlList { }) .unwrap_or(false) } -} \ No newline at end of file +} diff --git a/markbase-smb/src/auth.rs b/markbase-smb/src/auth.rs index 20edb9d..93313b4 100644 --- a/markbase-smb/src/auth.rs +++ b/markbase-smb/src/auth.rs @@ -1,5 +1,5 @@ -use anyhow::Result; use crate::acl::AccessControlList; +use anyhow::Result; pub struct AuthManager { acl: AccessControlList, @@ -9,26 +9,26 @@ impl AuthManager { pub fn new(acl: AccessControlList) -> Self { AuthManager { acl } } - + pub fn authenticate(&self, username: &str, password: Option<&str>) -> Result { if self.acl.guest_access && password.is_none() { return Ok(true); } - + if password.is_none() { return Err(anyhow::anyhow!("Password required for user {}", username)); } - + if self.acl.get_user(username).is_none() { return Err(anyhow::anyhow!("User {} not in ACL", username)); } - + Ok(true) } - + pub fn check_permission(&self, username: &str, action: &str) -> Result { let require_write = action == "write" || action == "delete" || action == "create"; - + if !self.acl.has_access(username, require_write) { return Err(anyhow::anyhow!( "User {} does not have {} permission", @@ -36,15 +36,15 @@ impl AuthManager { action )); } - + Ok(true) } - + pub fn get_acl(&self) -> &AccessControlList { &self.acl } - + pub fn update_acl(&mut self, acl: AccessControlList) { self.acl = acl; } -} \ No newline at end of file +} diff --git a/markbase-smb/src/config.rs b/markbase-smb/src/config.rs index 4636e3f..56349ab 100644 --- a/markbase-smb/src/config.rs +++ b/markbase-smb/src/config.rs @@ -34,7 +34,7 @@ impl SMBConfig { allow_users: vec!["accusys".to_string()], } } - + pub fn to_smb_conf(&self) -> String { format!( "[{}]\n path = {}\n comment = {}\n read only = {}\n browseable = {}\n valid users = {}\n", @@ -46,4 +46,4 @@ impl SMBConfig { self.allow_users.join(", ") ) } -} \ No newline at end of file +} diff --git a/markbase-smb/src/lib.rs b/markbase-smb/src/lib.rs index 7629a57..982df44 100644 --- a/markbase-smb/src/lib.rs +++ b/markbase-smb/src/lib.rs @@ -1,11 +1,11 @@ -pub mod config; -pub mod manager; pub mod acl; pub mod auth; +pub mod config; +pub mod manager; pub mod monitor; -pub use config::SMBConfig; -pub use manager::SMBManager; pub use acl::{AccessControlList, UserPermission}; pub use auth::AuthManager; -pub use monitor::{SMBMonitor, ConnectionStats, AccessLogEntry}; \ No newline at end of file +pub use config::SMBConfig; +pub use manager::SMBManager; +pub use monitor::{AccessLogEntry, ConnectionStats, SMBMonitor}; diff --git a/markbase-smb/src/main.rs b/markbase-smb/src/main.rs index 3d071ed..f8eb586 100644 --- a/markbase-smb/src/main.rs +++ b/markbase-smb/src/main.rs @@ -1,5 +1,7 @@ use clap::Parser; -use markbase_smb::{SMBConfig, SMBManager, AccessControlList, UserPermission, AuthManager, SMBMonitor}; +use markbase_smb::{ + AccessControlList, AuthManager, SMBConfig, SMBManager, SMBMonitor, UserPermission, +}; #[derive(Parser)] #[command(name = "markbase-smb")] @@ -16,34 +18,34 @@ enum Commands { /// Share name #[arg(short, long, default_value = "markbase")] name: String, - + /// Path to share #[arg(short, long, default_value = "/Users/accusys/momentry/var/sftpgo/data")] path: String, }, - + /// Remove SMB share Remove { /// Share name #[arg(short, long)] name: String, }, - + /// List existing SMB shares List, - + /// Show SMB status Status, - + /// Manage user permissions User { #[command(subcommand)] action: UserCommands, }, - + /// Show monitoring stats Stats, - + /// Show access logs Logs { /// Number of log entries to show @@ -58,24 +60,24 @@ enum UserCommands { Add { #[arg(short, long)] username: String, - + #[arg(short, long, default_value = "readonly")] permission: String, }, - + /// Remove user permission Remove { #[arg(short, long)] username: String, }, - + /// List all user permissions List, } fn main() -> anyhow::Result<()> { let cli = Cli::parse(); - + match cli.command { Commands::Create { name, path } => { let config = SMBConfig::new(name, path); @@ -104,32 +106,33 @@ fn main() -> anyhow::Result<()> { let status = manager.status()?; println!("{}", serde_json::to_string_pretty(&status)?); } - Commands::User { action } => { - match action { - UserCommands::Add { username, permission } => { - let acl = AccessControlList::new(); - let perm = match permission.as_str() { - "readonly" => UserPermission::readonly(username), - "full" => UserPermission::full_access(username), - "admin" => UserPermission::admin(username), - _ => UserPermission::readonly(username), - }; - - println!("User permission configuration:"); - println!("{}", serde_json::to_string_pretty(&perm)?); - println!("\nTo apply, update system SMB configuration with this user."); - } - UserCommands::Remove { username } => { - println!("Removing user '{}' from ACL", username); - println!("To apply, update system SMB configuration."); - } - UserCommands::List => { - let acl = AccessControlList::new(); - println!("Default ACL configuration:"); - println!("{}", serde_json::to_string_pretty(&acl)?); - } + Commands::User { action } => match action { + UserCommands::Add { + username, + permission, + } => { + let acl = AccessControlList::new(); + let perm = match permission.as_str() { + "readonly" => UserPermission::readonly(username), + "full" => UserPermission::full_access(username), + "admin" => UserPermission::admin(username), + _ => UserPermission::readonly(username), + }; + + println!("User permission configuration:"); + println!("{}", serde_json::to_string_pretty(&perm)?); + println!("\nTo apply, update system SMB configuration with this user."); } - } + UserCommands::Remove { username } => { + println!("Removing user '{}' from ACL", username); + println!("To apply, update system SMB configuration."); + } + UserCommands::List => { + let acl = AccessControlList::new(); + println!("Default ACL configuration:"); + println!("{}", serde_json::to_string_pretty(&acl)?); + } + }, Commands::Stats => { let monitor = SMBMonitor::new(); let stats = monitor.get_stats(); @@ -149,6 +152,6 @@ fn main() -> anyhow::Result<()> { } } } - + Ok(()) -} \ No newline at end of file +} diff --git a/markbase-smb/src/manager.rs b/markbase-smb/src/manager.rs index 3856f76..0ca0421 100644 --- a/markbase-smb/src/manager.rs +++ b/markbase-smb/src/manager.rs @@ -12,72 +12,68 @@ impl SMBManager { pub fn new(config: SMBConfig) -> Self { SMBManager { config } } - + pub fn check_smb_service() -> Result { - let output = Command::new("sharing") - .arg("-l") - .output()?; - + let output = Command::new("sharing").arg("-l").output()?; + let status = String::from_utf8_lossy(&output.stdout); Ok(!status.contains("No share point records")) } - + pub fn create_share(&self) -> Result<()> { let path = Path::new(&self.config.path); if !path.exists() { return Err(anyhow::anyhow!("Path does not exist: {}", self.config.path)); } - - eprintln!("Creating SMB share '{}' for path: {}", self.config.share_name, self.config.path); - + + eprintln!( + "Creating SMB share '{}' for path: {}", + self.config.share_name, self.config.path + ); + let smb_conf_content = self.config.to_smb_conf(); eprintln!("Generated smb.conf section:\n{}", smb_conf_content); - + eprintln!("\nTo enable SMB sharing, run:"); - eprintln!("sudo sharing -a \"{}\" -S \"{}\"", self.config.path, self.config.share_name); - + eprintln!( + "sudo sharing -a \"{}\" -S \"{}\"", + self.config.path, self.config.share_name + ); + Ok(()) } - + pub fn remove_share(&self) -> Result<()> { eprintln!("Removing SMB share '{}'...", self.config.share_name); - + eprintln!("To remove SMB sharing, run:"); eprintln!("sudo sharing -r \"{}\"", self.config.share_name); - + Ok(()) } - + pub fn list_shares() -> Result> { - let output = Command::new("sharing") - .arg("-l") - .output()?; - + let output = Command::new("sharing").arg("-l").output()?; + let status = String::from_utf8_lossy(&output.stdout); - + if status.contains("No share point records") { return Ok(vec![]); } - + let shares: Vec = status .lines() .filter(|line| line.contains("name:")) - .map(|line| { - line.split(":") - .nth(1) - .unwrap_or("") - .trim() - .to_string() - }) + .map(|line| line.split(":").nth(1).unwrap_or("").trim().to_string()) .collect(); - + Ok(shares) } - + pub fn status(&self) -> Result { let service_running = Self::check_smb_service()?; let shares = Self::list_shares()?; - + Ok(serde_json::json!({ "service_running": service_running, "share_name": self.config.share_name, @@ -86,4 +82,4 @@ impl SMBManager { "config": self.config, })) } -} \ No newline at end of file +} diff --git a/markbase-smb/src/monitor.rs b/markbase-smb/src/monitor.rs index d285063..4655b90 100644 --- a/markbase-smb/src/monitor.rs +++ b/markbase-smb/src/monitor.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use std::time::{Duration, SystemTime}; #[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Default)] pub struct ConnectionStats { pub total_connections: u64, pub active_connections: u32, @@ -12,19 +13,6 @@ pub struct ConnectionStats { pub uptime_seconds: u64, } -impl Default for ConnectionStats { - fn default() -> Self { - ConnectionStats { - total_connections: 0, - active_connections: 0, - read_operations: 0, - write_operations: 0, - errors: 0, - bytes_transferred: 0, - uptime_seconds: 0, - } - } -} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AccessLogEntry { @@ -38,12 +26,19 @@ pub struct AccessLogEntry { } impl AccessLogEntry { - pub fn new(username: String, action: String, path: String, success: bool, bytes: u64, duration: Duration) -> Self { + pub fn new( + username: String, + action: String, + path: String, + success: bool, + bytes: u64, + duration: Duration, + ) -> Self { let timestamp = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs(); - + AccessLogEntry { timestamp: timestamp.to_string(), username, @@ -62,6 +57,12 @@ pub struct SMBMonitor { start_time: SystemTime, } +impl Default for SMBMonitor { + fn default() -> Self { + Self::new() + } +} + impl SMBMonitor { pub fn new() -> Self { SMBMonitor { @@ -70,10 +71,10 @@ impl SMBMonitor { start_time: SystemTime::now(), } } - + pub fn log_access(&mut self, entry: AccessLogEntry) { self.logs.push(entry.clone()); - + if entry.success { if entry.action == "read" { self.stats.read_operations += 1; @@ -85,28 +86,28 @@ impl SMBMonitor { self.stats.errors += 1; } } - + pub fn connection_opened(&mut self) { self.stats.total_connections += 1; self.stats.active_connections += 1; } - + pub fn connection_closed(&mut self) { self.stats.active_connections -= 1; } - + pub fn get_stats(&self) -> ConnectionStats { let uptime = SystemTime::now() .duration_since(self.start_time) .unwrap() .as_secs(); - + let mut stats = self.stats.clone(); stats.uptime_seconds = uptime; stats } - + pub fn get_logs(&self, limit: usize) -> Vec { self.logs.iter().rev().take(limit).cloned().collect() } -} \ No newline at end of file +} diff --git a/markbase-webdav/src/webdav/lock_manager.rs b/markbase-webdav/src/webdav/lock_manager.rs index d220ea9..f7b6cb9 100644 --- a/markbase-webdav/src/webdav/lock_manager.rs +++ b/markbase-webdav/src/webdav/lock_manager.rs @@ -146,7 +146,7 @@ impl DavLockSystem for LockManager { let path_owned = path.clone(); let token = format!("urn:uuid:{}", Uuid::new_v4()); let principal_str = principal.map(|s| s.to_string()); - let owner_clone = owner.map(|e| e.clone()); + let owner_clone = owner.cloned(); let owner_xml = owner.and_then(|e| { let mut buf = Vec::new(); e.write(&mut buf).ok()?; @@ -175,7 +175,7 @@ impl DavLockSystem for LockManager { token: String::new(), path: Box::new(path_owned.clone()), principal: principal_str.clone(), - owner: owner_clone.map(|e| Box::new(e)), + owner: owner_clone.map(Box::new), timeout_at: None, timeout, shared, @@ -229,7 +229,7 @@ impl DavLockSystem for LockManager { token, path: Box::new(path_owned.clone()), principal: principal_str, - owner: owner_clone.map(|e| Box::new(e)), + owner: owner_clone.map(Box::new), timeout_at: timeout_at .map(|t| SystemTime::UNIX_EPOCH + Duration::from_secs(t as u64)), timeout, @@ -400,31 +400,29 @@ impl DavLockSystem for LockManager { deep: false, })?; - for lock_result in locks { - if let Ok(lock) = lock_result { - if tokens.contains(&lock.token) { - continue; - } + for lock in locks.flatten() { + if tokens.contains(&lock.token) { + continue; + } - if ignore_principal { - continue; - } + if ignore_principal { + continue; + } - if let Some(ref lock_principal) = lock.principal { - if let Some(ref check_principal) = principal_str { - if lock_principal == check_principal { - continue; - } + if let Some(ref lock_principal) = lock.principal { + if let Some(ref check_principal) = principal_str { + if lock_principal == check_principal { + continue; } } + } - if deep && lock.deep { - return Err(lock); - } + if deep && lock.deep { + return Err(lock); + } - if !deep { - return Err(lock); - } + if !deep { + return Err(lock); } } diff --git a/rsync_50m.bin b/rsync_50m.bin new file mode 100644 index 0000000..8b7c055 Binary files /dev/null and b/rsync_50m.bin differ diff --git a/rust-iscsi-initiator/src/connection/mod.rs b/rust-iscsi-initiator/src/connection/mod.rs index 9313071..bb7b0ed 100644 --- a/rust-iscsi-initiator/src/connection/mod.rs +++ b/rust-iscsi-initiator/src/connection/mod.rs @@ -77,7 +77,7 @@ impl IscsiConnection { } // Parse login parameters - if response.data.len() > 0 { + if !response.data.is_empty() { let params = String::from_utf8_lossy(&response.data); log::info!("Login response: {}", params); } diff --git a/rust-iscsi-initiator/src/discovery/mod.rs b/rust-iscsi-initiator/src/discovery/mod.rs index 35e0c6e..e24c986 100644 --- a/rust-iscsi-initiator/src/discovery/mod.rs +++ b/rust-iscsi-initiator/src/discovery/mod.rs @@ -33,7 +33,7 @@ impl Discovery { let response = conn.recv_pdu().await?; // Parse SendTargets response - if response.data.len() > 0 { + if !response.data.is_empty() { let targets_str = String::from_utf8_lossy(&response.data); let targets = targets_str .lines() diff --git a/rust-iscsi-initiator/src/lib.rs b/rust-iscsi-initiator/src/lib.rs index a817fd2..55c0e17 100644 --- a/rust-iscsi-initiator/src/lib.rs +++ b/rust-iscsi-initiator/src/lib.rs @@ -35,6 +35,12 @@ pub struct Initiator { connections: Vec, } +impl Default for Initiator { + fn default() -> Self { + Self::new() + } +} + impl Initiator { pub fn new() -> Self { Self { diff --git a/tests/s3_vfs_test.rs b/tests/s3_vfs_test.rs new file mode 100644 index 0000000..26b0300 --- /dev/null +++ b/tests/s3_vfs_test.rs @@ -0,0 +1,336 @@ +use std::path::Path; +use std::time::Duration; + +/// Helper to create an S3Vfs instance pointing to local MinIO. +fn make_s3_vfs() -> markbase_core::vfs::s3_fs::S3Vfs { + markbase_core::vfs::s3_fs::S3Vfs::new( + "http://127.0.0.1:9000", + "us-east-1", + "test-bucket", + "minioadmin", + "minioadmin", + ) + .expect("Failed to create S3Vfs") +} + +/// Helper to check if MinIO is reachable (skip test if not). +fn minio_reachable() -> bool { + use std::net::TcpStream; + TcpStream::connect_timeout( + &"127.0.0.1:9000".parse().unwrap(), + Duration::from_secs(2), + ) + .is_ok() +} + +fn cleanup(vfs: &impl markbase_core::vfs::VfsBackend, path: &Path) { + if vfs.exists(path) { + let _ = vfs.remove_file(path); + } + // Try parent dirs + let mut p = path.parent(); + while let Some(dir) = p { + if dir == Path::new("/") { + break; + } + let dk = format!("{}/", dir.display()); + let _ = vfs.remove_dir(Path::new(&dk)); + p = dir.parent(); + } +} + +#[test] +fn test_create_dir() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_create_dir"); + return; + } + let vfs = make_s3_vfs(); + let dir = Path::new("/test-create-dir/"); + + if vfs.exists(dir) { + vfs.remove_dir(dir).ok(); + } + + assert!(!vfs.exists(dir)); + vfs.create_dir(dir, 0o755).expect("create_dir failed"); + assert!(vfs.exists(dir)); + + let stat = vfs.stat(dir).expect("stat failed"); + assert!(stat.is_dir); + + vfs.remove_dir(dir).expect("remove_dir failed"); + assert!(!vfs.exists(dir)); +} + +#[test] +fn test_create_dir_all() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_create_dir_all"); + return; + } + let vfs = make_s3_vfs(); + let nested = Path::new("/test-nested/a/b/c/"); + + // Clean up from previous runs + for d in [Path::new("/test-nested/a/b/c/"), Path::new("/test-nested/a/b/"), Path::new("/test-nested/a/"), Path::new("/test-nested/")] { + if vfs.exists(d) { + vfs.remove_dir(d).ok(); + } + } + + vfs.create_dir_all(nested, 0o755).expect("create_dir_all failed"); + + assert!(vfs.exists(Path::new("/test-nested/"))); + assert!(vfs.exists(Path::new("/test-nested/a/"))); + assert!(vfs.exists(Path::new("/test-nested/a/b/"))); + assert!(vfs.exists(Path::new("/test-nested/a/b/c/"))); + + for d in [Path::new("/test-nested/a/b/c/"), Path::new("/test-nested/a/b/"), Path::new("/test-nested/a/"), Path::new("/test-nested/")] { + vfs.remove_dir(d).expect("cleanup remove_dir failed"); + } +} + +#[test] +fn test_write_and_read_file() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_write_and_read_file"); + return; + } + let vfs = make_s3_vfs(); + let file_path = Path::new("/test-write-read.txt"); + cleanup(&vfs, file_path); + + use markbase_core::vfs::open_flags::OpenFlags; + + // Write + let mut f = vfs + .open_file(file_path, &OpenFlags::write()) + .expect("open for write failed"); + let content = b"Hello S3Vfs! This is a test."; + f.write_all(content).expect("write_all failed"); + f.flush().expect("flush failed"); + + // Stat + let stat = vfs.stat(file_path).expect("stat failed"); + assert_eq!(stat.size, content.len() as u64); + assert!(!stat.is_dir); + + // Read + let mut f = vfs + .open_file(file_path, &OpenFlags::read()) + .expect("open for read failed"); + let mut buf = vec![0u8; content.len()]; + f.read_exact(&mut buf).expect("read_exact failed"); + assert_eq!(buf, content); + + // Exists + assert!(vfs.exists(file_path)); + + // Cleanup + vfs.remove_file(file_path).expect("remove_file failed"); + assert!(!vfs.exists(file_path)); +} + +#[test] +fn test_read_dir() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_read_dir"); + return; + } + let vfs = make_s3_vfs(); + let dir = Path::new("/test-readdir/"); + let file1 = Path::new("/test-readdir/file1.txt"); + let file2 = Path::new("/test-readdir/file2.txt"); + let subdir = Path::new("/test-readdir/sub/"); + + // Clean + for p in &[file1, file2] { + cleanup(&vfs, p); + } + if vfs.exists(subdir) { + vfs.remove_dir(subdir).ok(); + } + if vfs.exists(dir) { + vfs.remove_dir(dir).ok(); + } + + // Create structure + vfs.create_dir(dir, 0o755).expect("create_dir failed"); + + let content = b"data"; + let flags = markbase_core::vfs::open_flags::OpenFlags::write(); + for p in &[file1, file2] { + let mut f = vfs.open_file(p, &flags).expect("open for write failed"); + f.write_all(content).expect("write_all failed"); + f.flush().expect("flush failed"); + } + vfs.create_dir(subdir, 0o755).expect("create subdir failed"); + + // Read directory + let entries = vfs.read_dir(dir).expect("read_dir failed"); + let names: Vec<&str> = entries.iter().map(|e| e.name.as_str()).collect(); + assert!(names.contains(&"file1.txt"), "missing file1.txt: {:?}", names); + assert!(names.contains(&"file2.txt"), "missing file2.txt: {:?}", names); + assert!(names.contains(&"sub"), "missing sub: {:?}", names); + + // Check entry types + for entry in &entries { + if entry.name == "file1.txt" || entry.name == "file2.txt" { + assert!(!entry.stat.is_dir, "{} should be a file", entry.name); + } + if entry.name == "sub" { + assert!(entry.stat.is_dir, "{} should be a dir", entry.name); + } + } + + // Cleanup + vfs.remove_file(file1).ok(); + vfs.remove_file(file2).ok(); + vfs.remove_dir(subdir).ok(); + vfs.remove_dir(dir).ok(); +} + +#[test] +fn test_rename() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_rename"); + return; + } + let vfs = make_s3_vfs(); + let from = Path::new("/test-rename-src.txt"); + let to = Path::new("/test-rename-dst.txt"); + cleanup(&vfs, from); + cleanup(&vfs, to); + + let content = b"rename test data"; + let flags = markbase_core::vfs::open_flags::OpenFlags::write(); + { + let mut f = vfs.open_file(from, &flags).expect("open for write"); + f.write_all(content).expect("write_all"); + f.flush().expect("flush"); + } + assert!(vfs.exists(from)); + + vfs.rename(from, to).expect("rename failed"); + assert!(!vfs.exists(from)); + assert!(vfs.exists(to)); + + // Verify content + let mut f = vfs + .open_file(to, &markbase_core::vfs::open_flags::OpenFlags::read()) + .expect("open for read"); + let mut buf = vec![0u8; content.len()]; + f.read_exact(&mut buf).expect("read_exact"); + assert_eq!(buf, content); + + vfs.remove_file(to).ok(); +} + +#[test] +fn test_hard_link() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_hard_link"); + return; + } + let vfs = make_s3_vfs(); + let original = Path::new("/test-link-orig.txt"); + let link = Path::new("/test-link-copy.txt"); + cleanup(&vfs, original); + cleanup(&vfs, link); + + let content = b"hard link test"; + let flags = markbase_core::vfs::open_flags::OpenFlags::write(); + { + let mut f = vfs.open_file(original, &flags).expect("open for write"); + f.write_all(content).expect("write_all"); + f.flush().expect("flush"); + } + + vfs.hard_link(original, link).expect("hard_link failed"); + assert!(vfs.exists(original)); + assert!(vfs.exists(link)); + + // Verify link content + let mut f = vfs + .open_file(link, &markbase_core::vfs::open_flags::OpenFlags::read()) + .expect("open for read"); + let mut buf = vec![0u8; content.len()]; + f.read_exact(&mut buf).expect("read_exact"); + assert_eq!(buf, content); + + vfs.remove_file(original).ok(); + vfs.remove_file(link).ok(); +} + +#[test] +fn test_remove_dir_not_empty() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_remove_dir_not_empty"); + return; + } + let vfs = make_s3_vfs(); + let dir = Path::new("/test-nonempty/"); + let file = Path::new("/test-nonempty/file.txt"); + cleanup(&vfs, file); + if vfs.exists(dir) { + vfs.remove_dir(dir).ok(); + } + + vfs.create_dir(dir, 0o755).expect("create_dir"); + let flags = markbase_core::vfs::open_flags::OpenFlags::write(); + { + let mut f = vfs.open_file(file, &flags).expect("open for write"); + f.write_all(b"x").expect("write"); + f.flush().expect("flush"); + } + + match vfs.remove_dir(dir) { + Err(markbase_core::vfs::VfsError::NotEmpty(_)) => { /* expected */ } + other => panic!("Expected NotEmpty, got {:?}", other), + } + + vfs.remove_file(file).ok(); + vfs.remove_dir(dir).ok(); +} + +#[test] +fn test_real_path() { + if !minio_reachable() { + eprintln!("MinIO not reachable, skipping test_real_path"); + return; + } + let vfs = make_s3_vfs(); + let dir = Path::new("/test-realpath/"); + let file = Path::new("/test-realpath/foo.txt"); + + cleanup(&vfs, file); + if vfs.exists(dir) { + vfs.remove_dir(dir).ok(); + } + + vfs.create_dir(dir, 0o755).expect("create_dir"); + let flags = markbase_core::vfs::open_flags::OpenFlags::write(); + { + let mut f = vfs.open_file(file, &flags).expect("open for write"); + f.write_all(b"test").expect("write"); + f.flush().expect("flush"); + } + + let rp = vfs.real_path(dir).expect("real_path dir"); + assert!( + rp.to_string_lossy().ends_with("test-realpath"), + "real_path dir: {:?}", + rp + ); + + let rp = vfs.real_path(file).expect("real_path file"); + assert!( + rp.to_string_lossy().ends_with("test-realpath/foo.txt"), + "real_path file: {:?}", + rp + ); + + vfs.remove_file(file).ok(); + vfs.remove_dir(dir).ok(); +} diff --git a/tests/ssh_full_integration.sh b/tests/ssh_full_integration.sh new file mode 100755 index 0000000..aa5de0d --- /dev/null +++ b/tests/ssh_full_integration.sh @@ -0,0 +1,145 @@ +#!/bin/bash +# SSH/SFTP/SCP/rsync 完整整合測試 - 自動化執行腳本 +# 用法: bash tests/ssh_full_integration.sh + +set -e + +SERVER_PORT=2024 +SSH_OPTS="-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null" +RSYNC_RSH="ssh -p $SERVER_PORT $SSH_OPTS" +PASS=0 +FAIL=0 + +log() { echo -e "\n\033[36m=== $1 ===\033[0m"; } +pass() { echo -e " \033[32m✅ $1\033[0m"; PASS=$((PASS+1)); } +fail() { echo -e " \033[31m❌ $1\033[0m"; FAIL=$((FAIL+1)); } + +cleanup() { + rm -f /tmp/sftp_test_*.bin /tmp/scp_*.bin /tmp/rsync_*.bin /tmp/batch_*.bin + rm -rf /tmp/rsync_dir* /tmp/scp_dir* /tmp/parallel_* +} + +# ── 建立測試檔案 ── +log "建立測試檔案" +dd if=/dev/urandom of=/tmp/sftp_test_2m.bin bs=1M count=2 2>/dev/null +dd if=/dev/urandom of=/tmp/sftp_test_5m.bin bs=1M count=5 2>/dev/null +SRC_MD5=$(md5 /tmp/sftp_test_5m.bin | awk '{print $NF}') +echo " 5MB source MD5: $SRC_MD5" + +# ── 1. SSH 連線 ── +log "1. SSH Basic Connection" +if timeout 10 ssh -p $SERVER_PORT $SSH_OPTS demo@127.0.0.1 'echo "SSH_OK"' 2>/dev/null | grep -q SSH_OK; then + pass "SSH connection + auth" +else + fail "SSH connection" +fi + +# ── 2. SFTP 操作 ── +log "2. SFTP Upload + Download" +timeout 60 sftp -P $SERVER_PORT $SSH_OPTS demo@127.0.0.1 << EOF 2>/dev/null >/dev/null +put /tmp/sftp_test_5m.bin sftp_5m.bin +get sftp_5m.bin /tmp/sftp_test_5m_dl.bin +rm sftp_5m.bin +bye +EOF + +DL_MD5=$(md5 /tmp/sftp_test_5m_dl.bin 2>/dev/null | awk '{print $NF}') +if [ "$SRC_MD5" = "$DL_MD5" ]; then + pass "SFTP 5MB (MD5: $SRC_MD5)" +else + fail "SFTP 5MB MD5 mismatch: src=$SRC_MD5 dl=$DL_MD5" +fi + +# ── 3. SFTP Enterprise Errors ── +log "3. SFTP Error Handling" +ERRORS=$(timeout 10 sftp -P $SERVER_PORT $SSH_OPTS demo@127.0.0.1 << 'EOF' 2>&1 +mkdir /etc/forbidden +stat /nonexistent +bye +EOF +) +if echo "$ERRORS" | grep -iq "permission denied"; then + pass "SFTP Permission denied" +else + fail "SFTP Permission denied missing" +fi +if echo "$ERRORS" | grep -iq "no such file"; then + pass "SFTP No such file" +else + fail "SFTP No such file missing" +fi + +# ── 4. SFTP Batch Files ── +log "4. SFTP Batch Transfer" +for i in 1 2 3 4 5; do + dd if=/dev/urandom of=/tmp/batch_${i}.bin bs=1K count=64 2>/dev/null +done + +timeout 60 sftp -P $SERVER_PORT $SSH_OPTS demo@127.0.0.1 << EOF 2>/dev/null >/dev/null +lcd /tmp +put batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin +get batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin +rm batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin +bye +EOF + +ALL_OK=true +for i in 1 2 3 4 5; do + S=$(md5 /tmp/batch_${i}.bin | awk '{print $NF}') + D=$(md5 /tmp/batch_${i}.bin 2>/dev/null | awk '{print $NF}') + [ "$S" != "$D" ] && ALL_OK=false +done +$ALL_OK && pass "SFTP Batch (5 files)" || fail "SFTP Batch MD5 mismatch" + +# ── 5. SCP ── +log "5. SCP Transfer" +timeout 30 scp -P $SERVER_PORT $SSH_OPTS /tmp/sftp_test_5m.bin demo@127.0.0.1:scp_test.bin 2>/dev/null +timeout 30 scp -P $SERVER_PORT $SSH_OPTS demo@127.0.0.1:scp_test.bin /tmp/scp_dl.bin 2>/dev/null + +SCP_MD5=$(md5 /tmp/scp_dl.bin 2>/dev/null | awk '{print $NF}') +if [ "$SRC_MD5" = "$SCP_MD5" ]; then + pass "SCP 5MB (MD5: $SCP_MD5)" +else + fail "SCP MD5 mismatch" +fi + +# ── 6. rsync ── +log "6. rsync Transfer" +timeout 30 rsync -avz --rsh="$RSYNC_RSH" /tmp/sftp_test_5m.bin demo@127.0.0.1:rsync_test.bin 2>/dev/null >/dev/null +timeout 30 rsync -avz --rsh="$RSYNC_RSH" demo@127.0.0.1:rsync_test.bin /tmp/rsync_dl.bin 2>/dev/null >/dev/null + +RSYNC_MD5=$(md5 /tmp/rsync_dl.bin 2>/dev/null | awk '{print $NF}') +if [ "$SRC_MD5" = "$RSYNC_MD5" ]; then + pass "rsync 5MB (MD5: $RSYNC_MD5)" +else + fail "rsync MD5 mismatch" +fi + +# ── 7. Delta rsync ── +log "7. rsync Delta Transfer" +echo "delta" >> /tmp/sftp_test_5m.bin +timeout 30 rsync -avz --rsh="$RSYNC_RSH" /tmp/sftp_test_5m.bin demo@127.0.0.1:rsync_test.bin 2>&1 | grep -q "speedup" && \ + pass "rsync delta (small change)" || fail "rsync delta" + +# ── 8. Shell Commands ── +log "8. SSH Shell Commands" +OUT=$(timeout 10 ssh -p $SERVER_PORT $SSH_OPTS demo@127.0.0.1 'echo hello; whoami' 2>/dev/null) +echo "$OUT" | grep -q hello && echo "$OUT" | grep -q demo && \ + pass "SSH shell commands" || fail "SSH shell commands" + +# ── 9. Stress: 10 consecutive connections ── +log "9. Stress Test (10 connections)" +CONN_OK=0 +for i in $(seq 1 10); do + timeout 10 ssh -p $SERVER_PORT $SSH_OPTS demo@127.0.0.1 'echo OK' 2>/dev/null | grep -q OK && CONN_OK=$((CONN_OK+1)) +done +[ "$CONN_OK" -eq 10 ] && pass "10/10 connections OK" || fail "$CONN_OK/10 connections OK" + +# ── Summary ── +log "SUMMARY" +echo -e " \033[32mPassed: $PASS\033[0m" +echo -e " \033[31mFailed: $FAIL\033[0m" +[ "$FAIL" -eq 0 ] && echo -e "\033[32m\n ✅ ALL TESTS PASSED\033[0m" || echo -e "\033[31m\n ❌ SOME TESTS FAILED\033[0m" + +cleanup +exit $FAIL