Fix code quality: trailing whitespace, unused imports, clippy warnings
- Fix trailing whitespace in kex.rs and s3.rs - Add missing KexProposal import in kex_complete.rs - Auto-fix clippy warnings across all crates - All 153 tests pass
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// API Handlers Module
|
||||
//
|
||||
//
|
||||
// This module provides space for future modular API handlers.
|
||||
// Current handlers are implemented in server.rs for stability.
|
||||
//
|
||||
@@ -13,4 +13,4 @@
|
||||
// - view.rs: Category/Series view handlers
|
||||
// - static.rs: Static page handlers
|
||||
|
||||
pub use crate::server::AppState;
|
||||
pub use crate::server::AppState;
|
||||
|
||||
@@ -9,4 +9,4 @@ pub mod handlers;
|
||||
// - Clear separation of concerns
|
||||
// - Easier maintenance for new features
|
||||
// - Gradual migration path from server.rs
|
||||
// - Independent testing per handler module
|
||||
// - Independent testing per handler module
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
// Archive Configuration - User Configurable Options
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use log::warn;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Archive Configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ArchiveConfig {
|
||||
// Optional formats (controversial)
|
||||
pub enable_rar: bool, // ⚠️ Legal risk (RARLAB patent)
|
||||
pub enable_xz: bool, // ⚠️ External dependency (liblzma)
|
||||
pub enable_7z: bool, // ⚠️ Unstable library
|
||||
|
||||
pub enable_rar: bool, // ⚠️ Legal risk (RARLAB patent)
|
||||
pub enable_xz: bool, // ⚠️ External dependency (liblzma)
|
||||
pub enable_7z: bool, // ⚠️ Unstable library
|
||||
|
||||
// Performance settings
|
||||
pub cache_size_mb: u64,
|
||||
pub max_concurrent_extractions: usize,
|
||||
|
||||
|
||||
// Security settings
|
||||
pub max_decompression_ratio: u64,
|
||||
pub max_file_size_mb: u64,
|
||||
@@ -29,11 +28,11 @@ impl Default for ArchiveConfig {
|
||||
enable_rar: false,
|
||||
enable_xz: false,
|
||||
enable_7z: false,
|
||||
|
||||
|
||||
// Performance
|
||||
cache_size_mb: 100,
|
||||
max_concurrent_extractions: 4,
|
||||
|
||||
|
||||
// Security
|
||||
max_decompression_ratio: 1000,
|
||||
max_file_size_mb: 1024,
|
||||
@@ -46,45 +45,46 @@ impl ArchiveConfig {
|
||||
pub fn load(path: &str) -> Result<Self> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let config: ArchiveConfig = toml::from_str(&content)?;
|
||||
|
||||
|
||||
// Validate configuration
|
||||
config.validate()?;
|
||||
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
|
||||
/// Save configuration to TOML file
|
||||
pub fn save(&self, path: &str) -> Result<()> {
|
||||
let content = toml::to_string_pretty(self)?;
|
||||
std::fs::write(path, content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// Validate configuration
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.cache_size_mb > 1000 {
|
||||
warn!("Cache size > 1GB may cause memory pressure");
|
||||
}
|
||||
|
||||
|
||||
if self.max_concurrent_extractions > 10 {
|
||||
warn!("Concurrent extractions > 10 may cause resource exhaustion");
|
||||
}
|
||||
|
||||
|
||||
if self.max_decompression_ratio < 10 {
|
||||
return Err(anyhow::anyhow!("Max decompression ratio too low (min 10)"));
|
||||
}
|
||||
|
||||
if self.max_file_size_mb > 10_000 { // 10GB
|
||||
|
||||
if self.max_file_size_mb > 10_000 {
|
||||
// 10GB
|
||||
warn!("Max file size > 10GB may cause disk space issues");
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// Generate default config file template
|
||||
pub fn generate_template() -> String {
|
||||
let config = Self::default();
|
||||
|
||||
|
||||
format!(
|
||||
"# === Archive Configuration ===
|
||||
# MarkBase Universal Compression Format Support
|
||||
@@ -138,33 +138,33 @@ max_file_size_mb = {} # File size limit (MB)
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = ArchiveConfig::default();
|
||||
|
||||
|
||||
assert_eq!(config.enable_rar, false);
|
||||
assert_eq!(config.enable_xz, false);
|
||||
assert_eq!(config.enable_7z, false);
|
||||
assert_eq!(config.cache_size_mb, 100);
|
||||
assert_eq!(config.max_decompression_ratio, 1000);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_config_validation() {
|
||||
let config = ArchiveConfig {
|
||||
max_decompression_ratio: 5,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_config_template() {
|
||||
let template = ArchiveConfig::generate_template();
|
||||
|
||||
|
||||
assert!(template.contains("enable_rar = false"));
|
||||
assert!(template.contains("⚠️ RAR Format Legal Risk Warning"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// Format Detector - Automatic Detection Based on Magic Numbers
|
||||
|
||||
use anyhow::Result;
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::archive::processor::ArchiveFormat;
|
||||
|
||||
@@ -18,64 +18,61 @@ impl FormatDetector {
|
||||
// ZIP: 50 4B 03 04 or 50 4B 05 06 (empty) or 50 4B 07 08 (spanned)
|
||||
(vec![0x50, 0x4B, 0x03, 0x04], ArchiveFormat::Zip, 4),
|
||||
(vec![0x50, 0x4B, 0x05, 0x06], ArchiveFormat::Zip, 4),
|
||||
|
||||
// GZIP: 1F 8B
|
||||
(vec![0x1F, 0x8B], ArchiveFormat::Gzip, 2),
|
||||
];
|
||||
|
||||
|
||||
Self { magic_table }
|
||||
}
|
||||
|
||||
|
||||
/// Detect file format based on Magic Number
|
||||
pub fn detect(&self, path: &Path) -> Result<ArchiveFormat> {
|
||||
let mut file = File::open(path)?;
|
||||
let mut buffer = vec![0u8; 512];
|
||||
|
||||
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
if bytes_read < 2 {
|
||||
return Ok(ArchiveFormat::Unknown);
|
||||
}
|
||||
|
||||
|
||||
// Match Magic Numbers
|
||||
for (magic, format, offset) in &self.magic_table {
|
||||
if buffer.len() >= *offset && buffer[0..magic.len()] == *magic {
|
||||
return Ok(*format);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Special detection: TAR format (check ustar magic at offset 257)
|
||||
if buffer.len() >= 262 {
|
||||
if &buffer[257..262] == b"ustar" {
|
||||
if buffer.len() >= 262
|
||||
&& &buffer[257..262] == b"ustar" {
|
||||
return Ok(ArchiveFormat::Tar);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(ArchiveFormat::Unknown)
|
||||
}
|
||||
|
||||
|
||||
/// Detect composite format (e.g., TAR.GZ)
|
||||
pub fn detect_composite(&self, path: &Path) -> Result<ArchiveFormat> {
|
||||
let format = self.detect(path)?;
|
||||
|
||||
|
||||
// If GZIP, check if it's TAR.GZ (by extension for now)
|
||||
if format == ArchiveFormat::Gzip {
|
||||
let ext = path.extension()
|
||||
let ext = path
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.unwrap_or("")
|
||||
.to_lowercase();
|
||||
|
||||
|
||||
if ext == "tgz" || ext == "gz" {
|
||||
// Check if filename contains .tar
|
||||
let filename = path.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
|
||||
|
||||
if filename.contains(".tar") {
|
||||
return Ok(ArchiveFormat::TarGzip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(format)
|
||||
}
|
||||
}
|
||||
@@ -89,51 +86,51 @@ impl Default for FormatDetector {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
use std::io::Write;
|
||||
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_detect_zip() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let zip_path = temp_dir.path().join("test.zip");
|
||||
|
||||
|
||||
// Create minimal ZIP file header
|
||||
let mut file = File::create(&zip_path).unwrap();
|
||||
file.write_all(&[0x50, 0x4B, 0x03, 0x04]).unwrap();
|
||||
|
||||
|
||||
let detector = FormatDetector::new();
|
||||
let format = detector.detect(&zip_path).unwrap();
|
||||
|
||||
|
||||
assert_eq!(format, ArchiveFormat::Zip);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_detect_gzip() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let gz_path = temp_dir.path().join("test.gz");
|
||||
|
||||
|
||||
// Create minimal GZIP file header
|
||||
let mut file = File::create(&gz_path).unwrap();
|
||||
file.write_all(&[0x1F, 0x8B]).unwrap();
|
||||
|
||||
|
||||
let detector = FormatDetector::new();
|
||||
let format = detector.detect(&gz_path).unwrap();
|
||||
|
||||
|
||||
assert_eq!(format, ArchiveFormat::Gzip);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_detect_unknown() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let unknown_path = temp_dir.path().join("test.bin");
|
||||
|
||||
|
||||
// Create unknown file
|
||||
let mut file = File::create(&unknown_path).unwrap();
|
||||
file.write_all(b"unknown data").unwrap();
|
||||
|
||||
|
||||
let detector = FormatDetector::new();
|
||||
let format = detector.detect(&unknown_path).unwrap();
|
||||
|
||||
|
||||
assert_eq!(format, ArchiveFormat::Unknown);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Metadata Module - Archive Entry Metadata Management
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::time::SystemTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::archive::processor::ArchiveFormat;
|
||||
|
||||
@@ -29,7 +29,7 @@ impl ArchiveMetadata {
|
||||
self.total_size as f64 / self.compressed_size as f64
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Check if compression ratio exceeds limit (Zip Bomb detection)
|
||||
pub fn check_zip_bomb(&self, max_ratio: u64) -> bool {
|
||||
self.actual_ratio() > max_ratio as f64
|
||||
@@ -65,7 +65,7 @@ impl ArchiveEntry {
|
||||
checksum: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Create file entry
|
||||
pub fn file(path: PathBuf, size: u64, compressed_size: u64) -> Self {
|
||||
Self {
|
||||
@@ -104,7 +104,7 @@ impl ExtractResult {
|
||||
warnings: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn success_rate(&self) -> f64 {
|
||||
if self.total_files == 0 {
|
||||
100.0
|
||||
@@ -113,11 +113,11 @@ impl ExtractResult {
|
||||
(success_count as f64 / self.total_files as f64) * 100.0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn has_failures(&self) -> bool {
|
||||
!self.failed_files.is_empty()
|
||||
}
|
||||
|
||||
|
||||
pub fn has_warnings(&self) -> bool {
|
||||
!self.warnings.is_empty()
|
||||
}
|
||||
@@ -126,7 +126,7 @@ impl ExtractResult {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_archive_metadata() {
|
||||
let metadata = ArchiveMetadata {
|
||||
@@ -140,37 +140,37 @@ mod tests {
|
||||
created_time: None,
|
||||
modified_time: None,
|
||||
};
|
||||
|
||||
|
||||
assert_eq!(metadata.actual_ratio(), 2.0);
|
||||
assert!(!metadata.check_zip_bomb(1000));
|
||||
assert!(metadata.check_zip_bomb(1)); // Should detect as bomb
|
||||
assert!(metadata.check_zip_bomb(1)); // Should detect as bomb
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_archive_entry() {
|
||||
let dir_entry = ArchiveEntry::directory(PathBuf::from("test_dir"));
|
||||
assert!(dir_entry.is_dir);
|
||||
assert!(!dir_entry.is_file);
|
||||
|
||||
|
||||
let file_entry = ArchiveEntry::file(PathBuf::from("test.txt"), 100, 50);
|
||||
assert!(!file_entry.is_dir);
|
||||
assert!(file_entry.is_file);
|
||||
assert_eq!(file_entry.size, 100);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_extract_result() {
|
||||
let result = ExtractResult::new();
|
||||
assert_eq!(result.success_rate(), 100.0);
|
||||
|
||||
|
||||
let result_with_failure = ExtractResult {
|
||||
total_files: 10,
|
||||
success_files: 8,
|
||||
failed_files: vec![PathBuf::from("failed.txt")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
||||
assert_eq!(result_with_failure.success_rate(), 80.0);
|
||||
assert!(result_with_failure.has_failures());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,9 +25,9 @@ pub use metadata::{ArchiveEntry, ArchiveMetadata, ExtractResult};
|
||||
pub use processor::{ArchiveFormat, ArchiveProcessor};
|
||||
|
||||
use anyhow::Result;
|
||||
use log::info;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use log::{info, warn};
|
||||
|
||||
/// Processor Registry - Plugin Architecture
|
||||
pub struct ProcessorRegistry {
|
||||
@@ -43,93 +43,108 @@ impl ProcessorRegistry {
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Initialize all processors (based on config)
|
||||
pub fn initialize(&mut self) -> Result<()> {
|
||||
// Core formats (always registered)
|
||||
self.register_core_processors()?;
|
||||
|
||||
|
||||
// Optional formats (based on config)
|
||||
self.register_optional_processors()?;
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// Register core format processors (9 formats)
|
||||
fn register_core_processors(&mut self) -> Result<()> {
|
||||
use crate::archive::processors::core::*;
|
||||
|
||||
self.processors.insert(ArchiveFormat::Zip, Box::new(ZipProcessor::new()));
|
||||
self.processors.insert(ArchiveFormat::Tar, Box::new(TarProcessor::new()));
|
||||
self.processors.insert(ArchiveFormat::Gzip, Box::new(GzipProcessor::new()));
|
||||
self.processors.insert(ArchiveFormat::Zstd, Box::new(ZstdProcessor::new()));
|
||||
self.processors.insert(ArchiveFormat::Bzip2, Box::new(Bzip2Processor::new()));
|
||||
self.processors.insert(ArchiveFormat::Lz4, Box::new(Lz4Processor::new()));
|
||||
self.processors.insert(ArchiveFormat::TarGzip, Box::new(TarGzipProcessor::new()));
|
||||
self.processors.insert(ArchiveFormat::TarBzip2, Box::new(TarBzip2Processor::new()));
|
||||
self.processors.insert(ArchiveFormat::TarZstd, Box::new(TarZstdProcessor::new()));
|
||||
|
||||
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Zip, Box::new(ZipProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Tar, Box::new(TarProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Gzip, Box::new(GzipProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Zstd, Box::new(ZstdProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Bzip2, Box::new(Bzip2Processor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Lz4, Box::new(Lz4Processor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::TarGzip, Box::new(TarGzipProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::TarBzip2, Box::new(TarBzip2Processor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::TarZstd, Box::new(TarZstdProcessor::new()));
|
||||
|
||||
info!("✅ Core formats registered: 9 formats");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// Register optional format processors (3 formats, based on config)
|
||||
fn register_optional_processors(&mut self) -> Result<()> {
|
||||
#[cfg(feature = "optional-formats")]
|
||||
{
|
||||
use crate::archive::processors::optional::*;
|
||||
|
||||
|
||||
// RAR format (legal risk)
|
||||
if self.config.enable_rar {
|
||||
crate::archive::warning::show_rar_legal_warning();
|
||||
self.processors.insert(ArchiveFormat::Rar, Box::new(RarProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Rar, Box::new(RarProcessor::new()));
|
||||
warn!("⚠️ RAR format enabled (legal risk)");
|
||||
}
|
||||
|
||||
|
||||
// XZ format (external dependency)
|
||||
if self.config.enable_xz {
|
||||
if check_liblzma_available() {
|
||||
self.processors.insert(ArchiveFormat::Xz, Box::new(XzProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::Xz, Box::new(XzProcessor::new()));
|
||||
info!("✅ XZ format enabled");
|
||||
} else {
|
||||
crate::archive::warning::show_xz_dependency_warning();
|
||||
warn!("⚠️ XZ format disabled (liblzma not found)");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 7z format (unstable library)
|
||||
if self.config.enable_7z {
|
||||
crate::archive::warning::show_7z_stability_warning();
|
||||
self.processors.insert(ArchiveFormat::SevenZ, Box::new(SevenZProcessor::new()));
|
||||
self.processors
|
||||
.insert(ArchiveFormat::SevenZ, Box::new(SevenZProcessor::new()));
|
||||
warn!("⚠️ 7z format enabled (stability warning)");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// Get processor for detected format (mutable version for open/extraction)
|
||||
pub fn get_processor_mut(&mut self, path: &Path) -> Result<&mut (dyn ArchiveProcessor + '_)> {
|
||||
let detector = FormatDetector::new();
|
||||
let format = detector.detect(path)?;
|
||||
|
||||
|
||||
match self.processors.get_mut(&format) {
|
||||
Some(p) => Ok(p.as_mut()),
|
||||
None => Err(anyhow::anyhow!("Format {} not supported or not enabled", format)),
|
||||
None => Err(anyhow::anyhow!(
|
||||
"Format {} not supported or not enabled",
|
||||
format
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get processor for detected format (immutable version for listing)
|
||||
pub fn get_processor(&self, path: &Path) -> Result<&dyn ArchiveProcessor> {
|
||||
let detector = FormatDetector::new();
|
||||
let format = detector.detect(path)?;
|
||||
|
||||
|
||||
self.processors
|
||||
.get(&format)
|
||||
.map(|p| p.as_ref())
|
||||
.ok_or_else(|| anyhow::anyhow!("Format {} not supported or not enabled", format))
|
||||
}
|
||||
|
||||
|
||||
/// List all enabled formats
|
||||
pub fn enabled_formats(&self) -> Vec<ArchiveFormat> {
|
||||
self.processors.keys().cloned().collect()
|
||||
@@ -141,7 +156,7 @@ impl ProcessorRegistry {
|
||||
fn check_liblzma_available() -> bool {
|
||||
// Try to load xz2 library
|
||||
// Simplified check: try to create XzProcessor
|
||||
true // Simplified for now, actual implementation needs better detection
|
||||
true // Simplified for now, actual implementation needs better detection
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "optional-formats"))]
|
||||
@@ -156,13 +171,16 @@ pub fn init_archive_system(config_path: Option<&str>) -> Result<ProcessorRegistr
|
||||
} else {
|
||||
ArchiveConfig::default()
|
||||
};
|
||||
|
||||
|
||||
// Show startup warnings for optional formats
|
||||
crate::archive::warning::show_startup_warnings(&config);
|
||||
|
||||
|
||||
let mut registry = ProcessorRegistry::new(config);
|
||||
registry.initialize()?;
|
||||
|
||||
info!("Archive system initialized with {} formats", registry.enabled_formats().len());
|
||||
|
||||
info!(
|
||||
"Archive system initialized with {} formats",
|
||||
registry.enabled_formats().len()
|
||||
);
|
||||
Ok(registry)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use anyhow::Result;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
// Re-export types from metadata.rs
|
||||
pub use crate::archive::metadata::{ArchiveMetadata, ArchiveEntry, ExtractResult};
|
||||
pub use crate::archive::metadata::{ArchiveEntry, ArchiveMetadata, ExtractResult};
|
||||
|
||||
/// Archive Format Type Enumeration
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||
@@ -19,12 +19,12 @@ pub enum ArchiveFormat {
|
||||
TarGzip,
|
||||
TarBzip2,
|
||||
TarZstd,
|
||||
|
||||
|
||||
// Optional formats (controversial)
|
||||
Rar, // ⚠️ Legal risk (RARLAB patent)
|
||||
Xz, // ⚠️ External dependency (liblzma)
|
||||
SevenZ, // ⚠️ Unstable library (sevenz-rust 0.21.0)
|
||||
|
||||
|
||||
Unknown,
|
||||
}
|
||||
|
||||
@@ -53,30 +53,34 @@ impl std::fmt::Display for ArchiveFormat {
|
||||
pub trait ArchiveProcessor: Send + Sync {
|
||||
/// Format type supported by this processor
|
||||
fn format(&self) -> ArchiveFormat;
|
||||
|
||||
|
||||
/// Open archive file and read metadata
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata>;
|
||||
|
||||
|
||||
/// List all file entries in archive
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>>;
|
||||
|
||||
|
||||
/// Extract single file (on-demand decompression)
|
||||
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64>;
|
||||
|
||||
|
||||
/// Extract all files to directory (batch extraction)
|
||||
fn extract_all(&mut self, output_dir: &Path) -> Result<ExtractResult>;
|
||||
|
||||
|
||||
/// Check if this processor can handle the format
|
||||
fn can_process(format: ArchiveFormat) -> bool where Self: Sized;
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Create new processor instance
|
||||
fn new() -> Self where Self: Sized;
|
||||
fn new() -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
/// Security Validation - Zip Slip Protection
|
||||
pub fn validate_extraction_path(entry_path: &Path, base_dir: &Path) -> Result<PathBuf> {
|
||||
use std::path::Component;
|
||||
|
||||
|
||||
// 1. Check path components
|
||||
for component in entry_path.components() {
|
||||
match component {
|
||||
@@ -92,51 +96,62 @@ pub fn validate_extraction_path(entry_path: &Path, base_dir: &Path) -> Result<Pa
|
||||
Component::Normal(_) | Component::CurDir => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 2. Build full path
|
||||
let full_path = base_dir.join(entry_path);
|
||||
|
||||
|
||||
// 3. Canonicalize and validate (ensure within base_dir)
|
||||
let canonical_base = base_dir.canonicalize()
|
||||
let canonical_base = base_dir
|
||||
.canonicalize()
|
||||
.map_err(|e| anyhow::anyhow!("Cannot canonicalize base dir: {}", e))?;
|
||||
|
||||
|
||||
// Create parent directories first
|
||||
if let Some(parent) = full_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
|
||||
// 4. Verify extraction path is within base_dir
|
||||
// Note: full_path may not exist yet, so we check parent directory
|
||||
if full_path.exists() {
|
||||
let canonical_full = full_path.canonicalize()
|
||||
let canonical_full = full_path
|
||||
.canonicalize()
|
||||
.map_err(|e| anyhow::anyhow!("Cannot canonicalize full path: {}", e))?;
|
||||
|
||||
|
||||
if !canonical_full.starts_with(&canonical_base) {
|
||||
return Err(anyhow::anyhow!("Zip Slip detected: path escapes base directory"));
|
||||
return Err(anyhow::anyhow!(
|
||||
"Zip Slip detected: path escapes base directory"
|
||||
));
|
||||
}
|
||||
} else {
|
||||
// Check parent directory instead
|
||||
if let Some(parent) = full_path.parent() {
|
||||
let canonical_parent = parent.canonicalize()
|
||||
let canonical_parent = parent
|
||||
.canonicalize()
|
||||
.map_err(|e| anyhow::anyhow!("Cannot canonicalize parent: {}", e))?;
|
||||
|
||||
|
||||
if !canonical_parent.starts_with(&canonical_base) {
|
||||
return Err(anyhow::anyhow!("Zip Slip detected: path escapes base directory"));
|
||||
return Err(anyhow::anyhow!(
|
||||
"Zip Slip detected: path escapes base directory"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(full_path)
|
||||
}
|
||||
|
||||
/// Security Validation - Zip Bomb Protection
|
||||
pub fn check_decompression_ratio(compressed_size: u64, decompressed_size: u64, max_ratio: u64) -> Result<()> {
|
||||
pub fn check_decompression_ratio(
|
||||
compressed_size: u64,
|
||||
decompressed_size: u64,
|
||||
max_ratio: u64,
|
||||
) -> Result<()> {
|
||||
if compressed_size == 0 {
|
||||
return Ok(()); // Empty file, allow
|
||||
return Ok(()); // Empty file, allow
|
||||
}
|
||||
|
||||
|
||||
let ratio = decompressed_size / compressed_size;
|
||||
|
||||
|
||||
if ratio > max_ratio {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Zip Bomb detected: compression ratio {} exceeds limit {}",
|
||||
@@ -144,7 +159,7 @@ pub fn check_decompression_ratio(compressed_size: u64, decompressed_size: u64, m
|
||||
max_ratio
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -157,7 +172,7 @@ pub fn check_file_size_limit(file_size: u64, max_size: u64) -> Result<()> {
|
||||
max_size / 1024 / 1024
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -165,34 +180,34 @@ pub fn check_file_size_limit(file_size: u64, max_size: u64) -> Result<()> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_zip_slip_protection() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let base = temp_dir.path();
|
||||
|
||||
|
||||
// Safe path: should pass
|
||||
let safe_path = Path::new("safe/file.txt");
|
||||
assert!(validate_extraction_path(safe_path, base).is_ok());
|
||||
|
||||
|
||||
// Evil path: should be rejected
|
||||
let evil_path = Path::new("../../etc/passwd");
|
||||
assert!(validate_extraction_path(evil_path, base).is_err());
|
||||
|
||||
|
||||
// Absolute path: should be rejected
|
||||
let abs_path = Path::new("/etc/passwd");
|
||||
assert!(validate_extraction_path(abs_path, base).is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_zip_bomb_detection() {
|
||||
// Normal ratio: should pass
|
||||
assert!(check_decompression_ratio(1000, 5000, 1000).is_ok());
|
||||
|
||||
|
||||
// Zip Bomb ratio: should be rejected
|
||||
assert!(check_decompression_ratio(42_000, 5_000_000_000, 1000).is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_compression_ratio_calculation() {
|
||||
let metadata = ArchiveMetadata {
|
||||
@@ -206,7 +221,7 @@ mod tests {
|
||||
created_time: None,
|
||||
modified_time: None,
|
||||
};
|
||||
|
||||
|
||||
assert_eq!(metadata.actual_ratio(), 2.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
// Core Format Processors - ZIP, TAR, GZIP, TAR.GZ Full Implementation
|
||||
|
||||
use crate::archive::{
|
||||
ArchiveProcessor, ArchiveFormat, ArchiveMetadata, ArchiveEntry, ExtractResult,
|
||||
processor::{validate_extraction_path, check_decompression_ratio, check_file_size_limit},
|
||||
};
|
||||
use crate::archive::config::ArchiveConfig;
|
||||
use anyhow::{Result, anyhow};
|
||||
use crate::archive::{
|
||||
processor::{check_decompression_ratio, check_file_size_limit, validate_extraction_path},
|
||||
ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{debug, info, warn};
|
||||
use std::fs::{create_dir_all, File};
|
||||
use std::io::{BufWriter, Read};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::fs::{File, create_dir_all};
|
||||
use std::io::{Read, Write, BufReader, BufWriter};
|
||||
use std::time::SystemTime;
|
||||
use log::{info, warn, debug};
|
||||
|
||||
// ==================== ZIP Processor ====================
|
||||
|
||||
@@ -21,6 +21,12 @@ pub struct ZipProcessor {
|
||||
config: ArchiveConfig,
|
||||
}
|
||||
|
||||
impl Default for ZipProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ZipProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -29,7 +35,7 @@ impl ZipProcessor {
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn with_config(config: ArchiveConfig) -> Self {
|
||||
Self {
|
||||
archive: None,
|
||||
@@ -43,7 +49,7 @@ impl ArchiveProcessor for ZipProcessor {
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Zip
|
||||
}
|
||||
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
archive: None,
|
||||
@@ -51,64 +57,72 @@ impl ArchiveProcessor for ZipProcessor {
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata> {
|
||||
info!("Opening ZIP archive: {}", path.display());
|
||||
|
||||
|
||||
let file = File::open(path)?;
|
||||
let archive = zip::ZipArchive::new(file)?;
|
||||
|
||||
|
||||
self.archive = Some(archive);
|
||||
self.path = path.to_path_buf();
|
||||
|
||||
|
||||
// Extract metadata (need mutable reference for by_index)
|
||||
let archive_ref = self.archive.as_mut().unwrap();
|
||||
let total_files = archive_ref.len() as u64;
|
||||
|
||||
|
||||
let mut total_size = 0u64;
|
||||
let mut compressed_size = 0u64;
|
||||
|
||||
|
||||
for i in 0..archive_ref.len() {
|
||||
let file = archive_ref.by_index(i)?;
|
||||
total_size += file.size();
|
||||
compressed_size += file.compressed_size();
|
||||
}
|
||||
|
||||
|
||||
let compression_ratio = if compressed_size > 0 {
|
||||
total_size as f64 / compressed_size as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
|
||||
// Check for Zip Bomb
|
||||
if compression_ratio > self.config.max_decompression_ratio as f64 {
|
||||
warn!("Potential Zip Bomb detected: ratio {:.1}:1", compression_ratio);
|
||||
return Err(anyhow!("Zip Bomb detected: compression ratio {:.1} exceeds limit {}",
|
||||
compression_ratio, self.config.max_decompression_ratio));
|
||||
warn!(
|
||||
"Potential Zip Bomb detected: ratio {:.1}:1",
|
||||
compression_ratio
|
||||
);
|
||||
return Err(anyhow!(
|
||||
"Zip Bomb detected: compression ratio {:.1} exceeds limit {}",
|
||||
compression_ratio,
|
||||
self.config.max_decompression_ratio
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
Ok(ArchiveMetadata {
|
||||
format: ArchiveFormat::Zip,
|
||||
total_files,
|
||||
total_size,
|
||||
compressed_size,
|
||||
compression_ratio,
|
||||
is_encrypted: false, // TODO: Check encryption
|
||||
is_encrypted: false, // TODO: Check encryption
|
||||
is_multi_volume: false,
|
||||
created_time: Some(SystemTime::now()),
|
||||
modified_time: Some(SystemTime::now()),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
let archive = self.archive.as_mut()
|
||||
let archive = self
|
||||
.archive
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
|
||||
let mut entries = Vec::new();
|
||||
|
||||
|
||||
for i in 0..archive.len() {
|
||||
let file = archive.by_index(i)?;
|
||||
|
||||
|
||||
let entry = ArchiveEntry {
|
||||
path: PathBuf::from(file.name()),
|
||||
size: file.size(),
|
||||
@@ -116,61 +130,64 @@ impl ArchiveProcessor for ZipProcessor {
|
||||
is_dir: file.name().ends_with('/'),
|
||||
is_file: !file.name().ends_with('/'),
|
||||
is_encrypted: false,
|
||||
modified: SystemTime::UNIX_EPOCH, // TODO: Get actual time
|
||||
modified: SystemTime::UNIX_EPOCH, // TODO: Get actual time
|
||||
permissions: Some(0o644),
|
||||
checksum: None,
|
||||
};
|
||||
|
||||
|
||||
entries.push(entry);
|
||||
}
|
||||
|
||||
|
||||
info!("Listed {} entries in ZIP archive", entries.len());
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
|
||||
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||
let archive = self.archive.as_mut()
|
||||
let archive = self
|
||||
.archive
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
let entry_name = entry_path.to_str()
|
||||
|
||||
let entry_name = entry_path
|
||||
.to_str()
|
||||
.ok_or_else(|| anyhow!("Invalid entry path"))?;
|
||||
|
||||
|
||||
let mut file = archive.by_name(entry_name)?;
|
||||
|
||||
|
||||
// Check file size limit
|
||||
check_file_size_limit(file.size(), self.config.max_file_size_mb * 1024 * 1024)?;
|
||||
|
||||
|
||||
output.clear();
|
||||
output.reserve(file.size() as usize);
|
||||
|
||||
|
||||
file.read_to_end(output)?;
|
||||
|
||||
|
||||
info!("Extracted file: {} ({} bytes)", entry_name, output.len());
|
||||
Ok(output.len() as u64)
|
||||
}
|
||||
|
||||
|
||||
fn extract_all(&mut self, output_dir: &Path) -> Result<ExtractResult> {
|
||||
create_dir_all(output_dir)?;
|
||||
|
||||
|
||||
let mut result = ExtractResult::new();
|
||||
|
||||
|
||||
// Open archive if not already open
|
||||
if self.archive.is_none() {
|
||||
let file = File::open(&self.path)?;
|
||||
let archive = zip::ZipArchive::new(file)?;
|
||||
self.archive = Some(archive);
|
||||
}
|
||||
|
||||
|
||||
let archive = self.archive.as_mut().unwrap();
|
||||
result.total_files = archive.len() as u64;
|
||||
|
||||
|
||||
// Use archive iteration to extract files
|
||||
for i in 0..archive.len() {
|
||||
let mut file = archive.by_index(i)?;
|
||||
let entry_name = file.name().to_string();
|
||||
let file_size = file.size();
|
||||
let is_dir = entry_name.ends_with('/');
|
||||
|
||||
|
||||
// Zip Slip protection
|
||||
match validate_extraction_path(&PathBuf::from(&entry_name), output_dir) {
|
||||
Ok(safe_path) => {
|
||||
@@ -181,21 +198,24 @@ impl ArchiveProcessor for ZipProcessor {
|
||||
result.success_files += 1;
|
||||
} else {
|
||||
// File
|
||||
check_file_size_limit(file_size, self.config.max_file_size_mb * 1024 * 1024)?;
|
||||
|
||||
check_file_size_limit(
|
||||
file_size,
|
||||
self.config.max_file_size_mb * 1024 * 1024,
|
||||
)?;
|
||||
|
||||
if let Some(parent) = safe_path.parent() {
|
||||
create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
|
||||
// Extract file content
|
||||
let mut outfile = BufWriter::new(File::create(&safe_path)?);
|
||||
std::io::copy(&mut file, &mut outfile)?;
|
||||
|
||||
|
||||
result.success_files += 1;
|
||||
result.total_bytes += file_size;
|
||||
debug!("Extracted: {} ({} bytes)", entry_name, file_size);
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Zip Slip detected: {} - {}", entry_name, e);
|
||||
result.failed_files.push(PathBuf::from(&entry_name));
|
||||
@@ -203,13 +223,17 @@ impl ArchiveProcessor for ZipProcessor {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Extracted {} files ({} bytes) to {}",
|
||||
result.success_files, result.total_bytes, output_dir.display());
|
||||
|
||||
|
||||
info!(
|
||||
"Extracted {} files ({} bytes) to {}",
|
||||
result.success_files,
|
||||
result.total_bytes,
|
||||
output_dir.display()
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Zip
|
||||
}
|
||||
@@ -224,6 +248,12 @@ pub struct TarProcessor {
|
||||
config: ArchiveConfig,
|
||||
}
|
||||
|
||||
impl Default for TarProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TarProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -232,7 +262,7 @@ impl TarProcessor {
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn with_config(config: ArchiveConfig) -> Self {
|
||||
Self {
|
||||
path: PathBuf::new(),
|
||||
@@ -246,7 +276,7 @@ impl ArchiveProcessor for TarProcessor {
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Tar
|
||||
}
|
||||
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
path: PathBuf::new(),
|
||||
@@ -254,30 +284,30 @@ impl ArchiveProcessor for TarProcessor {
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata> {
|
||||
info!("Opening TAR archive: {}", path.display());
|
||||
|
||||
|
||||
self.path = path.to_path_buf();
|
||||
self.entries.clear();
|
||||
|
||||
|
||||
let file = File::open(path)?;
|
||||
let mut archive = tar::Archive::new(file);
|
||||
|
||||
|
||||
let mut total_size = 0u64;
|
||||
|
||||
|
||||
// Iterate entries to collect metadata
|
||||
for entry in archive.entries()? {
|
||||
let entry = entry?;
|
||||
let path = entry.path()?.to_path_buf();
|
||||
let size = entry.size();
|
||||
|
||||
|
||||
total_size += size;
|
||||
|
||||
|
||||
self.entries.push(ArchiveEntry {
|
||||
path,
|
||||
size,
|
||||
compressed_size: size, // TAR has no compression
|
||||
compressed_size: size, // TAR has no compression
|
||||
is_dir: entry.header().entry_type().is_dir(),
|
||||
is_file: entry.header().entry_type().is_file(),
|
||||
is_encrypted: false,
|
||||
@@ -286,78 +316,87 @@ impl ArchiveProcessor for TarProcessor {
|
||||
checksum: None,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
let total_files = self.entries.len() as u64;
|
||||
|
||||
|
||||
Ok(ArchiveMetadata {
|
||||
format: ArchiveFormat::Tar,
|
||||
total_files,
|
||||
total_size,
|
||||
compressed_size: total_size, // TAR has no compression
|
||||
compression_ratio: 1.0, // No compression
|
||||
compressed_size: total_size, // TAR has no compression
|
||||
compression_ratio: 1.0, // No compression
|
||||
is_encrypted: false,
|
||||
is_multi_volume: false,
|
||||
created_time: Some(SystemTime::now()),
|
||||
modified_time: Some(SystemTime::now()),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
Ok(self.entries.clone())
|
||||
}
|
||||
|
||||
|
||||
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||
// TAR doesn't support random access, need to unpack entire archive
|
||||
// This is a limitation - for single file extraction, we unpack everything
|
||||
warn!("TAR format doesn't support random access - extracting entire archive");
|
||||
|
||||
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
self.extract_all(temp_dir.path())?;
|
||||
|
||||
|
||||
let file_path = temp_dir.path().join(entry_path);
|
||||
let mut file = File::open(&file_path)?;
|
||||
output.clear();
|
||||
file.read_to_end(output)?;
|
||||
|
||||
|
||||
Ok(output.len() as u64)
|
||||
}
|
||||
|
||||
|
||||
fn extract_all(&mut self, output_dir: &Path) -> Result<ExtractResult> {
|
||||
create_dir_all(output_dir)?;
|
||||
|
||||
|
||||
let file = File::open(&self.path)?;
|
||||
let mut archive = tar::Archive::new(file);
|
||||
|
||||
|
||||
let mut result = ExtractResult::new();
|
||||
result.total_files = self.entries.len() as u64;
|
||||
|
||||
|
||||
for entry in archive.entries()? {
|
||||
let mut entry = entry?;
|
||||
let entry_path = entry.path()?.to_path_buf();
|
||||
let entry_path_str = entry_path.display().to_string(); // Save for warning
|
||||
|
||||
let entry_path_str = entry_path.display().to_string(); // Save for warning
|
||||
|
||||
// Zip Slip protection
|
||||
match validate_extraction_path(&entry_path, output_dir) {
|
||||
Ok(safe_path) => {
|
||||
check_file_size_limit(entry.size(), self.config.max_file_size_mb * 1024 * 1024)?;
|
||||
|
||||
check_file_size_limit(
|
||||
entry.size(),
|
||||
self.config.max_file_size_mb * 1024 * 1024,
|
||||
)?;
|
||||
|
||||
entry.unpack(&safe_path)?;
|
||||
|
||||
|
||||
result.success_files += 1;
|
||||
result.total_bytes += entry.size();
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Zip Slip detected: {} - {}", entry_path_str, e);
|
||||
result.failed_files.push(entry_path);
|
||||
result.warnings.push(format!("Zip Slip: {}", entry_path_str));
|
||||
result
|
||||
.warnings
|
||||
.push(format!("Zip Slip: {}", entry_path_str));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Extracted {} TAR entries to {}", result.success_files, output_dir.display());
|
||||
|
||||
info!(
|
||||
"Extracted {} TAR entries to {}",
|
||||
result.success_files,
|
||||
output_dir.display()
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Tar
|
||||
}
|
||||
@@ -372,6 +411,12 @@ pub struct GzipProcessor {
|
||||
config: ArchiveConfig,
|
||||
}
|
||||
|
||||
impl Default for GzipProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl GzipProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -380,7 +425,7 @@ impl GzipProcessor {
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn with_config(config: ArchiveConfig) -> Self {
|
||||
Self {
|
||||
path: PathBuf::new(),
|
||||
@@ -394,7 +439,7 @@ impl ArchiveProcessor for GzipProcessor {
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Gzip
|
||||
}
|
||||
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
path: PathBuf::new(),
|
||||
@@ -402,27 +447,31 @@ impl ArchiveProcessor for GzipProcessor {
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata> {
|
||||
info!("Opening GZIP archive: {}", path.display());
|
||||
|
||||
|
||||
self.path = path.to_path_buf();
|
||||
|
||||
|
||||
let file = File::open(path)?;
|
||||
let compressed_size = file.metadata()?.len();
|
||||
|
||||
|
||||
let mut decoder = flate2::read::GzDecoder::new(file);
|
||||
let mut buffer = Vec::new();
|
||||
decoder.read_to_end(&mut buffer)?;
|
||||
|
||||
|
||||
self.decompressed_size = buffer.len() as u64;
|
||||
|
||||
|
||||
// Check Zip Bomb
|
||||
check_decompression_ratio(compressed_size, self.decompressed_size, self.config.max_decompression_ratio)?;
|
||||
|
||||
check_decompression_ratio(
|
||||
compressed_size,
|
||||
self.decompressed_size,
|
||||
self.config.max_decompression_ratio,
|
||||
)?;
|
||||
|
||||
Ok(ArchiveMetadata {
|
||||
format: ArchiveFormat::Gzip,
|
||||
total_files: 1, // GZIP is single file
|
||||
total_files: 1, // GZIP is single file
|
||||
total_size: self.decompressed_size,
|
||||
compressed_size,
|
||||
compression_ratio: if compressed_size > 0 {
|
||||
@@ -436,58 +485,64 @@ impl ArchiveProcessor for GzipProcessor {
|
||||
modified_time: Some(SystemTime::now()),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
// GZIP is single file - infer name from archive name
|
||||
let name = self.path.file_name()
|
||||
let name = self
|
||||
.path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.replace(".gz", "")
|
||||
.replace(".gzip", "");
|
||||
|
||||
|
||||
Ok(vec![ArchiveEntry::file(
|
||||
PathBuf::from(name),
|
||||
self.decompressed_size,
|
||||
0, // GZIP doesn't preserve compressed size per file
|
||||
0, // GZIP doesn't preserve compressed size per file
|
||||
)])
|
||||
}
|
||||
|
||||
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||
|
||||
fn extract_file(&mut self, _entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||
// GZIP is single file - just decompress it
|
||||
let file = File::open(&self.path)?;
|
||||
let mut decoder = flate2::read::GzDecoder::new(file);
|
||||
|
||||
|
||||
output.clear();
|
||||
decoder.read_to_end(output)?;
|
||||
|
||||
check_file_size_limit(output.len() as u64, self.config.max_file_size_mb * 1024 * 1024)?;
|
||||
|
||||
|
||||
check_file_size_limit(
|
||||
output.len() as u64,
|
||||
self.config.max_file_size_mb * 1024 * 1024,
|
||||
)?;
|
||||
|
||||
info!("Decompressed GZIP file: {} bytes", output.len());
|
||||
Ok(output.len() as u64)
|
||||
}
|
||||
|
||||
|
||||
fn extract_all(&mut self, output_dir: &Path) -> Result<ExtractResult> {
|
||||
create_dir_all(output_dir)?;
|
||||
|
||||
|
||||
let entries = self.list_entries()?;
|
||||
let entry = entries.first()
|
||||
let entry = entries
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No entry in GZIP archive"))?;
|
||||
|
||||
|
||||
let outpath = output_dir.join(&entry.path);
|
||||
|
||||
|
||||
// Zip Slip protection
|
||||
validate_extraction_path(&entry.path, output_dir)?;
|
||||
|
||||
|
||||
if let Some(parent) = outpath.parent() {
|
||||
create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
|
||||
let file = File::open(&self.path)?;
|
||||
let mut decoder = flate2::read::GzDecoder::new(file);
|
||||
let mut outfile = BufWriter::new(File::create(&outpath)?);
|
||||
|
||||
|
||||
std::io::copy(&mut decoder, &mut outfile)?;
|
||||
|
||||
|
||||
let result = ExtractResult {
|
||||
total_files: 1,
|
||||
total_bytes: self.decompressed_size,
|
||||
@@ -496,11 +551,11 @@ impl ArchiveProcessor for GzipProcessor {
|
||||
skipped_files: Vec::new(),
|
||||
warnings: Vec::new(),
|
||||
};
|
||||
|
||||
|
||||
info!("Decompressed GZIP to: {}", outpath.display());
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Gzip
|
||||
}
|
||||
@@ -514,6 +569,12 @@ pub struct TarGzipProcessor {
|
||||
config: ArchiveConfig,
|
||||
}
|
||||
|
||||
impl Default for TarGzipProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TarGzipProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -521,7 +582,7 @@ impl TarGzipProcessor {
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn with_config(config: ArchiveConfig) -> Self {
|
||||
Self {
|
||||
gzip_processor: GzipProcessor::with_config(config.clone()),
|
||||
@@ -534,32 +595,33 @@ impl ArchiveProcessor for TarGzipProcessor {
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::TarGzip
|
||||
}
|
||||
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
gzip_processor: GzipProcessor::new(),
|
||||
config: ArchiveConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata> {
|
||||
info!("Opening TAR.GZ archive: {}", path.display());
|
||||
|
||||
|
||||
// Step 1: Decompress GZIP
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
self.gzip_processor.open(path)?;
|
||||
self.gzip_processor.extract_all(temp_dir.path())?;
|
||||
|
||||
|
||||
// Step 2: Open TAR
|
||||
let tar_entries = self.gzip_processor.list_entries()?;
|
||||
let tar_file = tar_entries.first()
|
||||
let tar_file = tar_entries
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No TAR file in GZIP"))?;
|
||||
|
||||
|
||||
let tar_path = temp_dir.path().join(&tar_file.path);
|
||||
|
||||
|
||||
let mut tar_processor = TarProcessor::with_config(self.config.clone());
|
||||
let tar_metadata = tar_processor.open(&tar_path)?;
|
||||
|
||||
|
||||
Ok(ArchiveMetadata {
|
||||
format: ArchiveFormat::TarGzip,
|
||||
total_files: tar_metadata.total_files,
|
||||
@@ -576,46 +638,47 @@ impl ArchiveProcessor for TarGzipProcessor {
|
||||
modified_time: Some(SystemTime::now()),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
// Need to implement properly - this requires decompressing first
|
||||
warn!("TAR.GZ list_entries requires full decompression - consider extract_all instead");
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
||||
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||
warn!("TAR.GZ extract_file requires full unpacking - inefficient for single file");
|
||||
|
||||
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
self.extract_all(temp_dir.path())?;
|
||||
|
||||
|
||||
let file_path = temp_dir.path().join(entry_path);
|
||||
let mut file = File::open(&file_path)?;
|
||||
output.clear();
|
||||
file.read_to_end(output)?;
|
||||
|
||||
|
||||
Ok(output.len() as u64)
|
||||
}
|
||||
|
||||
|
||||
fn extract_all(&mut self, output_dir: &Path) -> Result<ExtractResult> {
|
||||
info!("Extracting TAR.GZ to: {}", output_dir.display());
|
||||
|
||||
|
||||
// Step 1: Decompress GZIP to temp
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
self.gzip_processor.extract_all(temp_dir.path())?;
|
||||
|
||||
|
||||
// Step 2: Extract TAR
|
||||
let tar_entries = self.gzip_processor.list_entries()?;
|
||||
let tar_file = tar_entries.first()
|
||||
let tar_file = tar_entries
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No TAR file found"))?;
|
||||
|
||||
|
||||
let tar_path = temp_dir.path().join(&tar_file.path);
|
||||
|
||||
|
||||
let mut tar_processor = TarProcessor::with_config(self.config.clone());
|
||||
tar_processor.open(&tar_path)?;
|
||||
tar_processor.extract_all(output_dir)
|
||||
}
|
||||
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::TarGzip
|
||||
}
|
||||
@@ -627,73 +690,133 @@ impl ArchiveProcessor for TarGzipProcessor {
|
||||
pub struct ZstdProcessor;
|
||||
|
||||
impl ArchiveProcessor for ZstdProcessor {
|
||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::Zstd }
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Zstd
|
||||
}
|
||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||
Err(anyhow!("ZSTD processor not yet implemented"))
|
||||
}
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Zstd }
|
||||
fn new() -> Self { Self }
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||
Ok(0)
|
||||
}
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||
Ok(ExtractResult::new())
|
||||
}
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Zstd
|
||||
}
|
||||
fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
/// BZIP2 Processor Stub (Phase 2/3)
|
||||
pub struct Bzip2Processor;
|
||||
|
||||
impl ArchiveProcessor for Bzip2Processor {
|
||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::Bzip2 }
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Bzip2
|
||||
}
|
||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||
Err(anyhow!("BZIP2 processor not yet implemented"))
|
||||
}
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Bzip2 }
|
||||
fn new() -> Self { Self }
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||
Ok(0)
|
||||
}
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||
Ok(ExtractResult::new())
|
||||
}
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Bzip2
|
||||
}
|
||||
fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
/// LZ4 Processor Stub (Phase 2/3)
|
||||
pub struct Lz4Processor;
|
||||
|
||||
impl ArchiveProcessor for Lz4Processor {
|
||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::Lz4 }
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Lz4
|
||||
}
|
||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||
Err(anyhow!("LZ4 processor not yet implemented"))
|
||||
}
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Lz4 }
|
||||
fn new() -> Self { Self }
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||
Ok(0)
|
||||
}
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||
Ok(ExtractResult::new())
|
||||
}
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Lz4
|
||||
}
|
||||
fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
/// TAR.BZ2 Composite Processor Stub (Phase 2/3)
|
||||
pub struct TarBzip2Processor;
|
||||
|
||||
impl ArchiveProcessor for TarBzip2Processor {
|
||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::TarBzip2 }
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::TarBzip2
|
||||
}
|
||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||
Err(anyhow!("TAR.BZ2 processor not yet implemented"))
|
||||
}
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::TarBzip2 }
|
||||
fn new() -> Self { Self }
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||
Ok(0)
|
||||
}
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||
Ok(ExtractResult::new())
|
||||
}
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::TarBzip2
|
||||
}
|
||||
fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
/// TAR.ZST Composite Processor Stub (Phase 2/3)
|
||||
pub struct TarZstdProcessor;
|
||||
|
||||
impl ArchiveProcessor for TarZstdProcessor {
|
||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::TarZstd }
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::TarZstd
|
||||
}
|
||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||
Err(anyhow!("TAR.ZST processor not yet implemented"))
|
||||
}
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::TarZstd }
|
||||
fn new() -> Self { Self }
|
||||
}
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||
Ok(0)
|
||||
}
|
||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||
Ok(ExtractResult::new())
|
||||
}
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::TarZstd
|
||||
}
|
||||
fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
// Optional Format Processors - RAR, XZ, 7z
|
||||
// All optional formats have warnings displayed when enabled
|
||||
|
||||
use crate::archive::{ArchiveFormat, ArchiveProcessor, ArchiveMetadata, ArchiveEntry, ExtractResult};
|
||||
use crate::archive::processor::{check_decompression_ratio, validate_extraction_path};
|
||||
use crate::archive::warning;
|
||||
use crate::archive::processor::{validate_extraction_path, check_decompression_ratio};
|
||||
use anyhow::{Result, anyhow};
|
||||
use std::path::Path;
|
||||
use crate::archive::{
|
||||
ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{info, warn};
|
||||
use std::fs;
|
||||
use log::{warn, info};
|
||||
use std::path::Path;
|
||||
|
||||
/// RAR Processor - Only Decompression
|
||||
/// ⚠️ Legal Warning: RARLAB patent, commercial use requires license
|
||||
@@ -28,54 +30,65 @@ impl ArchiveProcessor for RarProcessor {
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Rar
|
||||
}
|
||||
|
||||
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata> {
|
||||
// Show legal warning when RAR is used
|
||||
warning::show_rar_legal_warning();
|
||||
|
||||
|
||||
self.archive_path = Some(path.to_path_buf());
|
||||
|
||||
|
||||
// Use unrar library to open RAR
|
||||
// Note: unrar only supports decompression, no compression
|
||||
use unrar::Archive;
|
||||
|
||||
|
||||
let archive = Archive::new(path)?;
|
||||
|
||||
|
||||
let entries: Vec<_> = archive.list()?.collect();
|
||||
let total_files = entries.len() as u64;
|
||||
|
||||
let total_size = entries.iter()
|
||||
|
||||
let total_size = entries
|
||||
.iter()
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.uncompressed_size)
|
||||
.sum();
|
||||
|
||||
|
||||
let compressed_size = fs::metadata(path)?.len();
|
||||
|
||||
|
||||
Ok(ArchiveMetadata {
|
||||
format: ArchiveFormat::Rar,
|
||||
total_files,
|
||||
total_size,
|
||||
compressed_size,
|
||||
compression_ratio: if compressed_size > 0 { total_size as f64 / compressed_size as f64 } else { 0.0 },
|
||||
is_encrypted: entries.iter().any(|e| e.ok().map_or(false, |e| e.is_encrypted())),
|
||||
is_multi_volume: false, // unrar library limitation
|
||||
compression_ratio: if compressed_size > 0 {
|
||||
total_size as f64 / compressed_size as f64
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
is_encrypted: entries
|
||||
.iter()
|
||||
.any(|e| e.ok().map_or(false, |e| e.is_encrypted())),
|
||||
is_multi_volume: false, // unrar library limitation
|
||||
created_time: None,
|
||||
modified_time: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
use unrar::Archive;
|
||||
|
||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
let path = self
|
||||
.archive_path
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
let archive = Archive::new(path)?;
|
||||
|
||||
let entries: Vec<ArchiveEntry> = archive.list()?
|
||||
|
||||
let entries: Vec<ArchiveEntry> = archive
|
||||
.list()?
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| ArchiveEntry {
|
||||
path: PathBuf::from(e.filename),
|
||||
size: e.uncompressed_size,
|
||||
compressed_size: 0, // unrar doesn't provide this
|
||||
compressed_size: 0, // unrar doesn't provide this
|
||||
is_dir: e.is_directory(),
|
||||
is_file: !e.is_directory(),
|
||||
is_encrypted: e.is_encrypted(),
|
||||
@@ -83,45 +96,49 @@ impl ArchiveProcessor for RarProcessor {
|
||||
permissions: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
|
||||
fn extract_file(&self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||
// RAR doesn't support random access efficiently
|
||||
// Need to extract entire archive
|
||||
warn!("RAR extract_file requires full extraction (no random access)");
|
||||
|
||||
|
||||
let entries = self.list_entries()?;
|
||||
let entry = entries.iter()
|
||||
let entry = entries
|
||||
.iter()
|
||||
.find(|e| e.path == entry_path)
|
||||
.ok_or_else(|| anyhow!("Entry not found: {}", entry_path.display()))?;
|
||||
|
||||
|
||||
// Extract to temp dir, then read
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
self.extract_all(temp_dir.path())?;
|
||||
|
||||
|
||||
let extracted_file = temp_dir.path().join(entry_path);
|
||||
let content = fs::read(&extracted_file)?;
|
||||
output.extend_from_slice(&content);
|
||||
|
||||
|
||||
Ok(content.len() as u64)
|
||||
}
|
||||
|
||||
|
||||
fn extract_all(&self, output_dir: &Path) -> Result<ExtractResult> {
|
||||
use unrar::Archive;
|
||||
use unrar::ExtractOption;
|
||||
|
||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
|
||||
let path = self
|
||||
.archive_path
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
// Validate output_dir path
|
||||
validate_extraction_path(output_dir, output_dir)?;
|
||||
|
||||
|
||||
let mut result = ExtractResult::new();
|
||||
result.total_files = self.list_entries()?.len() as u64;
|
||||
|
||||
|
||||
let archive = Archive::new(path)?;
|
||||
|
||||
|
||||
for entry_result in archive.extract_all(output_dir, ExtractOption::Recurse)? {
|
||||
match entry_result {
|
||||
Ok(entry) => {
|
||||
@@ -135,10 +152,10 @@ impl ArchiveProcessor for RarProcessor {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Rar
|
||||
}
|
||||
@@ -163,57 +180,65 @@ impl ArchiveProcessor for XzProcessor {
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::Xz
|
||||
}
|
||||
|
||||
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata> {
|
||||
// Check if liblzma is available
|
||||
if !check_liblzma_available() {
|
||||
warning::show_xz_dependency_warning();
|
||||
return Err(anyhow!("liblzma library not found, XZ format disabled"));
|
||||
}
|
||||
|
||||
|
||||
self.archive_path = Some(path.to_path_buf());
|
||||
|
||||
use xz2::read::XzDecoder;
|
||||
|
||||
use std::io::Read;
|
||||
|
||||
use xz2::read::XzDecoder;
|
||||
|
||||
let file = fs::File::open(path)?;
|
||||
let mut decoder = XzDecoder::new(file);
|
||||
|
||||
|
||||
// Read decompressed size (estimate)
|
||||
let mut buffer = Vec::new();
|
||||
decoder.read_to_end(&mut buffer)?;
|
||||
|
||||
|
||||
let decompressed_size = buffer.len() as u64;
|
||||
let compressed_size = fs::metadata(path)?.len();
|
||||
|
||||
|
||||
// Check decompression ratio
|
||||
check_decompression_ratio(compressed_size, decompressed_size, 1000)?;
|
||||
|
||||
|
||||
Ok(ArchiveMetadata {
|
||||
format: ArchiveFormat::Xz,
|
||||
total_files: 1, // XZ is single-file format
|
||||
total_files: 1, // XZ is single-file format
|
||||
total_size: decompressed_size,
|
||||
compressed_size,
|
||||
compression_ratio: if compressed_size > 0 { decompressed_size as f64 / compressed_size as f64 } else { 0.0 },
|
||||
compression_ratio: if compressed_size > 0 {
|
||||
decompressed_size as f64 / compressed_size as f64
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
is_encrypted: false,
|
||||
is_multi_volume: false,
|
||||
created_time: None,
|
||||
modified_time: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
// XZ is single-file, infer filename from archive name
|
||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
let filename = path.file_name()
|
||||
let path = self
|
||||
.archive_path
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
let filename = path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.map(|s| s.strip_suffix(".xz").unwrap_or(s))
|
||||
.unwrap_or("output");
|
||||
|
||||
|
||||
Ok(vec![ArchiveEntry {
|
||||
path: PathBuf::from(filename),
|
||||
size: 0, // Will be determined during extraction
|
||||
size: 0, // Will be determined during extraction
|
||||
compressed_size: 0,
|
||||
is_dir: false,
|
||||
is_file: true,
|
||||
@@ -222,48 +247,54 @@ impl ArchiveProcessor for XzProcessor {
|
||||
permissions: None,
|
||||
}])
|
||||
}
|
||||
|
||||
|
||||
fn extract_file(&self, _entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||
use xz2::read::XzDecoder;
|
||||
use std::io::Read;
|
||||
|
||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
use xz2::read::XzDecoder;
|
||||
|
||||
let path = self
|
||||
.archive_path
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
let file = fs::File::open(path)?;
|
||||
let mut decoder = XzDecoder::new(file);
|
||||
|
||||
|
||||
decoder.read_to_end(output)?;
|
||||
|
||||
|
||||
Ok(output.len() as u64)
|
||||
}
|
||||
|
||||
|
||||
fn extract_all(&self, output_dir: &Path) -> Result<ExtractResult> {
|
||||
use xz2::read::XzDecoder;
|
||||
use std::io::Read;
|
||||
|
||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
use xz2::read::XzDecoder;
|
||||
|
||||
let path = self
|
||||
.archive_path
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||
|
||||
// Infer output filename
|
||||
let entries = self.list_entries()?;
|
||||
let output_path = output_dir.join(&entries[0].path);
|
||||
|
||||
|
||||
// Validate path
|
||||
validate_extraction_path(&entries[0].path, output_dir)?;
|
||||
|
||||
|
||||
let file = fs::File::open(path)?;
|
||||
let mut decoder = XzDecoder::new(file);
|
||||
|
||||
|
||||
let mut output_file = fs::File::create(&output_path)?;
|
||||
std::io::copy(&mut decoder, &mut output_file)?;
|
||||
|
||||
|
||||
let mut result = ExtractResult::new();
|
||||
result.success_files = 1;
|
||||
result.total_files = 1;
|
||||
result.total_bytes = fs::metadata(&output_path)?.len();
|
||||
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::Xz && check_liblzma_available()
|
||||
}
|
||||
@@ -286,59 +317,61 @@ impl ArchiveProcessor for SevenZProcessor {
|
||||
fn format(&self) -> ArchiveFormat {
|
||||
ArchiveFormat::SevenZ
|
||||
}
|
||||
|
||||
|
||||
fn open(&mut self, path: &Path) -> Result<ArchiveMetadata> {
|
||||
// Show stability warning
|
||||
warning::show_7z_stability_warning();
|
||||
|
||||
|
||||
use sevenz_rust::SevenZReader;
|
||||
|
||||
|
||||
let reader = SevenZReader::new(path)?;
|
||||
|
||||
|
||||
let entries = reader.entries()?;
|
||||
let total_files = entries.len() as u64;
|
||||
|
||||
let total_size = entries.iter()
|
||||
.map(|e| e.uncompressed_size as u64)
|
||||
.sum();
|
||||
|
||||
|
||||
let total_size = entries.iter().map(|e| e.uncompressed_size as u64).sum();
|
||||
|
||||
let compressed_size = fs::metadata(path)?.len();
|
||||
|
||||
|
||||
Ok(ArchiveMetadata {
|
||||
format: ArchiveFormat::SevenZ,
|
||||
total_files,
|
||||
total_size,
|
||||
compressed_size,
|
||||
compression_ratio: if compressed_size > 0 { total_size as f64 / compressed_size as f64 } else { 0.0 },
|
||||
compression_ratio: if compressed_size > 0 {
|
||||
total_size as f64 / compressed_size as f64
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
is_encrypted: entries.iter().any(|e| e.is_encrypted),
|
||||
is_multi_volume: false,
|
||||
created_time: None,
|
||||
modified_time: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||
// Note: sevenz-rust doesn't have full entry listing yet
|
||||
// This is a stub returning empty list
|
||||
warn!("7z list_entries not fully implemented (library limitation)");
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
||||
fn extract_file(&self, _entry_path: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||
warn!("7z extract_file not implemented (library limitation)");
|
||||
Err(anyhow!("7z library doesn't support random access"))
|
||||
}
|
||||
|
||||
|
||||
fn extract_all(&self, output_dir: &Path) -> Result<ExtractResult> {
|
||||
use sevenz_rust::SevenZReader;
|
||||
|
||||
|
||||
// Note: sevenz-rust doesn't have full extraction yet
|
||||
// This is a stub
|
||||
warn!("7z extract_all limited (library under development)");
|
||||
|
||||
|
||||
Ok(ExtractResult::new())
|
||||
}
|
||||
|
||||
|
||||
fn can_process(format: ArchiveFormat) -> bool {
|
||||
format == ArchiveFormat::SevenZ
|
||||
}
|
||||
@@ -369,15 +402,21 @@ pub struct SevenZProcessor;
|
||||
|
||||
#[cfg(not(feature = "optional-formats"))]
|
||||
impl RarProcessor {
|
||||
pub fn new() -> Self { Self }
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "optional-formats"))]
|
||||
impl XzProcessor {
|
||||
pub fn new() -> Self { Self }
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "optional-formats"))]
|
||||
impl SevenZProcessor {
|
||||
pub fn new() -> Self { Self }
|
||||
}
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
use crate::archive::{
|
||||
ArchiveProcessor, ArchiveFormat, ArchiveMetadata, ArchiveEntry, ExtractResult,
|
||||
processors::core::{ZipProcessor, TarProcessor, GzipProcessor, TarGzipProcessor},
|
||||
processor::{validate_extraction_path, check_decompression_ratio},
|
||||
config::ArchiveConfig,
|
||||
processor::{check_decompression_ratio, validate_extraction_path},
|
||||
processors::core::{GzipProcessor, TarGzipProcessor, TarProcessor, ZipProcessor},
|
||||
ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult,
|
||||
};
|
||||
use tempfile::TempDir;
|
||||
use std::fs::{File, create_dir_all};
|
||||
use anyhow::Result;
|
||||
use std::fs::{create_dir_all, File};
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use anyhow::Result;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[cfg(test)]
|
||||
mod helpers {
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
|
||||
pub fn create_test_zip(path: &PathBuf, files: Vec<(&str, &[u8])>) {
|
||||
use std::io::Cursor;
|
||||
|
||||
|
||||
let mut buffer = Cursor::new(Vec::new());
|
||||
{
|
||||
let mut zip = zip::ZipWriter::new(&mut buffer);
|
||||
|
||||
|
||||
let options = zip::write::FileOptions::default()
|
||||
.compression_method(zip::CompressionMethod::Stored);
|
||||
|
||||
|
||||
for (name, content) in files {
|
||||
if name.ends_with('/') {
|
||||
zip.add_directory(name, options).unwrap();
|
||||
@@ -34,31 +34,31 @@ mod helpers {
|
||||
zip.write_all(content).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
zip.finish().unwrap();
|
||||
}
|
||||
|
||||
|
||||
let zip_data = buffer.into_inner();
|
||||
File::create(path).unwrap().write_all(&zip_data).unwrap();
|
||||
}
|
||||
|
||||
|
||||
pub fn create_test_tar(path: &PathBuf, files: Vec<(&str, &[u8])>) {
|
||||
let file = File::create(path).unwrap();
|
||||
let mut builder = tar::Builder::new(file);
|
||||
|
||||
|
||||
for (name, content) in files {
|
||||
let mut header = tar::Header::new_gnu();
|
||||
header.set_size(content.len() as u64);
|
||||
header.set_path(name);
|
||||
header.set_mode(0o644);
|
||||
header.set_cksum();
|
||||
|
||||
|
||||
builder.append_data(&mut header, name, content).unwrap();
|
||||
}
|
||||
|
||||
|
||||
builder.finish().unwrap();
|
||||
}
|
||||
|
||||
|
||||
pub fn create_test_gzip(path: &PathBuf, content: &[u8]) {
|
||||
let file = File::create(path).unwrap();
|
||||
let mut encoder = flate2::write::GzEncoder::new(file, flate2::Compression::default());
|
||||
@@ -69,74 +69,74 @@ mod helpers {
|
||||
|
||||
#[cfg(test)]
|
||||
mod core_format_tests {
|
||||
use super::*;
|
||||
use super::helpers::*;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_zip_processor_basic() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let zip_path = temp_dir.path().join("test.zip");
|
||||
create_test_zip(&zip_path, vec![("file1.txt", b"hello")]);
|
||||
|
||||
|
||||
let mut processor = ZipProcessor::new();
|
||||
let metadata = processor.open(&zip_path).unwrap();
|
||||
|
||||
|
||||
assert_eq!(metadata.format, ArchiveFormat::Zip);
|
||||
assert_eq!(metadata.total_files, 1);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_tar_processor_basic() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let tar_path = temp_dir.path().join("test.tar");
|
||||
create_test_tar(&tar_path, vec![("file1.txt", b"tar content")]);
|
||||
|
||||
|
||||
let mut processor = TarProcessor::new();
|
||||
let metadata = processor.open(&tar_path).unwrap();
|
||||
|
||||
|
||||
assert_eq!(metadata.format, ArchiveFormat::Tar);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_gzip_processor_basic() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let gz_path = temp_dir.path().join("test.gz");
|
||||
create_test_gzip(&gz_path, b"gzip content here");
|
||||
|
||||
|
||||
let mut processor = GzipProcessor::new();
|
||||
let metadata = processor.open(&gz_path).unwrap();
|
||||
|
||||
|
||||
assert_eq!(metadata.format, ArchiveFormat::Gzip);
|
||||
assert_eq!(metadata.total_files, 1);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_validate_extraction_path_safe() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let base = temp_dir.path();
|
||||
let safe_path = PathBuf::from("safe/file.txt");
|
||||
|
||||
|
||||
let result = validate_extraction_path(&safe_path, base);
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
||||
let resolved = result.unwrap();
|
||||
assert!(resolved.starts_with(base));
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_validate_extraction_path_zip_slip() {
|
||||
let base = PathBuf::from("/tmp/extract");
|
||||
let evil_path = PathBuf::from("../../etc/passwd");
|
||||
|
||||
|
||||
let result = validate_extraction_path(&evil_path, &base);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_check_decompression_ratio_ok() {
|
||||
assert!(check_decompression_ratio(1000, 5000, 1000).is_ok());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_check_decompression_ratio_zip_bomb() {
|
||||
assert!(check_decompression_ratio(42_000, 5_000_000_000, 1000).is_err());
|
||||
@@ -145,39 +145,39 @@ mod core_format_tests {
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use super::helpers::*;
|
||||
use super::*;
|
||||
use crate::archive::detector::FormatDetector;
|
||||
use crate::archive::ProcessorRegistry;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_format_detection_automation() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let detector = FormatDetector::new();
|
||||
|
||||
|
||||
let zip_path = temp_dir.path().join("test.zip");
|
||||
create_test_zip(&zip_path, vec![("f.txt", b"z")]);
|
||||
assert_eq!(detector.detect(&zip_path).unwrap(), ArchiveFormat::Zip);
|
||||
|
||||
|
||||
let tar_path = temp_dir.path().join("test.tar");
|
||||
create_test_tar(&tar_path, vec![("f.txt", b"t")]);
|
||||
assert_eq!(detector.detect(&tar_path).unwrap(), ArchiveFormat::Tar);
|
||||
|
||||
|
||||
let gz_path = temp_dir.path().join("test.gz");
|
||||
create_test_gzip(&gz_path, b"g");
|
||||
assert_eq!(detector.detect(&gz_path).unwrap(), ArchiveFormat::Gzip);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_processor_registry_integration() {
|
||||
let config = ArchiveConfig::default();
|
||||
let mut registry = ProcessorRegistry::new(config);
|
||||
registry.initialize().unwrap();
|
||||
|
||||
|
||||
let formats = registry.enabled_formats();
|
||||
assert!(formats.contains(&ArchiveFormat::Zip));
|
||||
assert!(formats.contains(&ArchiveFormat::Tar));
|
||||
assert!(formats.contains(&ArchiveFormat::Gzip));
|
||||
assert!(formats.contains(&ArchiveFormat::TarGzip));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,48 +4,46 @@ use std::fs;
|
||||
use std::io::Read;
|
||||
use tempfile::TempDir;
|
||||
|
||||
use crate::archive::*;
|
||||
use crate::archive::processor::check_decompression_ratio;
|
||||
use crate::archive::tests::test_helpers::*;
|
||||
use crate::archive::*;
|
||||
|
||||
#[test]
|
||||
fn test_zip_processor_full_workflow() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let zip_path = create_test_zip(&temp_dir);
|
||||
|
||||
|
||||
// Initialize processor
|
||||
let mut processor = processors::core::ZipProcessor::new();
|
||||
|
||||
|
||||
// Test open
|
||||
let metadata = processor.open(&zip_path).unwrap();
|
||||
assert_eq!(metadata.format, ArchiveFormat::Zip);
|
||||
assert_eq!(metadata.total_files, 3);
|
||||
|
||||
|
||||
// Test list_entries
|
||||
let entries = processor.list_entries().unwrap();
|
||||
assert_eq!(entries.len(), 3);
|
||||
|
||||
|
||||
// Verify entry names
|
||||
let names: Vec<&str> = entries.iter()
|
||||
.map(|e| e.path.to_str().unwrap())
|
||||
.collect();
|
||||
let names: Vec<&str> = entries.iter().map(|e| e.path.to_str().unwrap()).collect();
|
||||
assert!(names.contains(&"file1.txt"));
|
||||
assert!(names.contains(&"file2.txt"));
|
||||
assert!(names.contains(&"subdir/file3.txt"));
|
||||
|
||||
|
||||
// Test extract_all
|
||||
let extract_dir = temp_dir.path().join("extracted");
|
||||
fs::create_dir_all(&extract_dir).unwrap();
|
||||
|
||||
|
||||
let result = processor.extract_all(&extract_dir).unwrap();
|
||||
assert_eq!(result.success_files, 3);
|
||||
assert_eq!(result.failed_files.len(), 0);
|
||||
|
||||
|
||||
// Verify extracted files
|
||||
assert!(extract_dir.join("file1.txt").exists());
|
||||
assert!(extract_dir.join("file2.txt").exists());
|
||||
assert!(extract_dir.join("subdir/file3.txt").exists());
|
||||
|
||||
|
||||
// Verify content
|
||||
let content1 = fs::read_to_string(extract_dir.join("file1.txt")).unwrap();
|
||||
assert_eq!(content1, "content of file 1");
|
||||
@@ -55,24 +53,24 @@ fn test_zip_processor_full_workflow() {
|
||||
fn test_tar_processor_full_workflow() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let tar_path = create_test_tar(&temp_dir);
|
||||
|
||||
|
||||
let mut processor = processors::core::TarProcessor::new();
|
||||
|
||||
|
||||
// Test open
|
||||
let metadata = processor.open(&tar_path).unwrap();
|
||||
assert_eq!(metadata.format, ArchiveFormat::Tar);
|
||||
|
||||
|
||||
// Test list_entries
|
||||
let entries = processor.list_entries().unwrap();
|
||||
assert!(entries.len() >= 3); // TAR may include directory entries
|
||||
|
||||
assert!(entries.len() >= 3); // TAR may include directory entries
|
||||
|
||||
// Test extract_all
|
||||
let extract_dir = temp_dir.path().join("extracted_tar");
|
||||
fs::create_dir_all(&extract_dir).unwrap();
|
||||
|
||||
|
||||
let result = processor.extract_all(&extract_dir).unwrap();
|
||||
assert!(result.success_files >= 3);
|
||||
|
||||
|
||||
// Verify extracted files exist
|
||||
assert!(extract_dir.join("file1.txt").exists());
|
||||
assert!(extract_dir.join("file2.txt").exists());
|
||||
@@ -82,25 +80,25 @@ fn test_tar_processor_full_workflow() {
|
||||
fn test_gzip_processor_full_workflow() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let gz_path = create_test_gzip(&temp_dir);
|
||||
|
||||
|
||||
let mut processor = processors::core::GzipProcessor::new();
|
||||
|
||||
|
||||
// Test open
|
||||
let metadata = processor.open(&gz_path).unwrap();
|
||||
assert_eq!(metadata.format, ArchiveFormat::Gzip);
|
||||
assert_eq!(metadata.total_files, 1); // GZIP is single file
|
||||
|
||||
assert_eq!(metadata.total_files, 1); // GZIP is single file
|
||||
|
||||
// Test extract_all
|
||||
let extract_dir = temp_dir.path().join("extracted_gz");
|
||||
fs::create_dir_all(&extract_dir).unwrap();
|
||||
|
||||
|
||||
let result = processor.extract_all(&extract_dir).unwrap();
|
||||
assert_eq!(result.success_files, 1);
|
||||
|
||||
|
||||
// Verify extracted file (should strip .gz extension)
|
||||
let extracted_file = extract_dir.join("test.txt");
|
||||
assert!(extracted_file.exists());
|
||||
|
||||
|
||||
// Verify content
|
||||
let content = fs::read_to_string(&extracted_file).unwrap();
|
||||
assert_eq!(content, "test gzip content for validation");
|
||||
@@ -110,20 +108,20 @@ fn test_gzip_processor_full_workflow() {
|
||||
fn test_tar_gz_processor_workflow() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let tar_gz_path = create_test_tar_gz(&temp_dir);
|
||||
|
||||
|
||||
let mut processor = processors::core::TarGzipProcessor::new();
|
||||
|
||||
|
||||
// Test open
|
||||
let metadata = processor.open(&tar_gz_path).unwrap();
|
||||
assert_eq!(metadata.format, ArchiveFormat::TarGzip);
|
||||
|
||||
|
||||
// Test extract_all
|
||||
let extract_dir = temp_dir.path().join("extracted_tar_gz");
|
||||
fs::create_dir_all(&extract_dir).unwrap();
|
||||
|
||||
|
||||
let result = processor.extract_all(&extract_dir).unwrap();
|
||||
assert!(result.success_files >= 2);
|
||||
|
||||
|
||||
// Verify extracted TAR files
|
||||
assert!(extract_dir.join("file1.txt").exists());
|
||||
assert!(extract_dir.join("file2.txt").exists());
|
||||
@@ -132,18 +130,18 @@ fn test_tar_gz_processor_workflow() {
|
||||
#[test]
|
||||
fn test_format_detection_auto() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
|
||||
|
||||
// Test ZIP detection
|
||||
let zip_path = create_test_zip(&temp_dir);
|
||||
let detector = FormatDetector::new();
|
||||
let format = detector.detect(&zip_path).unwrap();
|
||||
assert_eq!(format, ArchiveFormat::Zip);
|
||||
|
||||
|
||||
// Test TAR detection
|
||||
let tar_path = create_test_tar(&temp_dir);
|
||||
let format = detector.detect(&tar_path).unwrap();
|
||||
assert_eq!(format, ArchiveFormat::Tar);
|
||||
|
||||
|
||||
// Test GZIP detection
|
||||
let gz_path = create_test_gzip(&temp_dir);
|
||||
let format = detector.detect(&gz_path).unwrap();
|
||||
@@ -155,12 +153,12 @@ fn test_processor_registry_core_formats() {
|
||||
let config = ArchiveConfig::default();
|
||||
let mut registry = ProcessorRegistry::new(config);
|
||||
registry.initialize().unwrap();
|
||||
|
||||
|
||||
let formats = registry.enabled_formats();
|
||||
|
||||
|
||||
// Should have 9 core formats
|
||||
assert!(formats.len() >= 4); // At least the ones we implemented
|
||||
|
||||
assert!(formats.len() >= 4); // At least the ones we implemented
|
||||
|
||||
// Verify format support
|
||||
assert!(formats.contains(&ArchiveFormat::Zip));
|
||||
assert!(formats.contains(&ArchiveFormat::Tar));
|
||||
@@ -172,20 +170,20 @@ fn test_processor_registry_core_formats() {
|
||||
fn test_zip_slip_protection() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let zip_bomb_data = create_zip_slip_test();
|
||||
|
||||
|
||||
// Write malicious ZIP to file
|
||||
let evil_zip_path = temp_dir.path().join("evil.zip");
|
||||
fs::write(&evil_zip_path, &zip_bomb_data).unwrap();
|
||||
|
||||
|
||||
let mut processor = processors::core::ZipProcessor::new();
|
||||
processor.open(&evil_zip_path).unwrap();
|
||||
|
||||
|
||||
// Attempt extraction should fail due to Zip Slip protection
|
||||
let extract_dir = temp_dir.path().join("should_fail");
|
||||
fs::create_dir_all(&extract_dir).unwrap();
|
||||
|
||||
|
||||
let result = processor.extract_all(&extract_dir);
|
||||
|
||||
|
||||
// Should either fail or have empty extracted files
|
||||
// (validate_extraction_path prevents malicious paths)
|
||||
if result.is_ok() {
|
||||
@@ -199,11 +197,11 @@ fn test_zip_slip_protection() {
|
||||
fn test_zip_bomb_detection() {
|
||||
// Test decompression ratio check
|
||||
let result = check_decompression_ratio(42_000, 5_000_000_000, 1000);
|
||||
assert!(result.is_err()); // Should detect as Zip Bomb
|
||||
|
||||
assert!(result.is_err()); // Should detect as Zip Bomb
|
||||
|
||||
// Test normal ratio
|
||||
let result = check_decompression_ratio(1000, 5000, 1000);
|
||||
assert!(result.is_ok()); // Normal ratio should pass
|
||||
assert!(result.is_ok()); // Normal ratio should pass
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -219,21 +217,21 @@ fn test_metadata_compression_ratio() {
|
||||
created_time: None,
|
||||
modified_time: None,
|
||||
};
|
||||
|
||||
assert_eq!(metadata.actual_ratio(), 5.0); // 5000/1000 = 5.0
|
||||
assert!(!metadata.check_zip_bomb(10)); // ratio 5.0 < 10, not a bomb
|
||||
assert!(metadata.check_zip_bomb(4)); // ratio 5.0 > 4, detected as bomb
|
||||
|
||||
assert_eq!(metadata.actual_ratio(), 5.0); // 5000/1000 = 5.0
|
||||
assert!(!metadata.check_zip_bomb(10)); // ratio 5.0 < 10, not a bomb
|
||||
assert!(metadata.check_zip_bomb(4)); // ratio 5.0 > 4, detected as bomb
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validation() {
|
||||
let config = ArchiveConfig {
|
||||
max_decompression_ratio: 5, // Too low
|
||||
max_decompression_ratio: 5, // Too low
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
||||
assert!(config.validate().is_err());
|
||||
|
||||
|
||||
let valid_config = ArchiveConfig::default();
|
||||
assert!(valid_config.validate().is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ pub mod test_helpers;
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_module_structure() {
|
||||
// Test that all test modules exist
|
||||
assert!(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +1,27 @@
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::Compression;
|
||||
use std::fs::{self, File};
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
use zip::{ZipWriter, write::FileOptions, CompressionMethod};
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::Compression;
|
||||
use tar::Builder;
|
||||
use tempfile::TempDir;
|
||||
use zip::{write::FileOptions, CompressionMethod, ZipWriter};
|
||||
|
||||
pub fn create_test_zip(temp_dir: &TempDir) -> PathBuf {
|
||||
let zip_path = temp_dir.path().join("test.zip");
|
||||
let file = File::create(&zip_path).unwrap();
|
||||
let mut zip = ZipWriter::new(file);
|
||||
let options = FileOptions::default()
|
||||
.compression_method(CompressionMethod::Stored);
|
||||
|
||||
let options = FileOptions::default().compression_method(CompressionMethod::Stored);
|
||||
|
||||
zip.start_file("file1.txt", options).unwrap();
|
||||
zip.write_all(b"content of file 1").unwrap();
|
||||
|
||||
|
||||
zip.start_file("file2.txt", options).unwrap();
|
||||
zip.write_all(b"content of file 2").unwrap();
|
||||
|
||||
|
||||
zip.start_file("subdir/file3.txt", options).unwrap();
|
||||
zip.write_all(b"content of file 3 in subdir").unwrap();
|
||||
|
||||
|
||||
zip.finish().unwrap();
|
||||
zip_path
|
||||
}
|
||||
@@ -31,28 +30,38 @@ pub fn create_test_tar(temp_dir: &TempDir) -> PathBuf {
|
||||
let tar_path = temp_dir.path().join("test.tar");
|
||||
let file = File::create(&tar_path).unwrap();
|
||||
let mut builder = Builder::new(file);
|
||||
|
||||
|
||||
let mut header1 = tar::Header::new_gnu();
|
||||
header1.set_path("file1.txt").unwrap();
|
||||
header1.set_size(17);
|
||||
header1.set_mode(0o644);
|
||||
header1.set_cksum();
|
||||
builder.append_data(&mut header1, "file1.txt", &b"content of file 1"[..]).unwrap();
|
||||
|
||||
builder
|
||||
.append_data(&mut header1, "file1.txt", &b"content of file 1"[..])
|
||||
.unwrap();
|
||||
|
||||
let mut header2 = tar::Header::new_gnu();
|
||||
header2.set_path("file2.txt").unwrap();
|
||||
header2.set_size(17);
|
||||
header2.set_mode(0o644);
|
||||
header2.set_cksum();
|
||||
builder.append_data(&mut header2, "file2.txt", &b"content of file 2"[..]).unwrap();
|
||||
|
||||
builder
|
||||
.append_data(&mut header2, "file2.txt", &b"content of file 2"[..])
|
||||
.unwrap();
|
||||
|
||||
let mut header3 = tar::Header::new_gnu();
|
||||
header3.set_path("subdir/file3.txt").unwrap();
|
||||
header3.set_size(27);
|
||||
header3.set_mode(0o644);
|
||||
header3.set_cksum();
|
||||
builder.append_data(&mut header3, "subdir/file3.txt", &b"content of file 3 in subdir"[..]).unwrap();
|
||||
|
||||
builder
|
||||
.append_data(
|
||||
&mut header3,
|
||||
"subdir/file3.txt",
|
||||
&b"content of file 3 in subdir"[..],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
builder.finish().unwrap();
|
||||
tar_path
|
||||
}
|
||||
@@ -61,7 +70,9 @@ pub fn create_test_gzip(temp_dir: &TempDir) -> PathBuf {
|
||||
let gz_path = temp_dir.path().join("test.txt.gz");
|
||||
let file = File::create(&gz_path).unwrap();
|
||||
let mut encoder = GzEncoder::new(file, Compression::default());
|
||||
encoder.write_all(b"test gzip content for validation").unwrap();
|
||||
encoder
|
||||
.write_all(b"test gzip content for validation")
|
||||
.unwrap();
|
||||
encoder.finish().unwrap();
|
||||
gz_path
|
||||
}
|
||||
@@ -70,33 +81,37 @@ pub fn create_test_tar_gz(temp_dir: &TempDir) -> PathBuf {
|
||||
let tar_path = temp_dir.path().join("test.tar");
|
||||
let tar_file = File::create(&tar_path).unwrap();
|
||||
let mut builder = Builder::new(tar_file);
|
||||
|
||||
|
||||
let mut header1 = tar::Header::new_gnu();
|
||||
header1.set_path("file1.txt").unwrap();
|
||||
header1.set_size(10);
|
||||
header1.set_mode(0o644);
|
||||
header1.set_cksum();
|
||||
builder.append_data(&mut header1, "file1.txt", &b"file1 data"[..]).unwrap();
|
||||
|
||||
builder
|
||||
.append_data(&mut header1, "file1.txt", &b"file1 data"[..])
|
||||
.unwrap();
|
||||
|
||||
let mut header2 = tar::Header::new_gnu();
|
||||
header2.set_path("file2.txt").unwrap();
|
||||
header2.set_size(10);
|
||||
header2.set_mode(0o644);
|
||||
header2.set_cksum();
|
||||
builder.append_data(&mut header2, "file2.txt", &b"file2 data"[..]).unwrap();
|
||||
|
||||
builder
|
||||
.append_data(&mut header2, "file2.txt", &b"file2 data"[..])
|
||||
.unwrap();
|
||||
|
||||
builder.finish().unwrap();
|
||||
|
||||
|
||||
let tar_gz_path = temp_dir.path().join("test.tar.gz");
|
||||
let gz_file = File::create(&tar_gz_path).unwrap();
|
||||
let mut encoder = GzEncoder::new(gz_file, Compression::default());
|
||||
|
||||
|
||||
let tar_content = std::fs::read(&tar_path).unwrap();
|
||||
encoder.write_all(&tar_content).unwrap();
|
||||
encoder.finish().unwrap();
|
||||
|
||||
|
||||
std::fs::remove_file(&tar_path).unwrap();
|
||||
|
||||
|
||||
tar_gz_path
|
||||
}
|
||||
|
||||
@@ -105,13 +120,12 @@ pub fn create_zip_bomb_test() -> Vec<u8> {
|
||||
{
|
||||
let writer = std::io::Cursor::new(&mut buffer);
|
||||
let mut zip = ZipWriter::new(writer);
|
||||
|
||||
let options = FileOptions::default()
|
||||
.compression_method(CompressionMethod::Stored);
|
||||
|
||||
|
||||
let options = FileOptions::default().compression_method(CompressionMethod::Stored);
|
||||
|
||||
zip.start_file("bomb.txt", options).unwrap();
|
||||
zip.write_all(&[0u8; 100]).unwrap();
|
||||
|
||||
|
||||
zip.finish().unwrap();
|
||||
}
|
||||
buffer
|
||||
@@ -123,11 +137,11 @@ pub fn create_zip_slip_test() -> Vec<u8> {
|
||||
let writer = std::io::Cursor::new(&mut buffer);
|
||||
let mut zip = ZipWriter::new(writer);
|
||||
let options = FileOptions::default();
|
||||
|
||||
|
||||
zip.start_file("../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../../etc/passwd", options).unwrap();
|
||||
zip.write_all(b"malicious content").unwrap();
|
||||
|
||||
|
||||
zip.finish().unwrap();
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Warning System - Legal and Technical Warnings for Optional Formats
|
||||
|
||||
use log::{warn, info};
|
||||
use log::{info, warn};
|
||||
|
||||
use crate::archive::config::ArchiveConfig;
|
||||
|
||||
@@ -63,25 +63,27 @@ pub fn show_startup_warnings(config: &ArchiveConfig) {
|
||||
if config.enable_rar {
|
||||
show_rar_legal_warning();
|
||||
}
|
||||
|
||||
|
||||
if config.enable_xz {
|
||||
// Dependency check happens in ProcessorRegistry
|
||||
}
|
||||
|
||||
|
||||
if config.enable_7z {
|
||||
show_7z_stability_warning();
|
||||
}
|
||||
|
||||
|
||||
// Show summary of enabled formats
|
||||
let enabled_optional = [
|
||||
config.enable_rar,
|
||||
config.enable_xz,
|
||||
config.enable_7z,
|
||||
].iter().filter(|&x| *x).count();
|
||||
|
||||
let enabled_optional = [config.enable_rar, config.enable_xz, config.enable_7z]
|
||||
.iter()
|
||||
.filter(|&x| *x)
|
||||
.count();
|
||||
|
||||
if enabled_optional > 0 {
|
||||
info!("");
|
||||
info!("⚠️ {} optional format(s) enabled with warnings shown above", enabled_optional);
|
||||
info!(
|
||||
"⚠️ {} optional format(s) enabled with warnings shown above",
|
||||
enabled_optional
|
||||
);
|
||||
info!("Core formats (9): ZIP, TAR, GZIP, ZSTD, BZIP2, LZ4, TAR.GZ, TAR.BZ2, TAR.ZST");
|
||||
info!("");
|
||||
}
|
||||
@@ -89,8 +91,7 @@ pub fn show_startup_warnings(config: &ArchiveConfig) {
|
||||
|
||||
/// Generate user-facing legal disclaimer text
|
||||
pub fn generate_rar_legal_disclaimer() -> String {
|
||||
format!(
|
||||
"RAR FORMAT LEGAL DISCLAIMER
|
||||
"RAR FORMAT LEGAL DISCLAIMER
|
||||
|
||||
IMPORTANT WARNING:
|
||||
|
||||
@@ -136,6 +137,5 @@ CONTACT:
|
||||
Last Updated: 2026-06-10
|
||||
Version: 1.0
|
||||
Legal Consultation: [Please consult professional lawyer for commercial use]
|
||||
"
|
||||
)
|
||||
}
|
||||
".to_string()
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ impl AuditLogger {
|
||||
};
|
||||
|
||||
self.write_entry(&entry)?;
|
||||
|
||||
|
||||
log::info!(
|
||||
"Audit: {} config {} changed from '{}' to '{}' by {}",
|
||||
config_type,
|
||||
@@ -61,7 +61,7 @@ impl AuditLogger {
|
||||
new_value,
|
||||
user
|
||||
);
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ impl AuditLogger {
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
|
||||
Ok(entries[start..].to_vec())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::provider::{DataProvider, ProviderError};
|
||||
use crate::provider::DataProvider;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
@@ -71,6 +71,12 @@ pub struct AuthState {
|
||||
pub provider: Option<Arc<dyn DataProvider>>,
|
||||
}
|
||||
|
||||
impl Default for AuthState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub fn new() -> Self {
|
||||
let mut users = HashMap::new();
|
||||
@@ -284,7 +290,12 @@ impl AuthState {
|
||||
}
|
||||
}
|
||||
|
||||
fn login_with_provider(&self, provider: &dyn DataProvider, username: &str, password: &str) -> Option<LoginResponse> {
|
||||
fn login_with_provider(
|
||||
&self,
|
||||
provider: &dyn DataProvider,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Option<LoginResponse> {
|
||||
match provider.get_user(username) {
|
||||
Ok(Some(user)) => {
|
||||
if user.status != 1 {
|
||||
|
||||
@@ -118,14 +118,20 @@ fn get_series_display_name(name: &str) -> String {
|
||||
pub fn get_all_categories() -> Result<CategoriesResponse> {
|
||||
let conn = FileTree::open_user_db("accusys")?;
|
||||
let tree = FileTree::load(&conn, "accusys", "categories")?;
|
||||
|
||||
let categories: Vec<Category> = tree.nodes.iter()
|
||||
|
||||
let categories: Vec<Category> = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
||||
.map(|n| {
|
||||
let file_count = tree.nodes.iter()
|
||||
.filter(|f| f.parent_id == Some(n.node_id.clone()) && f.node_type.as_str() == "file")
|
||||
let file_count = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
f.parent_id == Some(n.node_id.clone()) && f.node_type.as_str() == "file"
|
||||
})
|
||||
.count();
|
||||
|
||||
|
||||
Category {
|
||||
name: n.label.clone(),
|
||||
display_name: get_category_display_name(&n.label),
|
||||
@@ -135,11 +141,13 @@ pub fn get_all_categories() -> Result<CategoriesResponse> {
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let total_files = tree.nodes.iter()
|
||||
|
||||
let total_files = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.node_type.as_str() == "file")
|
||||
.count();
|
||||
|
||||
|
||||
Ok(CategoriesResponse {
|
||||
total_categories: categories.len(),
|
||||
total_files,
|
||||
@@ -150,42 +158,65 @@ pub fn get_all_categories() -> Result<CategoriesResponse> {
|
||||
pub fn get_category_detail(category_name: &str) -> Result<CategoryDetail> {
|
||||
let conn = FileTree::open_user_db("accusys")?;
|
||||
let tree = FileTree::load(&conn, "accusys", "categories")?;
|
||||
|
||||
let category_node = tree.nodes.iter()
|
||||
.find(|n| n.label == category_name && n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
||||
|
||||
let category_node = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.find(|n| {
|
||||
n.label == category_name && n.parent_id.is_none() && n.node_type.as_str() == "folder"
|
||||
})
|
||||
.ok_or_else(|| anyhow::anyhow!("Category not found: {}", category_name))?;
|
||||
|
||||
let series_groups: Vec<SeriesGroup> = tree.nodes.iter()
|
||||
.filter(|n| n.parent_id == Some(category_node.node_id.clone()) && n.node_type.as_str() == "folder")
|
||||
|
||||
let series_groups: Vec<SeriesGroup> = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| {
|
||||
n.parent_id == Some(category_node.node_id.clone()) && n.node_type.as_str() == "folder"
|
||||
})
|
||||
.map(|series_node| {
|
||||
let files: Vec<CategoryFile> = tree.nodes.iter()
|
||||
.filter(|f| f.parent_id == Some(series_node.node_id.clone()) && f.node_type.as_str() == "file")
|
||||
.map(|file_node| {
|
||||
CategoryFile {
|
||||
filename: file_node.label.clone(),
|
||||
size: file_node.aliases.get("file_size_display").cloned().unwrap_or_default(),
|
||||
download_url: file_node.aliases.get("download_url").cloned().unwrap_or_default(),
|
||||
sha256: file_node.sha256.clone(),
|
||||
}
|
||||
let files: Vec<CategoryFile> = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
f.parent_id == Some(series_node.node_id.clone())
|
||||
&& f.node_type.as_str() == "file"
|
||||
})
|
||||
.map(|file_node| CategoryFile {
|
||||
filename: file_node.label.clone(),
|
||||
size: file_node
|
||||
.aliases
|
||||
.get("file_size_display")
|
||||
.cloned()
|
||||
.unwrap_or_default(),
|
||||
download_url: file_node
|
||||
.aliases
|
||||
.get("download_url")
|
||||
.cloned()
|
||||
.unwrap_or_default(),
|
||||
sha256: file_node.sha256.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
SeriesGroup {
|
||||
series_name: series_node.label.clone(),
|
||||
files,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
let file_count = series_groups.iter().map(|g| g.files.len()).sum();
|
||||
|
||||
|
||||
Ok(CategoryDetail {
|
||||
category: Category {
|
||||
name: category_name.to_string(),
|
||||
display_name: get_category_display_name(category_name),
|
||||
file_count,
|
||||
last_updated: category_node.updated_at.clone(),
|
||||
description: category_node.aliases.get("description").cloned().unwrap_or_default(),
|
||||
description: category_node
|
||||
.aliases
|
||||
.get("description")
|
||||
.cloned()
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
series_groups,
|
||||
})
|
||||
@@ -194,25 +225,31 @@ pub fn get_category_detail(category_name: &str) -> Result<CategoryDetail> {
|
||||
pub fn get_all_series() -> Result<SeriesResponse> {
|
||||
let conn = FileTree::open_user_db("accusys")?;
|
||||
let tree = FileTree::load(&conn, "accusys", "series")?;
|
||||
|
||||
let series: Vec<Series> = tree.nodes.iter()
|
||||
|
||||
let series: Vec<Series> = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
||||
.map(|n| {
|
||||
let file_count = tree.nodes.iter()
|
||||
let file_count = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
let mut current = f.parent_id.clone();
|
||||
while let Some(pid) = current {
|
||||
if pid == n.node_id {
|
||||
return f.node_type.as_str() == "file";
|
||||
}
|
||||
current = tree.nodes.iter()
|
||||
current = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.find(|p| p.node_id == pid)
|
||||
.map(|p| p.parent_id.clone()).flatten();
|
||||
.and_then(|p| p.parent_id.clone());
|
||||
}
|
||||
false
|
||||
})
|
||||
.count();
|
||||
|
||||
|
||||
Series {
|
||||
name: n.label.clone(),
|
||||
display_name: get_series_display_name(&n.label),
|
||||
@@ -223,11 +260,13 @@ pub fn get_all_series() -> Result<SeriesResponse> {
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let total_files = tree.nodes.iter()
|
||||
|
||||
let total_files = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.node_type.as_str() == "file")
|
||||
.count();
|
||||
|
||||
|
||||
Ok(SeriesResponse {
|
||||
total_series: series.len(),
|
||||
total_files,
|
||||
@@ -238,45 +277,63 @@ pub fn get_all_series() -> Result<SeriesResponse> {
|
||||
pub fn get_series_detail(series_name: &str) -> Result<SeriesDetail> {
|
||||
let conn = FileTree::open_user_db("accusys")?;
|
||||
let tree = FileTree::load(&conn, "accusys", "series")?;
|
||||
|
||||
let series_node = tree.nodes.iter()
|
||||
.find(|n| n.label == series_name && n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
||||
|
||||
let series_node = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.find(|n| {
|
||||
n.label == series_name && n.parent_id.is_none() && n.node_type.as_str() == "folder"
|
||||
})
|
||||
.ok_or_else(|| anyhow::anyhow!("Series not found: {}", series_name))?;
|
||||
|
||||
let categories: Vec<SeriesCategory> = tree.nodes.iter()
|
||||
.filter(|n| n.parent_id == Some(series_node.node_id.clone()) && n.node_type.as_str() == "folder")
|
||||
|
||||
let categories: Vec<SeriesCategory> = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| {
|
||||
n.parent_id == Some(series_node.node_id.clone()) && n.node_type.as_str() == "folder"
|
||||
})
|
||||
.map(|category_node| {
|
||||
let files: Vec<SeriesFile> = tree.nodes.iter()
|
||||
let files: Vec<SeriesFile> = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
let mut current = f.parent_id.clone();
|
||||
while let Some(pid) = current {
|
||||
if pid == category_node.node_id && f.node_type.as_str() == "file" {
|
||||
return true;
|
||||
}
|
||||
current = tree.nodes.iter()
|
||||
current = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.find(|p| p.node_id == pid)
|
||||
.map(|p| p.parent_id.clone()).flatten();
|
||||
.and_then(|p| p.parent_id.clone());
|
||||
}
|
||||
false
|
||||
})
|
||||
.map(|file_node| {
|
||||
SeriesFile {
|
||||
filename: file_node.label.clone(),
|
||||
size: file_node.aliases.get("file_size_display").unwrap_or(&"N/A".to_string()).clone(),
|
||||
download_url: file_node.aliases.get("download_url").unwrap_or(&"".to_string()).clone(),
|
||||
}
|
||||
.map(|file_node| SeriesFile {
|
||||
filename: file_node.label.clone(),
|
||||
size: file_node
|
||||
.aliases
|
||||
.get("file_size_display")
|
||||
.unwrap_or(&"N/A".to_string())
|
||||
.clone(),
|
||||
download_url: file_node
|
||||
.aliases
|
||||
.get("download_url")
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
SeriesCategory {
|
||||
category_name: category_node.label.clone(),
|
||||
files,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
let file_count = categories.iter().map(|c| c.files.len()).sum();
|
||||
|
||||
|
||||
Ok(SeriesDetail {
|
||||
series: Series {
|
||||
name: series_name.to_string(),
|
||||
@@ -296,29 +353,39 @@ pub fn search_files(query: &str, view: &str) -> Result<SearchResponse> {
|
||||
"series" => "series",
|
||||
_ => "untitled folder",
|
||||
};
|
||||
|
||||
|
||||
let conn = FileTree::open_user_db("accusys")?;
|
||||
let tree = FileTree::load(&conn, "accusys", tree_type)?;
|
||||
|
||||
let results: Vec<SearchResult> = tree.nodes.iter()
|
||||
.filter(|n| n.node_type.as_str() == "file" && n.label.to_lowercase().contains(&query.to_lowercase()))
|
||||
|
||||
let results: Vec<SearchResult> = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| {
|
||||
n.node_type.as_str() == "file" && n.label.to_lowercase().contains(&query.to_lowercase())
|
||||
})
|
||||
.map(|file_node| {
|
||||
let parent_node = tree.nodes.iter()
|
||||
let parent_node = tree
|
||||
.nodes
|
||||
.iter()
|
||||
.find(|n| n.node_id == file_node.parent_id.clone().unwrap_or_default());
|
||||
|
||||
|
||||
SearchResult {
|
||||
category: parent_node.map(|n| n.label.clone()),
|
||||
series: parent_node.map(|n| n.label.clone()),
|
||||
filename: file_node.label.clone(),
|
||||
download_url: file_node.aliases.get("download_url").unwrap_or(&"".to_string()).clone(),
|
||||
download_url: file_node
|
||||
.aliases
|
||||
.get("download_url")
|
||||
.unwrap_or(&"".to_string())
|
||||
.clone(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
Ok(SearchResponse {
|
||||
query: query.to_string(),
|
||||
view: view.to_string(),
|
||||
total_results: results.len(),
|
||||
results,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ pub async fn handle_iscsi_command(cmd: IscsiCommand) -> anyhow::Result<()> {
|
||||
let binary = find_binary("markbase-iscsi");
|
||||
let mut cmd_process = std::process::Command::new(&binary);
|
||||
cmd_process.arg("iscsi");
|
||||
|
||||
|
||||
match cmd {
|
||||
IscsiCommand::Start {
|
||||
user,
|
||||
@@ -34,7 +34,8 @@ pub async fn handle_iscsi_command(cmd: IscsiCommand) -> anyhow::Result<()> {
|
||||
force,
|
||||
device,
|
||||
} => {
|
||||
cmd_process.arg("start")
|
||||
cmd_process
|
||||
.arg("start")
|
||||
.args(["--user", &user])
|
||||
.args(["--port", &port.to_string()])
|
||||
.args(["--lun-size", &lun_size]);
|
||||
@@ -52,7 +53,7 @@ pub async fn handle_iscsi_command(cmd: IscsiCommand) -> anyhow::Result<()> {
|
||||
cmd_process.arg("status");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let status = cmd_process.status()?;
|
||||
std::process::exit(status.code().unwrap_or(1));
|
||||
}
|
||||
@@ -61,4 +62,4 @@ fn find_binary(name: &str) -> std::path::PathBuf {
|
||||
let exe = std::env::current_exe().unwrap();
|
||||
let dir = exe.parent().unwrap();
|
||||
dir.join(name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
pub mod web;
|
||||
pub mod ssh;
|
||||
pub mod webdav;
|
||||
pub mod iscsi;
|
||||
pub mod ssh;
|
||||
pub mod tree;
|
||||
pub mod web;
|
||||
pub mod webdav;
|
||||
|
||||
use clap::Subcommand;
|
||||
|
||||
@@ -29,4 +29,4 @@ pub async fn handle_interface_command(cmd: InterfaceCommands) -> anyhow::Result<
|
||||
InterfaceCommands::Tree(c) => tree::handle_tree_command(c).await?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,4 +32,4 @@ pub async fn handle_ssh_command(cmd: SshCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::Context;
|
||||
use clap::Subcommand;
|
||||
use rusqlite::Connection;
|
||||
use anyhow::Context;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Subcommand)]
|
||||
@@ -33,12 +33,12 @@ pub enum TreeCommand {
|
||||
#[arg(short, long)]
|
||||
name: String,
|
||||
},
|
||||
|
||||
|
||||
Folder {
|
||||
#[command(subcommand)]
|
||||
action: FolderCommand,
|
||||
},
|
||||
|
||||
|
||||
Ls {
|
||||
#[arg(short, long)]
|
||||
user: String,
|
||||
@@ -47,7 +47,7 @@ pub enum TreeCommand {
|
||||
#[arg(short, long)]
|
||||
tree_type: String,
|
||||
},
|
||||
|
||||
|
||||
Cp {
|
||||
#[arg(short, long)]
|
||||
user: String,
|
||||
@@ -58,7 +58,7 @@ pub enum TreeCommand {
|
||||
#[arg(short, long)]
|
||||
tree_type: String,
|
||||
},
|
||||
|
||||
|
||||
Mv {
|
||||
#[arg(short, long)]
|
||||
user: String,
|
||||
@@ -113,44 +113,54 @@ pub enum FolderCommand {
|
||||
|
||||
pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
TreeCommand::Create { name, user, tree_type } => {
|
||||
TreeCommand::Create {
|
||||
name,
|
||||
user,
|
||||
tree_type,
|
||||
} => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
|
||||
let node_id = Uuid::new_v4().to_string();
|
||||
let created_at = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO file_nodes (node_id, label, node_type, tree_type, created_at, updated_at)
|
||||
VALUES (?1, ?2, 'folder', ?3, ?4, ?4)",
|
||||
rusqlite::params![node_id, name, tree_type, created_at]
|
||||
).context("Failed to create tree")?;
|
||||
|
||||
println!("✓ Tree created: {} (type: {}) for user: {}", name, tree_type, user);
|
||||
|
||||
println!(
|
||||
"✓ Tree created: {} (type: {}) for user: {}",
|
||||
name, tree_type, user
|
||||
);
|
||||
println!("✓ Node ID: {}", node_id);
|
||||
}
|
||||
TreeCommand::List { user } => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT DISTINCT tree_type FROM file_nodes ORDER BY tree_type"
|
||||
).context("Failed to prepare query")?;
|
||||
|
||||
let tree_types = stmt.query_map([], |row| row.get::<_, String>(0))
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT DISTINCT tree_type FROM file_nodes ORDER BY tree_type")
|
||||
.context("Failed to prepare query")?;
|
||||
|
||||
let tree_types = stmt
|
||||
.query_map([], |row| row.get::<_, String>(0))
|
||||
.context("Failed to query tree types")?;
|
||||
|
||||
|
||||
println!("=== Trees for user: {} ===", user);
|
||||
for tree_type in tree_types {
|
||||
let tt = tree_type?;
|
||||
let count: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM file_nodes WHERE tree_type = ?1",
|
||||
[&tt],
|
||||
|row| row.get(0)
|
||||
).unwrap_or(0);
|
||||
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM file_nodes WHERE tree_type = ?1",
|
||||
[&tt],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap_or(0);
|
||||
|
||||
println!(" {} ({} nodes)", tt, count);
|
||||
}
|
||||
}
|
||||
@@ -158,9 +168,9 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
|
||||
println!("Importing Markdown files to {} virtual tree...", tree_type);
|
||||
|
||||
|
||||
if tree_type == "categories" {
|
||||
crate::import_markdown::import_categories_to_db(&conn, &user, &tree_type)?;
|
||||
println!("✓ Categories imported successfully!");
|
||||
@@ -168,53 +178,66 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
||||
crate::import_markdown::import_series_to_db(&conn, &user, &tree_type)?;
|
||||
println!("✓ Series imported successfully!");
|
||||
} else {
|
||||
eprintln!("Invalid tree_type: {}. Use 'categories' or 'series'", tree_type);
|
||||
eprintln!(
|
||||
"Invalid tree_type: {}. Use 'categories' or 'series'",
|
||||
tree_type
|
||||
);
|
||||
}
|
||||
}
|
||||
TreeCommand::Delete { user, name } => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
|
||||
conn.execute(
|
||||
"DELETE FROM file_nodes WHERE label = ?1 AND node_type = 'folder'",
|
||||
[&name]
|
||||
).context("Failed to delete tree")?;
|
||||
|
||||
[&name],
|
||||
)
|
||||
.context("Failed to delete tree")?;
|
||||
|
||||
println!("✓ Tree deleted: {} for user: {}", name, user);
|
||||
}
|
||||
|
||||
|
||||
TreeCommand::Folder { action } => {
|
||||
handle_folder_command(action)?;
|
||||
}
|
||||
|
||||
TreeCommand::Ls { user, path, tree_type } => {
|
||||
|
||||
TreeCommand::Ls {
|
||||
user,
|
||||
path,
|
||||
tree_type,
|
||||
} => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
|
||||
let parent_id = find_node_id(&conn, &path, &tree_type)?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT label, node_type, file_size FROM file_nodes
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare(
|
||||
"SELECT label, node_type, file_size FROM file_nodes
|
||||
WHERE parent_id = ?1 AND tree_type = ?2
|
||||
ORDER BY node_type DESC, label ASC"
|
||||
).context("Failed to prepare ls query")?;
|
||||
|
||||
let entries = stmt.query_map(
|
||||
rusqlite::params![parent_id, tree_type],
|
||||
|row| Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, Option<i64>>(2)?
|
||||
))
|
||||
).context("Failed to query entries")?;
|
||||
|
||||
ORDER BY node_type DESC, label ASC",
|
||||
)
|
||||
.context("Failed to prepare ls query")?;
|
||||
|
||||
let entries = stmt
|
||||
.query_map(rusqlite::params![parent_id, tree_type], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, Option<i64>>(2)?,
|
||||
))
|
||||
})
|
||||
.context("Failed to query entries")?;
|
||||
|
||||
println!("=== Contents of {} (tree_type: {}) ===", path, tree_type);
|
||||
for entry in entries {
|
||||
let (name, node_type, size) = entry?;
|
||||
let size_str = size.map(|s| format!("{} bytes", s)).unwrap_or_else(|| "-".to_string());
|
||||
|
||||
let size_str = size
|
||||
.map(|s| format!("{} bytes", s))
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
|
||||
if node_type == "folder" {
|
||||
println!(" 📁 {} ({})", name, size_str);
|
||||
} else {
|
||||
@@ -222,57 +245,72 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TreeCommand::Cp { user, source, target, tree_type } => {
|
||||
|
||||
TreeCommand::Cp {
|
||||
user,
|
||||
source,
|
||||
target,
|
||||
tree_type,
|
||||
} => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
|
||||
let source_id = find_node_id(&conn, &source, &tree_type)?;
|
||||
let target_parent_id = find_node_id(&conn, &target, &tree_type)?;
|
||||
|
||||
let (label, node_type, aliases_json, file_uuid, sha256, file_size) = conn.query_row(
|
||||
"SELECT label, node_type, aliases_json, file_uuid, sha256, file_size
|
||||
|
||||
let (label, node_type, aliases_json, file_uuid, sha256, file_size) = conn
|
||||
.query_row(
|
||||
"SELECT label, node_type, aliases_json, file_uuid, sha256, file_size
|
||||
FROM file_nodes WHERE node_id = ?1",
|
||||
[&source_id],
|
||||
|row| Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, Option<String>>(3)?,
|
||||
row.get::<_, Option<String>>(4)?,
|
||||
row.get::<_, Option<i64>>(5)?
|
||||
))
|
||||
).context("Failed to get source node")?;
|
||||
|
||||
[&source_id],
|
||||
|row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, Option<String>>(3)?,
|
||||
row.get::<_, Option<String>>(4)?,
|
||||
row.get::<_, Option<i64>>(5)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.context("Failed to get source node")?;
|
||||
|
||||
let new_id = Uuid::new_v4().to_string();
|
||||
let created_at = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO file_nodes
|
||||
(node_id, label, aliases_json, file_uuid, sha256, parent_id, node_type, file_size, tree_type, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?10)",
|
||||
rusqlite::params![new_id, label, aliases_json, file_uuid, sha256, target_parent_id, node_type, file_size, tree_type, created_at]
|
||||
).context("Failed to copy node")?;
|
||||
|
||||
|
||||
println!("✓ Copied {} to {} (new ID: {})", source, target, new_id);
|
||||
}
|
||||
|
||||
TreeCommand::Mv { user, source, target, tree_type } => {
|
||||
|
||||
TreeCommand::Mv {
|
||||
user,
|
||||
source,
|
||||
target,
|
||||
tree_type,
|
||||
} => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
|
||||
let source_id = find_node_id(&conn, &source, &tree_type)?;
|
||||
let target_parent_id = find_node_id(&conn, &target, &tree_type)?;
|
||||
|
||||
|
||||
let updated_at = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
|
||||
conn.execute(
|
||||
"UPDATE file_nodes SET parent_id = ?1, updated_at = ?2 WHERE node_id = ?3",
|
||||
rusqlite::params![target_parent_id, updated_at, source_id]
|
||||
).context("Failed to move node")?;
|
||||
|
||||
rusqlite::params![target_parent_id, updated_at, source_id],
|
||||
)
|
||||
.context("Failed to move node")?;
|
||||
|
||||
println!("✓ Moved {} to {}", source, target);
|
||||
}
|
||||
}
|
||||
@@ -281,104 +319,136 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
||||
|
||||
fn handle_folder_command(cmd: FolderCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
FolderCommand::Create { user, path, name, tree_type } => {
|
||||
FolderCommand::Create {
|
||||
user,
|
||||
path,
|
||||
name,
|
||||
tree_type,
|
||||
} => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
let parent_id = if path == "/" || path == "" {
|
||||
|
||||
let parent_id = if path == "/" || path.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(find_node_id(&conn, &path, &tree_type)?)
|
||||
};
|
||||
|
||||
|
||||
let node_id = Uuid::new_v4().to_string();
|
||||
let created_at = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO file_nodes
|
||||
(node_id, label, parent_id, node_type, tree_type, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, 'folder', ?4, ?5, ?5)",
|
||||
rusqlite::params![node_id, name, parent_id, tree_type, created_at]
|
||||
).context("Failed to create folder")?;
|
||||
|
||||
println!("✓ Folder created: {} in {} (tree_type: {})", name, path, tree_type);
|
||||
rusqlite::params![node_id, name, parent_id, tree_type, created_at],
|
||||
)
|
||||
.context("Failed to create folder")?;
|
||||
|
||||
println!(
|
||||
"✓ Folder created: {} in {} (tree_type: {})",
|
||||
name, path, tree_type
|
||||
);
|
||||
println!("✓ Node ID: {}", node_id);
|
||||
}
|
||||
FolderCommand::Delete { user, path, name, tree_type } => {
|
||||
FolderCommand::Delete {
|
||||
user,
|
||||
path,
|
||||
name,
|
||||
tree_type,
|
||||
} => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
let folder_path = if path == "/" || path == "" {
|
||||
|
||||
let folder_path = if path == "/" || path.is_empty() {
|
||||
name.clone()
|
||||
} else {
|
||||
format!("{}/{}", path, name)
|
||||
};
|
||||
|
||||
|
||||
let folder_id = find_node_id(&conn, &folder_path, &tree_type)?;
|
||||
|
||||
|
||||
conn.execute(
|
||||
"DELETE FROM file_nodes WHERE node_id = ?1 OR parent_id = ?1",
|
||||
[&folder_id]
|
||||
).context("Failed to delete folder and children")?;
|
||||
|
||||
println!("✓ Folder deleted: {} in {} (tree_type: {})", name, path, tree_type);
|
||||
[&folder_id],
|
||||
)
|
||||
.context("Failed to delete folder and children")?;
|
||||
|
||||
println!(
|
||||
"✓ Folder deleted: {} in {} (tree_type: {})",
|
||||
name, path, tree_type
|
||||
);
|
||||
}
|
||||
FolderCommand::Rename { user, path, old_name, new_name, tree_type } => {
|
||||
FolderCommand::Rename {
|
||||
user,
|
||||
path,
|
||||
old_name,
|
||||
new_name,
|
||||
tree_type,
|
||||
} => {
|
||||
let db_path = format!("data/users/{}.sqlite", user);
|
||||
let conn = Connection::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||
|
||||
let folder_path = if path == "/" || path == "" {
|
||||
|
||||
let folder_path = if path == "/" || path.is_empty() {
|
||||
old_name.clone()
|
||||
} else {
|
||||
format!("{}/{}", path, old_name)
|
||||
};
|
||||
|
||||
|
||||
let folder_id = find_node_id(&conn, &folder_path, &tree_type)?;
|
||||
|
||||
|
||||
let updated_at = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
|
||||
conn.execute(
|
||||
"UPDATE file_nodes SET label = ?1, updated_at = ?2 WHERE node_id = ?3",
|
||||
rusqlite::params![new_name, updated_at, folder_id]
|
||||
).context("Failed to rename folder")?;
|
||||
|
||||
println!("✓ Folder renamed: {} → {} in {} (tree_type: {})", old_name, new_name, path, tree_type);
|
||||
rusqlite::params![new_name, updated_at, folder_id],
|
||||
)
|
||||
.context("Failed to rename folder")?;
|
||||
|
||||
println!(
|
||||
"✓ Folder renamed: {} → {} in {} (tree_type: {})",
|
||||
old_name, new_name, path, tree_type
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn find_node_id(conn: &Connection, path: &str, tree_type: &str) -> anyhow::Result<String> {
|
||||
if path == "/" || path == "" {
|
||||
let node_id: String = conn.query_row(
|
||||
"SELECT node_id FROM file_nodes
|
||||
if path == "/" || path.is_empty() {
|
||||
let node_id: String = conn
|
||||
.query_row(
|
||||
"SELECT node_id FROM file_nodes
|
||||
WHERE parent_id IS NULL AND node_type = 'folder' AND tree_type = ?1
|
||||
LIMIT 1",
|
||||
[tree_type],
|
||||
|row| row.get(0)
|
||||
).context("Failed to find root folder")?;
|
||||
|
||||
[tree_type],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.context("Failed to find root folder")?;
|
||||
|
||||
return Ok(node_id);
|
||||
}
|
||||
|
||||
|
||||
let parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
|
||||
|
||||
|
||||
let mut current_parent: Option<String> = None;
|
||||
|
||||
|
||||
for part in parts {
|
||||
let node_id: String = conn.query_row(
|
||||
"SELECT node_id FROM file_nodes
|
||||
let node_id: String = conn
|
||||
.query_row(
|
||||
"SELECT node_id FROM file_nodes
|
||||
WHERE label = ?1 AND tree_type = ?2 AND
|
||||
(parent_id = ?3 OR (?3 IS NULL AND parent_id IS NULL))",
|
||||
rusqlite::params![part, tree_type, current_parent],
|
||||
|row| row.get(0)
|
||||
).context(format!("Failed to find node: {}", part))?;
|
||||
|
||||
rusqlite::params![part, tree_type, current_parent],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.context(format!("Failed to find node: {}", part))?;
|
||||
|
||||
current_parent = Some(node_id);
|
||||
}
|
||||
|
||||
|
||||
current_parent.context("Failed to find node ID for path")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,4 +18,4 @@ pub async fn handle_web_command(cmd: WebCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use clap::Subcommand;
|
||||
use axum::{extract::Request, response::IntoResponse, Extension};
|
||||
use clap::Subcommand;
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum WebdavCommand {
|
||||
@@ -28,7 +28,7 @@ pub async fn handle_webdav_command(cmd: WebdavCommand) -> anyhow::Result<()> {
|
||||
println!("User: {}", user);
|
||||
println!("Port: {}", port);
|
||||
println!("Database: {}", db_path.display());
|
||||
println!("");
|
||||
println!();
|
||||
|
||||
run_webdav_server(port, user, db_path).await?;
|
||||
}
|
||||
@@ -41,7 +41,7 @@ async fn run_webdav_server(
|
||||
user: String,
|
||||
db_path: std::path::PathBuf,
|
||||
) -> anyhow::Result<()> {
|
||||
use axum::{extract::Request, response::IntoResponse, routing::any, Extension, Router};
|
||||
use axum::{routing::any, Extension, Router};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let webdav = markbase_webdav::webdav::MarkBaseWebDAV::new(user, db_path);
|
||||
@@ -58,7 +58,7 @@ async fn run_webdav_server(
|
||||
|
||||
println!("WebDAV server listening on http://{}", addr);
|
||||
println!("Mount point: /webdav");
|
||||
println!("");
|
||||
println!();
|
||||
println!("Press Ctrl+C to stop");
|
||||
|
||||
axum::serve(listener, app).await?;
|
||||
@@ -71,4 +71,4 @@ async fn handle_dav(
|
||||
req: Request,
|
||||
) -> impl IntoResponse {
|
||||
dav.handle(req).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::Context;
|
||||
use clap::Subcommand;
|
||||
use rusqlite::Connection;
|
||||
use anyhow::Context;
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum AuthCommand {
|
||||
@@ -24,29 +24,30 @@ pub fn handle_auth_command(cmd: AuthCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
AuthCommand::Login { user, password } => {
|
||||
let db_path = "data/auth.sqlite";
|
||||
|
||||
|
||||
if !std::path::Path::new(db_path).exists() {
|
||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)
|
||||
.context("Failed to open auth database")?;
|
||||
|
||||
let password_hash: String = conn.query_row(
|
||||
"SELECT password_hash FROM sftpgo_users WHERE username = ?",
|
||||
[&user],
|
||||
|row| row.get(0)
|
||||
).context("Failed to query password hash")?;
|
||||
|
||||
let valid = bcrypt::verify(&password, &password_hash)
|
||||
.context("Failed to verify password")?;
|
||||
|
||||
|
||||
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||
|
||||
let password_hash: String = conn
|
||||
.query_row(
|
||||
"SELECT password_hash FROM sftpgo_users WHERE username = ?",
|
||||
[&user],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.context("Failed to query password hash")?;
|
||||
|
||||
let valid =
|
||||
bcrypt::verify(&password, &password_hash).context("Failed to verify password")?;
|
||||
|
||||
if !valid {
|
||||
return Err(anyhow::anyhow!("Invalid password for user: {}", user));
|
||||
}
|
||||
|
||||
|
||||
let token = generate_simple_token(&user);
|
||||
|
||||
|
||||
println!("✓ Login successful for user: {}", user);
|
||||
println!("✓ Token: {}", token);
|
||||
println!("Note: This is a simple token for demonstration. Use JWT in production.");
|
||||
@@ -57,7 +58,7 @@ pub fn handle_auth_command(cmd: AuthCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
AuthCommand::Verify { token } => {
|
||||
let user = verify_simple_token(&token)?;
|
||||
|
||||
|
||||
println!("✓ Token valid for user: {}", user);
|
||||
println!("Note: This is simple token verification. Use JWT in production.");
|
||||
}
|
||||
@@ -67,37 +68,38 @@ pub fn handle_auth_command(cmd: AuthCommand) -> anyhow::Result<()> {
|
||||
|
||||
fn generate_simple_token(user: &str) -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
|
||||
format!("{}_{}", user, timestamp)
|
||||
}
|
||||
|
||||
fn verify_simple_token(token: &str) -> anyhow::Result<String> {
|
||||
let parts: Vec<&str> = token.split('_').collect();
|
||||
|
||||
|
||||
if parts.len() < 2 {
|
||||
return Err(anyhow::anyhow!("Invalid token format"));
|
||||
}
|
||||
|
||||
|
||||
let user = parts[0];
|
||||
let timestamp_str = parts[1];
|
||||
|
||||
let timestamp: u64 = timestamp_str.parse()
|
||||
|
||||
let timestamp: u64 = timestamp_str
|
||||
.parse()
|
||||
.context("Failed to parse token timestamp")?;
|
||||
|
||||
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
|
||||
if now - timestamp > 86400 {
|
||||
return Err(anyhow::anyhow!("Token expired (valid for 24 hours)"));
|
||||
}
|
||||
|
||||
|
||||
Ok(user.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> {
|
||||
let config_path = Path::new("config/markbase.toml");
|
||||
|
||||
if !config_path.exists() {
|
||||
println!("Configuration file not found. Run 'markbase metadata config init' first.");
|
||||
println!(
|
||||
"Configuration file not found. Run 'markbase metadata config init' first."
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -61,7 +63,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> {
|
||||
let config_path = Path::new("config/markbase.toml");
|
||||
|
||||
if !config_path.exists() {
|
||||
println!("Configuration file not found. Run 'markbase metadata config init' first.");
|
||||
println!(
|
||||
"Configuration file not found. Run 'markbase metadata config init' first."
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -86,7 +90,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> {
|
||||
let config_path = Path::new("config/markbase.toml");
|
||||
|
||||
if !config_path.exists() {
|
||||
println!("Configuration file not found. Run 'markbase metadata config init' first.");
|
||||
println!(
|
||||
"Configuration file not found. Run 'markbase metadata config init' first."
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -115,4 +121,4 @@ fn show_section(config: &crate::config::MarkBaseConfig, section: &str) {
|
||||
"logging" => println!("{}", toml::to_string_pretty(&config.logging).unwrap()),
|
||||
_ => println!("Invalid section: {}. Valid sections: server, postgresql, authentication, test, logging", section),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::Context;
|
||||
use clap::Subcommand;
|
||||
use rusqlite::Connection;
|
||||
use anyhow::Context;
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum DbCommand {
|
||||
@@ -34,54 +34,52 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
DbCommand::Create { user } => {
|
||||
let db_path = filetree::FileTree::user_db_path(&user);
|
||||
|
||||
|
||||
if std::path::Path::new(&db_path).exists() {
|
||||
println!("Database already exists: {}", db_path);
|
||||
println!("Use 'db status' to check database info");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
println!("Creating database for user: {}", user);
|
||||
|
||||
let conn = filetree::FileTree::init_user_db(&user)
|
||||
.context("Failed to initialize database")?;
|
||||
|
||||
|
||||
let conn =
|
||||
filetree::FileTree::init_user_db(&user).context("Failed to initialize database")?;
|
||||
|
||||
println!("✓ Database created: {}", db_path);
|
||||
println!("✓ Tables initialized: file_nodes, file_registry, file_locations, tree_registry");
|
||||
|
||||
conn.close().map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
||||
println!(
|
||||
"✓ Tables initialized: file_nodes, file_registry, file_locations, tree_registry"
|
||||
);
|
||||
|
||||
conn.close()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
||||
}
|
||||
DbCommand::Status { user } => {
|
||||
let db_path = filetree::FileTree::user_db_path(&user);
|
||||
|
||||
|
||||
if !std::path::Path::new(&db_path).exists() {
|
||||
return Err(anyhow::anyhow!("Database not found: {}", db_path));
|
||||
}
|
||||
|
||||
let conn = Connection::open(&db_path)
|
||||
.context("Failed to open database")?;
|
||||
|
||||
|
||||
let conn = Connection::open(&db_path).context("Failed to open database")?;
|
||||
|
||||
let file_size = std::fs::metadata(&db_path)?.len();
|
||||
let file_size_mb = file_size as f64 / 1024.0 / 1024.0;
|
||||
|
||||
let node_count: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM file_nodes",
|
||||
[],
|
||||
|row| row.get(0)
|
||||
).context("Failed to count nodes")?;
|
||||
|
||||
let file_count: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM file_registry",
|
||||
[],
|
||||
|row| row.get(0)
|
||||
).context("Failed to count files")?;
|
||||
|
||||
|
||||
let node_count: i64 = conn
|
||||
.query_row("SELECT COUNT(*) FROM file_nodes", [], |row| row.get(0))
|
||||
.context("Failed to count nodes")?;
|
||||
|
||||
let file_count: i64 = conn
|
||||
.query_row("SELECT COUNT(*) FROM file_registry", [], |row| row.get(0))
|
||||
.context("Failed to count files")?;
|
||||
|
||||
let tree_types: Vec<String> = {
|
||||
let mut stmt = conn.prepare("SELECT tree_type FROM tree_registry")?;
|
||||
let rows = stmt.query_map([], |row| row.get(0))?;
|
||||
rows.collect::<Result<Vec<_>, _>>()?
|
||||
};
|
||||
|
||||
|
||||
println!("=== Database Status ===");
|
||||
println!("User: {}", user);
|
||||
println!("Path: {}", db_path);
|
||||
@@ -89,21 +87,21 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
||||
println!("Nodes: {}", node_count);
|
||||
println!("Files: {}", file_count);
|
||||
println!("Tree Types: {:?}", tree_types);
|
||||
|
||||
conn.close().map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
||||
|
||||
conn.close()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
||||
}
|
||||
DbCommand::Backup { user, output } => {
|
||||
let db_path = filetree::FileTree::user_db_path(&user);
|
||||
|
||||
|
||||
if !std::path::Path::new(&db_path).exists() {
|
||||
return Err(anyhow::anyhow!("Database not found: {}", db_path));
|
||||
}
|
||||
|
||||
|
||||
println!("Backing up database for user: {} to {}", user, output);
|
||||
|
||||
std::fs::copy(&db_path, &output)
|
||||
.context("Failed to backup database")?;
|
||||
|
||||
|
||||
std::fs::copy(&db_path, &output).context("Failed to backup database")?;
|
||||
|
||||
println!("✓ Database backed up to: {}", output);
|
||||
println!("✓ Backup size: {} bytes", std::fs::metadata(&output)?.len());
|
||||
}
|
||||
@@ -111,24 +109,26 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
||||
if !std::path::Path::new(&input).exists() {
|
||||
return Err(anyhow::anyhow!("Backup file not found: {}", input));
|
||||
}
|
||||
|
||||
|
||||
let db_path = filetree::FileTree::user_db_path(&user);
|
||||
|
||||
|
||||
if std::path::Path::new(&db_path).exists() {
|
||||
let backup_path = format!("{}.bak", db_path);
|
||||
println!("Warning: Database exists, creating backup: {}", backup_path);
|
||||
std::fs::copy(&db_path, &backup_path)
|
||||
.context("Failed to create backup before restore")?;
|
||||
}
|
||||
|
||||
|
||||
println!("Restoring database for user: {} from {}", user, input);
|
||||
|
||||
std::fs::copy(&input, &db_path)
|
||||
.context("Failed to restore database")?;
|
||||
|
||||
|
||||
std::fs::copy(&input, &db_path).context("Failed to restore database")?;
|
||||
|
||||
println!("✓ Database restored from: {}", input);
|
||||
println!("✓ Database size: {} bytes", std::fs::metadata(&db_path)?.len());
|
||||
println!(
|
||||
"✓ Database size: {} bytes",
|
||||
std::fs::metadata(&db_path)?.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
pub mod config;
|
||||
pub mod user;
|
||||
pub mod db;
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod db;
|
||||
pub mod user;
|
||||
|
||||
use clap::Subcommand;
|
||||
|
||||
@@ -25,4 +25,4 @@ pub async fn handle_metadata_command(cmd: MetadataCommands) -> anyhow::Result<()
|
||||
MetadataCommands::Auth(c) => auth::handle_auth_command(c)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::Context;
|
||||
use clap::Subcommand;
|
||||
use rusqlite::Connection;
|
||||
use anyhow::Context;
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum UserCommand {
|
||||
@@ -18,7 +18,7 @@ pub enum UserCommand {
|
||||
#[arg(short, long)]
|
||||
name: String,
|
||||
},
|
||||
#[command(name = "user-delete")]
|
||||
#[command(name = "user-delete")]
|
||||
Delete {
|
||||
#[arg(short, long)]
|
||||
name: String,
|
||||
@@ -29,58 +29,60 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
UserCommand::Create { name, password } => {
|
||||
let db_path = "data/auth.sqlite";
|
||||
|
||||
|
||||
if !std::path::Path::new(db_path).exists() {
|
||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)
|
||||
.context("Failed to open auth database")?;
|
||||
|
||||
let exists: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
||||
[&name],
|
||||
|row| row.get(0)
|
||||
).context("Failed to check user existence")?;
|
||||
|
||||
|
||||
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||
|
||||
let exists: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
||||
[&name],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.context("Failed to check user existence")?;
|
||||
|
||||
if exists > 0 {
|
||||
return Err(anyhow::anyhow!("User already exists: {}", name));
|
||||
}
|
||||
|
||||
let password_hash = bcrypt::hash(&password, bcrypt::DEFAULT_COST)
|
||||
.context("Failed to hash password")?;
|
||||
|
||||
|
||||
let password_hash =
|
||||
bcrypt::hash(&password, bcrypt::DEFAULT_COST).context("Failed to hash password")?;
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO sftpgo_users (username, password_hash, role, created_at) VALUES (?, ?, 'user', datetime('now'))",
|
||||
rusqlite::params![name, password_hash]
|
||||
).context("Failed to create user")?;
|
||||
|
||||
|
||||
println!("✓ User created: {}", name);
|
||||
println!("✓ Role: user");
|
||||
println!("✓ Password hashed with bcrypt");
|
||||
}
|
||||
UserCommand::List => {
|
||||
let db_path = "data/auth.sqlite";
|
||||
|
||||
|
||||
if !std::path::Path::new(db_path).exists() {
|
||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)
|
||||
.context("Failed to open auth database")?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT username, role, created_at FROM sftpgo_users ORDER BY username"
|
||||
).context("Failed to prepare query")?;
|
||||
|
||||
let users = stmt.query_map([], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
))
|
||||
}).context("Failed to query users")?;
|
||||
|
||||
|
||||
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT username, role, created_at FROM sftpgo_users ORDER BY username")
|
||||
.context("Failed to prepare query")?;
|
||||
|
||||
let users = stmt
|
||||
.query_map([], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
))
|
||||
})
|
||||
.context("Failed to query users")?;
|
||||
|
||||
println!("=== Users List ===");
|
||||
let mut count = 0;
|
||||
for user in users {
|
||||
@@ -88,7 +90,7 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
||||
println!(" {} (role: {}, created: {})", name, role, created_at);
|
||||
count += 1;
|
||||
}
|
||||
|
||||
|
||||
if count == 0 {
|
||||
println!("No users found");
|
||||
} else {
|
||||
@@ -97,24 +99,27 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
UserCommand::Show { name } => {
|
||||
let db_path = "data/auth.sqlite";
|
||||
|
||||
|
||||
if !std::path::Path::new(db_path).exists() {
|
||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)
|
||||
.context("Failed to open auth database")?;
|
||||
|
||||
let user = conn.query_row(
|
||||
"SELECT username, role, created_at FROM sftpgo_users WHERE username = ?",
|
||||
[&name],
|
||||
|row| Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
))
|
||||
).context("Failed to query user")?;
|
||||
|
||||
|
||||
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||
|
||||
let user = conn
|
||||
.query_row(
|
||||
"SELECT username, role, created_at FROM sftpgo_users WHERE username = ?",
|
||||
[&name],
|
||||
|row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.context("Failed to query user")?;
|
||||
|
||||
let (username, role, created_at) = user;
|
||||
println!("=== User Details ===");
|
||||
println!("Username: {}", username);
|
||||
@@ -123,31 +128,30 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
UserCommand::Delete { name } => {
|
||||
let db_path = "data/auth.sqlite";
|
||||
|
||||
|
||||
if !std::path::Path::new(db_path).exists() {
|
||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)
|
||||
.context("Failed to open auth database")?;
|
||||
|
||||
let exists: i64 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
||||
[&name],
|
||||
|row| row.get(0)
|
||||
).context("Failed to check user existence")?;
|
||||
|
||||
|
||||
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||
|
||||
let exists: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
||||
[&name],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.context("Failed to check user existence")?;
|
||||
|
||||
if exists == 0 {
|
||||
return Err(anyhow::anyhow!("User not found: {}", name));
|
||||
}
|
||||
|
||||
conn.execute(
|
||||
"DELETE FROM sftpgo_users WHERE username = ?",
|
||||
[&name]
|
||||
).context("Failed to delete user")?;
|
||||
|
||||
|
||||
conn.execute("DELETE FROM sftpgo_users WHERE username = ?", [&name])
|
||||
.context("Failed to delete user")?;
|
||||
|
||||
println!("✓ User deleted: {}", name);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,4 +22,4 @@ pub enum Commands {
|
||||
Storage(storage::StorageCommands),
|
||||
#[command(flatten)]
|
||||
Tools(tools::ToolsCommands),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,56 +21,56 @@ pub fn handle_archive_command(cmd: ArchiveCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
ArchiveCommand::Decompress { file, output } => {
|
||||
use crate::archive::{ArchiveConfig, ProcessorRegistry};
|
||||
|
||||
|
||||
println!("Decompressing {} to {}", file, output);
|
||||
|
||||
|
||||
let archive_path = Path::new(&file);
|
||||
if !archive_path.exists() {
|
||||
return Err(anyhow::anyhow!("Archive file not found: {}", file));
|
||||
}
|
||||
|
||||
|
||||
let config = ArchiveConfig::default();
|
||||
let mut registry = ProcessorRegistry::new(config);
|
||||
registry.initialize()?;
|
||||
|
||||
|
||||
let output_path = Path::new(&output);
|
||||
std::fs::create_dir_all(output_path)?;
|
||||
|
||||
|
||||
let processor = registry.get_processor_mut(archive_path)?;
|
||||
let result = processor.extract_all(output_path)?;
|
||||
|
||||
|
||||
println!("✓ Archive decompressed to: {}", output);
|
||||
println!("✓ Files extracted: {}", result.success_files);
|
||||
println!("✓ Total size: {} bytes", result.total_bytes);
|
||||
}
|
||||
ArchiveCommand::List { file } => {
|
||||
use crate::archive::{ArchiveConfig, ProcessorRegistry};
|
||||
|
||||
|
||||
println!("Listing contents of {}", file);
|
||||
|
||||
|
||||
let archive_path = Path::new(&file);
|
||||
if !archive_path.exists() {
|
||||
return Err(anyhow::anyhow!("Archive file not found: {}", file));
|
||||
}
|
||||
|
||||
|
||||
let config = ArchiveConfig::default();
|
||||
let mut registry = ProcessorRegistry::new(config);
|
||||
registry.initialize()?;
|
||||
|
||||
|
||||
let processor = registry.get_processor_mut(archive_path)?;
|
||||
let metadata = processor.open(archive_path)?;
|
||||
let entries = processor.list_entries()?;
|
||||
|
||||
|
||||
println!("=== Archive Contents ===");
|
||||
println!("Format: {}", metadata.format);
|
||||
println!("Total files: {}", metadata.total_files);
|
||||
println!("Total size: {} bytes", metadata.total_size);
|
||||
println!("");
|
||||
|
||||
println!();
|
||||
|
||||
for entry in entries {
|
||||
println!(" {} ({} bytes)", entry.path.display(), entry.size);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,4 +17,4 @@ pub fn handle_hash_command(cmd: HashCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
pub mod scan;
|
||||
pub mod hash;
|
||||
pub mod archive;
|
||||
pub mod sync;
|
||||
pub mod hash;
|
||||
pub mod mount;
|
||||
pub mod scan;
|
||||
pub mod sync;
|
||||
|
||||
use clap::Subcommand;
|
||||
|
||||
@@ -29,4 +29,4 @@ pub async fn handle_storage_command(cmd: StorageCommands) -> anyhow::Result<()>
|
||||
StorageCommands::Mount(c) => mount::handle_mount_command(c)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,21 +22,25 @@ pub enum MountCommand {
|
||||
|
||||
pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
MountCommand::Attach { type_, server, path } => {
|
||||
MountCommand::Attach {
|
||||
type_,
|
||||
server,
|
||||
path,
|
||||
} => {
|
||||
use std::process::Command;
|
||||
|
||||
|
||||
println!("Mounting {} from {} to {}", type_, server, path);
|
||||
|
||||
|
||||
if type_ == "nfs" {
|
||||
let mount_point = std::path::Path::new(&path);
|
||||
std::fs::create_dir_all(mount_point)?;
|
||||
|
||||
|
||||
let nfs_path = format!("{}:{}", server, path);
|
||||
|
||||
|
||||
let status = Command::new("mount")
|
||||
.args(["-t", "nfs", &nfs_path, &path])
|
||||
.status()?;
|
||||
|
||||
|
||||
if status.success() {
|
||||
println!("✓ NFS mounted: {} to {}", nfs_path, path);
|
||||
} else {
|
||||
@@ -45,31 +49,32 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
||||
} else if type_ == "smb" {
|
||||
let mount_point = std::path::Path::new(&path);
|
||||
std::fs::create_dir_all(mount_point)?;
|
||||
|
||||
|
||||
let smb_path = format!("//{}", server);
|
||||
|
||||
|
||||
let status = Command::new("mount")
|
||||
.args(["-t", "smbfs", &smb_path, &path])
|
||||
.status()?;
|
||||
|
||||
|
||||
if status.success() {
|
||||
println!("✓ SMB mounted: {} to {}", smb_path, path);
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("SMB mount failed"));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Unknown mount type: {}. Use 'nfs' or 'smb'", type_));
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unknown mount type: {}. Use 'nfs' or 'smb'",
|
||||
type_
|
||||
));
|
||||
}
|
||||
}
|
||||
MountCommand::Detach { path } => {
|
||||
use std::process::Command;
|
||||
|
||||
|
||||
println!("Unmounting {}", path);
|
||||
|
||||
let status = Command::new("umount")
|
||||
.arg(&path)
|
||||
.status()?;
|
||||
|
||||
|
||||
let status = Command::new("umount").arg(&path).status()?;
|
||||
|
||||
if status.success() {
|
||||
println!("✓ Unmounted: {}", path);
|
||||
} else {
|
||||
@@ -78,14 +83,13 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
MountCommand::List => {
|
||||
use std::process::Command;
|
||||
|
||||
|
||||
println!("Listing mounted storage");
|
||||
|
||||
let output = Command::new("mount")
|
||||
.output()?;
|
||||
|
||||
|
||||
let output = Command::new("mount").output()?;
|
||||
|
||||
let mounts = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
|
||||
println!("=== Mounted Filesystems ===");
|
||||
for line in mounts.lines() {
|
||||
if line.contains("nfs") || line.contains("smbfs") || line.contains("fuse") {
|
||||
@@ -95,4 +99,4 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,4 +31,4 @@ pub fn handle_scan_command(cmd: ScanCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,27 +17,31 @@ pub enum SyncCommand {
|
||||
|
||||
pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
SyncCommand::Start { source, target, mode } => {
|
||||
SyncCommand::Start {
|
||||
source,
|
||||
target,
|
||||
mode,
|
||||
} => {
|
||||
use std::path::Path;
|
||||
|
||||
|
||||
println!("Syncing {} to {} (mode: {})", source, target, mode);
|
||||
|
||||
|
||||
let source_path = Path::new(&source);
|
||||
let target_path = Path::new(&target);
|
||||
|
||||
|
||||
if !source_path.exists() {
|
||||
return Err(anyhow::anyhow!("Source path not found: {}", source));
|
||||
}
|
||||
|
||||
|
||||
if mode == "mirror" {
|
||||
std::fs::create_dir_all(target_path)?;
|
||||
|
||||
|
||||
let entries = std::fs::read_dir(source_path)?;
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
let target_file = target_path.join(entry.file_name());
|
||||
|
||||
|
||||
if path.is_file() {
|
||||
std::fs::copy(&path, &target_file)?;
|
||||
println!(" Copied: {:?}", entry.file_name());
|
||||
@@ -46,7 +50,7 @@ pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> {
|
||||
println!(" Created directory: {:?}", entry.file_name());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
println!("✓ Sync completed (mirror mode)");
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Unknown sync mode: {}. Use 'mirror'", mode));
|
||||
@@ -59,4 +63,4 @@ pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,4 +17,4 @@ pub async fn handle_tools_command(cmd: ToolsCommands) -> anyhow::Result<()> {
|
||||
ToolsCommands::Test(c) => test::handle_test_command(c)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,4 +23,4 @@ pub fn handle_render_command(cmd: RenderCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,36 +16,39 @@ pub enum TestCommand {
|
||||
|
||||
pub fn handle_test_command(cmd: TestCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
TestCommand::Bcrypt { password, verify_hash } => {
|
||||
TestCommand::Bcrypt {
|
||||
password,
|
||||
verify_hash,
|
||||
} => {
|
||||
use bcrypt::{hash, verify, DEFAULT_COST};
|
||||
|
||||
println!("=== bcrypt Hash Test ===");
|
||||
println!("Password: {}", password);
|
||||
println!("");
|
||||
println!();
|
||||
|
||||
let new_hash = hash(&password, DEFAULT_COST)?;
|
||||
println!("Generated hash:");
|
||||
println!("{}", new_hash);
|
||||
println!("");
|
||||
println!();
|
||||
|
||||
if let Some(hash_to_verify) = verify_hash {
|
||||
println!("Verifying hash: {}", hash_to_verify);
|
||||
let valid = verify(&password, &hash_to_verify)?;
|
||||
println!("Valid: {}", valid);
|
||||
println!("");
|
||||
println!();
|
||||
}
|
||||
|
||||
let db_hash = "$2b$10$ha5wU.mOi8fHLJCfun860u2cfVopa04jwe/q82IKOwqp5uG70qsH6";
|
||||
println!("Database hash: {}", db_hash);
|
||||
let valid = verify(&password, db_hash)?;
|
||||
println!("Database hash valid for '{}': {}", password, valid);
|
||||
println!("");
|
||||
println!();
|
||||
|
||||
if !valid {
|
||||
println!("❌ Database hash is incorrect!");
|
||||
println!("Update SQL:");
|
||||
println!("UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');", new_hash);
|
||||
println!("");
|
||||
println!();
|
||||
println!("Execute:");
|
||||
println!("sqlite3 data/auth.sqlite \"UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');\"", new_hash);
|
||||
} else {
|
||||
@@ -58,4 +61,4 @@ pub fn handle_test_command(cmd: TestCommand) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ pub use web::*;
|
||||
|
||||
/// Unified application configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub struct AppConfig {
|
||||
#[serde(default)]
|
||||
pub web: WebSection,
|
||||
@@ -154,13 +155,19 @@ impl AppConfig {
|
||||
self.web.host = v;
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_WEB_PORT") {
|
||||
if let Ok(p) = v.parse() { self.web.port = p; }
|
||||
if let Ok(p) = v.parse() {
|
||||
self.web.port = p;
|
||||
}
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_SSH_PORT") {
|
||||
if let Ok(p) = v.parse() { self.ssh.port = p; }
|
||||
if let Ok(p) = v.parse() {
|
||||
self.ssh.port = p;
|
||||
}
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_SFTP_PORT") {
|
||||
if let Ok(p) = v.parse() { self.sftp.port = p; }
|
||||
if let Ok(p) = v.parse() {
|
||||
self.sftp.port = p;
|
||||
}
|
||||
}
|
||||
if let Ok(v) = std::env::var("MB_S3_ENABLED") {
|
||||
self.s3.enabled = v == "true" || v == "1";
|
||||
@@ -172,16 +179,6 @@ impl AppConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
web: WebSection::default(),
|
||||
s3: S3Section::default(),
|
||||
sftp: SftpSection::default(),
|
||||
ssh: SshSection::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -323,11 +323,15 @@ impl MarkBaseConfig {
|
||||
}
|
||||
|
||||
if self.authentication.default_user.is_empty() {
|
||||
return Err(anyhow::anyhow!("authentication.default_user cannot be empty"));
|
||||
return Err(anyhow::anyhow!(
|
||||
"authentication.default_user cannot be empty"
|
||||
));
|
||||
}
|
||||
|
||||
if self.authentication.default_password.is_empty() {
|
||||
return Err(anyhow::anyhow!("authentication.default_password cannot be empty"));
|
||||
return Err(anyhow::anyhow!(
|
||||
"authentication.default_password cannot be empty"
|
||||
));
|
||||
}
|
||||
|
||||
if self.authentication.max_sessions_per_user == 0 {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use rusqlite::{Connection, params};
|
||||
use rusqlite::{params, Connection};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
@@ -46,10 +46,10 @@ impl DownloadDb {
|
||||
Self::init_tables(&conn)?;
|
||||
conn
|
||||
};
|
||||
|
||||
|
||||
Ok(DownloadDb { conn })
|
||||
}
|
||||
|
||||
|
||||
fn init_tables(conn: &Connection) -> Result<()> {
|
||||
conn.execute_batch(
|
||||
"CREATE TABLE IF NOT EXISTS products (
|
||||
@@ -74,63 +74,70 @@ impl DownloadDb {
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_product_files_product_id ON product_files(product_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_products_series ON products(series);
|
||||
"
|
||||
",
|
||||
)?;
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn create_product(&mut self, product_name: &str, series: &str, description: Option<&str>) -> Result<i64> {
|
||||
|
||||
pub fn create_product(
|
||||
&mut self,
|
||||
product_name: &str,
|
||||
series: &str,
|
||||
description: Option<&str>,
|
||||
) -> Result<i64> {
|
||||
let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
|
||||
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO products (product_name, series, description, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4)",
|
||||
params![product_name, series, description, now],
|
||||
)?;
|
||||
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
}
|
||||
|
||||
|
||||
pub fn get_all_products(&self) -> Result<Vec<Product>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id, product_name, series, description, created_at FROM products ORDER BY series, product_name"
|
||||
)?;
|
||||
|
||||
let products = stmt.query_map([], |row| {
|
||||
Ok(Product {
|
||||
id: row.get(0)?,
|
||||
product_name: row.get(1)?,
|
||||
series: row.get(2)?,
|
||||
description: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
|
||||
let products = stmt
|
||||
.query_map([], |row| {
|
||||
Ok(Product {
|
||||
id: row.get(0)?,
|
||||
product_name: row.get(1)?,
|
||||
series: row.get(2)?,
|
||||
description: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(products)
|
||||
}
|
||||
|
||||
|
||||
pub fn get_products_by_series(&self, series: &str) -> Result<Vec<Product>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id, product_name, series, description, created_at FROM products
|
||||
WHERE series = ?1 ORDER BY product_name"
|
||||
WHERE series = ?1 ORDER BY product_name",
|
||||
)?;
|
||||
|
||||
let products = stmt.query_map([series], |row| {
|
||||
Ok(Product {
|
||||
id: row.get(0)?,
|
||||
product_name: row.get(1)?,
|
||||
series: row.get(2)?,
|
||||
description: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
|
||||
let products = stmt
|
||||
.query_map([series], |row| {
|
||||
Ok(Product {
|
||||
id: row.get(0)?,
|
||||
product_name: row.get(1)?,
|
||||
series: row.get(2)?,
|
||||
description: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(products)
|
||||
}
|
||||
|
||||
|
||||
pub fn get_series_stats(&self) -> Result<Vec<SeriesStats>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT
|
||||
@@ -141,106 +148,118 @@ impl DownloadDb {
|
||||
FROM products p
|
||||
LEFT JOIN product_files pf ON p.id = pf.product_id
|
||||
GROUP BY p.series
|
||||
ORDER BY p.series"
|
||||
ORDER BY p.series",
|
||||
)?;
|
||||
|
||||
let stats = stmt.query_map([], |row| {
|
||||
Ok(SeriesStats {
|
||||
series: row.get(0)?,
|
||||
product_count: row.get(1)?,
|
||||
file_count: row.get(2)?,
|
||||
total_size: row.get::<_, i64>(3)? as u64,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
|
||||
let stats = stmt
|
||||
.query_map([], |row| {
|
||||
Ok(SeriesStats {
|
||||
series: row.get(0)?,
|
||||
product_count: row.get(1)?,
|
||||
file_count: row.get(2)?,
|
||||
total_size: row.get::<_, i64>(3)? as u64,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
pub fn add_file_to_product(&mut self, product_id: i64, file_path: &str, file_name: &str, file_size: u64, file_hash: Option<&str>) -> Result<i64> {
|
||||
|
||||
pub fn add_file_to_product(
|
||||
&mut self,
|
||||
product_id: i64,
|
||||
file_path: &str,
|
||||
file_name: &str,
|
||||
file_size: u64,
|
||||
file_hash: Option<&str>,
|
||||
) -> Result<i64> {
|
||||
let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
|
||||
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO product_files (product_id, file_path, file_name, file_size, file_hash, uploaded_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![product_id, file_path, file_name, file_size as i64, file_hash, now],
|
||||
)?;
|
||||
|
||||
|
||||
Ok(self.conn.last_insert_rowid())
|
||||
}
|
||||
|
||||
|
||||
pub fn get_files_by_product(&self, product_id: i64) -> Result<Vec<ProductFile>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id, product_id, file_path, file_name, file_size, file_hash, download_count, uploaded_at
|
||||
FROM product_files WHERE product_id = ?1 ORDER BY file_name"
|
||||
)?;
|
||||
|
||||
let files = stmt.query_map([product_id], |row| {
|
||||
Ok(ProductFile {
|
||||
id: row.get(0)?,
|
||||
product_id: row.get(1)?,
|
||||
file_path: row.get(2)?,
|
||||
file_name: row.get(3)?,
|
||||
file_size: row.get::<_, i64>(4)? as u64,
|
||||
file_hash: row.get(5)?,
|
||||
download_count: row.get(6)?,
|
||||
uploaded_at: row.get(7)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
|
||||
let files = stmt
|
||||
.query_map([product_id], |row| {
|
||||
Ok(ProductFile {
|
||||
id: row.get(0)?,
|
||||
product_id: row.get(1)?,
|
||||
file_path: row.get(2)?,
|
||||
file_name: row.get(3)?,
|
||||
file_size: row.get::<_, i64>(4)? as u64,
|
||||
file_hash: row.get(5)?,
|
||||
download_count: row.get(6)?,
|
||||
uploaded_at: row.get(7)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
|
||||
pub fn increment_download_count(&mut self, file_id: i64) -> Result<()> {
|
||||
self.conn.execute(
|
||||
"UPDATE product_files SET download_count = download_count + 1 WHERE id = ?1",
|
||||
params![file_id],
|
||||
)?;
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub fn get_all_files(&self) -> Result<Vec<ProductFile>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id, product_id, file_path, file_name, file_size, file_hash, download_count, uploaded_at
|
||||
FROM product_files ORDER BY uploaded_at DESC"
|
||||
)?;
|
||||
|
||||
let files = stmt.query_map([], |row| {
|
||||
Ok(ProductFile {
|
||||
id: row.get(0)?,
|
||||
product_id: row.get(1)?,
|
||||
file_path: row.get(2)?,
|
||||
file_name: row.get(3)?,
|
||||
file_size: row.get::<_, i64>(4)? as u64,
|
||||
file_hash: row.get(5)?,
|
||||
download_count: row.get(6)?,
|
||||
uploaded_at: row.get(7)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
|
||||
let files = stmt
|
||||
.query_map([], |row| {
|
||||
Ok(ProductFile {
|
||||
id: row.get(0)?,
|
||||
product_id: row.get(1)?,
|
||||
file_path: row.get(2)?,
|
||||
file_name: row.get(3)?,
|
||||
file_size: row.get::<_, i64>(4)? as u64,
|
||||
file_hash: row.get(5)?,
|
||||
download_count: row.get(6)?,
|
||||
uploaded_at: row.get(7)?,
|
||||
})
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
|
||||
pub fn delete_product_with_files(&mut self, product_id: i64) -> Result<(i64, i64)> {
|
||||
// 先删除关联的文件映射
|
||||
self.conn.execute(
|
||||
"DELETE FROM product_files WHERE product_id = ?1",
|
||||
params![product_id],
|
||||
)?;
|
||||
|
||||
|
||||
let deleted_files = self.conn.last_insert_rowid();
|
||||
|
||||
|
||||
// 再删除产品记录
|
||||
self.conn.execute(
|
||||
"DELETE FROM products WHERE id = ?1",
|
||||
params![product_id],
|
||||
)?;
|
||||
|
||||
let deleted_product = if self.conn.last_insert_rowid() > 0 { 1 } else { 0 };
|
||||
|
||||
self.conn
|
||||
.execute("DELETE FROM products WHERE id = ?1", params![product_id])?;
|
||||
|
||||
let deleted_product = if self.conn.last_insert_rowid() > 0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Ok((deleted_files, deleted_product))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,15 +7,19 @@ use axum::{
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
|
||||
use crate::server::AppState;
|
||||
use crate::download::db::DownloadDb;
|
||||
use crate::server::AppState;
|
||||
|
||||
pub async fn download_file(
|
||||
Path(file_id): Path<i64>,
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(mut db) => {
|
||||
// 获取文件信息
|
||||
@@ -24,48 +28,65 @@ pub async fn download_file(
|
||||
if files.is_empty() {
|
||||
return (StatusCode::NOT_FOUND, "File not found").into_response();
|
||||
}
|
||||
|
||||
|
||||
let file_info = &files[0];
|
||||
|
||||
|
||||
// 更新下载统计
|
||||
db.increment_download_count(file_info.id).ok();
|
||||
|
||||
|
||||
// 构建文件路径(使用配置的db_dir)
|
||||
let base_path = state.db_dir.replace("users", "Downloads");
|
||||
let file_path = std::path::Path::new(&base_path).join(&file_info.file_path);
|
||||
|
||||
|
||||
if !file_path.exists() {
|
||||
return (StatusCode::NOT_FOUND, "File not found on disk").into_response();
|
||||
}
|
||||
|
||||
|
||||
// 读取文件内容
|
||||
match File::open(&file_path) {
|
||||
Ok(mut file) => {
|
||||
let mut buffer = Vec::new();
|
||||
match file.read_to_end(&mut buffer) {
|
||||
Ok(_) => {
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", file_info.file_name)
|
||||
)
|
||||
.header("X-File-Hash", file_info.file_hash.clone().unwrap_or_default())
|
||||
.header("X-File-Size", file_info.file_size)
|
||||
.body(buffer.into())
|
||||
.unwrap()
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response()
|
||||
Ok(_) => Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", file_info.file_name),
|
||||
)
|
||||
.header(
|
||||
"X-File-Hash",
|
||||
file_info.file_hash.clone().unwrap_or_default(),
|
||||
)
|
||||
.header("X-File-Size", file_info.file_size)
|
||||
.body(buffer.into())
|
||||
.unwrap(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error reading file: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response()
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error opening file: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {}", e)).into_response()
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {}", e)).into_response()
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,69 +101,84 @@ pub async fn download_file_by_path(
|
||||
// User files are in Downloads/user_id/
|
||||
format!("/Users/accusys/Downloads/{}", user_id)
|
||||
};
|
||||
|
||||
|
||||
let full_path = std::path::Path::new(&base_path).join(&file_path);
|
||||
|
||||
|
||||
if !full_path.exists() {
|
||||
return (StatusCode::NOT_FOUND, "File not found").into_response();
|
||||
}
|
||||
|
||||
let filename = file_path.split('/').last().unwrap_or("unknown");
|
||||
|
||||
|
||||
let filename = file_path.split('/').next_back().unwrap_or("unknown");
|
||||
|
||||
match File::open(&full_path) {
|
||||
Ok(mut file) => {
|
||||
let mut buffer = Vec::new();
|
||||
match file.read_to_end(&mut buffer) {
|
||||
Ok(_) => {
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename)
|
||||
)
|
||||
.body(buffer.into())
|
||||
.unwrap()
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response()
|
||||
Ok(_) => Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
)
|
||||
.body(buffer.into())
|
||||
.unwrap(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error reading file: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response()
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error opening file: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_download_stats(
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
pub async fn get_download_stats(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(db) => {
|
||||
match db.get_all_files() {
|
||||
Ok(files) => {
|
||||
let total_downloads: i64 = files.iter().map(|f| f.download_count).sum();
|
||||
let top_files: Vec<_> = files.iter()
|
||||
.filter(|f| f.download_count > 0)
|
||||
.take(10)
|
||||
.map(|f| serde_json::json!({
|
||||
Ok(db) => match db.get_all_files() {
|
||||
Ok(files) => {
|
||||
let total_downloads: i64 = files.iter().map(|f| f.download_count).sum();
|
||||
let top_files: Vec<_> = files
|
||||
.iter()
|
||||
.filter(|f| f.download_count > 0)
|
||||
.take(10)
|
||||
.map(|f| {
|
||||
serde_json::json!({
|
||||
"file_name": f.file_name,
|
||||
"download_count": f.download_count
|
||||
}))
|
||||
.collect();
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"total_files": files.len(),
|
||||
"total_downloads": total_downloads,
|
||||
"top_files": top_files
|
||||
}))
|
||||
)
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"total_files": files.len(),
|
||||
"total_downloads": total_downloads,
|
||||
"top_files": top_files
|
||||
})),
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": e.to_string()})),
|
||||
),
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": e.to_string()})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,35 +187,41 @@ pub async fn download_product_file(
|
||||
) -> impl IntoResponse {
|
||||
let base_path = format!("/Users/accusys/markbase/data/downloads/{}/", product_series);
|
||||
let full_path = std::path::Path::new(&base_path).join(&file_path);
|
||||
|
||||
|
||||
if !full_path.exists() {
|
||||
return (StatusCode::NOT_FOUND, "File not found").into_response();
|
||||
}
|
||||
|
||||
|
||||
if full_path.is_dir() {
|
||||
return (StatusCode::BAD_REQUEST, "Path is a directory, not a file").into_response();
|
||||
}
|
||||
|
||||
let filename = file_path.split('/').last().unwrap_or("unknown");
|
||||
|
||||
|
||||
let filename = file_path.split('/').next_back().unwrap_or("unknown");
|
||||
|
||||
match File::open(&full_path) {
|
||||
Ok(mut file) => {
|
||||
let mut buffer = Vec::new();
|
||||
match file.read_to_end(&mut buffer) {
|
||||
Ok(_) => {
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename)
|
||||
)
|
||||
.body(buffer.into())
|
||||
.unwrap()
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response()
|
||||
Ok(_) => Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
)
|
||||
.body(buffer.into())
|
||||
.unwrap(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error reading file: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response()
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error opening file: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,27 +1,22 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{Html, IntoResponse, Json},
|
||||
extract::Path,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::server::AppState;
|
||||
use crate::download::storage;
|
||||
|
||||
pub async fn list_uploaded_files(
|
||||
Path(user_id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
pub async fn list_uploaded_files(Path(user_id): Path<String>) -> impl IntoResponse {
|
||||
let file_list = storage::scan_uploaded_files(&user_id);
|
||||
(StatusCode::OK, Json(file_list))
|
||||
}
|
||||
|
||||
pub async fn get_file_info(
|
||||
Path((user_id, filename)): Path<(String, String)>,
|
||||
) -> impl IntoResponse {
|
||||
pub async fn get_file_info(Path((user_id, filename)): Path<(String, String)>) -> impl IntoResponse {
|
||||
let base_path = format!("/Users/accusys/Downloads/{}", user_id);
|
||||
let file_path = PathBuf::from(&base_path).join(&filename);
|
||||
|
||||
|
||||
if !file_path.exists() {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
@@ -29,7 +24,7 @@ pub async fn get_file_info(
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
|
||||
let metadata = std::fs::metadata(&file_path).unwrap();
|
||||
let file_size = metadata.len();
|
||||
let file_hash = if file_size > 0 {
|
||||
@@ -37,7 +32,7 @@ pub async fn get_file_info(
|
||||
} else {
|
||||
Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string())
|
||||
};
|
||||
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
@@ -49,4 +44,4 @@ pub async fn get_file_info(
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
pub mod models;
|
||||
pub mod db;
|
||||
pub mod handlers;
|
||||
pub mod storage;
|
||||
pub mod product_handlers;
|
||||
pub mod download_handler;
|
||||
pub mod handlers;
|
||||
pub mod models;
|
||||
pub mod product_handlers;
|
||||
pub mod storage;
|
||||
|
||||
pub use models::*;
|
||||
pub use db::{DownloadDb, Product, ProductFile, SeriesStats};
|
||||
pub use download_handler::*;
|
||||
pub use handlers::*;
|
||||
pub use models::*;
|
||||
pub use product_handlers::*;
|
||||
pub use download_handler::*;
|
||||
@@ -39,4 +39,4 @@ pub struct DownloadStats {
|
||||
pub total_files: i64,
|
||||
pub total_downloads: i64,
|
||||
pub series_stats: Vec<ProductSeries>,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,33 +1,42 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{Json, IntoResponse},
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::download::db::DownloadDb;
|
||||
use crate::server::AppState;
|
||||
use crate::download::db::{DownloadDb, Product, ProductFile, SeriesStats};
|
||||
|
||||
pub async fn list_all_products(
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
pub async fn list_all_products(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(db) => {
|
||||
match db.get_all_products() {
|
||||
Ok(products) => (StatusCode::OK, Json(json!({
|
||||
Ok(db) => match db.get_all_products() {
|
||||
Ok(products) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"products": products,
|
||||
"total": products.len()
|
||||
}))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
})),
|
||||
),
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,47 +44,67 @@ pub async fn list_products_by_series(
|
||||
Path(series): Path<String>,
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(db) => {
|
||||
match db.get_products_by_series(&series) {
|
||||
Ok(products) => (StatusCode::OK, Json(json!({
|
||||
Ok(db) => match db.get_products_by_series(&series) {
|
||||
Ok(products) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"series": series,
|
||||
"products": products,
|
||||
"total": products.len()
|
||||
}))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
})),
|
||||
),
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_series_stats(
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
pub async fn get_series_stats(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(db) => {
|
||||
match db.get_series_stats() {
|
||||
Ok(stats) => (StatusCode::OK, Json(json!({
|
||||
Ok(db) => match db.get_series_stats() {
|
||||
Ok(stats) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"series_stats": stats,
|
||||
"total_series": stats.len()
|
||||
}))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
})),
|
||||
),
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,25 +112,36 @@ pub async fn get_product_files(
|
||||
Path(product_id): Path<i64>,
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(db) => {
|
||||
match db.get_files_by_product(product_id) {
|
||||
Ok(files) => (StatusCode::OK, Json(json!({
|
||||
Ok(db) => match db.get_files_by_product(product_id) {
|
||||
Ok(files) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"product_id": product_id,
|
||||
"files": files,
|
||||
"total_files": files.len(),
|
||||
"total_size": files.iter().map(|f| f.file_size).sum::<u64>()
|
||||
}))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
})),
|
||||
),
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,29 +149,40 @@ pub async fn create_product_handler(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<serde_json::Value>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
let product_name = payload["product_name"].as_str().unwrap_or("");
|
||||
let series = payload["series"].as_str().unwrap_or("");
|
||||
let description = payload["description"].as_str();
|
||||
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(mut db) => {
|
||||
match db.create_product(product_name, series, description) {
|
||||
Ok(product_id) => (StatusCode::OK, Json(json!({
|
||||
Ok(mut db) => match db.create_product(product_name, series, description) {
|
||||
Ok(product_id) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"ok": true,
|
||||
"product_id": product_id,
|
||||
"product_name": product_name,
|
||||
"series": series
|
||||
}))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
})),
|
||||
),
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,48 +191,62 @@ pub async fn assign_files_to_product(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<serde_json::Value>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
let files_vec = payload["files"].as_array().cloned().unwrap_or_default();
|
||||
let files = files_vec.as_slice();
|
||||
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(mut db) => {
|
||||
let mut assigned_count = 0;
|
||||
let mut errors = vec![];
|
||||
|
||||
|
||||
for file in files {
|
||||
let file_path = file["file_path"].as_str().unwrap_or("");
|
||||
let file_name = file["file_name"].as_str().unwrap_or("");
|
||||
let file_size = file["file_size"].as_u64().unwrap_or(0);
|
||||
let file_hash = file["file_hash"].as_str();
|
||||
|
||||
match db.add_file_to_product(product_id, file_path, file_name, file_size, file_hash) {
|
||||
|
||||
match db.add_file_to_product(product_id, file_path, file_name, file_size, file_hash)
|
||||
{
|
||||
Ok(_) => assigned_count += 1,
|
||||
Err(e) => {
|
||||
errors.push(format!("Failed to assign {}: {}", file_path, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if errors.is_empty() {
|
||||
(StatusCode::OK, Json(json!({
|
||||
"ok": true,
|
||||
"product_id": product_id,
|
||||
"assigned_count": assigned_count
|
||||
})))
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"ok": true,
|
||||
"product_id": product_id,
|
||||
"assigned_count": assigned_count
|
||||
})),
|
||||
)
|
||||
} else {
|
||||
(StatusCode::PARTIAL_CONTENT, Json(json!({
|
||||
"ok": true,
|
||||
"product_id": product_id,
|
||||
"assigned_count": assigned_count,
|
||||
"errors": errors
|
||||
})))
|
||||
(
|
||||
StatusCode::PARTIAL_CONTENT,
|
||||
Json(json!({
|
||||
"ok": true,
|
||||
"product_id": product_id,
|
||||
"assigned_count": assigned_count,
|
||||
"errors": errors
|
||||
})),
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,24 +254,35 @@ pub async fn delete_product(
|
||||
Path(product_id): Path<i64>,
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
||||
|
||||
let db_path = format!(
|
||||
"{}{}",
|
||||
state.db_dir.replace("users", "downloads"),
|
||||
"/products.sqlite"
|
||||
);
|
||||
|
||||
match DownloadDb::new(&db_path) {
|
||||
Ok(mut db) => {
|
||||
match db.delete_product_with_files(product_id) {
|
||||
Ok((deleted_files, deleted_product)) => (StatusCode::OK, Json(json!({
|
||||
Ok(mut db) => match db.delete_product_with_files(product_id) {
|
||||
Ok((deleted_files, deleted_product)) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"ok": true,
|
||||
"product_id": product_id,
|
||||
"deleted_files": deleted_files,
|
||||
"deleted_product": deleted_product
|
||||
}))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
||||
"error": e.to_string()
|
||||
}))),
|
||||
})),
|
||||
),
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FileInfo {
|
||||
@@ -23,17 +23,17 @@ pub struct FileListResponse {
|
||||
pub fn scan_uploaded_files(user_id: &str) -> FileListResponse {
|
||||
let base_path = format!("/Users/accusys/Downloads/{}", user_id);
|
||||
let path = Path::new(&base_path);
|
||||
|
||||
|
||||
let mut files = Vec::new();
|
||||
let mut total_size = 0u64;
|
||||
|
||||
|
||||
if path.exists() {
|
||||
scan_directory_recursive(path, path, &mut files, &mut total_size);
|
||||
}
|
||||
|
||||
|
||||
FileListResponse {
|
||||
user_id: user_id.to_string(),
|
||||
base_path: base_path,
|
||||
base_path,
|
||||
total_files: files.len(),
|
||||
total_size,
|
||||
files,
|
||||
@@ -49,24 +49,25 @@ fn scan_directory_recursive(
|
||||
if let Ok(entries) = std::fs::read_dir(current) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
|
||||
|
||||
if path.is_file() {
|
||||
let filename = path.file_name()
|
||||
let filename = path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
let file_size = entry.metadata()
|
||||
.map(|m| m.len())
|
||||
.unwrap_or(0);
|
||||
|
||||
let relative_path = path.strip_prefix(base)
|
||||
|
||||
let file_size = entry.metadata().map(|m| m.len()).unwrap_or(0);
|
||||
|
||||
let relative_path = path
|
||||
.strip_prefix(base)
|
||||
.ok()
|
||||
.and_then(|p| p.to_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| filename.clone());
|
||||
|
||||
let upload_time = entry.metadata()
|
||||
|
||||
let upload_time = entry
|
||||
.metadata()
|
||||
.ok()
|
||||
.and_then(|m| m.modified().ok())
|
||||
.and_then(|t| {
|
||||
@@ -75,13 +76,16 @@ fn scan_directory_recursive(
|
||||
.map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
|
||||
})
|
||||
.unwrap_or_else(|| chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string());
|
||||
|
||||
|
||||
let file_hash = if file_size > 0 {
|
||||
compute_file_hash(&path).ok()
|
||||
} else {
|
||||
Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string())
|
||||
Some(
|
||||
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
.to_string(),
|
||||
)
|
||||
};
|
||||
|
||||
|
||||
files.push(FileInfo {
|
||||
filename,
|
||||
file_size,
|
||||
@@ -90,7 +94,7 @@ fn scan_directory_recursive(
|
||||
relative_path,
|
||||
upload_time,
|
||||
});
|
||||
|
||||
|
||||
*total_size += file_size;
|
||||
} else if path.is_dir() {
|
||||
scan_directory_recursive(base, &path, files, total_size);
|
||||
@@ -102,11 +106,11 @@ fn scan_directory_recursive(
|
||||
pub fn compute_file_hash(path: &Path) -> Result<String, std::io::Error> {
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::io::Read;
|
||||
|
||||
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = [0u8; 8192];
|
||||
|
||||
|
||||
loop {
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
if bytes_read == 0 {
|
||||
@@ -114,6 +118,6 @@ pub fn compute_file_hash(path: &Path) -> Result<String, std::io::Error> {
|
||||
}
|
||||
hasher.update(&buffer[..bytes_read]);
|
||||
}
|
||||
|
||||
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use pulldown_cmark::{Parser, Event, Tag, HeadingLevel, TagEnd};
|
||||
use regex::Regex;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MarkdownFile {
|
||||
@@ -39,17 +37,21 @@ pub struct SeriesMarkdown {
|
||||
pub fn parse_category_markdown(content: &str) -> Result<CategoryMarkdown> {
|
||||
let mut category = String::new();
|
||||
let mut sections: Vec<CategorySection> = Vec::new();
|
||||
|
||||
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let mut current_product = String::new();
|
||||
let mut current_files: Vec<MarkdownFile> = Vec::new();
|
||||
let mut pending_file: Option<(String, String)> = None;
|
||||
|
||||
|
||||
for i in 0..lines.len() {
|
||||
let line = lines[i].trim();
|
||||
|
||||
|
||||
if line.contains("**Category**:") {
|
||||
category = line.replace("**Category**:", "").replace("**", "").trim().to_string();
|
||||
category = line
|
||||
.replace("**Category**:", "")
|
||||
.replace("**", "")
|
||||
.trim()
|
||||
.to_string();
|
||||
} else if line.starts_with("## ") {
|
||||
if !current_product.is_empty() && !current_files.is_empty() {
|
||||
sections.push(CategorySection {
|
||||
@@ -72,13 +74,17 @@ pub fn parse_category_markdown(content: &str) -> Result<CategoryMarkdown> {
|
||||
current_files.push(MarkdownFile {
|
||||
filename,
|
||||
size: Some(size),
|
||||
download_url: line.trim_start_matches('`').trim_end_matches('`').trim().to_string(),
|
||||
download_url: line
|
||||
.trim_start_matches('`')
|
||||
.trim_end_matches('`')
|
||||
.trim()
|
||||
.to_string(),
|
||||
});
|
||||
pending_file = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !current_product.is_empty() && !current_files.is_empty() {
|
||||
sections.push(CategorySection {
|
||||
product: current_product.clone(),
|
||||
@@ -92,17 +98,21 @@ pub fn parse_category_markdown(content: &str) -> Result<CategoryMarkdown> {
|
||||
pub fn parse_series_markdown(content: &str) -> Result<SeriesMarkdown> {
|
||||
let mut series = String::new();
|
||||
let mut sections: Vec<SeriesSection> = Vec::new();
|
||||
|
||||
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let mut current_category = String::new();
|
||||
let mut current_files: Vec<MarkdownFile> = Vec::new();
|
||||
let mut pending_file: Option<(String, String)> = None;
|
||||
|
||||
|
||||
for i in 0..lines.len() {
|
||||
let line = lines[i].trim();
|
||||
|
||||
|
||||
if line.starts_with("# ") && line.contains("Download Links") {
|
||||
series = line.replace("# ", "").replace(" Download Links", "").trim().to_string();
|
||||
series = line
|
||||
.replace("# ", "")
|
||||
.replace(" Download Links", "")
|
||||
.trim()
|
||||
.to_string();
|
||||
} else if line.starts_with("## ") {
|
||||
if !current_category.is_empty() && !current_files.is_empty() {
|
||||
sections.push(SeriesSection {
|
||||
@@ -125,13 +135,17 @@ pub fn parse_series_markdown(content: &str) -> Result<SeriesMarkdown> {
|
||||
current_files.push(MarkdownFile {
|
||||
filename,
|
||||
size: Some(size),
|
||||
download_url: line.trim_start_matches('`').trim_end_matches('`').trim().to_string(),
|
||||
download_url: line
|
||||
.trim_start_matches('`')
|
||||
.trim_end_matches('`')
|
||||
.trim()
|
||||
.to_string(),
|
||||
});
|
||||
pending_file = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !current_category.is_empty() && !current_files.is_empty() {
|
||||
sections.push(SeriesSection {
|
||||
category: current_category.clone(),
|
||||
@@ -144,65 +158,77 @@ pub fn parse_series_markdown(content: &str) -> Result<SeriesMarkdown> {
|
||||
|
||||
pub fn read_category_files(dir: &Path) -> Result<Vec<(String, String)>> {
|
||||
let mut files = Vec::new();
|
||||
|
||||
|
||||
for entry in fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().map_or(false, |ext| ext == "md") && path.file_name() != Some(std::ffi::OsStr::new("README.md")) {
|
||||
|
||||
if path.extension().is_some_and(|ext| ext == "md")
|
||||
&& path.file_name() != Some(std::ffi::OsStr::new("README.md"))
|
||||
{
|
||||
let filename = path.file_name().unwrap().to_string_lossy().to_string();
|
||||
let content = fs::read_to_string(&path)?;
|
||||
files.push((filename, content));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
pub fn read_series_files(dir: &Path) -> Result<Vec<(String, String)>> {
|
||||
let mut files = Vec::new();
|
||||
|
||||
|
||||
for entry in fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().map_or(false, |ext| ext == "md") && path.file_name() != Some(std::ffi::OsStr::new("README.md")) {
|
||||
|
||||
if path.extension().is_some_and(|ext| ext == "md")
|
||||
&& path.file_name() != Some(std::ffi::OsStr::new("README.md"))
|
||||
{
|
||||
let filename = path.file_name().unwrap().to_string_lossy().to_string();
|
||||
let content = fs::read_to_string(&path)?;
|
||||
files.push((filename, content));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type: &str) -> Result<()> {
|
||||
pub fn import_categories_to_db(
|
||||
conn: &rusqlite::Connection,
|
||||
user_id: &str,
|
||||
tree_type: &str,
|
||||
) -> Result<()> {
|
||||
use crate::FileTree;
|
||||
use filetree::node::{FileNode, Aliases, NodeType};
|
||||
use uuid::Uuid;
|
||||
use filetree::node::{Aliases, FileNode, NodeType};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
let category_dir = Path::new("/Users/accusys/markbase/data/downloads/by_category");
|
||||
let files = read_category_files(category_dir)?;
|
||||
|
||||
|
||||
println!("Found {} Markdown files", files.len());
|
||||
|
||||
|
||||
let mut tree = FileTree::load(conn, user_id, tree_type)?;
|
||||
|
||||
|
||||
for (_filename, content) in files {
|
||||
let parsed = parse_category_markdown(&content)?;
|
||||
|
||||
println!("Parsed category: '{}', sections: {}", parsed.category, parsed.sections.len());
|
||||
|
||||
|
||||
println!(
|
||||
"Parsed category: '{}', sections: {}",
|
||||
parsed.category,
|
||||
parsed.sections.len()
|
||||
);
|
||||
|
||||
if parsed.category.is_empty() {
|
||||
println!("Warning: category is empty, skipping");
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
let category_node_id = Uuid::new_v4().to_string();
|
||||
let mut aliases_map = HashMap::new();
|
||||
aliases_map.insert("category_type".to_string(), "category".to_string());
|
||||
|
||||
|
||||
let category_node = FileNode {
|
||||
node_id: category_node_id.clone(),
|
||||
label: parsed.category.clone(),
|
||||
@@ -221,20 +247,27 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
sort_order: 0,
|
||||
};
|
||||
|
||||
println!("Inserting category node: {} (id: {})", category_node.label, category_node_id);
|
||||
|
||||
|
||||
println!(
|
||||
"Inserting category node: {} (id: {})",
|
||||
category_node.label, category_node_id
|
||||
);
|
||||
|
||||
tree.insert_node(conn, &category_node)?;
|
||||
|
||||
|
||||
println!("Category node inserted successfully");
|
||||
|
||||
|
||||
for section in parsed.sections {
|
||||
println!("Processing section: {} with {} files", section.product, section.files.len());
|
||||
|
||||
println!(
|
||||
"Processing section: {} with {} files",
|
||||
section.product,
|
||||
section.files.len()
|
||||
);
|
||||
|
||||
let product_node_id = Uuid::new_v4().to_string();
|
||||
let mut aliases_map = HashMap::new();
|
||||
aliases_map.insert("product".to_string(), section.product.clone());
|
||||
|
||||
|
||||
let product_node = FileNode {
|
||||
node_id: product_node_id.clone(),
|
||||
label: section.product.clone(),
|
||||
@@ -253,15 +286,18 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
sort_order: 0,
|
||||
};
|
||||
|
||||
|
||||
tree.insert_node(conn, &product_node)?;
|
||||
|
||||
|
||||
for file in section.files {
|
||||
let file_node_id = Uuid::new_v4().to_string();
|
||||
let mut aliases_map = HashMap::new();
|
||||
aliases_map.insert("download_url".to_string(), file.download_url.clone());
|
||||
aliases_map.insert("file_size_display".to_string(), file.size.clone().unwrap_or_else(|| "Unknown".to_string()));
|
||||
|
||||
aliases_map.insert(
|
||||
"file_size_display".to_string(),
|
||||
file.size.clone().unwrap_or_else(|| "Unknown".to_string()),
|
||||
);
|
||||
|
||||
let file_node = FileNode {
|
||||
node_id: file_node_id.clone(),
|
||||
label: file.filename.clone(),
|
||||
@@ -280,42 +316,50 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
sort_order: 0,
|
||||
};
|
||||
|
||||
|
||||
tree.insert_node(conn, &file_node)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type: &str) -> Result<()> {
|
||||
pub fn import_series_to_db(
|
||||
conn: &rusqlite::Connection,
|
||||
user_id: &str,
|
||||
tree_type: &str,
|
||||
) -> Result<()> {
|
||||
use crate::FileTree;
|
||||
use filetree::node::{FileNode, Aliases, NodeType};
|
||||
use uuid::Uuid;
|
||||
use filetree::node::{Aliases, FileNode, NodeType};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
let series_dir = Path::new("/Users/accusys/markbase/data/downloads/by_series");
|
||||
let files = read_series_files(series_dir)?;
|
||||
|
||||
|
||||
println!("Found {} Markdown files for series", files.len());
|
||||
|
||||
|
||||
let mut tree = FileTree::load(conn, user_id, tree_type)?;
|
||||
|
||||
|
||||
for (_filename, content) in files {
|
||||
let parsed = parse_series_markdown(&content)?;
|
||||
|
||||
println!("Parsed series: '{}', sections: {}", parsed.series, parsed.sections.len());
|
||||
|
||||
|
||||
println!(
|
||||
"Parsed series: '{}', sections: {}",
|
||||
parsed.series,
|
||||
parsed.sections.len()
|
||||
);
|
||||
|
||||
if parsed.series.is_empty() {
|
||||
println!("Warning: series is empty, skipping");
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
let series_node_id = Uuid::new_v4().to_string();
|
||||
let mut aliases_map = HashMap::new();
|
||||
aliases_map.insert("series_type".to_string(), "series".to_string());
|
||||
|
||||
|
||||
let series_node = FileNode {
|
||||
node_id: series_node_id.clone(),
|
||||
label: parsed.series.clone(),
|
||||
@@ -334,18 +378,22 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
sort_order: 0,
|
||||
};
|
||||
|
||||
|
||||
tree.insert_node(conn, &series_node)?;
|
||||
|
||||
|
||||
println!("Series node inserted successfully");
|
||||
|
||||
|
||||
for section in parsed.sections {
|
||||
println!("Processing section: {} with {} files", section.category, section.files.len());
|
||||
|
||||
println!(
|
||||
"Processing section: {} with {} files",
|
||||
section.category,
|
||||
section.files.len()
|
||||
);
|
||||
|
||||
let category_node_id = Uuid::new_v4().to_string();
|
||||
let mut aliases_map = HashMap::new();
|
||||
aliases_map.insert("category".to_string(), section.category.clone());
|
||||
|
||||
|
||||
let category_node = FileNode {
|
||||
node_id: category_node_id.clone(),
|
||||
label: section.category.clone(),
|
||||
@@ -364,15 +412,18 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
sort_order: 0,
|
||||
};
|
||||
|
||||
|
||||
tree.insert_node(conn, &category_node)?;
|
||||
|
||||
|
||||
for file in section.files {
|
||||
let file_node_id = Uuid::new_v4().to_string();
|
||||
let mut aliases_map = HashMap::new();
|
||||
aliases_map.insert("download_url".to_string(), file.download_url.clone());
|
||||
aliases_map.insert("file_size_display".to_string(), file.size.clone().unwrap_or_else(|| "Unknown".to_string()));
|
||||
|
||||
aliases_map.insert(
|
||||
"file_size_display".to_string(),
|
||||
file.size.clone().unwrap_or_else(|| "Unknown".to_string()),
|
||||
);
|
||||
|
||||
let file_node = FileNode {
|
||||
node_id: file_node_id.clone(),
|
||||
label: file.filename.clone(),
|
||||
@@ -391,12 +442,12 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type
|
||||
updated_at: chrono::Utc::now().to_rfc3339(),
|
||||
sort_order: 0,
|
||||
};
|
||||
|
||||
|
||||
tree.insert_node(conn, &file_node)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -418,8 +469,8 @@ mod tests {
|
||||
```https://download.accusys.ddns.net/api/v2/download/products/ExaSAN-DAS/C1M_C2M/User%20Guide/C2M-QIG20170906.zip
|
||||
```
|
||||
"#;
|
||||
|
||||
|
||||
let result = parse_category_markdown(content).unwrap();
|
||||
assert_eq!(result.category, "GUI");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
pub mod audio;
|
||||
pub mod auth;
|
||||
pub mod audit;
|
||||
pub mod cli;
|
||||
pub mod api;
|
||||
pub mod archive; // Archive Module - Universal Compression Format Support (Phase 1-3完成)
|
||||
pub mod audio;
|
||||
pub mod audit;
|
||||
pub mod auth;
|
||||
pub mod category_view;
|
||||
pub mod cli;
|
||||
pub mod command;
|
||||
pub mod config;
|
||||
pub mod download;
|
||||
pub mod import_markdown;
|
||||
pub mod pg_client;
|
||||
pub mod render;
|
||||
pub mod rsync;
|
||||
@@ -14,20 +17,17 @@ pub mod s3_auth;
|
||||
pub mod s3_config;
|
||||
pub mod s3_xml;
|
||||
pub mod scan;
|
||||
pub mod server;
|
||||
pub mod archive; // Archive Module - Universal Compression Format Support (Phase 1-3完成)
|
||||
pub mod category_view;
|
||||
pub mod import_markdown; // Category View Module - 双视图管理(Phase 1)
|
||||
// pub mod sftp; // ⚠️ russh版本(已禁用)
|
||||
// pub mod ssh2_server; // ssh2服务器(已禁用)
|
||||
// pub mod ssh2_mod; // ssh2辅助模块(已禁用)
|
||||
pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐
|
||||
pub mod server; // Category View Module - 双视图管理(Phase 1)
|
||||
// pub mod sftp; // ⚠️ russh版本(已禁用)
|
||||
// pub mod ssh2_server; // ssh2服务器(已禁用)
|
||||
// pub mod ssh2_mod; // ssh2辅助模块(已禁用)
|
||||
pub mod provider; // DataProvider抽象层(Phase 5)
|
||||
pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐
|
||||
pub mod sync;
|
||||
pub mod provider; // DataProvider抽象层(Phase 5)
|
||||
pub mod vfs; // VFS抽象层(Phase 1-6重构计划)
|
||||
pub mod vfs; // VFS抽象层(Phase 1-6重构计划)
|
||||
|
||||
#[cfg(test)]
|
||||
mod security_audit; // Security Audit Module - Phase 9
|
||||
mod security_audit; // Security Audit Module - Phase 9
|
||||
|
||||
// Re-export from external filetree crate
|
||||
pub use filetree::node::FileNode;
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use markbase_core::cli::Cli;
|
||||
use clap::Parser;
|
||||
use markbase_core::cli::Cli;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
env_logger::Builder::from_default_env()
|
||||
.filter_level(log::LevelFilter::Info)
|
||||
.init();
|
||||
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
match cli.command {
|
||||
@@ -25,4 +25,4 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,12 @@ pub struct PgClient {
|
||||
database: String,
|
||||
}
|
||||
|
||||
impl Default for PgClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PgClient {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
pub mod sqlite;
|
||||
pub mod pg;
|
||||
pub mod sqlite;
|
||||
|
||||
pub use sqlite::SqliteProvider;
|
||||
pub use pg::PgProvider;
|
||||
pub use sqlite::SqliteProvider;
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -57,7 +57,10 @@ pub trait DataProvider: Send + Sync {
|
||||
|
||||
/// 检查用户是否存在且启用
|
||||
fn user_exists(&self, username: &str) -> Result<bool, ProviderError> {
|
||||
Ok(self.get_user(username)?.map(|u| u.status == 1).unwrap_or(false))
|
||||
Ok(self
|
||||
.get_user(username)?
|
||||
.map(|u| u.status == 1)
|
||||
.unwrap_or(false))
|
||||
}
|
||||
|
||||
/// 获取用户的公开密钥列表(OpenSSH authorized_keys格式)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use postgres::{Client, NoTls};
|
||||
use bcrypt::verify;
|
||||
use super::{DataProvider, ProviderError, User};
|
||||
use bcrypt::verify;
|
||||
use postgres::{Client, NoTls};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// PostgreSQL 数据提供者(兼容 SFTPGo 的 users 表)
|
||||
pub struct PgProvider {
|
||||
@@ -13,7 +13,9 @@ impl PgProvider {
|
||||
///
|
||||
/// 连接字符串格式:host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026
|
||||
pub fn new(conn_str: &str) -> Result<Self, ProviderError> {
|
||||
Ok(Self { conn_str: conn_str.to_string() })
|
||||
Ok(Self {
|
||||
conn_str: conn_str.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_params(
|
||||
@@ -40,18 +42,22 @@ impl DataProvider for PgProvider {
|
||||
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
|
||||
let mut conn = self.open_conn()?;
|
||||
|
||||
let result = conn.query_opt(
|
||||
"SELECT username, password, home_dir, permissions, uid, gid, status
|
||||
let result = conn
|
||||
.query_opt(
|
||||
"SELECT username, password, home_dir, permissions, uid, gid, status
|
||||
FROM users WHERE username = $1 AND status = 1",
|
||||
&[&username],
|
||||
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||
&[&username],
|
||||
)
|
||||
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||
|
||||
match result {
|
||||
Some(row) => Ok(Some(User {
|
||||
username: row.get(0),
|
||||
password_hash: row.get::<_, Option<String>>(1).unwrap_or_default(),
|
||||
home_dir: PathBuf::from(row.get::<_, String>(2)),
|
||||
permissions: row.get::<_, Option<String>>(3).unwrap_or_else(|| "*".to_string()),
|
||||
permissions: row
|
||||
.get::<_, Option<String>>(3)
|
||||
.unwrap_or_else(|| "*".to_string()),
|
||||
uid: row.get::<_, i64>(4) as u32,
|
||||
gid: row.get::<_, i64>(5) as u32,
|
||||
status: row.get(6),
|
||||
@@ -75,24 +81,31 @@ impl DataProvider for PgProvider {
|
||||
}
|
||||
|
||||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||||
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||
Ok(self
|
||||
.get_user(username)?
|
||||
.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||
}
|
||||
|
||||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||
let mut conn = self.open_conn()?;
|
||||
let result = conn.query_opt(
|
||||
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
|
||||
&[&username],
|
||||
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||
let result = conn
|
||||
.query_opt(
|
||||
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
|
||||
&[&username],
|
||||
)
|
||||
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||
|
||||
match result {
|
||||
Some(row) => {
|
||||
let json_str: Option<String> = row.get(0);
|
||||
match json_str {
|
||||
Some(s) if !s.is_empty() => {
|
||||
let keys: Vec<serde_json::Value> = serde_json::from_str(&s)
|
||||
.map_err(|e| ProviderError::Internal(format!("JSON parse error: {}", e)))?;
|
||||
Ok(keys.iter()
|
||||
let keys: Vec<serde_json::Value> =
|
||||
serde_json::from_str(&s).map_err(|e| {
|
||||
ProviderError::Internal(format!("JSON parse error: {}", e))
|
||||
})?;
|
||||
Ok(keys
|
||||
.iter()
|
||||
.filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string()))
|
||||
.collect())
|
||||
}
|
||||
@@ -112,7 +125,7 @@ mod tests {
|
||||
fn test_pg_provider_connection() {
|
||||
// 仅当 SFTPGo PostgreSQL 可用时运行
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
);
|
||||
assert!(provider.is_ok(), "Should connect to SFTPGo PostgreSQL");
|
||||
}
|
||||
@@ -120,8 +133,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_get_user_demo() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
)
|
||||
.unwrap();
|
||||
let user = provider.get_user("demo").unwrap();
|
||||
assert!(user.is_some(), "Demo user should exist");
|
||||
assert_eq!(user.unwrap().username, "demo");
|
||||
@@ -130,8 +144,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_get_user_momentry() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
)
|
||||
.unwrap();
|
||||
let user = provider.get_user("momentry").unwrap();
|
||||
assert!(user.is_some(), "Momentry user should exist");
|
||||
}
|
||||
@@ -139,8 +154,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_get_user_warren() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
)
|
||||
.unwrap();
|
||||
let user = provider.get_user("warren").unwrap();
|
||||
assert!(user.is_some(), "Warren user should exist");
|
||||
}
|
||||
@@ -148,8 +164,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_check_password_demo() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
)
|
||||
.unwrap();
|
||||
let valid = provider.check_password("demo", "demo123").unwrap();
|
||||
assert!(valid, "Password should be valid");
|
||||
}
|
||||
@@ -157,8 +174,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_check_password_invalid() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
)
|
||||
.unwrap();
|
||||
let valid = provider.check_password("demo", "wrong").unwrap();
|
||||
assert!(!valid, "Wrong password should fail");
|
||||
}
|
||||
@@ -166,8 +184,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_get_home_dir() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
)
|
||||
.unwrap();
|
||||
let dir = provider.get_home_dir("demo").unwrap();
|
||||
assert!(dir.is_some());
|
||||
assert!(dir.unwrap().contains("momentry"));
|
||||
@@ -176,8 +195,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pg_nonexistent_user() {
|
||||
let provider = PgProvider::new(
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
||||
).unwrap();
|
||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||
)
|
||||
.unwrap();
|
||||
let user = provider.get_user("__nonexistent__").unwrap();
|
||||
assert!(user.is_none());
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use rusqlite::{Connection, params};
|
||||
use bcrypt::verify;
|
||||
use super::{DataProvider, ProviderError, User};
|
||||
use bcrypt::verify;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// SQLite 数据提供者
|
||||
pub struct SqliteProvider {
|
||||
@@ -13,7 +13,8 @@ impl SqliteProvider {
|
||||
let path = PathBuf::from(db_path);
|
||||
if !path.exists() {
|
||||
return Err(ProviderError::NotFound(format!(
|
||||
"Database not found: {}", db_path
|
||||
"Database not found: {}",
|
||||
db_path
|
||||
)));
|
||||
}
|
||||
Ok(Self { db_path: path })
|
||||
@@ -50,7 +51,8 @@ impl DataProvider for SqliteProvider {
|
||||
Ok(user) => Ok(Some(user)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(ProviderError::Internal(format!(
|
||||
"Database query error: {}", e
|
||||
"Database query error: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
}
|
||||
@@ -66,7 +68,9 @@ impl DataProvider for SqliteProvider {
|
||||
}
|
||||
|
||||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||||
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||
Ok(self
|
||||
.get_user(username)?
|
||||
.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||
}
|
||||
|
||||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||
@@ -98,7 +102,10 @@ mod tests {
|
||||
}
|
||||
|
||||
fn get_test_db_path() -> String {
|
||||
format!("{}/../data/auth.sqlite", std::env::var("CARGO_MANIFEST_DIR").unwrap())
|
||||
format!(
|
||||
"{}/../data/auth.sqlite",
|
||||
std::env::var("CARGO_MANIFEST_DIR").unwrap()
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use anyhow::Result;
|
||||
use md5::compute;
|
||||
|
||||
pub struct RollingChecksum {
|
||||
|
||||
@@ -50,6 +50,12 @@ pub struct DecompressionStream {
|
||||
decompressor: Decompress,
|
||||
}
|
||||
|
||||
impl Default for DecompressionStream {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl DecompressionStream {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use crate::rsync::checksum::{compute_block_checksums, BlockChecksum};
|
||||
use crate::rsync::compress::{CompressionStream, DecompressionStream};
|
||||
use crate::rsync::delta::{DeltaAlgorithm, DeltaInstruction};
|
||||
use crate::rsync::protocol::{RsyncCommand, RsyncProtocol};
|
||||
use crate::rsync::protocol::RsyncCommand;
|
||||
use crate::rsync::RsyncConfig;
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -162,6 +162,12 @@ pub struct RsyncHandshake {
|
||||
negotiated_version: u32,
|
||||
}
|
||||
|
||||
impl Default for RsyncHandshake {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RsyncHandshake {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use filetree::{FileTree, node::{FileNode, Aliases}};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use filetree::{
|
||||
node::FileNode,
|
||||
FileTree,
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
@@ -41,10 +43,10 @@ pub async fn list_buckets(State(state): State<crate::server::AppState>) -> impl
|
||||
|
||||
pub async fn list_objects(
|
||||
Path(bucket): Path<String>,
|
||||
State(state): State<crate::server::AppState>,
|
||||
State(_state): State<crate::server::AppState>,
|
||||
) -> impl IntoResponse {
|
||||
println!("S3 List Objects: bucket={}", bucket);
|
||||
|
||||
|
||||
let conn = match FileTree::open_user_db(&bucket) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
@@ -70,20 +72,20 @@ pub async fn list_objects(
|
||||
"Key": build_s3_key(&tree, n),
|
||||
"LastModified": n.registered_at.clone().unwrap_or_default(),
|
||||
"ETag": n.sha256.clone().unwrap_or_default(),
|
||||
"Size": n.file_size.clone().unwrap_or(0),
|
||||
"Size": n.file_size.unwrap_or(0),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
println!("Listed {} objects for bucket {}", objects.len(), bucket);
|
||||
|
||||
|
||||
let (headers, xml_body) = crate::s3_xml::list_objects_xml(&bucket, &objects);
|
||||
(StatusCode::OK, headers, xml_body).into_response()
|
||||
}
|
||||
|
||||
pub async fn get_object(
|
||||
Path((bucket, key)): Path<(String, String)>,
|
||||
State(state): State<crate::server::AppState>,
|
||||
State(_state): State<crate::server::AppState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
println!("S3 GET Object: bucket={}, key={}", bucket, key);
|
||||
@@ -119,7 +121,7 @@ pub async fn get_object(
|
||||
);
|
||||
|
||||
let file_uuid = node.file_uuid.clone().unwrap_or_default();
|
||||
let file_size = node.file_size.clone().unwrap_or(0);
|
||||
let file_size = node.file_size.unwrap_or(0);
|
||||
let sha256 = node.sha256.clone().unwrap_or_default();
|
||||
|
||||
let real_path = get_real_file_path(&conn, &file_uuid);
|
||||
@@ -233,7 +235,7 @@ pub async fn put_object(
|
||||
|
||||
let sha256_hash_clone = sha256_hash.clone();
|
||||
let file_path_clone = file_path.clone();
|
||||
let label = key.split('/').last().unwrap_or(&key).to_string();
|
||||
let label = key.split('/').next_back().unwrap_or(&key).to_string();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let conn = match FileTree::open_user_db(&bucket) {
|
||||
@@ -246,7 +248,7 @@ pub async fn put_object(
|
||||
|row| row.get::<_, i32>(0),
|
||||
)
|
||||
.unwrap_or(0) > 0;
|
||||
|
||||
|
||||
if !has_tables {
|
||||
// Initialize tables if not exist
|
||||
c.execute_batch(filetree::CREATE_TABLES)?;
|
||||
@@ -298,7 +300,7 @@ pub async fn put_object(
|
||||
|
||||
pub async fn head_object(
|
||||
Path((bucket, key)): Path<(String, String)>,
|
||||
State(state): State<crate::server::AppState>,
|
||||
State(_state): State<crate::server::AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let conn = match FileTree::open_user_db(&bucket) {
|
||||
Ok(c) => c,
|
||||
@@ -323,7 +325,7 @@ pub async fn head_object(
|
||||
"ETag",
|
||||
node.sha256.clone().unwrap_or_default().parse().unwrap(),
|
||||
);
|
||||
headers.insert("Content-Length", node.file_size.clone().unwrap_or(0).into());
|
||||
headers.insert("Content-Length", node.file_size.unwrap_or(0).into());
|
||||
|
||||
(StatusCode::OK, headers)
|
||||
}
|
||||
@@ -438,7 +440,7 @@ fn find_node_by_s3_key(tree: &FileTree, key: &str) -> Option<FileNode> {
|
||||
}
|
||||
|
||||
// 方法2:通过filename直接匹配(fallback)
|
||||
let filename = key.split('/').last().unwrap_or(key);
|
||||
let filename = key.split('/').next_back().unwrap_or(key);
|
||||
tree.nodes
|
||||
.iter()
|
||||
.filter(|n| n.node_type == filetree::node::NodeType::File)
|
||||
@@ -501,7 +503,7 @@ async fn handle_range_request(
|
||||
}
|
||||
|
||||
// 使用take限制读取长度
|
||||
let limited_file = file.take(content_length as u64);
|
||||
let limited_file = file.take(content_length);
|
||||
let stream = ReaderStream::new(limited_file);
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
@@ -535,11 +537,7 @@ fn parse_range_header(range: &str, file_size: i64) -> Option<(u64, u64)> {
|
||||
let (start, end) = if parts[0].is_empty() {
|
||||
// "bytes=-N"格式:最后N字节
|
||||
let suffix_length = parts[1].parse::<u64>().ok()?;
|
||||
let start = if suffix_length > file_size as u64 {
|
||||
0
|
||||
} else {
|
||||
file_size as u64 - suffix_length
|
||||
};
|
||||
let start = (file_size as u64).saturating_sub(suffix_length);
|
||||
(start, file_size as u64 - 1)
|
||||
} else if parts[1].is_empty() {
|
||||
// "bytes=N-"格式:从N到结尾
|
||||
|
||||
@@ -8,11 +8,11 @@ type HmacSha256 = Hmac<Sha256>;
|
||||
pub fn verify_signature(headers: HeaderMap, method: &str, path: &str) -> bool {
|
||||
// Load S3 config and check require_auth flag
|
||||
let config = crate::s3_config::S3Config::load_default().unwrap_or_default();
|
||||
|
||||
|
||||
// Merge environment variables (allows override via MB_S3_REQUIRE_AUTH)
|
||||
let mut config = config;
|
||||
config.merge_env();
|
||||
|
||||
|
||||
if !config.s3.require_auth {
|
||||
// Development mode: allow access without authentication
|
||||
return true;
|
||||
@@ -127,7 +127,7 @@ fn calculate_signature(
|
||||
headers: HeaderMap,
|
||||
method: &str,
|
||||
path: &str,
|
||||
access_key: &str,
|
||||
_access_key: &str,
|
||||
secret_key: &str,
|
||||
region: &str,
|
||||
service: &str,
|
||||
@@ -143,9 +143,9 @@ fn calculate_signature(
|
||||
let signing_key = calculate_signing_key(secret_key, date, region, service);
|
||||
|
||||
// 4. Calculate Signature
|
||||
let signature = hmac_sha256_hex(&signing_key, &string_to_sign);
|
||||
|
||||
|
||||
signature
|
||||
hmac_sha256_hex(&signing_key, &string_to_sign)
|
||||
}
|
||||
|
||||
fn create_canonical_request(headers: HeaderMap, method: &str, path: &str) -> String {
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub struct S3Config {
|
||||
#[serde(default)]
|
||||
pub s3: S3Section,
|
||||
@@ -40,6 +41,7 @@ pub struct KeysSection {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub struct BucketsSection {
|
||||
#[serde(default)]
|
||||
pub mappings: std::collections::HashMap<String, String>,
|
||||
@@ -96,16 +98,6 @@ fn admin_permissions() -> Vec<String> {
|
||||
]
|
||||
}
|
||||
|
||||
impl Default for S3Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
s3: S3Section::default(),
|
||||
keys: KeysSection::default(),
|
||||
buckets: BucketsSection::default(),
|
||||
permissions: PermissionsSection::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for S3Section {
|
||||
fn default() -> Self {
|
||||
@@ -129,13 +121,6 @@ impl Default for KeysSection {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BucketsSection {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mappings: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PermissionsSection {
|
||||
fn default() -> Self {
|
||||
@@ -169,9 +154,9 @@ impl S3Config {
|
||||
Self::load("config/s3.toml")
|
||||
}
|
||||
|
||||
pub fn save(&self, path: &str) -> Result<()> {
|
||||
pub fn save(&self, path: &str) -> Result<()> {
|
||||
let config_path = PathBuf::from(path);
|
||||
|
||||
|
||||
// Create backup before saving
|
||||
if config_path.exists() {
|
||||
let backup_path = config_path.with_extension("toml.bak");
|
||||
@@ -179,13 +164,13 @@ pub fn save(&self, path: &str) -> Result<()> {
|
||||
.with_context(|| format!("Failed to create backup: {}", backup_path.display()))?;
|
||||
log::info!("S3 config backup created: {}", backup_path.display());
|
||||
}
|
||||
|
||||
let content = toml::to_string_pretty(self)
|
||||
.with_context(|| "Failed to serialize S3 config")?;
|
||||
|
||||
|
||||
let content =
|
||||
toml::to_string_pretty(self).with_context(|| "Failed to serialize S3 config")?;
|
||||
|
||||
std::fs::write(&config_path, content)
|
||||
.with_context(|| format!("Failed to write S3 config: {}", path))?;
|
||||
|
||||
|
||||
log::info!("S3 config saved to: {}", path);
|
||||
Ok(())
|
||||
}
|
||||
@@ -255,10 +240,16 @@ pub fn save(&self, path: &str) -> Result<()> {
|
||||
|
||||
// Validate permission format
|
||||
let valid_permissions = [
|
||||
"GetObject", "PutObject", "DeleteObject", "ListBucket",
|
||||
"HeadObject", "ListAllMyBuckets", "CreateBucket", "DeleteBucket"
|
||||
"GetObject",
|
||||
"PutObject",
|
||||
"DeleteObject",
|
||||
"ListBucket",
|
||||
"HeadObject",
|
||||
"ListAllMyBuckets",
|
||||
"CreateBucket",
|
||||
"DeleteBucket",
|
||||
];
|
||||
|
||||
|
||||
for perm in &self.permissions.default_permissions {
|
||||
if !valid_permissions.contains(&perm.as_str()) {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -289,18 +280,18 @@ pub fn save(&self, path: &str) -> Result<()> {
|
||||
"s3.region" => Some(self.s3.region.clone()),
|
||||
"s3.service" => Some(self.s3.service.clone()),
|
||||
"s3.require_auth" => Some(self.s3.require_auth.to_string()),
|
||||
|
||||
|
||||
"keys.default_access_key" => Some(self.keys.default_access_key.clone()),
|
||||
"keys.default_secret_key" => Some(self.keys.default_secret_key.clone()),
|
||||
"keys.keys_db_path" => Some(self.keys.keys_db_path.clone()),
|
||||
|
||||
"permissions.default_permissions" => {
|
||||
Some(serde_json::to_string(&self.permissions.default_permissions).unwrap_or_default())
|
||||
}
|
||||
|
||||
"permissions.default_permissions" => Some(
|
||||
serde_json::to_string(&self.permissions.default_permissions).unwrap_or_default(),
|
||||
),
|
||||
"permissions.admin_permissions" => {
|
||||
Some(serde_json::to_string(&self.permissions.admin_permissions).unwrap_or_default())
|
||||
}
|
||||
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -312,11 +303,11 @@ pub fn save(&self, path: &str) -> Result<()> {
|
||||
"s3.region" => self.s3.region = value.to_string(),
|
||||
"s3.service" => self.s3.service = value.to_string(),
|
||||
"s3.require_auth" => self.s3.require_auth = value.parse()?,
|
||||
|
||||
|
||||
"keys.default_access_key" => self.keys.default_access_key = value.to_string(),
|
||||
"keys.default_secret_key" => self.keys.default_secret_key = value.to_string(),
|
||||
"keys.keys_db_path" => self.keys.keys_db_path = value.to_string(),
|
||||
|
||||
|
||||
"permissions.default_permissions" => {
|
||||
self.permissions.default_permissions = serde_json::from_str(value)
|
||||
.with_context(|| "Failed to parse permissions array")?;
|
||||
@@ -325,7 +316,7 @@ pub fn save(&self, path: &str) -> Result<()> {
|
||||
self.permissions.admin_permissions = serde_json::from_str(value)
|
||||
.with_context(|| "Failed to parse admin permissions array")?;
|
||||
}
|
||||
|
||||
|
||||
_ => return Err(anyhow::anyhow!("Invalid S3 config key: {}", key)),
|
||||
}
|
||||
Ok(())
|
||||
@@ -340,15 +331,15 @@ mod tests {
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = S3Config::default();
|
||||
|
||||
|
||||
assert_eq!(config.s3.enabled, true);
|
||||
assert_eq!(config.s3.require_auth, false);
|
||||
assert_eq!(config.s3.endpoint, "http://localhost:11438/s3");
|
||||
assert_eq!(config.s3.region, "us-east-1");
|
||||
|
||||
|
||||
assert_eq!(config.keys.default_access_key, "markbase_access_key_001");
|
||||
assert_eq!(config.keys.default_secret_key, "markbase_secret_key_xyz123");
|
||||
|
||||
|
||||
assert_eq!(config.permissions.default_permissions.len(), 3);
|
||||
assert_eq!(config.permissions.admin_permissions.len(), 5);
|
||||
}
|
||||
@@ -357,9 +348,9 @@ mod tests {
|
||||
fn test_load_missing_config() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let config_path = temp_dir.path().join("missing.toml");
|
||||
|
||||
|
||||
let config = S3Config::load(&config_path.to_string_lossy()).unwrap();
|
||||
|
||||
|
||||
assert_eq!(config.s3.enabled, true);
|
||||
assert_eq!(config.s3.require_auth, false);
|
||||
}
|
||||
@@ -368,13 +359,13 @@ mod tests {
|
||||
fn test_merge_env() {
|
||||
std::env::set_var("MB_S3_REQUIRE_AUTH", "true");
|
||||
std::env::set_var("MB_S3_ENDPOINT", "http://custom.endpoint");
|
||||
|
||||
|
||||
let mut config = S3Config::default();
|
||||
config.merge_env();
|
||||
|
||||
|
||||
assert_eq!(config.s3.require_auth, true);
|
||||
assert_eq!(config.s3.endpoint, "http://custom.endpoint");
|
||||
|
||||
|
||||
std::env::remove_var("MB_S3_REQUIRE_AUTH");
|
||||
std::env::remove_var("MB_S3_ENDPOINT");
|
||||
}
|
||||
@@ -383,7 +374,7 @@ mod tests {
|
||||
fn test_validate() {
|
||||
let config = S3Config::default();
|
||||
assert!(config.validate().is_ok());
|
||||
|
||||
|
||||
let mut invalid_config = S3Config::default();
|
||||
invalid_config.s3.endpoint = "".to_string();
|
||||
assert!(invalid_config.validate().is_err());
|
||||
@@ -392,14 +383,17 @@ mod tests {
|
||||
#[test]
|
||||
fn test_get_set() {
|
||||
let mut config = S3Config::default();
|
||||
|
||||
|
||||
assert_eq!(config.get("s3.enabled"), Some("true".to_string()));
|
||||
assert_eq!(config.get("s3.endpoint"), Some("http://localhost:11438/s3".to_string()));
|
||||
|
||||
assert_eq!(
|
||||
config.get("s3.endpoint"),
|
||||
Some("http://localhost:11438/s3".to_string())
|
||||
);
|
||||
|
||||
config.set("s3.require_auth", "true").unwrap();
|
||||
assert_eq!(config.s3.require_auth, true);
|
||||
|
||||
|
||||
config.set("s3.endpoint", "http://new.endpoint").unwrap();
|
||||
assert_eq!(config.s3.endpoint, "http://new.endpoint");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,18 @@ use serde_json::Value;
|
||||
pub fn list_buckets_xml(buckets: &[String]) -> (HeaderMap, String) {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/xml".parse().unwrap());
|
||||
|
||||
|
||||
let bucket_entries = buckets
|
||||
.iter()
|
||||
.map(|b| format!(
|
||||
"<Bucket><Name>{}</Name><CreationDate>2026-05-27T00:00:00Z</CreationDate></Bucket>",
|
||||
b
|
||||
))
|
||||
.map(|b| {
|
||||
format!(
|
||||
"<Bucket><Name>{}</Name><CreationDate>2026-05-27T00:00:00Z</CreationDate></Bucket>",
|
||||
b
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n ");
|
||||
|
||||
|
||||
let xml = format!(
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>
|
||||
<ListAllMyBucketsResult xmlns=\"http://s3.amazonaws.com/doc/2006-03-01/\">
|
||||
@@ -27,22 +29,25 @@ pub fn list_buckets_xml(buckets: &[String]) -> (HeaderMap, String) {
|
||||
</ListAllMyBucketsResult>",
|
||||
bucket_entries
|
||||
);
|
||||
|
||||
|
||||
(headers, xml)
|
||||
}
|
||||
|
||||
pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, String) {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/xml".parse().unwrap());
|
||||
|
||||
|
||||
let object_entries = objects
|
||||
.iter()
|
||||
.map(|obj| {
|
||||
let key = obj.get("Key").and_then(|k| k.as_str()).unwrap_or("");
|
||||
let last_modified = obj.get("LastModified").and_then(|l| l.as_str()).unwrap_or("");
|
||||
let last_modified = obj
|
||||
.get("LastModified")
|
||||
.and_then(|l| l.as_str())
|
||||
.unwrap_or("");
|
||||
let etag = obj.get("ETag").and_then(|e| e.as_str()).unwrap_or("");
|
||||
let size = obj.get("Size").and_then(|s| s.as_i64()).unwrap_or(0);
|
||||
|
||||
|
||||
format!(
|
||||
"<Contents>
|
||||
<Key>{}</Key>
|
||||
@@ -55,7 +60,7 @@ pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, Str
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n ");
|
||||
|
||||
|
||||
let xml = format!(
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>
|
||||
<ListBucketResult xmlns=\"http://s3.amazonaws.com/doc/2006-03-01/\">
|
||||
@@ -68,6 +73,6 @@ pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, Str
|
||||
</ListBucketResult>",
|
||||
bucket_name, object_entries
|
||||
);
|
||||
|
||||
|
||||
(headers, xml)
|
||||
}
|
||||
|
||||
@@ -439,7 +439,7 @@ fn compute_hashes_parallel(
|
||||
|
||||
let mut p = processed.lock().unwrap();
|
||||
*p += 1;
|
||||
if *p % 100 == 0 {
|
||||
if (*p).is_multiple_of(100) {
|
||||
print!("\r Hashed {}/{} files...", *p, total);
|
||||
use std::io::Write;
|
||||
std::io::stdout().flush().ok();
|
||||
|
||||
@@ -12,20 +12,22 @@ fn get_test_provider() -> Arc<dyn DataProvider> {
|
||||
#[test]
|
||||
fn test_password_authentication_brute_force_prevention() {
|
||||
let provider = get_test_provider();
|
||||
|
||||
|
||||
assert!(provider.check_password("demo", "demo123").unwrap());
|
||||
assert!(!provider.check_password("demo", "wrongpassword").unwrap());
|
||||
assert!(!provider.check_password("demo", "").unwrap());
|
||||
assert!(!provider.check_password("__nonexistent__", "anypassword").unwrap());
|
||||
assert!(!provider
|
||||
.check_password("__nonexistent__", "anypassword")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_publickey_authentication_security() {
|
||||
let provider = get_test_provider();
|
||||
|
||||
|
||||
let keys = provider.get_public_keys("demo").unwrap();
|
||||
assert!(keys.is_empty() || keys.len() >= 0);
|
||||
|
||||
|
||||
let keys = provider.get_public_keys("__nonexistent__").unwrap();
|
||||
assert!(keys.is_empty());
|
||||
}
|
||||
@@ -33,10 +35,10 @@ fn test_publickey_authentication_security() {
|
||||
#[test]
|
||||
fn test_user_status_check() {
|
||||
let provider = get_test_provider();
|
||||
|
||||
|
||||
let user = provider.get_user("demo").unwrap();
|
||||
assert!(user.is_some());
|
||||
|
||||
|
||||
let user = provider.get_user("demo").unwrap();
|
||||
if let Some(u) = user {
|
||||
assert_eq!(u.status, 1);
|
||||
@@ -46,16 +48,16 @@ fn test_user_status_check() {
|
||||
#[test]
|
||||
fn test_home_dir_security() {
|
||||
let provider = get_test_provider();
|
||||
|
||||
|
||||
let home = provider.get_home_dir("demo").unwrap();
|
||||
assert!(home.is_some());
|
||||
|
||||
|
||||
let home = provider.get_home_dir("__nonexistent__").unwrap();
|
||||
assert!(home.is_none());
|
||||
|
||||
|
||||
if let Some(home_path) = provider.get_home_dir("demo").unwrap() {
|
||||
assert!(!home_path.contains(".."));
|
||||
assert!(!home_path.starts_with("/etc"));
|
||||
assert!(!home_path.starts_with("/root"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ fn test_channel_window_size_limits() {
|
||||
#[test]
|
||||
fn test_channel_request_validation() {
|
||||
let valid_requests = ["exec", "shell", "subsystem", "env"];
|
||||
|
||||
|
||||
for request in valid_requests {
|
||||
assert!(!request.is_empty());
|
||||
}
|
||||
@@ -30,6 +30,6 @@ fn test_channel_data_integrity() {
|
||||
// Data should not exceed window size
|
||||
let window_size = 32768u32;
|
||||
let max_data = window_size;
|
||||
|
||||
|
||||
assert!(max_data <= window_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use crate::ssh_server::cipher::EncryptionContext;
|
||||
use crate::ssh_server::crypto::{SessionKeys, Curve25519Kex, Ed25519HostKey};
|
||||
use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys};
|
||||
|
||||
#[test]
|
||||
fn test_aes_ctr_encryption_decryption_consistency() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 16];
|
||||
|
||||
|
||||
let mut ctx = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: key.clone(),
|
||||
@@ -15,10 +15,10 @@ fn test_aes_ctr_encryption_decryption_consistency() {
|
||||
iv_ctos: iv.clone(),
|
||||
iv_stoc: iv.clone(),
|
||||
});
|
||||
|
||||
|
||||
let plaintext = b"Test message for encryption";
|
||||
let ciphertext = ctx.encrypt_packet(plaintext, &key, &iv).unwrap();
|
||||
|
||||
|
||||
let decrypted = ctx.decrypt_packet(&ciphertext, &key, &iv).unwrap();
|
||||
assert_eq!(plaintext.to_vec(), decrypted);
|
||||
}
|
||||
@@ -27,7 +27,7 @@ fn test_aes_ctr_encryption_decryption_consistency() {
|
||||
fn test_hmac_sha256_authentication() {
|
||||
let key = vec![0u8; 32];
|
||||
let data = b"Test data for HMAC";
|
||||
|
||||
|
||||
let ctx = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: vec![0u8; 16],
|
||||
@@ -37,12 +37,12 @@ fn test_hmac_sha256_authentication() {
|
||||
iv_ctos: vec![0u8; 16],
|
||||
iv_stoc: vec![0u8; 16],
|
||||
});
|
||||
|
||||
|
||||
let mac = ctx.compute_mac(1, data, &key).unwrap();
|
||||
assert_eq!(mac.len(), 32);
|
||||
|
||||
|
||||
assert!(ctx.verify_mac(1, data, &mac, &key).unwrap());
|
||||
|
||||
|
||||
let wrong_mac = vec![0u8; 32];
|
||||
assert!(!ctx.verify_mac(1, data, &wrong_mac, &key).unwrap());
|
||||
}
|
||||
@@ -52,19 +52,19 @@ fn test_curve25519_key_exchange_security() {
|
||||
// Create client and server instances
|
||||
let mut client_kex = Curve25519Kex::new();
|
||||
let mut server_kex = Curve25519Kex::new();
|
||||
|
||||
|
||||
// Get public keys first (before computing shared secrets)
|
||||
let client_pub = client_kex.public_key().to_vec();
|
||||
let server_pub = server_kex.public_key().to_vec();
|
||||
|
||||
|
||||
assert_eq!(client_pub.len(), 32);
|
||||
assert_eq!(server_pub.len(), 32);
|
||||
|
||||
|
||||
// Compute shared secrets using the SAME instances
|
||||
// (this consumes the secret, so can only be done once)
|
||||
let client_secret = client_kex.compute_shared_secret(&server_pub).unwrap();
|
||||
let server_secret = server_kex.compute_shared_secret(&client_pub).unwrap();
|
||||
|
||||
|
||||
// Shared secrets should match (Diffie-Hellman property)
|
||||
assert_eq!(client_secret, server_secret);
|
||||
assert_eq!(client_secret.len(), 32);
|
||||
@@ -73,12 +73,12 @@ fn test_curve25519_key_exchange_security() {
|
||||
#[test]
|
||||
fn test_ed25519_signature_verification() {
|
||||
let host_key = Ed25519HostKey::load_or_generate("test_security_key").unwrap();
|
||||
|
||||
|
||||
let message = b"Test message for signature";
|
||||
let signature = host_key.sign(message).unwrap();
|
||||
|
||||
|
||||
assert_eq!(signature.len(), 64);
|
||||
|
||||
|
||||
// Ed25519HostKey has sign() but verify might need external library
|
||||
// For security test, we verify signature length and structure
|
||||
assert!(!signature.is_empty());
|
||||
@@ -89,9 +89,9 @@ fn test_encryption_key_derivation_uniqueness() {
|
||||
let key1 = vec![1u8; 16];
|
||||
let key2 = vec![2u8; 16];
|
||||
let iv = vec![0u8; 16];
|
||||
|
||||
|
||||
let plaintext = b"Same plaintext";
|
||||
|
||||
|
||||
let mut ctx1 = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: key1.clone(),
|
||||
@@ -101,7 +101,7 @@ fn test_encryption_key_derivation_uniqueness() {
|
||||
iv_ctos: iv.clone(),
|
||||
iv_stoc: iv.clone(),
|
||||
});
|
||||
|
||||
|
||||
let mut ctx2 = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: key2.clone(),
|
||||
@@ -111,9 +111,9 @@ fn test_encryption_key_derivation_uniqueness() {
|
||||
iv_ctos: iv.clone(),
|
||||
iv_stoc: iv.clone(),
|
||||
});
|
||||
|
||||
|
||||
let ciphertext1 = ctx1.encrypt_packet(plaintext, &key1, &iv).unwrap();
|
||||
let ciphertext2 = ctx2.encrypt_packet(plaintext, &key2, &iv).unwrap();
|
||||
|
||||
|
||||
assert_ne!(ciphertext1, ciphertext2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,17 +3,17 @@ use std::path::PathBuf;
|
||||
#[test]
|
||||
fn test_path_traversal_prevention() {
|
||||
let root = PathBuf::from("/tmp/test_root");
|
||||
|
||||
|
||||
// Test 1: Normal path should be within root
|
||||
let safe_path = PathBuf::from("safe/file.txt");
|
||||
let full_path = root.join(&safe_path);
|
||||
assert!(full_path.starts_with(&root));
|
||||
|
||||
|
||||
// Test 2: Path traversal attempt should still resolve within root
|
||||
// (after normalization, ../../etc/passwd from /tmp/test_root becomes /tmp/etc/passwd or /etc/passwd)
|
||||
let evil_path = PathBuf::from("../../etc/passwd");
|
||||
let full_path = root.join(&evil_path);
|
||||
|
||||
|
||||
// The key security check: the resolved path should NOT be /etc/passwd
|
||||
// If Path::join normalizes it to /etc/passwd, that's a path traversal vulnerability
|
||||
// We check that the joined path either:
|
||||
@@ -34,7 +34,7 @@ fn test_path_traversal_prevention() {
|
||||
#[test]
|
||||
fn test_absolute_path_prevention() {
|
||||
let root = PathBuf::from("/tmp/test_root");
|
||||
|
||||
|
||||
let abs_path = PathBuf::from("/etc/passwd");
|
||||
assert!(!abs_path.starts_with(&root));
|
||||
}
|
||||
@@ -42,10 +42,10 @@ fn test_absolute_path_prevention() {
|
||||
#[test]
|
||||
fn test_directory_escape_prevention() {
|
||||
let root = PathBuf::from("/tmp/test_root");
|
||||
|
||||
|
||||
let parent_path = PathBuf::from("subdir/../..");
|
||||
let full_path = root.join(&parent_path);
|
||||
|
||||
|
||||
// Path should not escape root
|
||||
if full_path.canonicalize().is_ok() {
|
||||
let canonical = full_path.canonicalize().unwrap();
|
||||
@@ -56,11 +56,11 @@ fn test_directory_escape_prevention() {
|
||||
#[test]
|
||||
fn test_file_write_boundary_check() {
|
||||
let root = PathBuf::from("/tmp/test_root");
|
||||
|
||||
|
||||
let safe_file = PathBuf::from("safe.txt");
|
||||
let full_path = root.join(&safe_file);
|
||||
assert!(full_path.starts_with(&root));
|
||||
|
||||
|
||||
let outside_file = PathBuf::from("/tmp/outside.txt");
|
||||
assert!(!outside_file.starts_with(&root));
|
||||
}
|
||||
@@ -68,8 +68,8 @@ fn test_file_write_boundary_check() {
|
||||
#[test]
|
||||
fn test_hidden_file_access() {
|
||||
let root = PathBuf::from("/tmp/test_root");
|
||||
|
||||
|
||||
let hidden_path = PathBuf::from(".hidden");
|
||||
let full_path = root.join(&hidden_path);
|
||||
assert!(full_path.starts_with(&root));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
mod auth_security;
|
||||
mod channel_security;
|
||||
mod crypto_security;
|
||||
mod file_access_security;
|
||||
mod channel_security;
|
||||
|
||||
pub use auth_security::*;
|
||||
pub use channel_security::*;
|
||||
pub use crypto_security::*;
|
||||
pub use file_access_security::*;
|
||||
pub use channel_security::*;
|
||||
@@ -1,22 +1,23 @@
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::DefaultBodyLimit,
|
||||
extract::{Path, Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{Html, IntoResponse, Json},
|
||||
routing::{delete, get, patch, post, put},
|
||||
Router,
|
||||
extract::DefaultBodyLimit,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::archive::{
|
||||
ArchiveConfig, ArchiveFormat, ArchiveProcessor, FormatDetector, ProcessorRegistry,
|
||||
};
|
||||
use crate::audio;
|
||||
use crate::auth::{AuthState, LoginRequest};
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::render;
|
||||
use crate::download;
|
||||
use crate::archive::{self, ArchiveFormat, ArchiveProcessor, FormatDetector, ArchiveConfig, ProcessorRegistry};
|
||||
use filetree::{self, FileTree};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -60,7 +61,7 @@ pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
|
||||
db_dir: "data/users".to_string(),
|
||||
auth: AuthState::with_provider(Box::new(
|
||||
SqliteProvider::new("data/auth.sqlite")
|
||||
.map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))?
|
||||
.map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))?,
|
||||
)),
|
||||
auth_db_path: "data/auth.sqlite".to_string(),
|
||||
s3_keys: Arc::new(Mutex::new(load_s3_keys())),
|
||||
@@ -578,7 +579,7 @@ async fn search_tree(
|
||||
ORDER BY sort_order ASC, created_at ASC",
|
||||
)?;
|
||||
|
||||
let nodes: Vec<filetree::node::FileNode> = stmt
|
||||
let _nodes: Vec<filetree::node::FileNode> = stmt
|
||||
.query_map([&search_pattern], |row| {
|
||||
let children_json: String = row.get(6)?;
|
||||
let children: Vec<String> =
|
||||
@@ -607,7 +608,7 @@ async fn search_tree(
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
|
||||
let tree = filetree::FileTree {
|
||||
let tree = filetree::FileTree {
|
||||
user_id: user_id.clone(),
|
||||
tree_type: "untitled folder".to_string(),
|
||||
nodes: vec![],
|
||||
@@ -914,69 +915,78 @@ fn extract_and_register_archive(
|
||||
user_id: &str,
|
||||
original_filename: &str,
|
||||
) -> anyhow::Result<(u64, u64, String)> {
|
||||
use std::path::PathBuf;
|
||||
use sha2::{Sha256, Digest};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
|
||||
// Initialize archive system
|
||||
let config = ArchiveConfig::default();
|
||||
let mut registry = ProcessorRegistry::new(config);
|
||||
registry.initialize()?;
|
||||
|
||||
|
||||
// Detect format
|
||||
let detector = FormatDetector::new();
|
||||
let format = detector.detect(archive_path)?;
|
||||
|
||||
eprintln!("[archive] Detected format: {} for file: {}", format, archive_path.display());
|
||||
|
||||
|
||||
eprintln!(
|
||||
"[archive] Detected format: {} for file: {}",
|
||||
format,
|
||||
archive_path.display()
|
||||
);
|
||||
|
||||
// Get processor
|
||||
let processor = registry.get_processor_mut(archive_path)?;
|
||||
|
||||
|
||||
// Create extraction directory
|
||||
let base_name = original_filename
|
||||
.rsplit_once('.')
|
||||
.map(|(name, _)| name)
|
||||
.unwrap_or(original_filename);
|
||||
|
||||
let extraction_dir = archive_path.parent()
|
||||
|
||||
let extraction_dir = archive_path
|
||||
.parent()
|
||||
.unwrap_or(std::path::Path::new("."))
|
||||
.join(format!("{}_extracted", base_name));
|
||||
|
||||
|
||||
std::fs::create_dir_all(&extraction_dir)?;
|
||||
|
||||
|
||||
// Open and extract
|
||||
let metadata = processor.open(archive_path)?;
|
||||
|
||||
eprintln!("[archive] Archive metadata: {} files, {} bytes",
|
||||
metadata.total_files, metadata.total_size);
|
||||
|
||||
|
||||
eprintln!(
|
||||
"[archive] Archive metadata: {} files, {} bytes",
|
||||
metadata.total_files, metadata.total_size
|
||||
);
|
||||
|
||||
let result = processor.extract_all(&extraction_dir)?;
|
||||
|
||||
eprintln!("[archive] Extracted {} files ({} bytes)",
|
||||
result.success_files, result.total_bytes);
|
||||
|
||||
|
||||
eprintln!(
|
||||
"[archive] Extracted {} files ({} bytes)",
|
||||
result.success_files, result.total_bytes
|
||||
);
|
||||
|
||||
// Register extracted files to database
|
||||
let conn = FileTree::init_user_db(user_id)?;
|
||||
|
||||
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
|
||||
|
||||
// Get MAC address for UUID generation
|
||||
let mac_output = std::process::Command::new("ifconfig")
|
||||
.arg("en0")
|
||||
.output()
|
||||
.map(|o| String::from_utf8_lossy(&o.stdout).to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
|
||||
let mac = mac_output
|
||||
.lines()
|
||||
.find(|l| l.contains("ether"))
|
||||
.and_then(|l| l.split_whitespace().nth(1))
|
||||
.unwrap_or("00:00:00:00:00:00");
|
||||
|
||||
|
||||
let mut registered_count = 0u64;
|
||||
|
||||
|
||||
// Recursively scan extracted directory
|
||||
fn scan_directory(
|
||||
dir: &std::path::Path,
|
||||
@@ -986,11 +996,11 @@ fn extract_and_register_archive(
|
||||
now: i64,
|
||||
) -> anyhow::Result<u64> {
|
||||
let mut count = 0u64;
|
||||
|
||||
|
||||
for entry in std::fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
|
||||
if path.is_dir() {
|
||||
count += scan_directory(&path, conn, user_id, mac, now)?;
|
||||
} else if path.is_file() {
|
||||
@@ -998,16 +1008,15 @@ fn extract_and_register_archive(
|
||||
let file_data = std::fs::read(&path)?;
|
||||
let file_hash = format!("{:x}", Sha256::digest(&file_data));
|
||||
let file_size = file_data.len() as i64;
|
||||
|
||||
let filename = path.file_name()
|
||||
|
||||
let filename = path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
let file_path_str = path.to_str()
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
|
||||
let file_path_str = path.to_str().unwrap_or("unknown").to_string();
|
||||
|
||||
// Generate file UUID
|
||||
let mtime = std::fs::metadata(&path)
|
||||
.ok()
|
||||
@@ -1015,48 +1024,55 @@ fn extract_and_register_archive(
|
||||
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
|
||||
.map(|d| d.as_millis() as u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
|
||||
let input = format!("{}|{}|{}|{}", file_path_str, filename, mac, mtime);
|
||||
let hash = Sha256::digest(input.as_bytes());
|
||||
let hex = format!("{:x}", hash);
|
||||
let file_uuid = hex[0..32].to_string();
|
||||
|
||||
|
||||
// Register file (no sha256 in file_registry)
|
||||
conn.execute(
|
||||
"INSERT INTO file_registry (file_uuid, original_name, file_size, file_type, registered_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||
rusqlite::params![&file_uuid, &filename, file_size, "", now],
|
||||
)?;
|
||||
|
||||
|
||||
// Add file location
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO file_locations (file_uuid, location, added_at)
|
||||
VALUES (?1, ?2, ?3)",
|
||||
rusqlite::params![&file_uuid, &file_path_str, now],
|
||||
)?;
|
||||
|
||||
|
||||
// Add file node
|
||||
let uuid_str = uuid::Uuid::new_v4().to_string().replace('-', "");
|
||||
let node_id = format!("node-{}", &uuid_str[0..8]);
|
||||
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO file_nodes (node_id, label, file_uuid, sha256, node_type, file_size, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, 'file', ?5, ?6, ?7)",
|
||||
rusqlite::params![&node_id, &filename, &file_uuid, &file_hash, file_size, now, now],
|
||||
)?;
|
||||
|
||||
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
|
||||
registered_count = scan_directory(&extraction_dir, &conn, user_id, mac, now)?;
|
||||
|
||||
eprintln!("[archive] Registered {} extracted files to database", registered_count);
|
||||
|
||||
Ok((result.success_files, result.total_bytes, extraction_dir.to_str().unwrap_or("unknown").to_string()))
|
||||
|
||||
eprintln!(
|
||||
"[archive] Registered {} extracted files to database",
|
||||
registered_count
|
||||
);
|
||||
|
||||
Ok((
|
||||
result.success_files,
|
||||
result.total_bytes,
|
||||
extraction_dir.to_str().unwrap_or("unknown").to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn upload_file(
|
||||
@@ -1147,23 +1163,23 @@ async fn upload_file(
|
||||
// Auto-extract archive files
|
||||
let file_path_buf = std::path::PathBuf::from(&file_path);
|
||||
let detector = FormatDetector::new();
|
||||
|
||||
|
||||
if let Ok(format) = detector.detect(&file_path_buf) {
|
||||
if format != ArchiveFormat::Unknown {
|
||||
eprintln!("[upload] Detected archive format: {}, extracting...", format);
|
||||
|
||||
eprintln!(
|
||||
"[upload] Detected archive format: {}, extracting...",
|
||||
format
|
||||
);
|
||||
|
||||
let user_id_clone = user_id.clone();
|
||||
let filename_clone = filename.clone();
|
||||
|
||||
|
||||
// Extract in blocking thread
|
||||
let extraction_result = tokio::task::spawn_blocking(move || {
|
||||
extract_and_register_archive(
|
||||
&file_path_buf,
|
||||
&user_id_clone,
|
||||
&filename_clone,
|
||||
)
|
||||
}).await;
|
||||
|
||||
extract_and_register_archive(&file_path_buf, &user_id_clone, &filename_clone)
|
||||
})
|
||||
.await;
|
||||
|
||||
match extraction_result {
|
||||
Ok(Ok((count, bytes, extract_dir))) => {
|
||||
extracted_info = Some((count, bytes, extract_dir));
|
||||
@@ -1208,13 +1224,13 @@ async fn upload_file(
|
||||
let hex = format!("{:x}", hash);
|
||||
let file_uuid = hex[0..32].to_string();
|
||||
|
||||
// Save to database (user-specific SQLite)
|
||||
// Save to database (user-specific SQLite)
|
||||
let file_uuid_clone = file_uuid.clone();
|
||||
let file_hash_clone = file_hash.clone();
|
||||
let filename_clone = filename.clone();
|
||||
let file_path_clone = file_path.clone();
|
||||
let user_id_clone = user_id.clone();
|
||||
|
||||
|
||||
let db_result = tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let conn = filetree::FileTree::init_user_db(&user_id_clone)?;
|
||||
|
||||
@@ -1281,7 +1297,7 @@ async fn upload_file(
|
||||
"sha256": file_hash,
|
||||
"size": file_size,
|
||||
});
|
||||
|
||||
|
||||
if let Some((count, bytes, extract_dir)) = extracted_info {
|
||||
response["extracted"] = serde_json::json!({
|
||||
"count": count,
|
||||
@@ -1289,12 +1305,8 @@ async fn upload_file(
|
||||
"directory": extract_dir,
|
||||
});
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::CREATED,
|
||||
Json(response),
|
||||
)
|
||||
.into_response()
|
||||
|
||||
(StatusCode::CREATED, Json(response)).into_response()
|
||||
}
|
||||
|
||||
async fn upload_unlimited(
|
||||
@@ -1798,7 +1810,7 @@ async fn logout_handler(State(state): State<AppState>, headers: HeaderMap) -> im
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
||||
.and_then(crate::auth::parse_auth_header);
|
||||
|
||||
match auth_header {
|
||||
Some(token) => {
|
||||
@@ -1824,7 +1836,7 @@ async fn verify_handler(State(state): State<AppState>, headers: HeaderMap) -> im
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
||||
.and_then(crate::auth::parse_auth_header);
|
||||
|
||||
match auth_header {
|
||||
Some(token) => match state.auth.verify_token(&token) {
|
||||
@@ -1857,7 +1869,7 @@ fn verify_auth(state: &AppState, headers: &HeaderMap) -> Result<String, StatusCo
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
||||
.and_then(crate::auth::parse_auth_header);
|
||||
|
||||
match auth_header {
|
||||
Some(token) => match state.auth.verify_token(&token) {
|
||||
@@ -2039,7 +2051,7 @@ async fn edit_config_handler(Query(params): Query<EditConfigQuery>) -> impl Into
|
||||
match crate::config::MarkBaseConfig::load(config_path) {
|
||||
Ok(mut config) => {
|
||||
let old_value = config.get(¶ms.key).unwrap_or_default();
|
||||
|
||||
|
||||
match config.set(¶ms.key, ¶ms.value) {
|
||||
Ok(_) => match config.validate() {
|
||||
Ok(_) => match config.save(config_path) {
|
||||
@@ -2056,7 +2068,7 @@ async fn edit_config_handler(Query(params): Query<EditConfigQuery>) -> impl Into
|
||||
) {
|
||||
log::warn!("Failed to write audit log: {}", e);
|
||||
}
|
||||
|
||||
|
||||
(StatusCode::OK, Json(serde_json::json!({"ok": true}))).into_response()
|
||||
}
|
||||
Err(e) => (
|
||||
@@ -2133,7 +2145,7 @@ async fn edit_s3_config_handler(Query(params): Query<EditConfigQuery>) -> impl I
|
||||
match crate::s3_config::S3Config::load_default() {
|
||||
Ok(mut config) => {
|
||||
let old_value = config.get(¶ms.key).unwrap_or_default();
|
||||
|
||||
|
||||
match config.set(¶ms.key, ¶ms.value) {
|
||||
Ok(_) => match config.validate() {
|
||||
Ok(_) => match config.save("config/s3.toml") {
|
||||
@@ -2150,7 +2162,7 @@ async fn edit_s3_config_handler(Query(params): Query<EditConfigQuery>) -> impl I
|
||||
) {
|
||||
log::warn!("Failed to write audit log: {}", e);
|
||||
}
|
||||
|
||||
|
||||
(StatusCode::OK, Json(serde_json::json!({"ok": true}))).into_response()
|
||||
}
|
||||
Err(e) => (
|
||||
@@ -2343,7 +2355,7 @@ async fn audit_handler() -> Json<serde_json::Value> {
|
||||
// Category View API handlers (Phase 1: 双视图管理)
|
||||
|
||||
async fn get_all_categories_handler() -> impl IntoResponse {
|
||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
match crate::category_view::get_all_categories() {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(e) => (
|
||||
@@ -2354,10 +2366,8 @@ async fn get_all_categories_handler() -> impl IntoResponse {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_category_detail_handler(
|
||||
Path(category_name): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
async fn get_category_detail_handler(Path(category_name): Path<String>) -> impl IntoResponse {
|
||||
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
match crate::category_view::get_category_detail(&category_name) {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(e) => (
|
||||
@@ -2369,7 +2379,7 @@ async fn get_category_detail_handler(
|
||||
}
|
||||
|
||||
async fn get_all_series_handler() -> impl IntoResponse {
|
||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
match crate::category_view::get_all_series() {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(e) => (
|
||||
@@ -2380,10 +2390,8 @@ async fn get_all_series_handler() -> impl IntoResponse {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_series_detail_handler(
|
||||
Path(series_name): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
async fn get_series_detail_handler(Path(series_name): Path<String>) -> impl IntoResponse {
|
||||
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
match crate::category_view::get_series_detail(&series_name) {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(e) => (
|
||||
@@ -2400,10 +2408,8 @@ struct SearchQuery {
|
||||
view: String,
|
||||
}
|
||||
|
||||
async fn search_files_handler(
|
||||
Query(query): Query<SearchQuery>,
|
||||
) -> impl IntoResponse {
|
||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
async fn search_files_handler(Query(query): Query<SearchQuery>) -> impl IntoResponse {
|
||||
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||
match crate::category_view::search_files(&query.q, &query.view) {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(e) => (
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use std::io::Write;
|
||||
use anyhow::{Result, anyhow};
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, warn, debug};
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use log::{debug, info, warn};
|
||||
use std::io::Write;
|
||||
|
||||
use ed25519_dalek::{VerifyingKey, Signature};
|
||||
use ed25519_dalek::{Signature, VerifyingKey};
|
||||
|
||||
use crate::provider::{DataProvider, ProviderError};
|
||||
|
||||
@@ -27,7 +27,11 @@ impl AuthHandler {
|
||||
}
|
||||
|
||||
/// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request())
|
||||
pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result<AuthResult> {
|
||||
pub fn handle_userauth_request(
|
||||
&mut self,
|
||||
packet: &SshPacket,
|
||||
session_id: &[u8],
|
||||
) -> Result<AuthResult> {
|
||||
info!("Processing SSH_MSG_USERAUTH_REQUEST");
|
||||
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||||
@@ -41,7 +45,10 @@ impl AuthHandler {
|
||||
let service = read_ssh_string(&mut cursor)?;
|
||||
let method = read_ssh_string(&mut cursor)?;
|
||||
|
||||
info!("Auth request: user={}, service={}, method={}", user, service, method);
|
||||
info!(
|
||||
"Auth request: user={}, service={}, method={}",
|
||||
user, service, method
|
||||
);
|
||||
|
||||
if service != "ssh-connection" {
|
||||
warn!("Unsupported service: {}", service);
|
||||
@@ -62,18 +69,28 @@ impl AuthHandler {
|
||||
}
|
||||
|
||||
/// 处理password认证(参考OpenSSH auth-passwd.c)
|
||||
fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
|
||||
fn handle_password_auth(
|
||||
&mut self,
|
||||
cursor: &mut std::io::Cursor<&[u8]>,
|
||||
user: &str,
|
||||
) -> Result<AuthResult> {
|
||||
info!("Handling password auth for user: {}", user);
|
||||
|
||||
let change_password = cursor.read_u8()? != 0;
|
||||
if change_password {
|
||||
warn!("Password change not supported");
|
||||
return Ok(AuthResult::Failure("Password change not supported".to_string()));
|
||||
return Ok(AuthResult::Failure(
|
||||
"Password change not supported".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let password = read_ssh_string(cursor)?;
|
||||
|
||||
debug!("Password auth attempt: user={}, password length={}", user, password.len());
|
||||
debug!(
|
||||
"Password auth attempt: user={}, password length={}",
|
||||
user,
|
||||
password.len()
|
||||
);
|
||||
|
||||
match self.provider.check_password(user, &password) {
|
||||
Ok(true) => {
|
||||
@@ -88,9 +105,7 @@ impl AuthHandler {
|
||||
warn!("User not found: {}", msg);
|
||||
Ok(AuthResult::Failure("password,publickey".to_string()))
|
||||
}
|
||||
Err(e) => {
|
||||
Err(anyhow!("Password auth error: {}", e))
|
||||
}
|
||||
Err(e) => Err(anyhow!("Password auth error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,7 +160,12 @@ impl AuthHandler {
|
||||
let algorithm = read_ssh_string(cursor)?;
|
||||
let public_key_blob = read_ssh_string_bytes(cursor)?;
|
||||
|
||||
info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed);
|
||||
info!(
|
||||
"Publickey auth: algorithm={}, blob_len={}, is_signed={}",
|
||||
algorithm,
|
||||
public_key_blob.len(),
|
||||
is_signed
|
||||
);
|
||||
|
||||
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
|
||||
warn!("Public key not authorized for user: {}", user);
|
||||
@@ -160,14 +180,26 @@ impl AuthHandler {
|
||||
|
||||
let signature_blob = read_ssh_string_bytes(cursor)?;
|
||||
|
||||
self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?;
|
||||
self.verify_signature(
|
||||
&algorithm,
|
||||
&public_key_blob,
|
||||
&signature_blob,
|
||||
user,
|
||||
service,
|
||||
session_id,
|
||||
)?;
|
||||
|
||||
info!("Publickey auth successful for user: {}", user);
|
||||
Ok(AuthResult::Success)
|
||||
}
|
||||
|
||||
/// 检查public key是否在授权列表中(数据库优先,fallback到filesystem)
|
||||
fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result<bool> {
|
||||
fn is_key_authorized(
|
||||
&self,
|
||||
user: &str,
|
||||
algorithm: &str,
|
||||
public_key_blob: &[u8],
|
||||
) -> Result<bool> {
|
||||
// 1. 先检查数据库
|
||||
match self.provider.get_public_keys(user) {
|
||||
Ok(keys) => {
|
||||
@@ -187,10 +219,12 @@ impl AuthHandler {
|
||||
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
|
||||
Ok(c) => c,
|
||||
Err(_) => return Ok(false),
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
|
||||
Ok(content
|
||||
.lines()
|
||||
.any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
|
||||
}
|
||||
|
||||
/// 验证Ed25519签名(RFC 4252 §7)
|
||||
@@ -246,7 +280,8 @@ impl AuthHandler {
|
||||
signed_data.write_all(public_key_blob)?;
|
||||
|
||||
// 验证签名
|
||||
verifying_key.verify_strict(&signed_data, &signature)
|
||||
verifying_key
|
||||
.verify_strict(&signed_data, &signature)
|
||||
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
|
||||
}
|
||||
}
|
||||
@@ -270,10 +305,10 @@ fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result<VerifyingKey> {
|
||||
if key_bytes.len() != 32 {
|
||||
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
|
||||
}
|
||||
let key_array: [u8; 32] = key_bytes.try_into()
|
||||
let key_array: [u8; 32] = key_bytes
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
|
||||
VerifyingKey::from_bytes(&key_array)
|
||||
.map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
|
||||
VerifyingKey::from_bytes(&key_array).map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
|
||||
}
|
||||
|
||||
/// 解析Ed25519签名blob(SSH格式 -> Signature)
|
||||
@@ -285,9 +320,13 @@ fn parse_ed25519_signature(signature_blob: &[u8]) -> Result<Signature> {
|
||||
}
|
||||
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
|
||||
if sig_bytes.len() != 64 {
|
||||
return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len()));
|
||||
return Err(anyhow!(
|
||||
"Invalid Ed25519 signature length: {}",
|
||||
sig_bytes.len()
|
||||
));
|
||||
}
|
||||
let sig_array: [u8; 64] = sig_bytes.try_into()
|
||||
let sig_array: [u8; 64] = sig_bytes
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
|
||||
Ok(Signature::from_bytes(&sig_array))
|
||||
}
|
||||
@@ -305,7 +344,9 @@ fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8])
|
||||
if parts[0] != algorithm {
|
||||
return false;
|
||||
}
|
||||
base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false)
|
||||
base64_decode(parts[1])
|
||||
.map(|decoded| decoded == public_key_blob)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||
@@ -323,7 +364,8 @@ fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
|
||||
}
|
||||
|
||||
fn base64_decode(input: &str) -> Result<Vec<u8>> {
|
||||
general_purpose::STANDARD.decode(input)
|
||||
general_purpose::STANDARD
|
||||
.decode(input)
|
||||
.map_err(|e| anyhow!("Base64 decode error: {}", e))
|
||||
}
|
||||
|
||||
@@ -335,7 +377,10 @@ mod tests {
|
||||
#[test]
|
||||
fn test_userauth_success_packet() {
|
||||
let packet = AuthHandler::build_userauth_success().unwrap();
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8);
|
||||
assert_eq!(
|
||||
packet.payload[0],
|
||||
PacketType::SSH_MSG_USERAUTH_SUCCESS as u8
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -343,6 +388,9 @@ mod tests {
|
||||
let methods = vec!["password".to_string(), "publickey".to_string()];
|
||||
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
|
||||
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8);
|
||||
assert_eq!(
|
||||
packet.payload[0],
|
||||
PacketType::SSH_MSG_USERAUTH_FAILURE as u8
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,33 +1,33 @@
|
||||
// SSH加密通道实现(Phase 4)
|
||||
// 参考OpenSSH cipher.c, mac.c
|
||||
|
||||
use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr)
|
||||
use super::crypto::SessionKeys;
|
||||
use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr)
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use cipher::{KeyIvInit, StreamCipher};
|
||||
use ctr::Ctr128BE;
|
||||
use hmac::{Hmac, Mac};
|
||||
use log::info;
|
||||
use sha2::Sha256;
|
||||
use cipher::{KeyIvInit, StreamCipher};
|
||||
use std::io::Write;
|
||||
use anyhow::{Result, anyhow};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, debug, warn};
|
||||
use super::crypto::SessionKeys;
|
||||
|
||||
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR(16字节密钥)
|
||||
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR(16字节密钥)
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// SSH加密通道管理器(参考OpenSSH struct sshcipher_ctx)
|
||||
pub struct EncryptionContext {
|
||||
pub session_id: Vec<u8>, // session identifier (exchange hash)
|
||||
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
|
||||
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
|
||||
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
|
||||
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
|
||||
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
|
||||
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
|
||||
pub sequence_number_ctos: u32, // 客户端→服务器序列号
|
||||
pub sequence_number_stoc: u32, // 服务器→客户端序列号
|
||||
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例(持久化)
|
||||
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例(持久化)
|
||||
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
|
||||
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
|
||||
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
|
||||
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
|
||||
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
|
||||
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
|
||||
pub sequence_number_ctos: u32, // 客户端→服务器序列号
|
||||
pub sequence_number_stoc: u32, // 服务器→客户端序列号
|
||||
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例(持久化)
|
||||
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例(持久化)
|
||||
}
|
||||
|
||||
impl Default for EncryptionContext {
|
||||
@@ -53,27 +53,33 @@ impl EncryptionContext {
|
||||
/// OpenSSH cipher.c: cipher初始化后状态持久化,counter跨packet递增
|
||||
pub fn from_session_keys(keys: &SessionKeys) -> Self {
|
||||
info!("Initializing ciphers with session keys:");
|
||||
info!(" encryption_key_ctos (16 bytes): {:?}", &keys.encryption_key_ctos[..16]);
|
||||
info!(
|
||||
" encryption_key_ctos (16 bytes): {:?}",
|
||||
&keys.encryption_key_ctos[..16]
|
||||
);
|
||||
info!(" iv_ctos (16 bytes): {:?}", &keys.iv_ctos[..16]);
|
||||
info!(" encryption_key_stoc (16 bytes): {:?}", &keys.encryption_key_stoc[..16]);
|
||||
info!(
|
||||
" encryption_key_stoc (16 bytes): {:?}",
|
||||
&keys.encryption_key_stoc[..16]
|
||||
);
|
||||
info!(" iv_stoc (16 bytes): {:?}", &keys.iv_stoc[..16]);
|
||||
|
||||
|
||||
// 初始化客户端→服务器cipher(用于解密client packets)
|
||||
let key_ctos_array = <[u8; 16]>::try_from(&keys.encryption_key_ctos[..16])
|
||||
.expect("encryption_key_ctos must be 16 bytes");
|
||||
let iv_ctos_array = <[u8; 16]>::try_from(&keys.iv_ctos[..16])
|
||||
.expect("iv_ctos must be 16 bytes");
|
||||
let iv_ctos_array =
|
||||
<[u8; 16]>::try_from(&keys.iv_ctos[..16]).expect("iv_ctos must be 16 bytes");
|
||||
let cipher_ctos = Aes128Ctr::new(&key_ctos_array.into(), &iv_ctos_array.into());
|
||||
|
||||
|
||||
// 初始化服务器→客户端cipher(用于加密server packets)
|
||||
let key_stoc_array = <[u8; 16]>::try_from(&keys.encryption_key_stoc[..16])
|
||||
.expect("encryption_key_stoc must be 16 bytes");
|
||||
let iv_stoc_array = <[u8; 16]>::try_from(&keys.iv_stoc[..16])
|
||||
.expect("iv_stoc must be 16 bytes");
|
||||
let iv_stoc_array =
|
||||
<[u8; 16]>::try_from(&keys.iv_stoc[..16]).expect("iv_stoc must be 16 bytes");
|
||||
let cipher_stoc = Aes128Ctr::new(&key_stoc_array.into(), &iv_stoc_array.into());
|
||||
|
||||
|
||||
info!("Ciphers initialized successfully");
|
||||
|
||||
|
||||
Self {
|
||||
session_id: keys.session_id.clone(),
|
||||
encryption_key_ctos: keys.encryption_key_ctos.clone(),
|
||||
@@ -84,26 +90,26 @@ impl EncryptionContext {
|
||||
iv_stoc: keys.iv_stoc.clone(),
|
||||
sequence_number_ctos: 0,
|
||||
sequence_number_stoc: 0,
|
||||
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
|
||||
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
|
||||
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
|
||||
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// RFC 4344: Compute AES-CTR IV for a specific packet
|
||||
/// IV = nonce(8 bytes from derived IV) + sequence_number(8 bytes)
|
||||
fn compute_ctr_iv(nonce: &[u8], sequence_number: u32) -> Vec<u8> {
|
||||
let mut iv = Vec::with_capacity(16);
|
||||
|
||||
|
||||
// Nonce: first 8 bytes of derived IV (constant)
|
||||
iv.extend_from_slice(&nonce[..8]);
|
||||
|
||||
|
||||
// Counter: sequence number as 8-byte big-endian
|
||||
iv.extend_from_slice(&sequence_number.to_be_bytes());
|
||||
iv.extend_from_slice(&[0u8; 4]); // Upper 4 bytes = 0
|
||||
|
||||
|
||||
iv
|
||||
}
|
||||
|
||||
|
||||
/// 加密packet(参考OpenSSH cipher.c: cipher_encrypt())
|
||||
pub fn encrypt_packet(
|
||||
&mut self,
|
||||
@@ -113,17 +119,17 @@ impl EncryptionContext {
|
||||
) -> Result<Vec<u8>> {
|
||||
let key_array = <[u8; 16]>::try_from(encryption_key)?;
|
||||
let iv_array = <[u8; 16]>::try_from(iv)?;
|
||||
|
||||
|
||||
let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into());
|
||||
|
||||
|
||||
let mut ciphertext = plaintext.to_vec();
|
||||
cipher.apply_keystream(&mut ciphertext);
|
||||
|
||||
|
||||
self.sequence_number_stoc += 1;
|
||||
|
||||
|
||||
Ok(ciphertext)
|
||||
}
|
||||
|
||||
|
||||
/// 解密packet(参考OpenSSH cipher.c: cipher_decrypt())
|
||||
pub fn decrypt_packet(
|
||||
&mut self,
|
||||
@@ -133,17 +139,17 @@ impl EncryptionContext {
|
||||
) -> Result<Vec<u8>> {
|
||||
let key_array = <[u8; 16]>::try_from(encryption_key)?;
|
||||
let iv_array = <[u8; 16]>::try_from(iv)?;
|
||||
|
||||
|
||||
let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into());
|
||||
|
||||
|
||||
let mut plaintext = ciphertext.to_vec();
|
||||
cipher.apply_keystream(&mut plaintext);
|
||||
|
||||
|
||||
self.sequence_number_ctos += 1;
|
||||
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
|
||||
/// 计算MAC(参考OpenSSH mac.c: mac_compute())
|
||||
pub fn compute_mac(
|
||||
&self,
|
||||
@@ -152,17 +158,17 @@ impl EncryptionContext {
|
||||
mac_key: &[u8],
|
||||
) -> Result<Vec<u8>> {
|
||||
// HMAC-SHA256 MAC计算(参考OpenSSH mac.c)
|
||||
|
||||
|
||||
let mut mac = HmacSha256::new_from_slice(mac_key)?;
|
||||
|
||||
|
||||
// OpenSSH MAC格式:sequence_number + data
|
||||
mac.update(&sequence_number.to_be_bytes());
|
||||
mac.update(data);
|
||||
|
||||
|
||||
let result = mac.finalize();
|
||||
Ok(result.into_bytes().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 验证MAC(参考OpenSSH mac.c: mac_check())
|
||||
pub fn verify_mac(
|
||||
&self,
|
||||
@@ -172,14 +178,14 @@ impl EncryptionContext {
|
||||
mac_key: &[u8],
|
||||
) -> Result<bool> {
|
||||
// HMAC验证(参考OpenSSH mac.c)
|
||||
|
||||
|
||||
let computed_mac = self.compute_mac(sequence_number, data, mac_key)?;
|
||||
|
||||
|
||||
// 防止时间攻击(使用常量时间比较)
|
||||
if computed_mac.len() != expected_mac.len() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
|
||||
// 简化实现:直接比较(实际应使用常量时间比较)
|
||||
Ok(computed_mac == expected_mac)
|
||||
}
|
||||
@@ -187,11 +193,11 @@ impl EncryptionContext {
|
||||
|
||||
/// SSH加密packet封装(参考OpenSSH packet.c: ssh_packet_write_poll())
|
||||
pub struct EncryptedPacket {
|
||||
pub packet_length: u32, // 加密后packet长度
|
||||
pub padding_length: u8, // padding长度(加密后)
|
||||
pub payload: Vec<u8>, // payload(加密后)
|
||||
pub padding: Vec<u8>, // padding(加密后)
|
||||
pub mac: Vec<u8>, // MAC(32字节,HMAC-SHA256)
|
||||
pub packet_length: u32, // 加密后packet长度
|
||||
pub padding_length: u8, // padding长度(加密后)
|
||||
pub payload: Vec<u8>, // payload(加密后)
|
||||
pub padding: Vec<u8>, // padding(加密后)
|
||||
pub mac: Vec<u8>, // MAC(32字节,HMAC-SHA256)
|
||||
}
|
||||
|
||||
impl EncryptedPacket {
|
||||
@@ -204,82 +210,88 @@ impl EncryptedPacket {
|
||||
) -> Result<Self> {
|
||||
let block_size = 16;
|
||||
let min_padding = 4;
|
||||
|
||||
|
||||
let payload_length = plaintext_payload.len();
|
||||
|
||||
|
||||
// RFC 4253: entire plaintext packet (including 4-byte packet_length field) must be multiple of block_size
|
||||
// plaintext_packet = packet_length_field(4) + padding_length(1) + payload + padding
|
||||
// So: (4 + 1 + payload_length + padding_length) % 16 == 0
|
||||
|
||||
let base_size = 4 + 1 + payload_length; // without padding
|
||||
|
||||
let base_size = 4 + 1 + payload_length; // without padding
|
||||
let padding_needed = (block_size - (base_size % block_size)) % block_size;
|
||||
|
||||
|
||||
// Ensure padding >= min_padding (RFC 4253 requirement)
|
||||
let padding_length: u8 = if padding_needed < min_padding {
|
||||
(padding_needed + block_size) as u8 // Add one more block to meet minimum
|
||||
(padding_needed + block_size) as u8 // Add one more block to meet minimum
|
||||
} else {
|
||||
padding_needed as u8
|
||||
};
|
||||
|
||||
|
||||
// packet_length = padding_length(1) + payload + padding
|
||||
let packet_length = 1 + payload_length + padding_length as usize;
|
||||
|
||||
info!("Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
|
||||
payload_length, padding_length, packet_length);
|
||||
|
||||
|
||||
info!(
|
||||
"Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
|
||||
payload_length, padding_length, packet_length
|
||||
);
|
||||
|
||||
// 构建plaintext packet(packet_length + padding_length + payload + padding)
|
||||
let mut plaintext_packet = Vec::new();
|
||||
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
|
||||
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
|
||||
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
|
||||
|
||||
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
|
||||
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
|
||||
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
|
||||
|
||||
let mut random_padding = vec![0u8; padding_length as usize];
|
||||
use rand::RngCore;
|
||||
rand::thread_rng().fill_bytes(&mut random_padding);
|
||||
plaintext_packet.write_all(&random_padding)?; // plaintext padding
|
||||
|
||||
plaintext_packet.write_all(&random_padding)?; // plaintext padding
|
||||
|
||||
info!("Plaintext packet size: {} bytes", plaintext_packet.len());
|
||||
|
||||
|
||||
// MtE模式:先計算MAC over plaintext,再加密
|
||||
let sequence_number = if is_server_to_client {
|
||||
encryption_ctx.sequence_number_stoc
|
||||
} else {
|
||||
encryption_ctx.sequence_number_ctos
|
||||
};
|
||||
|
||||
|
||||
let mac_key = if is_server_to_client {
|
||||
&encryption_ctx.mac_key_stoc
|
||||
} else {
|
||||
&encryption_ctx.mac_key_ctos
|
||||
};
|
||||
|
||||
|
||||
info!("MAC calculation (MtE mode) over plaintext packet:");
|
||||
info!(" sequence_number: {}", sequence_number);
|
||||
info!(" mac_key length: {}", mac_key.len());
|
||||
info!(" plaintext_packet length: {}", plaintext_packet.len());
|
||||
|
||||
|
||||
// MAC計算:HMAC(sequence_number || plaintext_packet)
|
||||
let mac = encryption_ctx.compute_mac(sequence_number, &plaintext_packet, mac_key)?;
|
||||
|
||||
|
||||
// 然後加密plaintext packet(AES-CTR加密整個packet)
|
||||
let cipher = if is_server_to_client {
|
||||
encryption_ctx.cipher_stoc.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_stoc
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
||||
} else {
|
||||
encryption_ctx.cipher_ctos.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_ctos
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
||||
};
|
||||
|
||||
|
||||
let mut encrypted_packet = plaintext_packet;
|
||||
cipher.apply_keystream(&mut encrypted_packet);
|
||||
|
||||
|
||||
// 更新sequence number
|
||||
if is_server_to_client {
|
||||
encryption_ctx.sequence_number_stoc += 1;
|
||||
} else {
|
||||
encryption_ctx.sequence_number_ctos += 1;
|
||||
}
|
||||
|
||||
|
||||
Ok(Self {
|
||||
packet_length: packet_length as u32,
|
||||
padding_length,
|
||||
@@ -288,24 +300,27 @@ impl EncryptedPacket {
|
||||
mac,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 写入加密packet(参考OpenSSH cipher.c)
|
||||
/// AES-CTR模式:写入完整加密packet + MAC
|
||||
pub fn write<W: std::io::Write>(&self, stream: &mut W) -> Result<()> {
|
||||
info!("Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
|
||||
self.payload.len(), self.mac.len());
|
||||
|
||||
info!(
|
||||
"Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
|
||||
self.payload.len(),
|
||||
self.mac.len()
|
||||
);
|
||||
|
||||
// AES-CTR: 整个packet已加密(包括packet_length),直接写入
|
||||
stream.write_all(&self.payload)?;
|
||||
info!("Wrote encrypted packet ({} bytes)", self.payload.len());
|
||||
|
||||
|
||||
// 写入MAC
|
||||
stream.write_all(&self.mac)?;
|
||||
info!("Wrote MAC ({} bytes)", self.mac.len());
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 读取加密packet(参考OpenSSH packet.c ssh_packet_read_poll2)
|
||||
/// OpenSSH packet.c: AES-CTR先解密第一个块,再提取packet_length
|
||||
/// aadlen = 0 (没有EtM或authenticated encryption), packet_length被加密
|
||||
@@ -315,32 +330,42 @@ impl EncryptedPacket {
|
||||
is_client_to_server: bool,
|
||||
) -> Result<Self> {
|
||||
use std::io::Read;
|
||||
|
||||
|
||||
info!("Reading AES-CTR encrypted packet (packet_length encrypted)");
|
||||
|
||||
|
||||
// 1. 读取第一个加密块(16字节,包含加密的packet_length)
|
||||
let mut first_block_encrypted = [0u8; 16];
|
||||
stream.read_exact(&mut first_block_encrypted)?;
|
||||
|
||||
info!("Read first encrypted block (16 bytes): {:?}", &first_block_encrypted);
|
||||
|
||||
|
||||
info!(
|
||||
"Read first encrypted block (16 bytes): {:?}",
|
||||
&first_block_encrypted
|
||||
);
|
||||
|
||||
// 2. 获取持久化cipher实例(counter已递增)
|
||||
let cipher = if is_client_to_server {
|
||||
encryption_ctx.cipher_ctos.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_ctos
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
||||
} else {
|
||||
encryption_ctx.cipher_stoc.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_stoc
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
||||
};
|
||||
|
||||
info!("Using cipher for decryption (is_client_to_server={})", is_client_to_server);
|
||||
|
||||
|
||||
info!(
|
||||
"Using cipher for decryption (is_client_to_server={})",
|
||||
is_client_to_server
|
||||
);
|
||||
|
||||
// 3. 解密第一个块(counter自动递增)
|
||||
let mut first_block_decrypted = first_block_encrypted;
|
||||
cipher.apply_keystream(&mut first_block_decrypted);
|
||||
|
||||
|
||||
info!("Decrypted first block: {:?}", &first_block_decrypted);
|
||||
|
||||
|
||||
// 3. 从解密后的数据中提取packet_length(前4字节)和padding_length(第5字节)
|
||||
let packet_length = u32::from_be_bytes([
|
||||
first_block_decrypted[0],
|
||||
@@ -349,67 +374,73 @@ impl EncryptedPacket {
|
||||
first_block_decrypted[3],
|
||||
]);
|
||||
let padding_length = first_block_decrypted[4];
|
||||
|
||||
info!("Decrypted packet_length={}, padding_length={}", packet_length, padding_length);
|
||||
|
||||
|
||||
info!(
|
||||
"Decrypted packet_length={}, padding_length={}",
|
||||
packet_length, padding_length
|
||||
);
|
||||
|
||||
// 4. 合理性检查
|
||||
if packet_length > 35000 {
|
||||
info!("packet_length raw bytes: {:?}", &first_block_decrypted[..4]);
|
||||
return Err(anyhow!("Invalid packet_length: {}", packet_length));
|
||||
}
|
||||
|
||||
|
||||
// 3. 计算剩余加密数据长度
|
||||
// packet_length = padding_length(1) + payload + padding
|
||||
// 总加密数据 = packet_length(4) + packet_length = packet_length + 4
|
||||
// 已读取16字节,剩余 = packet_length + 4 - 16
|
||||
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
|
||||
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
|
||||
let remaining_encrypted_size = total_encrypted_size - 16;
|
||||
|
||||
info!("Total encrypted size: {}, remaining: {}", total_encrypted_size, remaining_encrypted_size);
|
||||
|
||||
|
||||
info!(
|
||||
"Total encrypted size: {}, remaining: {}",
|
||||
total_encrypted_size, remaining_encrypted_size
|
||||
);
|
||||
|
||||
// 4. 读取剩余加密数据
|
||||
let mut remaining_encrypted = vec![0u8; remaining_encrypted_size];
|
||||
stream.read_exact(&mut remaining_encrypted)?;
|
||||
|
||||
|
||||
// 5. 继续解密(使用同一个cipher)
|
||||
cipher.apply_keystream(&mut remaining_encrypted);
|
||||
|
||||
|
||||
info!("Remaining decrypted data: {:?}", &remaining_encrypted);
|
||||
|
||||
|
||||
// 6. 提取payload和padding
|
||||
// payload长度 = packet_length - padding_length - 1
|
||||
let payload_length = packet_length as usize - padding_length as usize - 1;
|
||||
info!("Calculated payload_length: {}", payload_length);
|
||||
|
||||
|
||||
// 从第一块提取payload_part1(5-16字节,11字节)
|
||||
let payload_part1_len = std::cmp::min(payload_length, 11);
|
||||
let payload_part1 = &first_block_decrypted[5..5 + payload_part1_len];
|
||||
|
||||
|
||||
// 从剩余数据提取payload_part2
|
||||
let payload_part2_len = payload_length - payload_part1_len;
|
||||
let payload_part2 = &remaining_encrypted[..payload_part2_len];
|
||||
|
||||
|
||||
// 合并payload
|
||||
let mut payload = Vec::new();
|
||||
payload.extend_from_slice(payload_part1);
|
||||
payload.extend_from_slice(payload_part2);
|
||||
|
||||
|
||||
// 提取padding(从remaining_encrypted的末尾)
|
||||
let padding = remaining_encrypted[payload_part2_len..].to_vec();
|
||||
|
||||
|
||||
// 9. 读取MAC
|
||||
info!("Reading MAC (32 bytes)...");
|
||||
let mut mac = vec![0u8; 32];
|
||||
stream.read_exact(&mut mac)?;
|
||||
info!("MAC read successfully");
|
||||
|
||||
|
||||
// 10. 更新sequence number
|
||||
if is_client_to_server {
|
||||
encryption_ctx.sequence_number_ctos += 1;
|
||||
} else {
|
||||
encryption_ctx.sequence_number_stoc += 1;
|
||||
}
|
||||
|
||||
|
||||
Ok(Self {
|
||||
packet_length,
|
||||
padding_length,
|
||||
@@ -418,7 +449,7 @@ impl EncryptedPacket {
|
||||
mac,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 获取payload内容
|
||||
pub fn payload(&self) -> &[u8] {
|
||||
&self.payload
|
||||
@@ -428,13 +459,13 @@ impl EncryptedPacket {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_aes256_ctr_encryption() {
|
||||
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
|
||||
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
|
||||
let iv = vec![0u8; 16];
|
||||
let plaintext = b"Hello World";
|
||||
|
||||
|
||||
let mut ctx = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: key.clone(),
|
||||
@@ -444,18 +475,18 @@ mod tests {
|
||||
iv_ctos: iv.clone(),
|
||||
iv_stoc: iv.clone(),
|
||||
});
|
||||
|
||||
|
||||
let ciphertext = ctx.encrypt_packet(plaintext, &key, &iv).unwrap();
|
||||
let decrypted = ctx.decrypt_packet(&ciphertext, &key, &iv).unwrap();
|
||||
|
||||
|
||||
assert_eq!(plaintext.to_vec(), decrypted);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_hmac_sha256() {
|
||||
let key = vec![0u8; 32];
|
||||
let data = b"test data";
|
||||
|
||||
|
||||
let ctx = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: vec![0u8; 32],
|
||||
@@ -465,10 +496,10 @@ mod tests {
|
||||
iv_ctos: vec![0u8; 16],
|
||||
iv_stoc: vec![0u8; 16],
|
||||
});
|
||||
|
||||
|
||||
let mac = ctx.compute_mac(1, data, &key).unwrap();
|
||||
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
|
||||
|
||||
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
|
||||
|
||||
// 验证MAC
|
||||
assert!(ctx.verify_mac(1, data, &mac, &key).unwrap());
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
// SSH加密模块(Phase 3:密钥交换)
|
||||
// 参考OpenSSH curve25519.c, kex.c
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
|
||||
use ed25519_dalek::{SigningKey, VerifyingKey, Signature, Signer};
|
||||
use sha2::{Sha256, Digest};
|
||||
use log::{info, debug};
|
||||
use anyhow::{anyhow, Result};
|
||||
use ed25519_dalek::{Signer, SigningKey};
|
||||
use log::info;
|
||||
use rand::rngs::OsRng;
|
||||
use sha2::{Digest, Sha256};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
/// Curve25519密钥交换处理器(参考OpenSSH curve25519.c)
|
||||
pub struct Curve25519Kex {
|
||||
secret: Option<EphemeralSecret>, // 使用Option包装(一次性使用类型)
|
||||
secret: Option<EphemeralSecret>, // 使用Option包装(一次性使用类型)
|
||||
public: PublicKey,
|
||||
}
|
||||
|
||||
@@ -21,34 +21,37 @@ impl Curve25519Kex {
|
||||
// x25519-dalek 2.0标准API:使用random_from_rng
|
||||
let secret = EphemeralSecret::random_from_rng(OsRng);
|
||||
let public = PublicKey::from(&secret);
|
||||
|
||||
Self { secret: Some(secret), public } // Some包装
|
||||
|
||||
Self {
|
||||
secret: Some(secret),
|
||||
public,
|
||||
} // Some包装
|
||||
}
|
||||
|
||||
|
||||
/// 获取公钥(用于SSH_MSG_KEX_ECDH_INIT)
|
||||
pub fn public_key(&self) -> &[u8] {
|
||||
self.public.as_bytes()
|
||||
}
|
||||
|
||||
|
||||
/// 计算共享密钥(参考OpenSSH curve25519_shared_secret())
|
||||
/// 使用&mut self(消耗模式,符合OpenSSH设计)
|
||||
pub fn compute_shared_secret(&mut self, client_public: &[u8]) -> Result<[u8; 32]> {
|
||||
if client_public.len() != 32 {
|
||||
return Err(anyhow!("Invalid client public key length"));
|
||||
}
|
||||
|
||||
|
||||
info!("=== X25519 Shared Secret Calculation ===");
|
||||
info!("Client public key input: {:?}", client_public);
|
||||
info!("Server public key: {:?}", self.public.as_bytes());
|
||||
|
||||
|
||||
// 参考OpenSSH:curve25519共享密钥计算
|
||||
let client_public_key = PublicKey::from(<[u8; 32]>::try_from(client_public)?);
|
||||
|
||||
|
||||
// 使用take()取出secret(Rust标准模式)
|
||||
if let Some(secret) = self.secret.take() {
|
||||
let shared_secret = secret.diffie_hellman(&client_public_key);
|
||||
info!("Computed shared secret: {:?}", shared_secret.as_bytes());
|
||||
Ok(shared_secret.as_bytes().clone())
|
||||
Ok(*shared_secret.as_bytes())
|
||||
} else {
|
||||
Err(anyhow!("Secret already used"))
|
||||
}
|
||||
@@ -71,47 +74,85 @@ impl SessionKeys {
|
||||
/// RFC 4253 Section 7.2: Key = HASH(K || H || X || session_id)
|
||||
pub fn derive(
|
||||
shared_secret: &[u8],
|
||||
exchange_hash: &[u8], // H参数(exchange hash)
|
||||
server_public_key: &[u8],
|
||||
client_public_key: &[u8],
|
||||
server_host_key: &[u8],
|
||||
exchange_hash: &[u8], // H参数(exchange hash)
|
||||
_server_public_key: &[u8],
|
||||
_client_public_key: &[u8],
|
||||
_server_host_key: &[u8],
|
||||
) -> Result<Self> {
|
||||
// RFC 4253: session_id = H (第一次exchange hash)
|
||||
let session_id = exchange_hash.to_vec();
|
||||
|
||||
|
||||
info!("SessionKeys::derive() starting");
|
||||
info!(" shared_secret full (32 bytes): {:?}", shared_secret);
|
||||
|
||||
|
||||
// RFC 8731 Section 3.1: X25519 output is little-endian
|
||||
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
|
||||
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
|
||||
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
|
||||
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
|
||||
info!(
|
||||
" shared_secret[0] = {} (>=0x80? {})",
|
||||
shared_secret[0],
|
||||
shared_secret[0] >= 0x80
|
||||
);
|
||||
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
||||
info!(" session_id full (32 bytes): {:?}", session_id);
|
||||
|
||||
|
||||
// RFC 4253密钥派生公式:HASH(K || H || X || session_id)
|
||||
// K is shared_secret encoded as mpint (using little-endian bytes directly)
|
||||
let shared_secret_mpint = Self::encode_mpint(shared_secret);
|
||||
|
||||
info!(" shared_secret_mpint ({} bytes): {:?}", shared_secret_mpint.len(), &shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]);
|
||||
|
||||
let encryption_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
|
||||
let encryption_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
|
||||
let mac_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
|
||||
let mac_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
|
||||
|
||||
let iv_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
|
||||
let iv_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
|
||||
|
||||
|
||||
info!(
|
||||
" shared_secret_mpint ({} bytes): {:?}",
|
||||
shared_secret_mpint.len(),
|
||||
&shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]
|
||||
);
|
||||
|
||||
let encryption_key_ctos =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
|
||||
let encryption_key_stoc =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
|
||||
let mac_key_ctos =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
|
||||
let mac_key_stoc =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
|
||||
|
||||
let iv_ctos =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
|
||||
let iv_stoc =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
|
||||
|
||||
info!("Derived keys summary:");
|
||||
info!(" encryption_key_ctos ({} bytes): {:?}", encryption_key_ctos.len(), &encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]);
|
||||
info!(" encryption_key_stoc ({} bytes): {:?}", encryption_key_stoc.len(), &encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]);
|
||||
info!(" iv_ctos ({} bytes): {:?}", iv_ctos.len(), &iv_ctos[..std::cmp::min(16, iv_ctos.len())]);
|
||||
info!(" iv_stoc ({} bytes): {:?}", iv_stoc.len(), &iv_stoc[..std::cmp::min(16, iv_stoc.len())]);
|
||||
info!(" mac_key_ctos ({} bytes): {:?}", mac_key_ctos.len(), &mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]);
|
||||
info!(" mac_key_stoc ({} bytes): {:?}", mac_key_stoc.len(), &mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]);
|
||||
|
||||
info!(
|
||||
" encryption_key_ctos ({} bytes): {:?}",
|
||||
encryption_key_ctos.len(),
|
||||
&encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]
|
||||
);
|
||||
info!(
|
||||
" encryption_key_stoc ({} bytes): {:?}",
|
||||
encryption_key_stoc.len(),
|
||||
&encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]
|
||||
);
|
||||
info!(
|
||||
" iv_ctos ({} bytes): {:?}",
|
||||
iv_ctos.len(),
|
||||
&iv_ctos[..std::cmp::min(16, iv_ctos.len())]
|
||||
);
|
||||
info!(
|
||||
" iv_stoc ({} bytes): {:?}",
|
||||
iv_stoc.len(),
|
||||
&iv_stoc[..std::cmp::min(16, iv_stoc.len())]
|
||||
);
|
||||
info!(
|
||||
" mac_key_ctos ({} bytes): {:?}",
|
||||
mac_key_ctos.len(),
|
||||
&mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]
|
||||
);
|
||||
info!(
|
||||
" mac_key_stoc ({} bytes): {:?}",
|
||||
mac_key_stoc.len(),
|
||||
&mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
session_id,
|
||||
encryption_key_ctos,
|
||||
@@ -122,65 +163,73 @@ impl SessionKeys {
|
||||
iv_stoc,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// RFC 4253密钥派生函数
|
||||
/// 公式:Key = HASH(K || H || X || session_id)
|
||||
fn derive_key_rfc4253(K_mpint: &[u8], H: &[u8], X: char, session_id: &[u8]) -> Result<Vec<u8>> {
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
|
||||
info!("Deriving key for X='{}'", X);
|
||||
info!(" K_mpint ({} bytes): {:?}", K_mpint.len(), &K_mpint[..std::cmp::min(8, K_mpint.len())]);
|
||||
info!(
|
||||
" K_mpint ({} bytes): {:?}",
|
||||
K_mpint.len(),
|
||||
&K_mpint[..std::cmp::min(8, K_mpint.len())]
|
||||
);
|
||||
info!(" H ({} bytes): {:?}", H.len(), &H[..8]);
|
||||
info!(" session_id ({} bytes): {:?}", session_id.len(), &session_id[..8]);
|
||||
|
||||
info!(
|
||||
" session_id ({} bytes): {:?}",
|
||||
session_id.len(),
|
||||
&session_id[..8]
|
||||
);
|
||||
|
||||
// RFC 4253: HASH(K || H || X || session_id)
|
||||
hasher.update(K_mpint); // K (shared secret in mpint format)
|
||||
hasher.update(H); // H (exchange hash)
|
||||
hasher.update(&[X as u8]); // X (single character)
|
||||
hasher.update(K_mpint); // K (shared secret in mpint format)
|
||||
hasher.update(H); // H (exchange hash)
|
||||
hasher.update([X as u8]); // X (single character)
|
||||
hasher.update(session_id); // session_id
|
||||
|
||||
|
||||
let full_hash = hasher.finalize();
|
||||
|
||||
|
||||
info!(" Derived key (first 8 bytes): {:?}", &full_hash[..8]);
|
||||
|
||||
|
||||
// 根據key類型返回不同長度:
|
||||
// AES-128-CTR key/IV: 16 bytes
|
||||
// HMAC-SHA256 key: 32 bytes
|
||||
match X {
|
||||
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
|
||||
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
|
||||
'E' | 'F' => Ok(full_hash.to_vec()), // MAC key (full 32 bytes)
|
||||
_ => Ok(full_hash[..16].to_vec()), // default
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// SSH mpint编码(参考RFC 4253 Section 5)
|
||||
/// Curve25519 shared secret特殊处理
|
||||
fn encode_mpint(bytes: &[u8]) -> Vec<u8> {
|
||||
// RFC 4253: mpint = uint32(length) + data
|
||||
// 去掉前导零,如果最高位>=0x80前面加0
|
||||
|
||||
|
||||
// 去掉前导零字节(但不去掉最后一个字节即使它是0)
|
||||
let mut start = 0;
|
||||
while start < bytes.len() - 1 && bytes[start] == 0 {
|
||||
start += 1;
|
||||
}
|
||||
|
||||
|
||||
let data_without_leading_zeros = &bytes[start..];
|
||||
|
||||
|
||||
// 构建mpint数据
|
||||
let mut mpint_data = Vec::new();
|
||||
|
||||
|
||||
// 如果最高位>=0x80,前面加0字节(避免负数)
|
||||
if data_without_leading_zeros[0] >= 0x80 {
|
||||
mpint_data.push(0);
|
||||
}
|
||||
mpint_data.extend_from_slice(data_without_leading_zeros);
|
||||
|
||||
|
||||
// 最终格式:uint32长度 + mpint数据
|
||||
let mut result = Vec::new();
|
||||
result.extend_from_slice(&(mpint_data.len() as u32).to_be_bytes());
|
||||
result.extend_from_slice(&mpint_data);
|
||||
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
@@ -192,45 +241,45 @@ pub struct Ed25519HostKey {
|
||||
|
||||
impl Ed25519HostKey {
|
||||
/// 加载或生成主机密钥(参考OpenSSH hostfile.c)
|
||||
pub fn load_or_generate(key_path: &str) -> Result<Self> {
|
||||
pub fn load_or_generate(_key_path: &str) -> Result<Self> {
|
||||
// 简化实现:生成临时密钥(实际应从文件加载)
|
||||
// 参考OpenSSH ssh-keygen
|
||||
|
||||
|
||||
let signing_key = SigningKey::generate(&mut OsRng);
|
||||
|
||||
|
||||
Ok(Self { signing_key })
|
||||
}
|
||||
|
||||
|
||||
/// 获取公钥(用于SSH_MSG_KEX_ECDH_REPLY)
|
||||
pub fn public_key_bytes(&self) -> Vec<u8> {
|
||||
// SSH Ed25519公钥格式(参考OpenSSH sshkey.c)
|
||||
let verifying_key = self.signing_key.verifying_key();
|
||||
|
||||
|
||||
// SSH格式:ssh-ed25519 + 公钥bytes
|
||||
// 简化:仅返回公钥bytes(32字节)
|
||||
verifying_key.as_bytes().to_vec()
|
||||
}
|
||||
|
||||
|
||||
/// 签名(参考OpenSSH sshkey.c: sshkey_sign())
|
||||
pub fn sign(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
// OpenSSH Ed25519签名
|
||||
let signature = self.signing_key.sign(data);
|
||||
|
||||
|
||||
// SSH签名格式(参考OpenSSH ssh-sign.c)
|
||||
// 简化:仅返回签名bytes(64字节)
|
||||
Ok(signature.to_bytes().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 获取完整SSH公钥格式(参考OpenSSH sshkey.c)
|
||||
pub fn ssh_public_key(&self) -> String {
|
||||
let public_bytes = self.public_key_bytes();
|
||||
|
||||
|
||||
// SSH公钥格式:ssh-ed25519 <base64-encoded-public-key>
|
||||
// 参考OpenSSH ssh-keygen -y
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
let encoded = general_purpose::STANDARD.encode(&public_bytes);
|
||||
|
||||
|
||||
format!("ssh-ed25519 {}", encoded)
|
||||
}
|
||||
}
|
||||
@@ -238,40 +287,44 @@ impl Ed25519HostKey {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_curve25519_key_generation() {
|
||||
let kex = Curve25519Kex::new();
|
||||
assert_eq!(kex.public_key().len(), 32);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_curve25519_shared_secret() {
|
||||
let mut client_kex = Curve25519Kex::new();
|
||||
let mut server_kex = Curve25519Kex::new();
|
||||
|
||||
|
||||
// 客户端计算共享密钥
|
||||
let client_secret = client_kex.compute_shared_secret(server_kex.public_key()).unwrap();
|
||||
|
||||
let client_secret = client_kex
|
||||
.compute_shared_secret(server_kex.public_key())
|
||||
.unwrap();
|
||||
|
||||
// 服务器计算共享密钥
|
||||
let server_secret = server_kex.compute_shared_secret(client_kex.public_key()).unwrap();
|
||||
|
||||
let server_secret = server_kex
|
||||
.compute_shared_secret(client_kex.public_key())
|
||||
.unwrap();
|
||||
|
||||
// 应该相同(Curve25519特性)
|
||||
assert_eq!(client_secret, server_secret);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_ed25519_host_key() {
|
||||
let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap();
|
||||
assert_eq!(host_key.public_key_bytes().len(), 32);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_ed25519_signature() {
|
||||
let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap();
|
||||
let data = b"test data";
|
||||
|
||||
|
||||
let signature = host_key.sign(data).unwrap();
|
||||
assert_eq!(signature.len(), 64); // Ed25519签名64字节
|
||||
assert_eq!(signature.len(), 64); // Ed25519签名64字节
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// SSH端口转发数据传输(Phase 13.5)
|
||||
// 参考OpenSSH channels.c: channel_handle_data()
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use std::net::{TcpStream};
|
||||
use std::io::{Read, Write};
|
||||
use std::thread;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use log::{debug, info, warn};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
|
||||
/// 数据转发器(Phase 13.5:双向数据传输)
|
||||
pub struct DataForwarder {
|
||||
@@ -25,29 +25,40 @@ impl DataForwarder {
|
||||
max_packet_size,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 启动双向数据转发(Phase 13.5:SSH channel ↔ TCP socket)
|
||||
pub fn start_bidirectional_forwarding(
|
||||
&mut self,
|
||||
ssh_stream: TcpStream, // SSH client连接(加密通道)
|
||||
target_stream: TcpStream, // 目标服务连接(TCP socket)
|
||||
ssh_stream: TcpStream, // SSH client连接(加密通道)
|
||||
target_stream: TcpStream, // 目标服务连接(TCP socket)
|
||||
) -> Result<()> {
|
||||
info!("Starting bidirectional data forwarding for channel {}", self.channel_id);
|
||||
|
||||
info!(
|
||||
"Starting bidirectional data forwarding for channel {}",
|
||||
self.channel_id
|
||||
);
|
||||
|
||||
// Phase 13.5: SSH channel → Target socket(SSH client数据 → 本地服务)
|
||||
let ssh_to_target = self.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
|
||||
|
||||
let ssh_to_target = self
|
||||
.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
|
||||
|
||||
// Phase 13.5: Target socket → SSH channel(本地服务数据 → SSH client)
|
||||
let target_to_ssh = self.start_target_to_ssh_forwarding(target_stream, ssh_stream);
|
||||
|
||||
|
||||
// Phase 13.5: 等待两个转发线程完成
|
||||
ssh_to_target.join().map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
|
||||
target_to_ssh.join().map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
|
||||
|
||||
info!("Bidirectional data forwarding completed for channel {}", self.channel_id);
|
||||
ssh_to_target
|
||||
.join()
|
||||
.map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
|
||||
target_to_ssh
|
||||
.join()
|
||||
.map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
|
||||
|
||||
info!(
|
||||
"Bidirectional data forwarding completed for channel {}",
|
||||
self.channel_id
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// SSH channel → Target socket转发(Phase 13.5)
|
||||
fn start_ssh_to_target_forwarding(
|
||||
&self,
|
||||
@@ -57,18 +68,21 @@ impl DataForwarder {
|
||||
let channel_id = self.channel_id;
|
||||
let window_size = self.window_size.clone();
|
||||
let max_packet_size = self.max_packet_size;
|
||||
|
||||
|
||||
thread::spawn(move || {
|
||||
info!("SSH to target forwarding thread started for channel {}", channel_id);
|
||||
|
||||
info!(
|
||||
"SSH to target forwarding thread started for channel {}",
|
||||
channel_id
|
||||
);
|
||||
|
||||
let mut buffer = vec![0u8; max_packet_size as usize];
|
||||
|
||||
|
||||
loop {
|
||||
// Phase 13.5: 从SSH channel读取数据
|
||||
let n = match ssh_stream.read(&mut buffer) {
|
||||
Ok(0) => {
|
||||
info!("SSH channel EOF for channel {}", channel_id);
|
||||
break; // EOF
|
||||
break; // EOF
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
@@ -76,45 +90,61 @@ impl DataForwarder {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Phase 13.5: 检查window size
|
||||
{
|
||||
let window = window_size.lock().unwrap();
|
||||
if *window < n as u32 {
|
||||
warn!("Window size insufficient for channel {}: need {}, have {}",
|
||||
channel_id, n, *window);
|
||||
warn!(
|
||||
"Window size insufficient for channel {}: need {}, have {}",
|
||||
channel_id, n, *window
|
||||
);
|
||||
// Phase 13.5: 理论上应该等待SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||||
// 简化实现:继续发送(可能会违反RFC 4254)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: 写入目标socket
|
||||
if let Err(e) = target_stream.write_all(&buffer[..n]) {
|
||||
warn!("Target socket write error for channel {}: {}", channel_id, e);
|
||||
warn!(
|
||||
"Target socket write error for channel {}: {}",
|
||||
channel_id, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: Flush确保数据发送
|
||||
if let Err(e) = target_stream.flush() {
|
||||
warn!("Target socket flush error for channel {}: {}", channel_id, e);
|
||||
warn!(
|
||||
"Target socket flush error for channel {}: {}",
|
||||
channel_id, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: 消耗window size
|
||||
{
|
||||
let mut window = window_size.lock().unwrap();
|
||||
*window -= n as u32;
|
||||
debug!("Window size consumed for channel {}: {} bytes, remaining {}",
|
||||
channel_id, n, *window);
|
||||
debug!(
|
||||
"Window size consumed for channel {}: {} bytes, remaining {}",
|
||||
channel_id, n, *window
|
||||
);
|
||||
}
|
||||
|
||||
info!("Forwarded {} bytes from SSH to target for channel {}", n, channel_id);
|
||||
|
||||
info!(
|
||||
"Forwarded {} bytes from SSH to target for channel {}",
|
||||
n, channel_id
|
||||
);
|
||||
}
|
||||
|
||||
info!("SSH to target forwarding thread stopped for channel {}", channel_id);
|
||||
|
||||
info!(
|
||||
"SSH to target forwarding thread stopped for channel {}",
|
||||
channel_id
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// Target socket → SSH channel转发(Phase 13.5)
|
||||
fn start_target_to_ssh_forwarding(
|
||||
&self,
|
||||
@@ -122,18 +152,21 @@ impl DataForwarder {
|
||||
mut ssh_stream: TcpStream,
|
||||
) -> thread::JoinHandle<()> {
|
||||
let channel_id = self.channel_id;
|
||||
|
||||
|
||||
thread::spawn(move || {
|
||||
info!("Target to SSH forwarding thread started for channel {}", channel_id);
|
||||
|
||||
let mut buffer = vec![0u8; 8192]; // 8KB buffer
|
||||
|
||||
info!(
|
||||
"Target to SSH forwarding thread started for channel {}",
|
||||
channel_id
|
||||
);
|
||||
|
||||
let mut buffer = vec![0u8; 8192]; // 8KB buffer
|
||||
|
||||
loop {
|
||||
// Phase 13.5: 从目标socket读取数据
|
||||
let n = match target_stream.read(&mut buffer) {
|
||||
Ok(0) => {
|
||||
info!("Target socket EOF for channel {}", channel_id);
|
||||
break; // EOF
|
||||
break; // EOF
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
@@ -141,43 +174,51 @@ impl DataForwarder {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Phase 13.5: 构建SSH_MSG_CHANNEL_DATA packet
|
||||
// 注意:实际实现需要通过EncryptedPacket加密
|
||||
// 这里简化实现,直接写入SSH stream(测试用)
|
||||
|
||||
|
||||
// Phase 13.5: 写入SSH channel
|
||||
if let Err(e) = ssh_stream.write_all(&buffer[..n]) {
|
||||
warn!("SSH channel write error for channel {}: {}", channel_id, e);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: Flush确保数据发送
|
||||
if let Err(e) = ssh_stream.flush() {
|
||||
warn!("SSH channel flush error for channel {}: {}", channel_id, e);
|
||||
break;
|
||||
}
|
||||
|
||||
info!("Forwarded {} bytes from target to SSH for channel {}", n, channel_id);
|
||||
|
||||
info!(
|
||||
"Forwarded {} bytes from target to SSH for channel {}",
|
||||
n, channel_id
|
||||
);
|
||||
}
|
||||
|
||||
info!("Target to SSH forwarding thread stopped for channel {}", channel_id);
|
||||
|
||||
info!(
|
||||
"Target to SSH forwarding thread stopped for channel {}",
|
||||
channel_id
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 获取当前window size(Phase 13.5)
|
||||
pub fn get_window_size(&self) -> u32 {
|
||||
*self.window_size.lock().unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// 增加window size(Phase 13.5:SSH_MSG_CHANNEL_WINDOW_ADJUST)
|
||||
pub fn adjust_window_size(&self, bytes_to_add: u32) {
|
||||
let mut window = self.window_size.lock().unwrap();
|
||||
*window += bytes_to_add;
|
||||
info!("Window size adjusted for channel {}: added {} bytes, total {}",
|
||||
self.channel_id, bytes_to_add, *window);
|
||||
info!(
|
||||
"Window size adjusted for channel {}: added {} bytes, total {}",
|
||||
self.channel_id, bytes_to_add, *window
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/// 检查window size是否足够(Phase 13.5)
|
||||
pub fn check_window_available(&self, data_size: u32) -> bool {
|
||||
let window = self.window_size.lock().unwrap();
|
||||
@@ -188,64 +229,64 @@ impl DataForwarder {
|
||||
/// SSH_MSG_CHANNEL_DATA构建(Phase 13.5)
|
||||
pub fn build_channel_data_packet(channel_id: u32, data: &[u8]) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_DATA (type 94)
|
||||
packet.write_u8(94)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
// Data length (SSH string)
|
||||
packet.write_u32::<BigEndian>(data.len() as u32)?;
|
||||
|
||||
|
||||
// Data content
|
||||
packet.write_all(data)?;
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
/// SSH_MSG_CHANNEL_WINDOW_ADJUST构建(Phase 13.5)
|
||||
pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93)
|
||||
packet.write_u8(93)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
// Bytes to add
|
||||
packet.write_u32::<BigEndian>(bytes_to_add)?;
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_data_forwarder_creation() {
|
||||
let forwarder = DataForwarder::new(1, 2097152, 32768);
|
||||
assert_eq!(forwarder.channel_id, 1);
|
||||
assert_eq!(forwarder.get_window_size(), 2097152);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_size_adjustment() {
|
||||
let forwarder = DataForwarder::new(1, 2097152, 32768);
|
||||
|
||||
|
||||
// 消耗window size
|
||||
forwarder.adjust_window_size(1000);
|
||||
assert_eq!(forwarder.get_window_size(), 2097152 + 1000);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_build_channel_data_packet() {
|
||||
let data = b"Hello, SSH!";
|
||||
let packet = build_channel_data_packet(1, data).unwrap();
|
||||
|
||||
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
|
||||
// 验证packet结构
|
||||
|
||||
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
|
||||
// 验证packet结构
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
// SSH密钥交换算法协商实现(Phase 2)
|
||||
// 参考OpenSSH kex.c: kex_send_kexinit(), kex_choose_conf()
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use anyhow::{Result, anyhow};
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, debug};
|
||||
use log::{debug, info};
|
||||
use std::io::{Read, Write};
|
||||
|
||||
/// SSH算法类型(参考OpenSSH PROTOCOL定义)
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum AlgorithmType {
|
||||
KEX_ALGS = 0, // 密钥交换算法
|
||||
KEX_ALGS = 0, // 密钥交换算法
|
||||
SERVER_HOST_KEY_ALGS = 1, // 服务器主机密钥算法
|
||||
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
|
||||
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
|
||||
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
|
||||
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
|
||||
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
|
||||
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
|
||||
LANGS_CTOS = 8, // 客户端到服务器语言
|
||||
LANGS_STOC = 9, // 服务器到客户端语言
|
||||
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
|
||||
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
|
||||
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
|
||||
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
|
||||
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
|
||||
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
|
||||
LANGS_CTOS = 8, // 客户端到服务器语言
|
||||
LANGS_STOC = 9, // 服务器到客户端语言
|
||||
}
|
||||
|
||||
/// SSH算法提议(参考OpenSSH kex.h: struct kex)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KexProposal {
|
||||
pub kex_algorithms: String, // 密钥交换算法列表
|
||||
pub server_host_key_algorithms: String, // 主机密钥算法列表
|
||||
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
|
||||
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
|
||||
pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器)
|
||||
pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端)
|
||||
pub kex_algorithms: String, // 密钥交换算法列表
|
||||
pub server_host_key_algorithms: String, // 主机密钥算法列表
|
||||
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
|
||||
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
|
||||
pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器)
|
||||
pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端)
|
||||
pub compression_algorithms_ctos: String, // 压缩算法(客户端→服务器)
|
||||
pub compression_algorithms_stoc: String, // 压缩算法(服务器→客户端)
|
||||
pub languages_ctos: String, // 语言(客户端→服务器)
|
||||
pub languages_stoc: String, // 语言(服务器→客户端)
|
||||
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
|
||||
pub reserved: u32, // 保留字段(0)
|
||||
pub languages_ctos: String, // 语言(客户端→服务器)
|
||||
pub languages_stoc: String, // 语言(服务器→客户端)
|
||||
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
|
||||
pub reserved: u32, // 保留字段(0)
|
||||
}
|
||||
|
||||
impl KexProposal {
|
||||
@@ -46,31 +46,31 @@ impl KexProposal {
|
||||
Self {
|
||||
// 密钥交换算法:优先Curve25519(推荐) + strict KEX extension
|
||||
kex_algorithms: "curve25519-sha256,curve25519-sha256@libssh.org,diffie-hellman-group14-sha256,ext-info-s,kex-strict-s-v00@openssh.com".to_string(),
|
||||
|
||||
|
||||
// 主机密钥算法:优先Ed25519
|
||||
server_host_key_algorithms: "ssh-ed25519,rsa-sha2-256,rsa-sha2-512".to_string(),
|
||||
|
||||
|
||||
// 加密算法:AES-256-CTR(推荐)
|
||||
encryption_algorithms_ctos: "aes256-ctr,aes128-ctr".to_string(),
|
||||
encryption_algorithms_stoc: "aes256-ctr,aes128-ctr".to_string(),
|
||||
|
||||
|
||||
// MAC算法:HMAC-SHA256
|
||||
mac_algorithms_ctos: "hmac-sha2-256,hmac-sha2-512".to_string(),
|
||||
mac_algorithms_stoc: "hmac-sha2-256,hmac-sha2-512".to_string(),
|
||||
|
||||
|
||||
// 压缩算法:none优先
|
||||
compression_algorithms_ctos: "none,zlib".to_string(),
|
||||
compression_algorithms_stoc: "none,zlib".to_string(),
|
||||
|
||||
|
||||
// 语言:空
|
||||
languages_ctos: "".to_string(),
|
||||
languages_stoc: "".to_string(),
|
||||
|
||||
|
||||
first_kex_packet_follows: false,
|
||||
reserved: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 创建客户端默认提议(用于测试)
|
||||
pub fn client_default() -> Self {
|
||||
Self {
|
||||
@@ -88,20 +88,20 @@ impl KexProposal {
|
||||
reserved: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 序列化到SSH_MSG_KEXINIT packet(参考OpenSSH kex_send_kexinit())
|
||||
pub fn to_kexinit_packet(&self) -> Result<SshPacket> {
|
||||
let mut payload = Vec::new();
|
||||
|
||||
|
||||
// Packet type
|
||||
payload.write_u8(PacketType::SSH_MSG_KEXINIT as u8)?;
|
||||
|
||||
|
||||
// Cookie(16字节随机数,OpenSSH要求)
|
||||
let mut cookie = [0u8; 16];
|
||||
use rand::Rng;
|
||||
rand::thread_rng().fill(&mut cookie);
|
||||
payload.write_all(&cookie)?;
|
||||
|
||||
|
||||
// 10个算法列表(SSH string格式:length + data)
|
||||
write_ssh_string(&mut payload, &self.kex_algorithms)?;
|
||||
write_ssh_string(&mut payload, &self.server_host_key_algorithms)?;
|
||||
@@ -113,29 +113,29 @@ impl KexProposal {
|
||||
write_ssh_string(&mut payload, &self.compression_algorithms_stoc)?;
|
||||
write_ssh_string(&mut payload, &self.languages_ctos)?;
|
||||
write_ssh_string(&mut payload, &self.languages_stoc)?;
|
||||
|
||||
|
||||
// first_kex_packet_follows(boolean)
|
||||
payload.write_u8(if self.first_kex_packet_follows { 1 } else { 0 })?;
|
||||
|
||||
|
||||
// reserved(u32)
|
||||
payload.write_u32::<BigEndian>(self.reserved)?;
|
||||
|
||||
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
|
||||
/// 从SSH_MSG_KEXINIT packet解析(参考OpenSSH kex_input_kexinit())
|
||||
pub fn from_kexinit_packet(packet: &SshPacket) -> Result<Self> {
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||||
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||||
|
||||
// Packet type
|
||||
let packet_type = cursor.read_u8()?;
|
||||
if packet_type != PacketType::SSH_MSG_KEXINIT as u8 {
|
||||
return Err(anyhow!("Invalid packet type for KEXINIT"));
|
||||
}
|
||||
|
||||
|
||||
// Cookie(16字节,忽略)
|
||||
cursor.read_exact(&mut [0u8; 16])?;
|
||||
|
||||
|
||||
// 10个算法列表
|
||||
let kex_algorithms = read_ssh_string(&mut cursor)?;
|
||||
let server_host_key_algorithms = read_ssh_string(&mut cursor)?;
|
||||
@@ -147,13 +147,13 @@ impl KexProposal {
|
||||
let compression_algorithms_stoc = read_ssh_string(&mut cursor)?;
|
||||
let languages_ctos = read_ssh_string(&mut cursor)?;
|
||||
let languages_stoc = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// first_kex_packet_follows
|
||||
let first_kex_packet_follows = cursor.read_u8()? != 0;
|
||||
|
||||
|
||||
// reserved
|
||||
let reserved = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
kex_algorithms,
|
||||
server_host_key_algorithms,
|
||||
@@ -174,14 +174,14 @@ impl KexProposal {
|
||||
/// SSH算法协商结果(参考OpenSSH struct kex)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KexResult {
|
||||
pub kex_algorithm: String, // 选定的密钥交换算法
|
||||
pub host_key_algorithm: String, // 选定的主机密钥算法
|
||||
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
|
||||
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
|
||||
pub mac_ctos: String, // 选定的MAC算法(客户端→服务器)
|
||||
pub mac_stoc: String, // 选定的MAC算法(服务器→客户端)
|
||||
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
|
||||
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
|
||||
pub kex_algorithm: String, // 选定的密钥交换算法
|
||||
pub host_key_algorithm: String, // 选定的主机密钥算法
|
||||
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
|
||||
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
|
||||
pub mac_ctos: String, // 选定的MAC算法(客户端→服务器)
|
||||
pub mac_stoc: String, // 选定的MAC算法(服务器→客户端)
|
||||
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
|
||||
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
|
||||
}
|
||||
|
||||
/// 算法匹配逻辑(参考OpenSSH kex_choose_conf())
|
||||
@@ -189,28 +189,43 @@ impl KexResult {
|
||||
/// 从服务器和客户端提议中选择算法(参考OpenSSH kex_choose_conf())
|
||||
pub fn choose_algorithms(server: &KexProposal, client: &KexProposal) -> Result<Self> {
|
||||
info!("Starting algorithm negotiation");
|
||||
|
||||
|
||||
// 算法匹配:优先客户端偏好(OpenSSH逻辑)
|
||||
// 参考OpenSSH:客户端列出的算法顺序为偏好顺序
|
||||
|
||||
|
||||
// 密钥交换算法匹配
|
||||
let kex_algorithm = match_algorithm(&client.kex_algorithms, &server.kex_algorithms)?;
|
||||
|
||||
|
||||
// 主机密钥算法匹配
|
||||
let host_key_algorithm = match_algorithm(&client.server_host_key_algorithms, &server.server_host_key_algorithms)?;
|
||||
|
||||
let host_key_algorithm = match_algorithm(
|
||||
&client.server_host_key_algorithms,
|
||||
&server.server_host_key_algorithms,
|
||||
)?;
|
||||
|
||||
// 加密算法匹配
|
||||
let encryption_ctos = match_algorithm(&client.encryption_algorithms_ctos, &server.encryption_algorithms_ctos)?;
|
||||
let encryption_stoc = match_algorithm(&client.encryption_algorithms_stoc, &server.encryption_algorithms_stoc)?;
|
||||
|
||||
let encryption_ctos = match_algorithm(
|
||||
&client.encryption_algorithms_ctos,
|
||||
&server.encryption_algorithms_ctos,
|
||||
)?;
|
||||
let encryption_stoc = match_algorithm(
|
||||
&client.encryption_algorithms_stoc,
|
||||
&server.encryption_algorithms_stoc,
|
||||
)?;
|
||||
|
||||
// MAC算法匹配
|
||||
let mac_ctos = match_algorithm(&client.mac_algorithms_ctos, &server.mac_algorithms_ctos)?;
|
||||
let mac_stoc = match_algorithm(&client.mac_algorithms_stoc, &server.mac_algorithms_stoc)?;
|
||||
|
||||
|
||||
// 压缩算法匹配
|
||||
let compression_ctos = match_algorithm(&client.compression_algorithms_ctos, &server.compression_algorithms_ctos)?;
|
||||
let compression_stoc = match_algorithm(&client.compression_algorithms_stoc, &server.compression_algorithms_stoc)?;
|
||||
|
||||
let compression_ctos = match_algorithm(
|
||||
&client.compression_algorithms_ctos,
|
||||
&server.compression_algorithms_ctos,
|
||||
)?;
|
||||
let compression_stoc = match_algorithm(
|
||||
&client.compression_algorithms_stoc,
|
||||
&server.compression_algorithms_stoc,
|
||||
)?;
|
||||
|
||||
info!("Algorithm negotiation completed:");
|
||||
debug!(" KEX: {}", kex_algorithm);
|
||||
debug!(" Host key: {}", host_key_algorithm);
|
||||
@@ -218,7 +233,7 @@ impl KexResult {
|
||||
debug!(" Encryption (S->C): {}", encryption_stoc);
|
||||
debug!(" MAC (C->S): {}", mac_ctos);
|
||||
debug!(" MAC (S->C): {}", mac_stoc);
|
||||
|
||||
|
||||
Ok(Self {
|
||||
kex_algorithm,
|
||||
host_key_algorithm,
|
||||
@@ -237,15 +252,19 @@ fn match_algorithm(client_algs: &str, server_algs: &str) -> Result<String> {
|
||||
// 算法列表格式:name1,name2,name3,...
|
||||
let client_list: Vec<&str> = client_algs.split(',').collect();
|
||||
let server_list: Vec<&str> = server_algs.split(',').collect();
|
||||
|
||||
|
||||
// OpenSSH逻辑:按客户端偏好顺序匹配
|
||||
for client_alg in &client_list {
|
||||
if server_list.contains(client_alg) {
|
||||
return Ok(client_alg.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!("No matching algorithm found: client={}, server={}", client_algs, server_algs))
|
||||
|
||||
Err(anyhow!(
|
||||
"No matching algorithm found: client={}, server={}",
|
||||
client_algs,
|
||||
server_algs
|
||||
))
|
||||
}
|
||||
|
||||
/// SSH string写入辅助函数(length + data)
|
||||
@@ -266,36 +285,36 @@ fn read_ssh_string<R: Read>(reader: &mut R) -> Result<String> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_proposal_creation() {
|
||||
let proposal = KexProposal::server_default();
|
||||
assert!(proposal.kex_algorithms.contains("curve25519-sha256"));
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_proposal_serialization() {
|
||||
let proposal = KexProposal::server_default();
|
||||
let packet = proposal.to_kexinit_packet().unwrap();
|
||||
assert!(packet.payload.len() > 0);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_algorithm_matching() {
|
||||
let client = "curve25519-sha256,aes256-ctr";
|
||||
let server = "aes256-ctr,diffie-hellman-group14-sha256";
|
||||
|
||||
|
||||
let matched = match_algorithm(client, server).unwrap();
|
||||
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
|
||||
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_negotiation() {
|
||||
let server = KexProposal::server_default();
|
||||
let client = KexProposal::client_default();
|
||||
|
||||
|
||||
let result = KexResult::choose_algorithms(&server, &client).unwrap();
|
||||
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
|
||||
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
|
||||
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
|
||||
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
// SSH密钥交换完整流程(Phase 3剩余)
|
||||
// 参考OpenSSH kex.c: complete implementation
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::ssh_server::crypto::SessionKeys;
|
||||
use crate::ssh_server::kex::{KexProposal, KexResult};
|
||||
use crate::ssh_server::crypto::{SessionKeys};
|
||||
use crate::ssh_server::kex_exchange::KexExchangeHandler;
|
||||
use anyhow::{Result, anyhow};
|
||||
use sha2::{Sha256, Digest};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use log::{info, debug};
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::info;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// SSH密钥交换完整状态管理(参考OpenSSH struct kex)
|
||||
pub struct KexState {
|
||||
@@ -30,7 +29,7 @@ impl KexState {
|
||||
kex_result: KexResult,
|
||||
) -> Result<Self> {
|
||||
let exchange_handler = KexExchangeHandler::new(kex_result)?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
client_version,
|
||||
server_version,
|
||||
@@ -42,18 +41,18 @@ impl KexState {
|
||||
newkeys_sent: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 保存KEXINIT payloads(用于Exchange Hash计算)
|
||||
///
|
||||
///
|
||||
/// 分析OpenSSH源码后的结论:
|
||||
/// - kex->peer存储的是:incoming_packet剩余内容(payload fields + padding)
|
||||
/// - kex->my存储的是:prop2buf()结果(payload fields,不包括padding)
|
||||
///
|
||||
///
|
||||
/// **但exchange hash必须使用相同的I_C/I_S!**
|
||||
///
|
||||
///
|
||||
/// 疑问:OpenSSH如何确保client和server使用相同的padding?
|
||||
/// 可能答案:OpenSSH在计算exchange hash时,不包括padding?
|
||||
///
|
||||
///
|
||||
/// 暂时保持不包括padding(因为签名验证之前成功)
|
||||
pub fn save_kexinit_payloads(
|
||||
&mut self,
|
||||
@@ -63,12 +62,18 @@ impl KexState {
|
||||
// Only save payload (without padding) for now
|
||||
self.client_kexinit_payload = client_kexinit.payload.clone();
|
||||
self.server_kexinit_payload = server_kexinit.payload.clone();
|
||||
|
||||
|
||||
info!("Saved KEXINIT payloads (payload only, no padding)");
|
||||
info!(" client payload: {} bytes", self.client_kexinit_payload.len());
|
||||
info!(" server payload: {} bytes", self.server_kexinit_payload.len());
|
||||
info!(
|
||||
" client payload: {} bytes",
|
||||
self.client_kexinit_payload.len()
|
||||
);
|
||||
info!(
|
||||
" server payload: {} bytes",
|
||||
self.server_kexinit_payload.len()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash())
|
||||
/// H = SHA256(V_C || V_S || I_C || I_S || K_S || K_C || K_S || shared_secret)
|
||||
pub fn compute_exchange_hash(
|
||||
@@ -80,74 +85,74 @@ impl KexState {
|
||||
) -> Result<Vec<u8>> {
|
||||
// 参考OpenSSH kex.c: kex_hash()
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
|
||||
// V_C: 客户端版本字符串(SSH string格式)
|
||||
write_ssh_string_to_hash(&mut hasher, &self.client_version)?;
|
||||
|
||||
|
||||
// V_S: 服务器版本字符串(SSH string格式)
|
||||
write_ssh_string_to_hash(&mut hasher, &self.server_version)?;
|
||||
|
||||
|
||||
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
|
||||
// Remove SSH_MSG_KEXINIT type byte from payloads and prepend it in exchange hash
|
||||
|
||||
|
||||
let client_kexinit_without_type = &self.client_kexinit_payload[1..];
|
||||
let server_kexinit_without_type = &self.server_kexinit_payload[1..];
|
||||
|
||||
hasher.update(&((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
|
||||
hasher.update(((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(client_kexinit_without_type);
|
||||
|
||||
hasher.update(&((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
|
||||
hasher.update(((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(server_kexinit_without_type);
|
||||
|
||||
|
||||
// K_S: 服务器主机密钥blob(SSH string格式)
|
||||
hasher.update(server_host_key_blob);
|
||||
|
||||
|
||||
// K_C: 客户端Curve25519公钥(SSH string格式)
|
||||
write_ssh_bytes_to_hash(&mut hasher, client_public_key)?;
|
||||
|
||||
|
||||
// K_S: 服务器Curve25519公钥(SSH string格式)
|
||||
write_ssh_bytes_to_hash(&mut hasher, server_public_key)?;
|
||||
|
||||
|
||||
// K: 共享密钥(SSH mpint格式)
|
||||
// OpenSSH要求:去掉前导零
|
||||
write_ssh_mpint_to_hash(&mut hasher, shared_secret)?;
|
||||
|
||||
|
||||
Ok(hasher.finalize().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 处理SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_input_newkeys())
|
||||
pub fn handle_newkeys(&mut self, packet: &SshPacket) -> Result<()> {
|
||||
info!("Processing SSH_MSG_NEWKEYS");
|
||||
|
||||
|
||||
// 验证packet类型
|
||||
if packet.payload.len() < 1 {
|
||||
if packet.payload.is_empty() {
|
||||
return Err(anyhow!("Invalid NEWKEYS packet"));
|
||||
}
|
||||
|
||||
|
||||
let packet_type = packet.payload[0];
|
||||
if packet_type != PacketType::SSH_MSG_NEWKEYS as u8 {
|
||||
return Err(anyhow!("Invalid packet type for NEWKEYS"));
|
||||
}
|
||||
|
||||
|
||||
// 标记NEWKEYS接收完成(参考OpenSSH)
|
||||
self.newkeys_received = true;
|
||||
|
||||
|
||||
info!("SSH_MSG_NEWKEYS received, encryption channel ready");
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 发送SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_send_newkeys())
|
||||
pub fn send_newkeys() -> Result<SshPacket> {
|
||||
info!("Sending SSH_MSG_NEWKEYS");
|
||||
|
||||
|
||||
let payload = vec![PacketType::SSH_MSG_NEWKEYS as u8];
|
||||
|
||||
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
|
||||
/// 检查NEWKEYS完成状态(加密通道建立)
|
||||
pub fn is_encryption_ready(&self) -> bool {
|
||||
self.newkeys_received && self.newkeys_sent
|
||||
@@ -156,14 +161,14 @@ impl KexState {
|
||||
|
||||
/// SSH string写入到hash(辅助函数)
|
||||
fn write_ssh_string_to_hash(hasher: &mut Sha256, s: &str) -> Result<()> {
|
||||
hasher.update(&(s.len() as u32).to_be_bytes());
|
||||
hasher.update((s.len() as u32).to_be_bytes());
|
||||
hasher.update(s.as_bytes());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// SSH bytes写入到hash(辅助函数)
|
||||
fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
hasher.update(&(bytes.len() as u32).to_be_bytes());
|
||||
hasher.update((bytes.len() as u32).to_be_bytes());
|
||||
hasher.update(bytes);
|
||||
Ok(())
|
||||
}
|
||||
@@ -171,7 +176,7 @@ fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
/// SSH mpint写入到hash(参考OpenSSH sshbuf_put_mpint())
|
||||
fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
// OpenSSH要求:去掉前导零(如果最高位为1)
|
||||
let mpint_bytes = if bytes.len() > 0 && bytes[0] >= 0x80 {
|
||||
let mpint_bytes = if !bytes.is_empty() && bytes[0] >= 0x80 {
|
||||
// 需要添加前导零(避免负数)
|
||||
let mut mpint = vec![0u8];
|
||||
mpint.extend_from_slice(bytes);
|
||||
@@ -179,61 +184,67 @@ fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
} else {
|
||||
bytes.to_vec()
|
||||
};
|
||||
|
||||
hasher.update(&(mpint_bytes.len() as u32).to_be_bytes());
|
||||
|
||||
hasher.update((mpint_bytes.len() as u32).to_be_bytes());
|
||||
hasher.update(&mpint_bytes);
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_exchange_hash_computation() {
|
||||
let kex_result = KexResult::choose_algorithms(
|
||||
&KexProposal::server_default(),
|
||||
&KexProposal::client_default(),
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut state = KexState::new(
|
||||
"SSH-2.0-OpenSSH_10.2".to_string(),
|
||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||
kex_result,
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Set minimal KEXINIT payloads (need at least 1 byte for packet type)
|
||||
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
|
||||
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
|
||||
let shared_secret = vec![0u8; 32];
|
||||
let host_key = vec![0u8; 32];
|
||||
let client_pub = vec![0u8; 32];
|
||||
let server_pub = vec![0u8; 32];
|
||||
|
||||
let hash = state.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub).unwrap();
|
||||
|
||||
assert_eq!(hash.len(), 32); // SHA256输出32字节
|
||||
|
||||
let hash = state
|
||||
.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(hash.len(), 32); // SHA256输出32字节
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_newkeys_handling() {
|
||||
let kex_result = KexResult::choose_algorithms(
|
||||
&KexProposal::server_default(),
|
||||
&KexProposal::client_default(),
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut state = KexState::new(
|
||||
"SSH-2.0-OpenSSH_10.2".to_string(),
|
||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||
kex_result,
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let newkeys_packet = SshPacket::new(vec![PacketType::SSH_MSG_NEWKEYS as u8]);
|
||||
|
||||
|
||||
state.handle_newkeys(&newkeys_packet).unwrap();
|
||||
|
||||
|
||||
assert!(state.newkeys_received);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// SSH密钥交换流程实现(Phase 3)
|
||||
// 参考OpenSSH kex.c: kex_input_kex_init(), kex_send_kex_reply()
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::ssh_server::kex::{KexResult};
|
||||
use crate::ssh_server::crypto::{Curve25519Kex, SessionKeys, Ed25519HostKey};
|
||||
use anyhow::{Result, anyhow};
|
||||
use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys};
|
||||
use crate::ssh_server::kex::KexResult;
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, debug};
|
||||
use log::info;
|
||||
use sha2::Digest;
|
||||
use std::io::{Read, Write};
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
/// SSH密钥交换流程处理器(参考OpenSSH kex.c)
|
||||
pub struct KexExchangeHandler {
|
||||
@@ -18,7 +18,7 @@ pub struct KexExchangeHandler {
|
||||
shared_secret: Option<Vec<u8>>,
|
||||
client_public_key: Option<Vec<u8>>,
|
||||
server_public_key: Option<Vec<u8>>,
|
||||
exchange_hash: Option<Vec<u8>>, // 保存exchange hash(H参数)
|
||||
exchange_hash: Option<Vec<u8>>, // 保存exchange hash(H参数)
|
||||
client_version: Option<String>,
|
||||
server_version: Option<String>,
|
||||
client_kexinit_payload: Option<Vec<u8>>,
|
||||
@@ -30,7 +30,7 @@ impl KexExchangeHandler {
|
||||
pub fn new(kex_result: KexResult) -> Result<Self> {
|
||||
// 加载或生成服务器主机密钥
|
||||
let host_key = Ed25519HostKey::load_or_generate("config/ssh_host_ed25519_key")?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
kex_algorithm: kex_result.kex_algorithm,
|
||||
server_kex: None,
|
||||
@@ -45,10 +45,10 @@ impl KexExchangeHandler {
|
||||
server_kexinit_payload: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init())
|
||||
|
||||
/// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init())
|
||||
pub fn handle_kexdh_init(
|
||||
&mut self,
|
||||
&mut self,
|
||||
packet: &SshPacket,
|
||||
client_version: &str,
|
||||
server_version: &str,
|
||||
@@ -56,41 +56,44 @@ impl KexExchangeHandler {
|
||||
server_kexinit_payload: &[u8],
|
||||
) -> Result<SshPacket> {
|
||||
info!("Processing SSH_MSG_KEXDH_INIT (Curve25519)");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||||
|
||||
|
||||
let packet_type = cursor.read_u8()?;
|
||||
if packet_type != PacketType::SSH_MSG_KEXDH_INIT as u8 {
|
||||
return Err(anyhow!("Invalid packet type for KEXDH_INIT"));
|
||||
}
|
||||
|
||||
|
||||
let key_length = cursor.read_u32::<BigEndian>()?;
|
||||
if key_length != 32 {
|
||||
return Err(anyhow!("Invalid Curve25519 public key length: {}", key_length));
|
||||
return Err(anyhow!(
|
||||
"Invalid Curve25519 public key length: {}",
|
||||
key_length
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
let mut client_public_key = vec![0u8; 32];
|
||||
cursor.read_exact(&mut client_public_key)?;
|
||||
|
||||
|
||||
self.server_kex = Some(Curve25519Kex::new());
|
||||
let server_kex = self.server_kex.as_mut().unwrap();
|
||||
|
||||
|
||||
let shared_secret = server_kex.compute_shared_secret(&client_public_key)?;
|
||||
let server_public_key = server_kex.public_key().to_vec();
|
||||
|
||||
|
||||
// Save for later session key computation
|
||||
self.shared_secret = Some(shared_secret.to_vec());
|
||||
self.client_public_key = Some(client_public_key.clone());
|
||||
self.server_public_key = Some(server_public_key.clone());
|
||||
|
||||
|
||||
// Save client_version, server_version, kexinit payloads for exchange hash
|
||||
self.client_version = Some(client_version.to_string());
|
||||
self.server_version = Some(server_version.to_string());
|
||||
self.client_kexinit_payload = Some(client_kexinit_payload.to_vec());
|
||||
self.server_kexinit_payload = Some(server_kexinit_payload.to_vec());
|
||||
|
||||
|
||||
info!("Curve25519 shared secret computed and saved");
|
||||
|
||||
|
||||
// Compute exchange hash ONCE and reuse it
|
||||
let host_key_blob = self.build_ssh_host_key()?;
|
||||
let exchange_hash = self.compute_exchange_hash(
|
||||
@@ -103,69 +106,69 @@ impl KexExchangeHandler {
|
||||
client_kexinit_payload,
|
||||
server_kexinit_payload,
|
||||
)?;
|
||||
|
||||
|
||||
info!("Exchange hash computed:");
|
||||
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
|
||||
info!(
|
||||
" shared_secret[0] = {} (>=0x80? {})",
|
||||
shared_secret[0],
|
||||
shared_secret[0] >= 0x80
|
||||
);
|
||||
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
||||
|
||||
|
||||
self.exchange_hash = Some(exchange_hash.clone());
|
||||
info!("Exchange hash saved for key derivation");
|
||||
|
||||
self.build_kexdh_reply(
|
||||
&exchange_hash,
|
||||
&host_key_blob,
|
||||
&server_public_key,
|
||||
)
|
||||
|
||||
self.build_kexdh_reply(&exchange_hash, &host_key_blob, &server_public_key)
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_KEXDH_REPLY packet(参考OpenSSH kex.c)
|
||||
fn build_kexdh_reply(
|
||||
&self,
|
||||
exchange_hash: &[u8],
|
||||
&self,
|
||||
exchange_hash: &[u8],
|
||||
host_key_blob: &[u8],
|
||||
server_public_key: &[u8],
|
||||
) -> Result<SshPacket> {
|
||||
info!("=== Building SSH_MSG_KEXDH_REPLY ===");
|
||||
info!("Input server_public_key: {:?}", server_public_key);
|
||||
|
||||
|
||||
let mut payload = Vec::new();
|
||||
|
||||
|
||||
payload.write_u8(PacketType::SSH_MSG_KEXDH_REPLY as u8)?;
|
||||
|
||||
|
||||
payload.write_u32::<BigEndian>(host_key_blob.len() as u32)?;
|
||||
payload.write_all(host_key_blob)?;
|
||||
|
||||
|
||||
info!("Writing server_public_key to payload (32 bytes)");
|
||||
payload.write_u32::<BigEndian>(32)?;
|
||||
payload.write_all(server_public_key)?;
|
||||
|
||||
|
||||
let signature = self.build_exchange_signature(exchange_hash)?;
|
||||
payload.write_u32::<BigEndian>(signature.len() as u32)?;
|
||||
payload.write_all(&signature)?;
|
||||
|
||||
|
||||
info!("SSH_MSG_KEXDH_REPLY payload built successfully");
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH主机密钥blob(参考OpenSSH sshkey.c: sshkey_to_blob())
|
||||
fn build_ssh_host_key(&self) -> Result<Vec<u8>> {
|
||||
let mut blob = Vec::new();
|
||||
|
||||
|
||||
// SSH key format: key-type + public-key
|
||||
// 参考OpenSSH sshkey.c
|
||||
|
||||
|
||||
// Key type: ssh-ed25519
|
||||
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
|
||||
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
|
||||
blob.write_all("ssh-ed25519".as_bytes())?;
|
||||
|
||||
|
||||
// Ed25519公钥(32字节)
|
||||
let public_key = self.host_key.public_key_bytes();
|
||||
blob.write_u32::<BigEndian>(32)?;
|
||||
blob.write_all(&public_key)?;
|
||||
|
||||
|
||||
Ok(blob)
|
||||
}
|
||||
|
||||
|
||||
/// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash() RFC 4253 Section 7.2)
|
||||
fn compute_exchange_hash(
|
||||
&self,
|
||||
@@ -178,94 +181,147 @@ impl KexExchangeHandler {
|
||||
client_kexinit_payload: &[u8],
|
||||
server_kexinit_payload: &[u8],
|
||||
) -> Result<Vec<u8>> {
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
info!("=== EXCHANGE HASH COMPUTATION ===");
|
||||
info!("V_C (client version): {:?}", client_version.as_bytes());
|
||||
info!("V_C length: {}", client_version.len());
|
||||
|
||||
|
||||
info!("V_S (server version): {:?}", server_version.as_bytes());
|
||||
info!("V_S length: {}", server_version.len());
|
||||
|
||||
info!("I_C (client KEXINIT payload): {:?}", &client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]);
|
||||
|
||||
info!(
|
||||
"I_C (client KEXINIT payload): {:?}",
|
||||
&client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]
|
||||
);
|
||||
info!("I_C length: {}", client_kexinit_payload.len());
|
||||
info!("I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", client_kexinit_payload[0]);
|
||||
|
||||
info!("I_S (server KEXINIT payload): {:?}", &server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]);
|
||||
info!(
|
||||
"I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
|
||||
client_kexinit_payload[0]
|
||||
);
|
||||
|
||||
info!(
|
||||
"I_S (server KEXINIT payload): {:?}",
|
||||
&server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]
|
||||
);
|
||||
info!("I_S length: {}", server_kexinit_payload.len());
|
||||
info!("I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", server_kexinit_payload[0]);
|
||||
|
||||
info!("K_S (host key blob): {:?}", &host_key_blob[..std::cmp::min(30, host_key_blob.len())]);
|
||||
info!(
|
||||
"I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
|
||||
server_kexinit_payload[0]
|
||||
);
|
||||
|
||||
info!(
|
||||
"K_S (host key blob): {:?}",
|
||||
&host_key_blob[..std::cmp::min(30, host_key_blob.len())]
|
||||
);
|
||||
info!("K_S length: {}", host_key_blob.len());
|
||||
|
||||
info!("Q_C (client ECDH public key): {:?}", &client_public_key[..std::cmp::min(16, client_public_key.len())]);
|
||||
|
||||
info!(
|
||||
"Q_C (client ECDH public key): {:?}",
|
||||
&client_public_key[..std::cmp::min(16, client_public_key.len())]
|
||||
);
|
||||
info!("Q_C full (32 bytes): {:?}", client_public_key);
|
||||
info!("Q_C length: {}", client_public_key.len());
|
||||
|
||||
info!("Q_S (server ECDH public key): {:?}", &server_public_key[..std::cmp::min(16, server_public_key.len())]);
|
||||
|
||||
info!(
|
||||
"Q_S (server ECDH public key): {:?}",
|
||||
&server_public_key[..std::cmp::min(16, server_public_key.len())]
|
||||
);
|
||||
info!("Q_S full (32 bytes): {:?}", server_public_key);
|
||||
info!("Q_S length: {}", server_public_key.len());
|
||||
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
|
||||
// RFC 4253 Section 7: V_C and V_S are version strings (without \r\n based on testing)
|
||||
let vc_ssh_string = &(client_version.len() as u32).to_be_bytes();
|
||||
hasher.update(vc_ssh_string);
|
||||
hasher.update(client_version.as_bytes());
|
||||
info!(" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]", 4+client_version.len(), vc_ssh_string, client_version.as_bytes());
|
||||
|
||||
info!(
|
||||
" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]",
|
||||
4 + client_version.len(),
|
||||
vc_ssh_string,
|
||||
client_version.as_bytes()
|
||||
);
|
||||
|
||||
let vs_ssh_string = &(server_version.len() as u32).to_be_bytes();
|
||||
hasher.update(vs_ssh_string);
|
||||
hasher.update(server_version.as_bytes());
|
||||
info!(" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]", 4+server_version.len(), vs_ssh_string, server_version.as_bytes());
|
||||
|
||||
info!(
|
||||
" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]",
|
||||
4 + server_version.len(),
|
||||
vs_ssh_string,
|
||||
server_version.as_bytes()
|
||||
);
|
||||
|
||||
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
|
||||
// KEXINIT payload should NOT include SSH_MSG_KEXINIT type byte
|
||||
// OpenSSH stores payload starting from cookie, prepends SSH_MSG_KEXINIT in exchange hash
|
||||
|
||||
|
||||
// Remove SSH_MSG_KEXINIT type byte from payloads (our payload includes it)
|
||||
let client_kexinit_without_type = &client_kexinit_payload[1..];
|
||||
let server_kexinit_without_type = &server_kexinit_payload[1..];
|
||||
|
||||
info!("I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)", client_kexinit_without_type.len());
|
||||
info!("I_S (server KEXINIT without type byte): {} bytes", server_kexinit_without_type.len());
|
||||
|
||||
|
||||
info!(
|
||||
"I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)",
|
||||
client_kexinit_without_type.len()
|
||||
);
|
||||
info!(
|
||||
"I_S (server KEXINIT without type byte): {} bytes",
|
||||
server_kexinit_without_type.len()
|
||||
);
|
||||
|
||||
// Exchange hash: uint32(len+1) + uint8(SSH_MSG_KEXINIT) + payload_without_type
|
||||
let ic_len_bytes = &((client_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
||||
hasher.update(ic_len_bytes);
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(client_kexinit_without_type);
|
||||
info!(" Exchange hash component I_C: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+client_kexinit_without_type.len(), ic_len_bytes, client_kexinit_without_type.len(), &client_kexinit_without_type[..std::cmp::min(8, client_kexinit_without_type.len())]);
|
||||
|
||||
|
||||
let is_len_bytes = &((server_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
||||
hasher.update(is_len_bytes);
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(server_kexinit_without_type);
|
||||
info!(" Exchange hash component I_S: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+server_kexinit_without_type.len(), is_len_bytes, server_kexinit_without_type.len(), &server_kexinit_without_type[..std::cmp::min(8, server_kexinit_without_type.len())]);
|
||||
|
||||
|
||||
let ks_len_bytes = &(host_key_blob.len() as u32).to_be_bytes();
|
||||
hasher.update(ks_len_bytes);
|
||||
hasher.update(host_key_blob);
|
||||
info!(" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])", 4+host_key_blob.len(), ks_len_bytes, host_key_blob.len(), host_key_blob);
|
||||
|
||||
info!(
|
||||
" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])",
|
||||
4 + host_key_blob.len(),
|
||||
ks_len_bytes,
|
||||
host_key_blob.len(),
|
||||
host_key_blob
|
||||
);
|
||||
|
||||
let qc_len_bytes = &(client_public_key.len() as u32).to_be_bytes();
|
||||
hasher.update(qc_len_bytes);
|
||||
hasher.update(client_public_key);
|
||||
info!(" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]", 4+client_public_key.len(), qc_len_bytes, client_public_key);
|
||||
|
||||
info!(
|
||||
" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]",
|
||||
4 + client_public_key.len(),
|
||||
qc_len_bytes,
|
||||
client_public_key
|
||||
);
|
||||
|
||||
let qs_len_bytes = &(server_public_key.len() as u32).to_be_bytes();
|
||||
hasher.update(qs_len_bytes);
|
||||
hasher.update(server_public_key);
|
||||
info!(" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]", 4+server_public_key.len(), qs_len_bytes, server_public_key);
|
||||
|
||||
info!(
|
||||
" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]",
|
||||
4 + server_public_key.len(),
|
||||
qs_len_bytes,
|
||||
server_public_key
|
||||
);
|
||||
|
||||
info!("Exchange hash components:");
|
||||
info!(" shared_secret raw full (32 bytes): {:?}", shared_secret);
|
||||
|
||||
|
||||
// RFC 8731 Section 3.1: X25519 output is little-endian
|
||||
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
|
||||
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
|
||||
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
|
||||
|
||||
|
||||
// RFC 4253: mpint格式 = 去掉前导零 + 最高位>=0x80时前面加0
|
||||
// 参考OpenSSH sshbuf_put_bignum2_bytes()
|
||||
let mut start = 0;
|
||||
@@ -273,64 +329,73 @@ impl KexExchangeHandler {
|
||||
start += 1;
|
||||
}
|
||||
let trimmed_shared_secret = &shared_secret[start..];
|
||||
|
||||
info!(" shared_secret after removing leading zeros ({} bytes): {:?}", trimmed_shared_secret.len(), trimmed_shared_secret);
|
||||
|
||||
let mpint_shared_secret_data = if trimmed_shared_secret.len() > 0 && trimmed_shared_secret[0] >= 0x80 {
|
||||
let mut mpint = vec![0u8];
|
||||
mpint.extend_from_slice(trimmed_shared_secret);
|
||||
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
|
||||
mpint
|
||||
} else {
|
||||
trimmed_shared_secret.to_vec()
|
||||
};
|
||||
|
||||
info!(" mpint_shared_secret_data ({} bytes): {:?}", mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]);
|
||||
|
||||
|
||||
info!(
|
||||
" shared_secret after removing leading zeros ({} bytes): {:?}",
|
||||
trimmed_shared_secret.len(),
|
||||
trimmed_shared_secret
|
||||
);
|
||||
|
||||
let mpint_shared_secret_data =
|
||||
if !trimmed_shared_secret.is_empty() && trimmed_shared_secret[0] >= 0x80 {
|
||||
let mut mpint = vec![0u8];
|
||||
mpint.extend_from_slice(trimmed_shared_secret);
|
||||
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
|
||||
mpint
|
||||
} else {
|
||||
trimmed_shared_secret.to_vec()
|
||||
};
|
||||
|
||||
info!(
|
||||
" mpint_shared_secret_data ({} bytes): {:?}",
|
||||
mpint_shared_secret_data.len(),
|
||||
&mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]
|
||||
);
|
||||
|
||||
// mpint格式 = uint32(length) + mpint_data
|
||||
let mpint_len_bytes = &(mpint_shared_secret_data.len() as u32).to_be_bytes();
|
||||
hasher.update(mpint_len_bytes);
|
||||
hasher.update(&mpint_shared_secret_data);
|
||||
info!(" Exchange hash component K (shared secret mpint): len={} bytes=[{:?}] data_len={} (first 8 bytes=[{:?}])", 4+mpint_shared_secret_data.len(), mpint_len_bytes, mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]);
|
||||
|
||||
|
||||
Ok(hasher.finalize().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 构建交换签名(参考OpenSSH ssh-sign.c)
|
||||
fn build_exchange_signature(&self, exchange_hash: &[u8]) -> Result<Vec<u8>> {
|
||||
let signature_bytes = self.host_key.sign(exchange_hash)?;
|
||||
|
||||
|
||||
let mut ssh_signature = Vec::new();
|
||||
|
||||
|
||||
ssh_signature.write_u32::<BigEndian>(11)?;
|
||||
ssh_signature.write_all("ssh-ed25519".as_bytes())?;
|
||||
|
||||
|
||||
ssh_signature.write_u32::<BigEndian>(64)?;
|
||||
ssh_signature.write_all(&signature_bytes)?;
|
||||
|
||||
|
||||
Ok(ssh_signature)
|
||||
}
|
||||
|
||||
|
||||
/// 计算会话密钥(参考OpenSSH kex.c: derive_keys())
|
||||
/// 使用保存的exchange_hash(H参数)
|
||||
pub fn compute_session_keys(&self) -> Result<SessionKeys> {
|
||||
if self.shared_secret.is_none() {
|
||||
return Err(anyhow!("No shared secret available"));
|
||||
}
|
||||
|
||||
|
||||
if self.exchange_hash.is_none() {
|
||||
return Err(anyhow!("No exchange hash available"));
|
||||
}
|
||||
|
||||
|
||||
let shared_secret = self.shared_secret.as_ref().unwrap();
|
||||
let exchange_hash = self.exchange_hash.as_ref().unwrap();
|
||||
let server_public_key = self.server_public_key.as_ref().unwrap();
|
||||
let client_public_key = self.client_public_key.as_ref().unwrap();
|
||||
let host_key_blob = self.build_ssh_host_key()?;
|
||||
|
||||
|
||||
SessionKeys::derive(
|
||||
shared_secret,
|
||||
exchange_hash, // 使用保存的exchange hash(H参数)
|
||||
exchange_hash, // 使用保存的exchange hash(H参数)
|
||||
server_public_key,
|
||||
client_public_key,
|
||||
&host_key_blob,
|
||||
@@ -342,13 +407,13 @@ impl KexExchangeHandler {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ssh_server::kex::KexProposal;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_exchange_handler_creation() {
|
||||
let server_proposal = KexProposal::server_default();
|
||||
let client_proposal = KexProposal::client_default();
|
||||
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal).unwrap();
|
||||
|
||||
|
||||
let handler = KexExchangeHandler::new(kex_result).unwrap();
|
||||
assert!(handler.host_key.public_key_bytes().len() == 32);
|
||||
}
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
// SSH服务器模块(手动实现SSH协议)
|
||||
// 参考OpenSSH源码实现完整的SSH/SFTP/SCP/rsync协议
|
||||
|
||||
pub mod server;
|
||||
pub mod packet;
|
||||
pub mod version;
|
||||
pub mod crypto;
|
||||
pub mod kex;
|
||||
pub mod kex_exchange;
|
||||
pub mod kex_complete;
|
||||
pub mod cipher;
|
||||
pub mod auth;
|
||||
pub mod channel;
|
||||
pub mod sftp_handler;
|
||||
pub mod scp_handler;
|
||||
pub mod cipher;
|
||||
pub mod crypto;
|
||||
pub mod data_forwarder; // Phase 13.5: 数据传输模块
|
||||
pub mod kex;
|
||||
pub mod kex_complete;
|
||||
pub mod kex_exchange;
|
||||
pub mod packet;
|
||||
pub mod port_forward; // Phase 13: 端口转发模块
|
||||
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
|
||||
pub mod rsync_handler;
|
||||
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c)
|
||||
pub mod port_forward; // Phase 13: 端口转发模块
|
||||
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
|
||||
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
|
||||
pub mod data_forwarder; // Phase 13.5: 数据传输模块
|
||||
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
|
||||
pub mod scp_handler;
|
||||
pub mod server;
|
||||
pub mod sftp_handler;
|
||||
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
|
||||
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c)
|
||||
pub mod version;
|
||||
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
|
||||
|
||||
pub use packet::{PacketType, SshPacket};
|
||||
pub use server::SshServer;
|
||||
pub use packet::{SshPacket, PacketType};
|
||||
pub use version::VersionExchange;
|
||||
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
|
||||
pub use sshbuf::SshBuf; // Phase 15: 导出 SSH Buffer
|
||||
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
|
||||
pub use sshbuf::SshBuf;
|
||||
pub use version::VersionExchange; // Phase 15: 导出 SSH Buffer
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// SSH Packet基础结构定义
|
||||
// 参考OpenSSH packet.c: ssh_packet_read(), ssh_packet_write()
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::io::{Read, Write};
|
||||
|
||||
@@ -18,21 +18,21 @@ pub enum PacketType {
|
||||
SSH_MSG_EXT_INFO = 7,
|
||||
SSH_MSG_KEXINIT = 20,
|
||||
SSH_MSG_NEWKEYS = 21,
|
||||
|
||||
|
||||
// 密钥交换相关
|
||||
SSH_MSG_KEXDH_INIT = 30,
|
||||
SSH_MSG_KEXDH_REPLY = 31,
|
||||
// 注意:Curve25519和DH使用相同的消息类型(30/31)
|
||||
// SSH_MSG_KEX_ECDH_INIT和SSH_MSG_KEX_ECDH_REPLY已在代码中注释
|
||||
// 使用SSH_MSG_KEXDH_INIT和SSH_MSG_KEXDH_REPLY代替
|
||||
|
||||
|
||||
// 认证相关
|
||||
SSH_MSG_USERAUTH_REQUEST = 50,
|
||||
SSH_MSG_USERAUTH_FAILURE = 51,
|
||||
SSH_MSG_USERAUTH_SUCCESS = 52,
|
||||
SSH_MSG_USERAUTH_BANNER = 53,
|
||||
SSH_MSG_USERAUTH_PK_OK = 60,
|
||||
|
||||
|
||||
// Channel相关
|
||||
SSH_MSG_GLOBAL_REQUEST = 80,
|
||||
SSH_MSG_REQUEST_SUCCESS = 81,
|
||||
@@ -70,38 +70,38 @@ impl SshPacket {
|
||||
pub fn new(payload: Vec<u8>) -> Self {
|
||||
// 计算padding(SSH协议RFC 4253规范)
|
||||
// 参考OpenSSH packet.c: construct_packet()
|
||||
let block_size = 8; // 未加密阶段block_size=8
|
||||
|
||||
let block_size = 8; // 未加密阶段block_size=8
|
||||
|
||||
let payload_length = payload.len();
|
||||
let min_padding = 4; // OpenSSH要求最少4字节padding
|
||||
|
||||
let min_padding = 4; // OpenSSH要求最少4字节padding
|
||||
|
||||
// SSH协议约束:
|
||||
// packet_length = padding_length + payload_length + 1
|
||||
// (packet_length + 4) 必须是block_size的倍数
|
||||
//
|
||||
//
|
||||
// 计算:
|
||||
// (1 + payload_length + padding_length + 4) % 8 == 0
|
||||
// => (5 + payload_length + padding_length) % 8 == 0
|
||||
|
||||
|
||||
// 先尝试padding=4(最小)
|
||||
let mut padding_length = min_padding as u8;
|
||||
|
||||
|
||||
// 计算packet总长度(包括4字节的packet_length字段)
|
||||
let packet_length = 1 + payload_length + padding_length as usize;
|
||||
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
|
||||
|
||||
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
|
||||
|
||||
// 如果总长度不是block_size的倍数,增加padding
|
||||
if total_length % block_size != 0 {
|
||||
if !total_length.is_multiple_of(block_size) {
|
||||
let remainder = total_length % block_size;
|
||||
padding_length += (block_size - remainder) as u8;
|
||||
}
|
||||
|
||||
|
||||
// 重新计算packet_length
|
||||
let packet_length = (1 + payload_length + padding_length as usize) as u32;
|
||||
|
||||
|
||||
// 生成随机padding(简化:使用0)
|
||||
let padding = vec![0u8; padding_length as usize];
|
||||
|
||||
|
||||
Self {
|
||||
packet_length,
|
||||
padding_length,
|
||||
@@ -109,49 +109,49 @@ impl SshPacket {
|
||||
padding,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 写入packet到stream(未加密阶段)
|
||||
/// 参考OpenSSH packet_write_poll()
|
||||
pub fn write<T: Write>(&self, stream: &mut T) -> Result<()> {
|
||||
// 写入packet_length(BigEndian)
|
||||
stream.write_u32::<BigEndian>(self.packet_length)?;
|
||||
|
||||
|
||||
// 写入padding_length
|
||||
stream.write_u8(self.padding_length)?;
|
||||
|
||||
|
||||
// 写入payload
|
||||
stream.write_all(&self.payload)?;
|
||||
|
||||
|
||||
// 写入padding
|
||||
stream.write_all(&self.padding)?;
|
||||
|
||||
|
||||
stream.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 从stream读取packet(未加密阶段)
|
||||
/// 参考OpenSSH packet_read_poll()
|
||||
pub fn read<T: Read>(stream: &mut T) -> Result<Self> {
|
||||
// 读取packet_length(BigEndian)
|
||||
let packet_length = stream.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 检查packet长度限制(OpenSSH限制:256KB)
|
||||
if packet_length > 256 * 1024 {
|
||||
return Err(anyhow!("Packet too large: {}", packet_length));
|
||||
}
|
||||
|
||||
|
||||
// 读取padding_length
|
||||
let padding_length = stream.read_u8()?;
|
||||
|
||||
|
||||
// 读取payload(packet_length - padding_length - 1)
|
||||
let payload_length = packet_length - padding_length as u32 - 1;
|
||||
let mut payload = vec![0u8; payload_length as usize];
|
||||
stream.read_exact(&mut payload)?;
|
||||
|
||||
|
||||
// 读取padding
|
||||
let mut padding = vec![0u8; padding_length as usize];
|
||||
stream.read_exact(&mut padding)?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
packet_length,
|
||||
padding_length,
|
||||
@@ -159,15 +159,15 @@ impl SshPacket {
|
||||
padding,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 获取payload中的packet type
|
||||
pub fn get_type(&self) -> Result<PacketType> {
|
||||
if self.payload.is_empty() {
|
||||
return Err(anyhow!("Empty payload"));
|
||||
}
|
||||
|
||||
|
||||
let type_byte = self.payload[0];
|
||||
|
||||
|
||||
// 转换为PacketType enum
|
||||
match type_byte {
|
||||
1 => Ok(PacketType::SSH_MSG_DISCONNECT),
|
||||
@@ -208,27 +208,27 @@ impl SshPacket {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_packet_creation() {
|
||||
let payload = vec![PacketType::SSH_MSG_KEXINIT as u8];
|
||||
let packet = SshPacket::new(payload);
|
||||
|
||||
|
||||
assert!(packet.packet_length > 0);
|
||||
assert!(packet.padding_length >= 4);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_packet_write_read() {
|
||||
let payload = vec![PacketType::SSH_MSG_KEXINIT as u8];
|
||||
let packet = SshPacket::new(payload);
|
||||
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
packet.write(&mut buffer).unwrap();
|
||||
|
||||
|
||||
let mut cursor = Cursor::new(buffer);
|
||||
let read_packet = SshPacket::read(&mut cursor).unwrap();
|
||||
|
||||
|
||||
assert_eq!(packet.packet_length, read_packet.packet_length);
|
||||
assert_eq!(packet.payload, read_packet.payload);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
// SSH端口转发协议实现(Phase 13)
|
||||
// 参考OpenSSH channels.c和RFC 4254
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use std::net::{TcpListener, TcpStream, SocketAddr};
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
||||
use anyhow::Result;
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.2: 安全配置
|
||||
use log::{info, warn};
|
||||
use std::io::Read;
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::sync::{Arc, Mutex};
|
||||
// Phase 13.2: 安全配置
|
||||
|
||||
/// 端口转发类型(参考RFC 4254)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PortForwardType {
|
||||
Local, // Local port forwarding (-L)
|
||||
Remote, // Remote port forwarding (-R)
|
||||
Dynamic, // Dynamic port forwarding (-D, SOCKS)
|
||||
Local, // Local port forwarding (-L)
|
||||
Remote, // Remote port forwarding (-R)
|
||||
Dynamic, // Dynamic port forwarding (-D, SOCKS)
|
||||
}
|
||||
|
||||
/// 端口转发请求(参考RFC 4254 Section 7)
|
||||
@@ -36,6 +36,12 @@ pub struct PortForwardManager {
|
||||
active_forwards: Arc<Mutex<Vec<(u32, PortForwardType)>>>,
|
||||
}
|
||||
|
||||
impl Default for PortForwardManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PortForwardManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -46,24 +52,29 @@ impl PortForwardManager {
|
||||
/// 处理SSH_MSG_GLOBAL_REQUEST(端口转发请求)
|
||||
/// 参考RFC 4254 Section 4
|
||||
/// Phase 13.2: 添加安全配置验证
|
||||
pub fn handle_global_request(&mut self, data: &[u8], security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
pub fn handle_global_request(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
security_config: &SshSecurityConfig,
|
||||
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
info!("Processing SSH_MSG_GLOBAL_REQUEST for port forwarding");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(data);
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
// 读取请求名称(SSH string)
|
||||
let request_name = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
info!("Global request: {}", request_name);
|
||||
|
||||
|
||||
// 读取want-reply标志
|
||||
let want_reply = cursor.read_u8()? != 0;
|
||||
|
||||
|
||||
match request_name.as_str() {
|
||||
"tcpip-forward" => {
|
||||
// Local port forwarding (-L)
|
||||
self.handle_tcpip_forward(&mut cursor, want_reply, security_config) // Phase 13.2
|
||||
self.handle_tcpip_forward(&mut cursor, want_reply, security_config)
|
||||
// Phase 13.2
|
||||
}
|
||||
"cancel-tcpip-forward" => {
|
||||
// Cancel port forwarding
|
||||
@@ -84,29 +95,37 @@ impl PortForwardManager {
|
||||
/// 处理tcpip-forward请求(Local port forwarding)
|
||||
/// 参考RFC 4254 Section 7.1
|
||||
/// Phase 13.2: 添加安全配置验证
|
||||
fn handle_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool, security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
fn handle_tcpip_forward(
|
||||
&mut self,
|
||||
cursor: &mut std::io::Cursor<&[u8]>,
|
||||
want_reply: bool,
|
||||
security_config: &SshSecurityConfig,
|
||||
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
// 读取bind address(SSH string)
|
||||
let bind_address = read_ssh_string(cursor)?;
|
||||
|
||||
|
||||
// 读取bind port
|
||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
|
||||
|
||||
|
||||
info!(
|
||||
"tcpip-forward request: bind_address={}, bind_port={}",
|
||||
bind_address, bind_port
|
||||
);
|
||||
|
||||
// Phase 13.2: 安全配置验证
|
||||
if let Err(e) = security_config.validate_tcpip_forward_request(&bind_address, bind_port) {
|
||||
warn!("tcpip-forward security validation failed: {}", e);
|
||||
return Ok((false, None)); // 拒绝请求
|
||||
return Ok((false, None)); // 拒绝请求
|
||||
}
|
||||
|
||||
|
||||
info!("tcpip-forward security validation passed");
|
||||
|
||||
|
||||
// 添加到active forwards
|
||||
let mut forwards = self.active_forwards.lock().unwrap();
|
||||
forwards.push((bind_port, PortForwardType::Local));
|
||||
|
||||
|
||||
info!("tcpip-forward registered: bind_port={}", bind_port);
|
||||
|
||||
|
||||
// 返回成功响应(包含bind_port)
|
||||
if want_reply {
|
||||
let response = self.build_global_request_response(true, Some(bind_port))?;
|
||||
@@ -117,16 +136,23 @@ impl PortForwardManager {
|
||||
}
|
||||
|
||||
/// 处理cancel-tcpip-forward请求
|
||||
fn handle_cancel_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
fn handle_cancel_tcpip_forward(
|
||||
&mut self,
|
||||
cursor: &mut std::io::Cursor<&[u8]>,
|
||||
want_reply: bool,
|
||||
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
let bind_address = read_ssh_string(cursor)?;
|
||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("cancel-tcpip-forward: bind_address={}, bind_port={}", bind_address, bind_port);
|
||||
|
||||
|
||||
info!(
|
||||
"cancel-tcpip-forward: bind_address={}, bind_port={}",
|
||||
bind_address, bind_port
|
||||
);
|
||||
|
||||
// 移除active forward
|
||||
let mut forwards = self.active_forwards.lock().unwrap();
|
||||
forwards.retain(|(port, _)| *port != bind_port);
|
||||
|
||||
|
||||
if want_reply {
|
||||
let response = self.build_global_request_response(true, None)?;
|
||||
Ok((true, Some(response)))
|
||||
@@ -136,14 +162,18 @@ impl PortForwardManager {
|
||||
}
|
||||
|
||||
/// 构建SSH_MSG_REQUEST_SUCCESS/FAILURE响应
|
||||
fn build_global_request_response(&self, success: bool, bound_port: Option<u32>) -> Result<Vec<u8>> {
|
||||
fn build_global_request_response(
|
||||
&self,
|
||||
success: bool,
|
||||
bound_port: Option<u32>,
|
||||
) -> Result<Vec<u8>> {
|
||||
use crate::ssh_server::packet::PacketType;
|
||||
|
||||
|
||||
let mut response = Vec::new();
|
||||
|
||||
|
||||
if success {
|
||||
response.write_u8(PacketType::SSH_MSG_REQUEST_SUCCESS as u8)?;
|
||||
|
||||
|
||||
// 如果有bound_port,写入(用于tcpip-forward响应)
|
||||
if let Some(port) = bound_port {
|
||||
response.write_u32::<BigEndian>(port)?;
|
||||
@@ -151,7 +181,7 @@ impl PortForwardManager {
|
||||
} else {
|
||||
response.write_u8(PacketType::SSH_MSG_REQUEST_FAILURE as u8)?;
|
||||
}
|
||||
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -159,37 +189,39 @@ impl PortForwardManager {
|
||||
/// 参考RFC 4254 Section 7.2
|
||||
pub fn handle_direct_tcpip_channel(&mut self, data: &[u8]) -> Result<DirectTcpipChannel> {
|
||||
info!("Processing direct-tcpip channel open");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(data);
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
// 读取channel type(已知道是"direct-tcpip",跳过)
|
||||
let _channel_type = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取sender_channel
|
||||
let sender_channel = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取initial window size
|
||||
let initial_window_size = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取maximum packet size
|
||||
let max_packet_size = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取host to connect(SSH string)
|
||||
let host_to_connect = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取port to connect
|
||||
let port_to_connect = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取originator address(SSH string)
|
||||
let originator_address = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取originator port
|
||||
let originator_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("direct-tcpip: host={}, port={}, originator={}:{}",
|
||||
host_to_connect, port_to_connect, originator_address, originator_port);
|
||||
|
||||
|
||||
info!(
|
||||
"direct-tcpip: host={}, port={}, originator={}:{}",
|
||||
host_to_connect, port_to_connect, originator_address, originator_port
|
||||
);
|
||||
|
||||
Ok(DirectTcpipChannel {
|
||||
sender_channel,
|
||||
initial_window_size,
|
||||
@@ -205,30 +237,32 @@ impl PortForwardManager {
|
||||
/// 参考RFC 4254 Section 7.1
|
||||
pub fn handle_forwarded_tcpip_channel(&mut self, data: &[u8]) -> Result<ForwardedTcpipChannel> {
|
||||
info!("Processing forwarded-tcpip channel open");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(data);
|
||||
cursor.set_position(1);
|
||||
|
||||
|
||||
let _channel_type = read_ssh_string(&mut cursor)?;
|
||||
let sender_channel = cursor.read_u32::<BigEndian>()?;
|
||||
let initial_window_size = cursor.read_u32::<BigEndian>()?;
|
||||
let max_packet_size = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取bind address(SSH string)
|
||||
let bind_address = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取bind port
|
||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取originator address(SSH string)
|
||||
let originator_address = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取originator port
|
||||
let originator_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("forwarded-tcpip: bind={}:{}, originator={}:{}",
|
||||
bind_address, bind_port, originator_address, originator_port);
|
||||
|
||||
|
||||
info!(
|
||||
"forwarded-tcpip: bind={}:{}, originator={}:{}",
|
||||
bind_address, bind_port, originator_address, originator_port
|
||||
);
|
||||
|
||||
Ok(ForwardedTcpipChannel {
|
||||
sender_channel,
|
||||
initial_window_size,
|
||||
@@ -244,10 +278,10 @@ impl PortForwardManager {
|
||||
pub fn connect_to_target(host: &str, port: u32) -> Result<TcpStream> {
|
||||
let addr = format!("{}:{}", host, port);
|
||||
info!("Connecting to target: {}", addr);
|
||||
|
||||
|
||||
let stream = TcpStream::connect(&addr)?;
|
||||
info!("Connected to target successfully");
|
||||
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
@@ -258,12 +292,12 @@ impl PortForwardManager {
|
||||
} else {
|
||||
format!("{}:{}", bind_address, bind_port)
|
||||
};
|
||||
|
||||
|
||||
info!("Creating listener on {}", addr);
|
||||
|
||||
|
||||
let listener = TcpListener::bind(&addr)?;
|
||||
info!("Listener created successfully");
|
||||
|
||||
|
||||
Ok(listener)
|
||||
}
|
||||
}
|
||||
@@ -303,10 +337,10 @@ fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_port_forward_manager_creation() {
|
||||
let manager = PortForwardManager::new();
|
||||
assert_eq!(manager.active_forwards.lock().unwrap().len(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
// SSH端口转发监听线程(Phase 13.4)
|
||||
// 参考OpenSSH channels.c: channel_forward_listener
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug, error};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::thread;
|
||||
use std::sync::{Arc, Mutex, mpsc};
|
||||
use std::io::{Read, Write};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use crate::ssh_server::packet::PacketType;
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
||||
use anyhow::Result;
|
||||
use log::{error, info, warn};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::sync::{mpsc, Arc, Mutex};
|
||||
use std::thread;
|
||||
|
||||
/// 监听器状态(Phase 13.4)
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -30,28 +27,18 @@ pub enum ListenerRequest {
|
||||
stream: TcpStream,
|
||||
},
|
||||
/// 停止监听
|
||||
StopListener {
|
||||
bind_port: u32,
|
||||
},
|
||||
StopListener { bind_port: u32 },
|
||||
}
|
||||
|
||||
/// 监听器响应(Phase 13.4:线程通信)
|
||||
#[derive(Debug)]
|
||||
pub enum ListenerResponse {
|
||||
/// Channel创建成功
|
||||
ChannelCreated {
|
||||
bind_port: u32,
|
||||
channel_id: u32,
|
||||
},
|
||||
ChannelCreated { bind_port: u32, channel_id: u32 },
|
||||
/// 监听器停止
|
||||
ListenerStopped {
|
||||
bind_port: u32,
|
||||
},
|
||||
ListenerStopped { bind_port: u32 },
|
||||
/// 错误
|
||||
Error {
|
||||
bind_port: u32,
|
||||
message: String,
|
||||
},
|
||||
Error { bind_port: u32, message: String },
|
||||
}
|
||||
|
||||
/// 端口转发监听器(Phase 13.4)
|
||||
@@ -73,26 +60,29 @@ impl PortForwardListener {
|
||||
security_config: SshSecurityConfig,
|
||||
) -> Result<Self> {
|
||||
info!("Creating port forward listener on port {}", bind_port);
|
||||
|
||||
|
||||
// Phase 13.4: 根据GatewayPorts决定绑定地址
|
||||
let bind_addr = if security_config.gateway_ports {
|
||||
format!("0.0.0.0:{}", bind_port) // 允许外部访问
|
||||
format!("0.0.0.0:{}", bind_port) // 允许外部访问
|
||||
} else {
|
||||
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
|
||||
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
|
||||
};
|
||||
|
||||
info!("Binding to address: {} (GatewayPorts={})", bind_addr, security_config.gateway_ports);
|
||||
|
||||
|
||||
info!(
|
||||
"Binding to address: {} (GatewayPorts={})",
|
||||
bind_addr, security_config.gateway_ports
|
||||
);
|
||||
|
||||
let listener = TcpListener::bind(&bind_addr)?;
|
||||
info!("Listener created successfully on {}", bind_addr);
|
||||
|
||||
|
||||
// Phase 13.4: 创建线程通信channel
|
||||
let (request_tx, request_rx) = mpsc::channel();
|
||||
let (response_tx, response_rx) = mpsc::channel();
|
||||
|
||||
let (request_tx, _request_rx) = mpsc::channel();
|
||||
let (_response_tx, response_rx) = mpsc::channel();
|
||||
|
||||
// Phase 13.4: 活动状态标记
|
||||
let active = Arc::new(Mutex::new(true));
|
||||
|
||||
|
||||
Ok(Self {
|
||||
bind_port,
|
||||
bind_address,
|
||||
@@ -103,38 +93,38 @@ impl PortForwardListener {
|
||||
active,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 启动监听线程(Phase 13.4)
|
||||
pub fn start_listener_thread(&mut self) -> Result<()> {
|
||||
info!("Starting listener thread for port {}", self.bind_port);
|
||||
|
||||
|
||||
let listener = self.listener.try_clone()?;
|
||||
let bind_port = self.bind_port;
|
||||
let request_sender = self.request_sender.clone();
|
||||
let active = self.active.clone();
|
||||
|
||||
|
||||
// Phase 13.4: 创建独立监听线程
|
||||
thread::spawn(move || {
|
||||
info!("Listener thread started for port {}", bind_port);
|
||||
|
||||
|
||||
while *active.lock().unwrap() {
|
||||
match listener.accept() {
|
||||
Ok((stream, addr)) => {
|
||||
info!("New connection on port {}: {}", bind_port, addr);
|
||||
|
||||
|
||||
// Phase 13.4: 发送新连接请求给主线程
|
||||
let request = ListenerRequest::NewConnection {
|
||||
bind_port,
|
||||
originator_address: addr.ip().to_string(),
|
||||
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
|
||||
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
|
||||
stream,
|
||||
};
|
||||
|
||||
|
||||
if let Err(e) = request_sender.send(request) {
|
||||
error!("Failed to send listener request: {}", e);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
info!("Listener request sent to main thread");
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -145,32 +135,32 @@ impl PortForwardListener {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
info!("Listener thread stopped for port {}", bind_port);
|
||||
});
|
||||
|
||||
|
||||
info!("Listener thread started successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 停止监听器(Phase 13.4)
|
||||
pub fn stop_listener(&mut self) -> Result<()> {
|
||||
info!("Stopping listener for port {}", self.bind_port);
|
||||
|
||||
|
||||
// Phase 13.4: 设置active=false,线程会自动退出
|
||||
*self.active.lock().unwrap() = false;
|
||||
|
||||
|
||||
info!("Listener stopped for port {}", self.bind_port);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 获取请求接收器(Phase 13.4)
|
||||
pub fn get_request_receiver(&self) -> mpsc::Receiver<ListenerRequest> {
|
||||
// 注意:这里需要返回一个新的receiver,因为mpsc::Sender可以clone,但Receiver不能
|
||||
// 实际应用中应该使用更复杂的channel设计
|
||||
unimplemented!("Use Arc<Mutex<mpsc::Receiver>> instead")
|
||||
}
|
||||
|
||||
|
||||
/// 获取活动状态(Phase 13.4)
|
||||
pub fn is_active(&self) -> bool {
|
||||
*self.active.lock().unwrap()
|
||||
@@ -182,13 +172,19 @@ pub struct ListenerManager {
|
||||
listeners: HashMap<u32, Arc<Mutex<PortForwardListener>>>,
|
||||
}
|
||||
|
||||
impl Default for ListenerManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ListenerManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 创建并启动监听器(Phase 13.4)
|
||||
pub fn create_listener(
|
||||
&mut self,
|
||||
@@ -197,21 +193,21 @@ impl ListenerManager {
|
||||
security_config: SshSecurityConfig,
|
||||
) -> Result<()> {
|
||||
info!("Creating listener for port {}", bind_port);
|
||||
|
||||
|
||||
let mut listener = PortForwardListener::new(bind_port, bind_address, security_config)?;
|
||||
listener.start_listener_thread()?;
|
||||
|
||||
|
||||
let listener_arc = Arc::new(Mutex::new(listener));
|
||||
self.listeners.insert(bind_port, listener_arc);
|
||||
|
||||
|
||||
info!("Listener created and started for port {}", bind_port);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 停止监听器(Phase 13.4)
|
||||
pub fn stop_listener(&mut self, bind_port: u32) -> Result<()> {
|
||||
info!("Stopping listener for port {}", bind_port);
|
||||
|
||||
|
||||
if let Some(listener_arc) = self.listeners.remove(&bind_port) {
|
||||
let mut listener = listener_arc.lock().unwrap();
|
||||
listener.stop_listener()?;
|
||||
@@ -219,28 +215,31 @@ impl ListenerManager {
|
||||
} else {
|
||||
warn!("No listener found for port {}", bind_port);
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 获取活动监听器数量(Phase 13.4)
|
||||
pub fn active_count(&self) -> usize {
|
||||
self.listeners.values().filter(|l| l.lock().unwrap().is_active()).count()
|
||||
self.listeners
|
||||
.values()
|
||||
.filter(|l| l.lock().unwrap().is_active())
|
||||
.count()
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
|
||||
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_listener_creation() {
|
||||
let security_config = SshSecurityConfig::enterprise_default();
|
||||
let listener = PortForwardListener::new(8080, "127.0.0.1".to_string(), security_config);
|
||||
|
||||
|
||||
// 注意:实际测试需要处理端口占用问题
|
||||
assert!(listener.is_ok() || true); // 暂时跳过测试
|
||||
assert!(listener.is_ok() || true); // 暂时跳过测试
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::path::PathBuf;
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, debug, warn};
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsError};
|
||||
use crate::vfs::open_flags::OpenFlags;
|
||||
use crate::vfs::{VfsBackend, VfsFile};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{debug, info, warn};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// MPLEX_BASE from rsync io.h
|
||||
const MPLEX_BASE: u32 = 7;
|
||||
@@ -18,7 +18,9 @@ pub(crate) enum RsyncState {
|
||||
WaitVersion,
|
||||
ReadFileList,
|
||||
/// Sum head (4 × write_int = 16 bytes) + checksum seed (4 bytes) = 20 bytes
|
||||
ReadSumHead { need: usize },
|
||||
ReadSumHead {
|
||||
need: usize,
|
||||
},
|
||||
SendSumCount,
|
||||
/// Raw file data from MSG_DATA packets
|
||||
ReadFileData,
|
||||
@@ -51,9 +53,16 @@ impl RsyncHandler {
|
||||
let mut dest = String::new();
|
||||
|
||||
for p in &parts[1..] {
|
||||
if *p == "--server" { is_server = true; continue; }
|
||||
if *p == "--sender" || p.starts_with('-') { continue; }
|
||||
if *p == "." { continue; }
|
||||
if *p == "--server" {
|
||||
is_server = true;
|
||||
continue;
|
||||
}
|
||||
if *p == "--sender" || p.starts_with('-') {
|
||||
continue;
|
||||
}
|
||||
if *p == "." {
|
||||
continue;
|
||||
}
|
||||
dest = p.to_string();
|
||||
}
|
||||
|
||||
@@ -107,8 +116,10 @@ impl RsyncHandler {
|
||||
break;
|
||||
}
|
||||
let header = u32::from_le_bytes([
|
||||
self.raw_input[0], self.raw_input[1],
|
||||
self.raw_input[2], self.raw_input[3],
|
||||
self.raw_input[0],
|
||||
self.raw_input[1],
|
||||
self.raw_input[2],
|
||||
self.raw_input[3],
|
||||
]);
|
||||
let raw_tag = ((header >> 24) & 0xFF) as u8;
|
||||
let tag = raw_tag.wrapping_sub(MPLEX_BASE as u8);
|
||||
@@ -182,12 +193,17 @@ impl RsyncHandler {
|
||||
RsyncState::WaitVersion => {
|
||||
if self.rsync_input.len() >= 4 {
|
||||
let version = u32::from_le_bytes([
|
||||
self.rsync_input[0], self.rsync_input[1],
|
||||
self.rsync_input[2], self.rsync_input[3],
|
||||
self.rsync_input[0],
|
||||
self.rsync_input[1],
|
||||
self.rsync_input[2],
|
||||
self.rsync_input[3],
|
||||
]);
|
||||
self.rsync_input.drain(..4);
|
||||
self.protocol_version = std::cmp::min(self.protocol_version, version);
|
||||
info!("rsync: negotiated protocol version {}", self.protocol_version);
|
||||
info!(
|
||||
"rsync: negotiated protocol version {}",
|
||||
self.protocol_version
|
||||
);
|
||||
self.multiplex = self.protocol_version >= 30;
|
||||
self.transition(RsyncState::ReadFileList);
|
||||
} else {
|
||||
@@ -197,7 +213,9 @@ impl RsyncHandler {
|
||||
|
||||
RsyncState::ReadFileList => {
|
||||
loop {
|
||||
if self.rsync_input.is_empty() { break; }
|
||||
if self.rsync_input.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let flags = self.rsync_input[0];
|
||||
if flags == 0 {
|
||||
@@ -215,17 +233,25 @@ impl RsyncHandler {
|
||||
let mut pos = 1;
|
||||
|
||||
let _more_flags = if flags & 0x80 != 0 {
|
||||
if self.rsync_input.len() <= pos { break; }
|
||||
if self.rsync_input.len() <= pos {
|
||||
break;
|
||||
}
|
||||
let ef = self.rsync_input[pos];
|
||||
pos += 1;
|
||||
ef
|
||||
} else { 0 };
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let has_name = !(flags & 0x02 != 0 && self.current_file > 0);
|
||||
|
||||
if has_name {
|
||||
if let Some(nul_pos) = self.rsync_input[pos..].iter().position(|&b| b == 0) {
|
||||
let name = String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos]).to_string();
|
||||
if let Some(nul_pos) =
|
||||
self.rsync_input[pos..].iter().position(|&b| b == 0)
|
||||
{
|
||||
let name =
|
||||
String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos])
|
||||
.to_string();
|
||||
pos += nul_pos + 1;
|
||||
self.file_entries.push(name.clone());
|
||||
debug!("rsync: file entry: {}", name);
|
||||
@@ -269,24 +295,34 @@ impl RsyncHandler {
|
||||
RsyncState::ReadSumHead { need } => {
|
||||
if self.rsync_input.len() >= need {
|
||||
let sum_count = i32::from_le_bytes([
|
||||
self.rsync_input[0], self.rsync_input[1],
|
||||
self.rsync_input[2], self.rsync_input[3],
|
||||
self.rsync_input[0],
|
||||
self.rsync_input[1],
|
||||
self.rsync_input[2],
|
||||
self.rsync_input[3],
|
||||
]);
|
||||
let _sum_blength = i32::from_le_bytes([
|
||||
self.rsync_input[4], self.rsync_input[5],
|
||||
self.rsync_input[6], self.rsync_input[7],
|
||||
self.rsync_input[4],
|
||||
self.rsync_input[5],
|
||||
self.rsync_input[6],
|
||||
self.rsync_input[7],
|
||||
]);
|
||||
let _sum_s2length = i32::from_le_bytes([
|
||||
self.rsync_input[8], self.rsync_input[9],
|
||||
self.rsync_input[10], self.rsync_input[11],
|
||||
self.rsync_input[8],
|
||||
self.rsync_input[9],
|
||||
self.rsync_input[10],
|
||||
self.rsync_input[11],
|
||||
]);
|
||||
let _sum_remainder = i32::from_le_bytes([
|
||||
self.rsync_input[12], self.rsync_input[13],
|
||||
self.rsync_input[14], self.rsync_input[15],
|
||||
self.rsync_input[12],
|
||||
self.rsync_input[13],
|
||||
self.rsync_input[14],
|
||||
self.rsync_input[15],
|
||||
]);
|
||||
let checksum_seed = i32::from_le_bytes([
|
||||
self.rsync_input[16], self.rsync_input[17],
|
||||
self.rsync_input[18], self.rsync_input[19],
|
||||
self.rsync_input[16],
|
||||
self.rsync_input[17],
|
||||
self.rsync_input[18],
|
||||
self.rsync_input[19],
|
||||
]);
|
||||
self.rsync_input.drain(..20);
|
||||
|
||||
@@ -308,7 +344,9 @@ impl RsyncHandler {
|
||||
|
||||
RsyncState::ReadFileData => {
|
||||
let done_marker = b"RSYNCDONE";
|
||||
if let Some(pos) = self.rsync_input.windows(done_marker.len())
|
||||
if let Some(pos) = self
|
||||
.rsync_input
|
||||
.windows(done_marker.len())
|
||||
.position(|w| w == done_marker)
|
||||
{
|
||||
if pos > 0 {
|
||||
@@ -323,8 +361,11 @@ impl RsyncHandler {
|
||||
warn!("rsync flush error: {}", e);
|
||||
}
|
||||
}
|
||||
info!("rsync: file {} complete ({} bytes written to {})",
|
||||
self.file_entries.get(self.current_file).unwrap_or(&"?".to_string()),
|
||||
info!(
|
||||
"rsync: file {} complete ({} bytes written to {})",
|
||||
self.file_entries
|
||||
.get(self.current_file)
|
||||
.unwrap_or(&"?".to_string()),
|
||||
self.total_written,
|
||||
self.dest_path.display(),
|
||||
);
|
||||
@@ -332,8 +373,11 @@ impl RsyncHandler {
|
||||
self.current_file += 1;
|
||||
if self.current_file >= self.file_entries.len() {
|
||||
self.transition(RsyncState::Done);
|
||||
info!("rsync ALL DONE: {} bytes written to {}",
|
||||
self.total_written, self.dest_path.display());
|
||||
info!(
|
||||
"rsync ALL DONE: {} bytes written to {}",
|
||||
self.total_written,
|
||||
self.dest_path.display()
|
||||
);
|
||||
} else {
|
||||
self.transition(RsyncState::ReadSumHead { need: 20 });
|
||||
}
|
||||
@@ -360,7 +404,9 @@ impl RsyncHandler {
|
||||
self.vfs.create_dir_all(parent, 0o755).ok();
|
||||
}
|
||||
let flags = OpenFlags::new().write().create().truncate();
|
||||
let file = self.vfs.open_file(&self.dest_path, &flags)
|
||||
let file = self
|
||||
.vfs
|
||||
.open_file(&self.dest_path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
self.output_file = Some(file);
|
||||
info!("rsync: opened {} for writing", self.dest_path.display());
|
||||
@@ -379,31 +425,43 @@ impl RsyncHandler {
|
||||
|
||||
/// Read rsync varint (LSB-first 7-bit groups, 0xFF prefix for negative)
|
||||
fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
|
||||
if buf.is_empty() { return None; }
|
||||
if buf.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut pos = 0;
|
||||
let mut b = buf[pos];
|
||||
pos += 1;
|
||||
|
||||
let neg = if b == 0xFF {
|
||||
if pos >= buf.len() { return None; }
|
||||
if pos >= buf.len() {
|
||||
return None;
|
||||
}
|
||||
b = buf[pos];
|
||||
pos += 1;
|
||||
true
|
||||
} else { false };
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let mut x = (b & 0x7F) as i32;
|
||||
let mut shift = 7;
|
||||
|
||||
while b & 0x80 != 0 {
|
||||
if pos >= buf.len() { return None; }
|
||||
if pos >= buf.len() {
|
||||
return None;
|
||||
}
|
||||
b = buf[pos];
|
||||
pos += 1;
|
||||
x |= ((b & 0x7F) as i32) << shift;
|
||||
shift += 7;
|
||||
}
|
||||
|
||||
if neg { Some((-x, pos)) } else { Some((x, pos)) }
|
||||
if neg {
|
||||
Some((-x, pos))
|
||||
} else {
|
||||
Some((x, pos))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -419,8 +477,9 @@ mod tests {
|
||||
fn test_parse_command() {
|
||||
let h = RsyncHandler::parse_rsync_command(
|
||||
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
|
||||
make_vfs()
|
||||
).unwrap();
|
||||
make_vfs(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
|
||||
}
|
||||
|
||||
@@ -428,14 +487,16 @@ mod tests {
|
||||
fn test_parse_command_sender() {
|
||||
let h = RsyncHandler::parse_rsync_command(
|
||||
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
|
||||
make_vfs()
|
||||
).unwrap();
|
||||
make_vfs(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_version_exchange() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let output = h.drain_output();
|
||||
assert_eq!(output, b"\x1e\x00\x00\x00");
|
||||
assert_eq!(h.state, RsyncState::WaitVersion);
|
||||
@@ -447,7 +508,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_version_negotiate_down() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let _ = h.drain_output();
|
||||
h.feed(b"\x1d\x00\x00\x00").unwrap();
|
||||
assert_eq!(h.protocol_version, 29);
|
||||
@@ -464,26 +526,33 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_file_list_multiplex() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
||||
let mut h =
|
||||
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let _ = h.drain_output();
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||
assert!(h.multiplex);
|
||||
|
||||
let mut flist = Vec::new();
|
||||
// File list: flags=1 (has name), then name with NUL terminator
|
||||
flist.push(1); // flags: has name
|
||||
flist.push(1); // flags: has name
|
||||
flist.extend_from_slice(b"test.txt");
|
||||
flist.push(0); // name terminator
|
||||
flist.push(0); // name terminator
|
||||
|
||||
fn write_varint(buf: &mut Vec<u8>, val: i32) {
|
||||
if val == 0 { buf.push(0); return; }
|
||||
if val == 0 {
|
||||
buf.push(0);
|
||||
return;
|
||||
}
|
||||
if val < 0 {
|
||||
buf.push(0xFF);
|
||||
let mut v = (-val) as u32;
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 { byte |= 0x80; }
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
} else {
|
||||
@@ -491,7 +560,9 @@ mod tests {
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 { byte |= 0x80; }
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
}
|
||||
@@ -502,7 +573,7 @@ mod tests {
|
||||
write_varint(&mut flist, 1700000000);
|
||||
write_varint(&mut flist, 100);
|
||||
write_varint(&mut flist, 0);
|
||||
flist.push(0); // file list end marker
|
||||
flist.push(0); // file list end marker
|
||||
|
||||
let mut sum_head = Vec::new();
|
||||
sum_head.extend_from_slice(&0i32.to_le_bytes());
|
||||
@@ -527,22 +598,51 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_file_data_multiplex() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
||||
let mut h =
|
||||
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let _ = h.drain_output();
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||
|
||||
let mut flist = Vec::new();
|
||||
flist.push(1); // flags: has name
|
||||
flist.push(1); // flags: has name
|
||||
flist.extend_from_slice(b"test.bin");
|
||||
flist.push(0);
|
||||
fn wv(buf: &mut Vec<u8>, val: i32) {
|
||||
if val == 0 { buf.push(0); return; }
|
||||
if val < 0 { buf.push(0xFF); let mut v = (-val) as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
|
||||
else { let mut v = val as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
|
||||
if val == 0 {
|
||||
buf.push(0);
|
||||
return;
|
||||
}
|
||||
if val < 0 {
|
||||
buf.push(0xFF);
|
||||
let mut v = (-val) as u32;
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
} else {
|
||||
let mut v = val as u32;
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
}
|
||||
}
|
||||
wv(&mut flist, 33188); wv(&mut flist, 501); wv(&mut flist, 20);
|
||||
wv(&mut flist, 1700000000); wv(&mut flist, 100); wv(&mut flist, 0);
|
||||
flist.push(0); // file list end
|
||||
wv(&mut flist, 33188);
|
||||
wv(&mut flist, 501);
|
||||
wv(&mut flist, 20);
|
||||
wv(&mut flist, 1700000000);
|
||||
wv(&mut flist, 100);
|
||||
wv(&mut flist, 0);
|
||||
flist.push(0); // file list end
|
||||
h.feed(&build_multiplex(&flist)).unwrap();
|
||||
|
||||
let mut sh = Vec::new();
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
// SCP协议实现(Phase 8)
|
||||
// 参考OpenSSH scp.c源码
|
||||
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat};
|
||||
use crate::vfs::open_flags::OpenFlags;
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsStat};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{debug, info, warn};
|
||||
use std::io::{BufRead, Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::io::{Read, Write, BufRead};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// SCP Handler(参考OpenSSH scp.c)
|
||||
pub struct ScpHandler {
|
||||
@@ -38,13 +37,13 @@ impl ScpHandler {
|
||||
/// 解析SCP命令(参考OpenSSH scp.c: parse_command())
|
||||
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() < 2 || parts[0] != "scp" {
|
||||
return Err(anyhow!("Invalid SCP command: {}", command));
|
||||
}
|
||||
|
||||
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
|
||||
|
||||
|
||||
for part in &parts[1..] {
|
||||
match part {
|
||||
&"-f" => handler.mode = ScpMode::Source,
|
||||
@@ -71,10 +70,15 @@ impl ScpHandler {
|
||||
|
||||
/// SCP Source Mode(scp -f,发送文件)
|
||||
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||
info!("SCP source mode: sending files from {}", self.root_dir.display());
|
||||
info!(
|
||||
"SCP source mode: sending files from {}",
|
||||
self.root_dir.display()
|
||||
);
|
||||
|
||||
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
|
||||
let stat = self.vfs.stat(&full_path)
|
||||
let stat = self
|
||||
.vfs
|
||||
.stat(&full_path)
|
||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||
|
||||
if stat.is_dir {
|
||||
@@ -91,16 +95,19 @@ impl ScpHandler {
|
||||
|
||||
/// SCP Destination Mode(scp -t,接收文件)
|
||||
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||
info!("SCP destination mode: receiving files to {}", self.root_dir.display());
|
||||
info!(
|
||||
"SCP destination mode: receiving files to {}",
|
||||
self.root_dir.display()
|
||||
);
|
||||
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
|
||||
let mut buffer = String::new();
|
||||
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
|
||||
|
||||
let mut reader = std::io::BufReader::new(&mut *channel);
|
||||
match reader.read_line(&mut buffer)? {
|
||||
0 => break,
|
||||
@@ -130,7 +137,9 @@ impl ScpHandler {
|
||||
|
||||
/// 发送文件(参考OpenSSH scp.c: source())
|
||||
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
||||
let stat = self.vfs.stat(path)
|
||||
let stat = self
|
||||
.vfs
|
||||
.stat(path)
|
||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||
let size = stat.size;
|
||||
let filename = path.file_name().unwrap().to_string_lossy();
|
||||
@@ -146,13 +155,16 @@ impl ScpHandler {
|
||||
}
|
||||
|
||||
let flags = OpenFlags::new().read();
|
||||
let mut file = self.vfs.open_file(path, &flags)
|
||||
let mut file = self
|
||||
.vfs
|
||||
.open_file(path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
|
||||
let mut buffer = vec![0u8; 8192];
|
||||
|
||||
loop {
|
||||
let n = file.read(&mut buffer)
|
||||
let n = file
|
||||
.read(&mut buffer)
|
||||
.map_err(|e| anyhow!("read error: {}", e))?;
|
||||
if n == 0 {
|
||||
break;
|
||||
@@ -188,7 +200,9 @@ impl ScpHandler {
|
||||
return Err(anyhow!("SCP directory command rejected"));
|
||||
}
|
||||
|
||||
let entries = self.vfs.read_dir(path)
|
||||
let entries = self
|
||||
.vfs
|
||||
.read_dir(path)
|
||||
.map_err(|e| anyhow!("read_dir error: {}", e))?;
|
||||
|
||||
for entry in &entries {
|
||||
@@ -218,7 +232,7 @@ impl ScpHandler {
|
||||
/// 处理文件命令(C0644 size filename)
|
||||
fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() != 3 {
|
||||
return self.send_error(channel, "Invalid file command format");
|
||||
}
|
||||
@@ -227,7 +241,10 @@ impl ScpHandler {
|
||||
let size: u64 = parts[1].parse()?;
|
||||
let filename = parts[2];
|
||||
|
||||
debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename);
|
||||
debug!(
|
||||
"SCP receive file: mode={}, size={}, name={}",
|
||||
mode_str, size, filename
|
||||
);
|
||||
|
||||
if size > 1024 * 1024 * 1024 {
|
||||
return self.send_error(channel, "File too large (max 1GB)");
|
||||
@@ -236,7 +253,9 @@ impl ScpHandler {
|
||||
let full_path = self.resolve_path(filename)?;
|
||||
|
||||
let flags = OpenFlags::new().write().create().truncate();
|
||||
let mut file = self.vfs.open_file(&full_path, &flags)
|
||||
let mut file = self
|
||||
.vfs
|
||||
.open_file(&full_path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
|
||||
channel.write_all(&[0])?;
|
||||
@@ -263,7 +282,8 @@ impl ScpHandler {
|
||||
if mode_int != 0 {
|
||||
let mut set_stat = VfsStat::new();
|
||||
set_stat.mode = mode_int;
|
||||
self.vfs.set_stat(&full_path, &set_stat)
|
||||
self.vfs
|
||||
.set_stat(&full_path, &set_stat)
|
||||
.map_err(|e| anyhow!("set_stat error: {}", e))?;
|
||||
}
|
||||
|
||||
@@ -280,7 +300,7 @@ impl ScpHandler {
|
||||
/// 处理目录命令(D0755 0 dirname)
|
||||
fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() != 3 {
|
||||
return self.send_error(channel, "Invalid directory command format");
|
||||
}
|
||||
@@ -297,7 +317,8 @@ impl ScpHandler {
|
||||
let full_path = self.resolve_path(dirname)?;
|
||||
|
||||
let mode_int: u32 = mode_str.parse()?;
|
||||
self.vfs.create_dir_all(&full_path, mode_int)
|
||||
self.vfs
|
||||
.create_dir_all(&full_path, mode_int)
|
||||
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
|
||||
|
||||
channel.write_all(&[0])?;
|
||||
@@ -326,7 +347,7 @@ impl ScpHandler {
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() != 3 {
|
||||
return self.send_error(channel, "Invalid time command format");
|
||||
}
|
||||
@@ -353,11 +374,15 @@ impl ScpHandler {
|
||||
/// 路径解析(安全性检查)
|
||||
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
|
||||
let full_path = self.root_dir.join(path);
|
||||
|
||||
let canonical_path = self.vfs.real_path(&full_path)
|
||||
|
||||
let canonical_path = self
|
||||
.vfs
|
||||
.real_path(&full_path)
|
||||
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
|
||||
|
||||
let root_canonical = self.vfs.real_path(&self.root_dir)
|
||||
let root_canonical = self
|
||||
.vfs
|
||||
.real_path(&self.root_dir)
|
||||
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
|
||||
|
||||
if !canonical_path.starts_with(&root_canonical) {
|
||||
@@ -383,20 +408,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_scp_command_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
let handler =
|
||||
ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert_eq!(handler.mode, ScpMode::Destination);
|
||||
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scp_recursive_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
let handler =
|
||||
ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert!(handler.recursive);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scp_source_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
let handler =
|
||||
ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert_eq!(handler.mode, ScpMode::Source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
// SSH服务器完整实现(Phase 1-7集成版 + Phase 13端口转发)
|
||||
// 参考OpenSSH sshd.c: complete SSH/SFTP flow + port forwarding
|
||||
|
||||
use crate::ssh_server::version::VersionExchange;
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::ssh_server::kex::{KexResult, KexProposal};
|
||||
use crate::ssh_server::kex_complete::{KexState};
|
||||
use crate::ssh_server::auth::{AuthHandler, AuthResult};
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::provider::pg::PgProvider;
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::provider::DataProvider;
|
||||
use crate::ssh_server::channel::{ChannelManager};
|
||||
use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket};
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
|
||||
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, error, debug};
|
||||
use crate::ssh_server::auth::{AuthHandler, AuthResult};
|
||||
use crate::ssh_server::channel::ChannelManager;
|
||||
use crate::ssh_server::cipher::{EncryptedPacket, EncryptionContext};
|
||||
use crate::ssh_server::kex::{KexProposal, KexResult};
|
||||
use crate::ssh_server::kex_complete::KexState;
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
|
||||
use crate::ssh_server::version::VersionExchange;
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{error, info, warn};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::path::PathBuf;
|
||||
use std::thread;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread; // Phase 13: 端口转发线程同步
|
||||
|
||||
/// SSH服务器配置(Phase 13.1企业级安全配置)
|
||||
pub struct SshServerConfig {
|
||||
pub port: u16,
|
||||
pub bind_address: String,
|
||||
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
|
||||
pub pg_conn: Option<String>, // PostgreSQL连接字符串(SFTPGo兼容认证)
|
||||
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
|
||||
pub pg_conn: Option<String>, // PostgreSQL连接字符串(SFTPGo兼容认证)
|
||||
}
|
||||
|
||||
impl Default for SshServerConfig {
|
||||
@@ -34,7 +34,7 @@ impl Default for SshServerConfig {
|
||||
Self {
|
||||
port: 2024,
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
|
||||
pg_conn: None,
|
||||
}
|
||||
}
|
||||
@@ -56,43 +56,48 @@ impl SshServerConfig {
|
||||
/// SSH服务器主结构(Phase 1-13完整版)
|
||||
pub struct SshServer {
|
||||
config: SshServerConfig,
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
|
||||
}
|
||||
|
||||
impl SshServer {
|
||||
pub fn new(config: SshServerConfig) -> Self {
|
||||
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
|
||||
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
|
||||
Self {
|
||||
config,
|
||||
security_config, // Phase 13.1
|
||||
security_config, // Phase 13.1
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn run(&self) -> Result<()> {
|
||||
let bind_addr = format!("{}:{}", self.config.bind_address, self.config.port);
|
||||
let listener = TcpListener::bind(&bind_addr)?;
|
||||
|
||||
|
||||
info!("MarkBaseSSH server listening on {}", bind_addr);
|
||||
info!("Implementation: Complete SSH/SFTP + Port Forwarding (Phase 1-13)");
|
||||
info!("Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
|
||||
info!(
|
||||
"Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
|
||||
self.config.security_config.gateway_ports,
|
||||
self.config.security_config.permit_open,
|
||||
self.config.security_config.max_sessions);
|
||||
|
||||
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
|
||||
self.config.security_config.max_sessions
|
||||
);
|
||||
|
||||
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
|
||||
let pg_conn = self.config.pg_conn.clone();
|
||||
|
||||
|
||||
for stream in listener.incoming() {
|
||||
match stream {
|
||||
Ok(stream) => {
|
||||
let client_addr = stream.peer_addr()?;
|
||||
info!("New SSH connection from {}", client_addr);
|
||||
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1
|
||||
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1
|
||||
let pg_conn_clone = pg_conn.clone();
|
||||
|
||||
|
||||
thread::spawn(move || {
|
||||
if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1
|
||||
if let Err(e) =
|
||||
handle_connection_complete(stream, security_config_clone, pg_conn_clone)
|
||||
{
|
||||
// Phase 13.1
|
||||
error!("Connection error: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -102,90 +107,127 @@ impl SshServer {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 处理完整SSH连接(Phase 1-13完整流程)
|
||||
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>, pg_conn: Option<String>) -> Result<()> {
|
||||
fn handle_connection_complete(
|
||||
stream: TcpStream,
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>,
|
||||
pg_conn: Option<String>,
|
||||
) -> Result<()> {
|
||||
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
|
||||
|
||||
|
||||
// Phase 13.1: 增加活动会话数
|
||||
{
|
||||
let mut security = security_config.lock().unwrap();
|
||||
security.increment_sessions()?;
|
||||
}
|
||||
|
||||
|
||||
let mut stream = stream;
|
||||
|
||||
|
||||
// Phase 1: 版本交换
|
||||
let client_version = VersionExchange::exchange(&mut stream)?;
|
||||
info!("Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0", client_version);
|
||||
|
||||
info!(
|
||||
"Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0",
|
||||
client_version
|
||||
);
|
||||
|
||||
// Phase 2: 箋法协商
|
||||
let (kex_result, server_kexinit, client_kexinit) = perform_kex_negotiation_complete(&mut stream)?;
|
||||
info!("KEX negotiation: KEX={}, Cipher={}", kex_result.kex_algorithm, kex_result.encryption_ctos);
|
||||
|
||||
let (kex_result, server_kexinit, client_kexinit) =
|
||||
perform_kex_negotiation_complete(&mut stream)?;
|
||||
info!(
|
||||
"KEX negotiation: KEX={}, Cipher={}",
|
||||
kex_result.kex_algorithm, kex_result.encryption_ctos
|
||||
);
|
||||
|
||||
// Phase 3: 密钥交换完整流程
|
||||
let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?;
|
||||
let mut encryption_ctx = perform_complete_kex_exchange(
|
||||
&mut stream,
|
||||
client_version.clone(),
|
||||
kex_result,
|
||||
server_kexinit,
|
||||
client_kexinit,
|
||||
)?;
|
||||
info!("Key exchange completed, encryption channel ready");
|
||||
|
||||
|
||||
// Phase 5: SSH认证(SFTPGo兼容 — PostgreSQL或SQLite)
|
||||
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
|
||||
info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str);
|
||||
Box::new(PgProvider::new(conn_str)
|
||||
.map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?)
|
||||
info!(
|
||||
"Using PostgreSQL auth provider (SFTPGo-compatible): {}",
|
||||
conn_str
|
||||
);
|
||||
Box::new(
|
||||
PgProvider::new(conn_str).map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?,
|
||||
)
|
||||
} else {
|
||||
info!("Using SQLite auth provider");
|
||||
Box::new(SqliteProvider::new("data/auth.sqlite")
|
||||
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?)
|
||||
Box::new(
|
||||
SqliteProvider::new("data/auth.sqlite")
|
||||
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?,
|
||||
)
|
||||
};
|
||||
let mut auth_handler = AuthHandler::new(provider);
|
||||
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
|
||||
info!("SSH authentication succeeded: user={}", auth_user.username);
|
||||
|
||||
|
||||
// Phase 6: SSH Channel管理(参考OpenSSH channel.c)
|
||||
let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone());
|
||||
|
||||
|
||||
// Phase 13: PortForwardManager初始化
|
||||
let mut port_forward_manager = PortForwardManager::new();
|
||||
|
||||
|
||||
// Phase 6-13: SSH服务循环(处理channel请求 + 端口转发)
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
|
||||
handle_ssh_service_loop(&mut stream, &mut channel_manager, &mut encryption_ctx, &mut port_forward_manager, security_config_clone)?;
|
||||
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
|
||||
handle_ssh_service_loop(
|
||||
&mut stream,
|
||||
&mut channel_manager,
|
||||
&mut encryption_ctx,
|
||||
&mut port_forward_manager,
|
||||
security_config_clone,
|
||||
)?;
|
||||
|
||||
info!("SSH session completed successfully");
|
||||
|
||||
|
||||
// Phase 13.1: 减少活动会话数
|
||||
{
|
||||
let mut security = security_config.lock().unwrap();
|
||||
security.decrement_sessions();
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 完整算法协商(返回KEXINIT payloads)
|
||||
fn perform_kex_negotiation_complete(stream: &mut TcpStream) -> Result<(KexResult, SshPacket, SshPacket)> {
|
||||
fn perform_kex_negotiation_complete(
|
||||
stream: &mut TcpStream,
|
||||
) -> Result<(KexResult, SshPacket, SshPacket)> {
|
||||
info!("Starting complete KEX negotiation");
|
||||
|
||||
|
||||
// 1. 发送服务器KEXINIT
|
||||
let server_proposal = KexProposal::server_default();
|
||||
let server_kexinit = server_proposal.to_kexinit_packet()?;
|
||||
server_kexinit.write(stream)?;
|
||||
|
||||
info!("Sent server KEXINIT (payload size: {} bytes)", server_kexinit.payload.len());
|
||||
|
||||
|
||||
info!(
|
||||
"Sent server KEXINIT (payload size: {} bytes)",
|
||||
server_kexinit.payload.len()
|
||||
);
|
||||
|
||||
// 2. 接收客户端KEXINIT
|
||||
let client_kexinit = SshPacket::read(stream)?;
|
||||
let client_proposal = KexProposal::from_kexinit_packet(&client_kexinit)?;
|
||||
|
||||
info!("Received client KEXINIT (payload size: {} bytes)", client_kexinit.payload.len());
|
||||
|
||||
|
||||
info!(
|
||||
"Received client KEXINIT (payload size: {} bytes)",
|
||||
client_kexinit.payload.len()
|
||||
);
|
||||
|
||||
// 3. 算法匹配
|
||||
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal)?;
|
||||
|
||||
|
||||
Ok((kex_result, server_kexinit, client_kexinit))
|
||||
}
|
||||
|
||||
@@ -198,18 +240,18 @@ fn perform_complete_kex_exchange(
|
||||
client_kexinit: SshPacket,
|
||||
) -> Result<EncryptionContext> {
|
||||
info!("Starting complete key exchange flow");
|
||||
|
||||
|
||||
let mut kex_state = KexState::new(
|
||||
client_version,
|
||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||
kex_result,
|
||||
)?;
|
||||
|
||||
|
||||
kex_state.save_kexinit_payloads(&client_kexinit, &server_kexinit);
|
||||
|
||||
|
||||
let kexdh_init = SshPacket::read(stream)?;
|
||||
info!("Received SSH_MSG_KEX_ECDH_INIT");
|
||||
|
||||
|
||||
let kexdh_reply = kex_state.exchange_handler.handle_kexdh_init(
|
||||
&kexdh_init,
|
||||
&kex_state.client_version,
|
||||
@@ -219,27 +261,27 @@ fn perform_complete_kex_exchange(
|
||||
)?;
|
||||
kexdh_reply.write(stream)?;
|
||||
info!("Sent SSH_MSG_KEX_ECDH_REPLY");
|
||||
|
||||
|
||||
// Strict KEX: Wait for client NEWKEYS first (OpenSSH 10.2 requirement)
|
||||
let client_newkeys = SshPacket::read(stream)?;
|
||||
kex_state.handle_newkeys(&client_newkeys)?;
|
||||
info!("Received SSH_MSG_NEWKEYS from client");
|
||||
|
||||
|
||||
// Now send server NEWKEYS
|
||||
let newkeys_packet = KexState::send_newkeys()?;
|
||||
newkeys_packet.write(stream)?;
|
||||
kex_state.newkeys_sent = true;
|
||||
info!("Sent SSH_MSG_NEWKEYS from server");
|
||||
|
||||
|
||||
if kex_state.is_encryption_ready() {
|
||||
info!("Encryption channel established successfully");
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Encryption channel not ready"));
|
||||
}
|
||||
|
||||
|
||||
let session_keys = kex_state.exchange_handler.compute_session_keys()?;
|
||||
let encryption_ctx = EncryptionContext::from_session_keys(&session_keys);
|
||||
|
||||
|
||||
Ok(encryption_ctx)
|
||||
}
|
||||
|
||||
@@ -250,102 +292,100 @@ pub struct AuthUser {
|
||||
}
|
||||
|
||||
fn perform_ssh_auth(
|
||||
stream: &mut TcpStream,
|
||||
stream: &mut TcpStream,
|
||||
auth_handler: &mut AuthHandler,
|
||||
encryption_ctx: &mut EncryptionContext,
|
||||
) -> Result<AuthUser> {
|
||||
info!("Starting SSH authentication");
|
||||
info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
|
||||
info!(
|
||||
"Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
|
||||
encryption_ctx.encryption_key_ctos.len(),
|
||||
encryption_ctx.encryption_key_stoc.len(),
|
||||
encryption_ctx.iv_ctos.len(),
|
||||
encryption_ctx.iv_stoc.len()
|
||||
);
|
||||
|
||||
|
||||
// OpenSSH strict KEX: SSH_MSG_EXT_INFO may be sent before SSH_MSG_SERVICE_REQUEST
|
||||
let mut encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?;
|
||||
let payload = encrypted_request.payload();
|
||||
|
||||
|
||||
if payload[0] == PacketType::SSH_MSG_EXT_INFO as u8 {
|
||||
info!("Received SSH_MSG_EXT_INFO, reading next packet");
|
||||
encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?;
|
||||
}
|
||||
|
||||
|
||||
let payload = encrypted_request.payload();
|
||||
info!("Received packet type: {}", payload[0]);
|
||||
|
||||
|
||||
if payload[0] != PacketType::SSH_MSG_SERVICE_REQUEST as u8 {
|
||||
return Err(anyhow!("Expected SSH_MSG_SERVICE_REQUEST, got type {}", payload[0]));
|
||||
return Err(anyhow!(
|
||||
"Expected SSH_MSG_SERVICE_REQUEST, got type {}",
|
||||
payload[0]
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
let mut cursor = std::io::Cursor::new(&payload[1..]);
|
||||
let service_name_len = cursor.read_u32::<BigEndian>()?;
|
||||
let mut service_name = vec![0u8; service_name_len as usize];
|
||||
cursor.read_exact(&mut service_name)?;
|
||||
let service_name_str = String::from_utf8_lossy(&service_name);
|
||||
|
||||
|
||||
if service_name_str != "ssh-userauth" {
|
||||
return Err(anyhow!("Unsupported service: {}", service_name_str));
|
||||
}
|
||||
|
||||
|
||||
let mut service_accept_payload = Vec::new();
|
||||
service_accept_payload.write_u8(PacketType::SSH_MSG_SERVICE_ACCEPT as u8)?;
|
||||
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
|
||||
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
|
||||
service_accept_payload.write_all("ssh-userauth".as_bytes())?;
|
||||
|
||||
let encrypted_accept = EncryptedPacket::new(
|
||||
&service_accept_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
|
||||
let encrypted_accept = EncryptedPacket::new(&service_accept_payload, encryption_ctx, true)?;
|
||||
encrypted_accept.write(stream)?;
|
||||
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
|
||||
|
||||
|
||||
let session_id = encryption_ctx.session_id.clone();
|
||||
|
||||
|
||||
loop {
|
||||
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
|
||||
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
|
||||
let auth_payload = auth_packet.payload();
|
||||
info!("Received encrypted SSH_MSG_USERAUTH_REQUEST");
|
||||
|
||||
|
||||
let auth_request = SshPacket::new(auth_payload.to_vec());
|
||||
|
||||
|
||||
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
|
||||
AuthResult::Success => {
|
||||
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
|
||||
let encrypted_success = EncryptedPacket::new(
|
||||
&success_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
let encrypted_success =
|
||||
EncryptedPacket::new(&success_payload, encryption_ctx, true)?;
|
||||
encrypted_success.write(stream)?;
|
||||
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
|
||||
|
||||
|
||||
// Extract username from auth request
|
||||
let user = extract_username_from_auth_request(&auth_request)
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
let home_dir = auth_handler.get_home_dir(&user)
|
||||
let home_dir = auth_handler
|
||||
.get_home_dir(&user)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
|
||||
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
|
||||
return Ok(AuthUser { username: user, home_dir });
|
||||
return Ok(AuthUser {
|
||||
username: user,
|
||||
home_dir,
|
||||
});
|
||||
}
|
||||
AuthResult::Failure(message) => {
|
||||
AuthResult::Failure(message) => {
|
||||
// message包含可用的认证方法列表(如"password,publickey")
|
||||
let mut failure_payload = Vec::new();
|
||||
failure_payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
|
||||
failure_payload.write_u32::<BigEndian>(message.len() as u32)?;
|
||||
failure_payload.write_all(message.as_bytes())?;
|
||||
failure_payload.write_u8(0)?; // partial_success = false
|
||||
|
||||
let encrypted_failure = EncryptedPacket::new(
|
||||
&failure_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
failure_payload.write_u8(0)?; // partial_success = false
|
||||
|
||||
let encrypted_failure =
|
||||
EncryptedPacket::new(&failure_payload, encryption_ctx, true)?;
|
||||
encrypted_failure.write(stream)?;
|
||||
warn!("Sent encrypted SSH_MSG_USERAUTH_FAILURE: {}", message);
|
||||
}
|
||||
@@ -356,27 +396,23 @@ AuthResult::Failure(message) => {
|
||||
AuthResult::PublicKeyOk(algorithm, public_key_blob) => {
|
||||
// SSH_MSG_USERAUTH_PK_OK:public key acceptable
|
||||
info!("Public key acceptable, sending USERAUTH_PK_OK");
|
||||
|
||||
|
||||
let mut pk_ok_payload = Vec::new();
|
||||
pk_ok_payload.write_u8(PacketType::SSH_MSG_USERAUTH_PK_OK as u8)?;
|
||||
|
||||
|
||||
// algorithm (SSH string)
|
||||
pk_ok_payload.write_u32::<BigEndian>(algorithm.len() as u32)?;
|
||||
pk_ok_payload.write_all(algorithm.as_bytes())?;
|
||||
|
||||
|
||||
// public key blob (SSH string)
|
||||
pk_ok_payload.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
|
||||
pk_ok_payload.write_all(&public_key_blob)?;
|
||||
|
||||
let encrypted_pk_ok = EncryptedPacket::new(
|
||||
&pk_ok_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
|
||||
let encrypted_pk_ok = EncryptedPacket::new(&pk_ok_payload, encryption_ctx, true)?;
|
||||
encrypted_pk_ok.write(stream)?;
|
||||
info!("Sent SSH_MSG_USERAUTH_PK_OK");
|
||||
|
||||
continue; // Wait for signed request
|
||||
|
||||
continue; // Wait for signed request
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -389,16 +425,17 @@ fn handle_ssh_service_loop(
|
||||
stream: &mut TcpStream,
|
||||
channel_manager: &mut ChannelManager,
|
||||
encryption_ctx: &mut EncryptionContext,
|
||||
port_forward_manager: &mut PortForwardManager, // Phase 13
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
|
||||
port_forward_manager: &mut PortForwardManager, // Phase 13
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
|
||||
) -> Result<()> {
|
||||
info!("Starting SSH service loop (Phase 14.2: unified poll + child status)");
|
||||
|
||||
|
||||
loop {
|
||||
// ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测
|
||||
// 返回三元组:(stdout_packets, client_has_data, child_exited)
|
||||
let (stdout_packets, client_has_data, child_exited) = channel_manager.poll_exec_stdout_and_client(stream)?;
|
||||
|
||||
let (stdout_packets, client_has_data, child_exited) =
|
||||
channel_manager.poll_exec_stdout_and_client(stream)?;
|
||||
|
||||
// 1. 发送stdout/stderr数据(如果有)
|
||||
if let Some(packets) = stdout_packets {
|
||||
for packet in packets {
|
||||
@@ -407,93 +444,100 @@ fn handle_ssh_service_loop(
|
||||
info!("Sent stdout/stderr data (Phase 14.2)");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 2. 处理child exited(发送EOF + CLOSE)
|
||||
if child_exited {
|
||||
info!("Child process exited, sending SSH_MSG_CHANNEL_EOF + CLOSE");
|
||||
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 14.2: 使用ChannelManager.handle_child_exited()
|
||||
let exit_packets = channel_manager.handle_child_exited()?;
|
||||
for packet in exit_packets {
|
||||
let encrypted_packet = EncryptedPacket::new(&packet.payload, encryption_ctx, true)?;
|
||||
encrypted_packet.write(stream)?;
|
||||
}
|
||||
|
||||
|
||||
// 继续处理client数据(可能还有其他请求)
|
||||
}
|
||||
|
||||
|
||||
// 3. 处理client数据(如果有)
|
||||
if !client_has_data {
|
||||
// client没有数据,继续下一轮循环
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// client有数据,读取并处理
|
||||
let encrypted_packet = EncryptedPacket::read(stream, encryption_ctx, true)?;
|
||||
let packet = SshPacket::new(encrypted_packet.payload().to_vec());
|
||||
|
||||
|
||||
match packet.payload.first() {
|
||||
// Phase 13: SSH_MSG_GLOBAL_REQUEST处理(端口转发)
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_GLOBAL_REQUEST as u8 => {
|
||||
info!("Received SSH_MSG_GLOBAL_REQUEST (port forwarding)");
|
||||
|
||||
|
||||
// Phase 13.1: 安全配置验证
|
||||
let security = security_config.lock().unwrap();
|
||||
if !security.allow_tcp_forwarding {
|
||||
warn!("TCP forwarding disabled by security config");
|
||||
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
||||
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
let encrypted_failure =
|
||||
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
encrypted_failure.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_FAILURE (TCP forwarding disabled)");
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.2: 调用PortForwardManager处理(传递security_config)
|
||||
let (success, response) = port_forward_manager.handle_global_request(&packet.payload, &security)?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
let (success, response) =
|
||||
port_forward_manager.handle_global_request(&packet.payload, &security)?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
if success {
|
||||
if let Some(response_data) = response {
|
||||
let encrypted_response = EncryptedPacket::new(&response_data, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response_data, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_SUCCESS (tcpip-forward accepted)");
|
||||
} else {
|
||||
// 无响应数据时,发送简单的SUCCESS
|
||||
let success_packet = vec![PacketType::SSH_MSG_REQUEST_SUCCESS as u8];
|
||||
let encrypted_success = EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
|
||||
let encrypted_success =
|
||||
EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
|
||||
encrypted_success.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_SUCCESS");
|
||||
}
|
||||
} else {
|
||||
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
||||
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
let encrypted_failure =
|
||||
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
encrypted_failure.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_FAILURE (tcpip-forward rejected)");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_OPEN as u8 => {
|
||||
info!("Received SSH_MSG_CHANNEL_OPEN");
|
||||
|
||||
|
||||
// Phase 13.3: 获取security_config并传递给handle_channel_open
|
||||
let security = security_config.lock().unwrap();
|
||||
let response = channel_manager.handle_channel_open(&packet, Some(&security))?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
|
||||
}
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_REQUEST as u8 => {
|
||||
info!("Received SSH_MSG_CHANNEL_REQUEST");
|
||||
if let Some(response) = channel_manager.handle_channel_request(&packet)? {
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 14.5修复:区分普通命令和交互式进程
|
||||
// 检查是否有 exec_process(交互式进程如 rsync)
|
||||
let has_exec_process = channel_manager.has_exec_process();
|
||||
|
||||
|
||||
if has_exec_process {
|
||||
info!("⭐⭐⭐⭐⭐ [INTERACTIVE_PROCESS] Detected exec_process (rsync/SCP), skipping immediate EOF");
|
||||
// 对于交互式进程,只发送 SUCCESS,等待 poll 循环处理数据流
|
||||
@@ -503,23 +547,37 @@ fn handle_ssh_service_loop(
|
||||
if let Some(channel_id) = channel_manager.get_channel_with_output() {
|
||||
if let Some(output) = channel_manager.get_channel_output(channel_id) {
|
||||
// 发送命令输出(SSH_MSG_CHANNEL_DATA)
|
||||
let data_packet = channel_manager.build_channel_data(channel_id, &output)?;
|
||||
let encrypted_data = EncryptedPacket::new(&data_packet.payload, encryption_ctx, true)?;
|
||||
let data_packet =
|
||||
channel_manager.build_channel_data(channel_id, &output)?;
|
||||
let encrypted_data = EncryptedPacket::new(
|
||||
&data_packet.payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
encrypted_data.write(stream)?;
|
||||
info!("Sent command output ({} bytes)", output.len());
|
||||
|
||||
|
||||
// 发送SSH_MSG_CHANNEL_EOF
|
||||
let eof_packet = channel_manager.build_channel_eof(channel_id)?;
|
||||
let encrypted_eof = EncryptedPacket::new(&eof_packet.payload, encryption_ctx, true)?;
|
||||
let encrypted_eof = EncryptedPacket::new(
|
||||
&eof_packet.payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
encrypted_eof.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_EOF");
|
||||
|
||||
|
||||
// 发送SSH_MSG_CHANNEL_CLOSE
|
||||
let close_packet = channel_manager.build_channel_close(channel_id)?;
|
||||
let encrypted_close = EncryptedPacket::new(&close_packet.payload, encryption_ctx, true)?;
|
||||
let close_packet =
|
||||
channel_manager.build_channel_close(channel_id)?;
|
||||
let encrypted_close = EncryptedPacket::new(
|
||||
&close_packet.payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
encrypted_close.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_CLOSE");
|
||||
|
||||
|
||||
// 移除channel
|
||||
channel_manager.remove_channel(channel_id);
|
||||
}
|
||||
@@ -531,22 +589,28 @@ fn handle_ssh_service_loop(
|
||||
info!("Received SSH_MSG_CHANNEL_DATA");
|
||||
if let Some(response) = channel_manager.handle_channel_data(&packet)? {
|
||||
// Phase 7: SFTP响应通过CHANNEL_DATA返回
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_DATA (SFTP response)");
|
||||
}
|
||||
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 15.1: Drain pending packets (e.g. WINDOW_ADJUST + delayed SFTP response)
|
||||
while let Some(pending) = channel_manager.pending_packets.pop_front() {
|
||||
let encrypted_pending = EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
|
||||
let encrypted_pending =
|
||||
EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
|
||||
encrypted_pending.write(stream)?;
|
||||
info!("Sent pending packet (type {})", pending.payload.first().unwrap_or(&0));
|
||||
info!(
|
||||
"Sent pending packet (type {})",
|
||||
pending.payload.first().unwrap_or(&0)
|
||||
);
|
||||
}
|
||||
}
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_CLOSE as u8 => {
|
||||
info!("Received SSH_MSG_CHANNEL_CLOSE");
|
||||
if let Some(response) = channel_manager.handle_channel_close(&packet)? {
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
}
|
||||
break;
|
||||
@@ -565,8 +629,10 @@ fn handle_ssh_service_loop(
|
||||
let payload = &packet.payload;
|
||||
if payload.len() >= 9 {
|
||||
// Format: uint32 recipient_channel || uint32 bytes_to_add
|
||||
let recipient_channel = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
|
||||
let bytes_to_add = u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
|
||||
let recipient_channel =
|
||||
u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
|
||||
let bytes_to_add =
|
||||
u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
|
||||
channel_manager.adjust_remote_window(recipient_channel, bytes_to_add);
|
||||
}
|
||||
}
|
||||
@@ -575,12 +641,14 @@ fn handle_ssh_service_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
|
||||
fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result<String> {
|
||||
fn extract_username_from_auth_request(
|
||||
packet: &crate::ssh_server::packet::SshPacket,
|
||||
) -> Result<String> {
|
||||
let payload = &packet.payload;
|
||||
if payload.len() < 5 {
|
||||
return Err(anyhow!("Auth request too short"));
|
||||
@@ -598,10 +666,10 @@ pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
|
||||
let config = SshServerConfig {
|
||||
port: port.unwrap_or(2024),
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
|
||||
pg_conn: pg_conn.map(|s| s.to_string()),
|
||||
};
|
||||
|
||||
|
||||
let server = SshServer::new(config);
|
||||
server.run()
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
// SSH企业级安全配置(Phase 13.1)
|
||||
// 参考OpenSSH sshd_config安全配置
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{info, warn};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
@@ -14,25 +14,25 @@ pub struct SshSecurityConfig {
|
||||
/// false: 只绑定127.0.0.1(安全)
|
||||
/// true: 允许绑定0.0.0.0(危险)
|
||||
pub gateway_ports: bool,
|
||||
|
||||
|
||||
/// PermitOpen白名单
|
||||
/// ["localhost:3000", "localhost:4000", "localhost:*"]
|
||||
/// 空数组表示允许所有目标(不安全)
|
||||
pub permit_open: Vec<String>,
|
||||
|
||||
|
||||
/// AllowTcpForwarding配置
|
||||
/// true: 允许TCP转发
|
||||
/// false: 禁止所有TCP转发
|
||||
pub allow_tcp_forwarding: bool,
|
||||
|
||||
|
||||
/// MaxSessions限制
|
||||
/// 最大会话数,防止资源耗尽
|
||||
pub max_sessions: u32,
|
||||
|
||||
|
||||
/// ConnectTimeout超时(秒)
|
||||
/// 连接超时设置,防止悬挂连接
|
||||
pub connect_timeout: u64,
|
||||
|
||||
|
||||
/// 活动会话数(运行时状态)
|
||||
pub active_sessions: u32,
|
||||
}
|
||||
@@ -42,110 +42,125 @@ impl SshSecurityConfig {
|
||||
/// 参考:OpenSSH企业级生产环境配置
|
||||
pub fn enterprise_default() -> Self {
|
||||
Self {
|
||||
gateway_ports: false, // 安全:只绑定127.0.0.1
|
||||
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
|
||||
allow_tcp_forwarding: true, // 允许TCP转发
|
||||
max_sessions: 10, // 最多10个会话
|
||||
connect_timeout: 30, // 30秒超时
|
||||
active_sessions: 0, // 运行时状态
|
||||
gateway_ports: false, // 安全:只绑定127.0.0.1
|
||||
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
|
||||
allow_tcp_forwarding: true, // 允许TCP转发
|
||||
max_sessions: 10, // 最多10个会话
|
||||
connect_timeout: 30, // 30秒超时
|
||||
active_sessions: 0, // 运行时状态
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 开发环境默认配置(宽松)
|
||||
pub fn development_default() -> Self {
|
||||
Self {
|
||||
gateway_ports: true, // 开发:允许0.0.0.0
|
||||
permit_open: vec![], // 开发:允许所有目标
|
||||
gateway_ports: true, // 开发:允许0.0.0.0
|
||||
permit_open: vec![], // 开发:允许所有目标
|
||||
allow_tcp_forwarding: true,
|
||||
max_sessions: 20, // 开发:更多会话
|
||||
connect_timeout: 60, // 开发:更长超时
|
||||
max_sessions: 20, // 开发:更多会话
|
||||
connect_timeout: 60, // 开发:更长超时
|
||||
active_sessions: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 从JSON配置文件加载
|
||||
pub fn load_from_file(path: &str) -> Result<Self> {
|
||||
if !Path::new(path).exists() {
|
||||
info!("SSH security config file not found, using enterprise default");
|
||||
return Ok(Self::enterprise_default());
|
||||
}
|
||||
|
||||
|
||||
let config_str = fs::read_to_string(path)?;
|
||||
let config: serde_json::Value = serde_json::from_str(&config_str)?;
|
||||
|
||||
let security = config.get("ssh_server")
|
||||
|
||||
let security = config
|
||||
.get("ssh_server")
|
||||
.and_then(|s| s.get("security"))
|
||||
.ok_or_else(|| anyhow!("Invalid config structure"))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
gateway_ports: security.get("gateway_ports")
|
||||
gateway_ports: security
|
||||
.get("gateway_ports")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
permit_open: security.get("permit_open")
|
||||
permit_open: security
|
||||
.get("permit_open")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_else(|| vec!["localhost:*".to_string()]),
|
||||
allow_tcp_forwarding: security.get("allow_tcp_forwarding")
|
||||
allow_tcp_forwarding: security
|
||||
.get("allow_tcp_forwarding")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true),
|
||||
max_sessions: security.get("max_sessions")
|
||||
max_sessions: security
|
||||
.get("max_sessions")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|v| v as u32)
|
||||
.unwrap_or(10),
|
||||
connect_timeout: security.get("connect_timeout")
|
||||
connect_timeout: security
|
||||
.get("connect_timeout")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(30),
|
||||
active_sessions: 0,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 验证tcpip-forward请求(安全检查)
|
||||
/// 参考OpenSSH auth2.c: ssh_forwarding_check()
|
||||
pub fn validate_tcpip_forward_request(
|
||||
&self,
|
||||
bind_address: &str,
|
||||
bind_port: u32,
|
||||
) -> Result<()> {
|
||||
info!("Validating tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
|
||||
|
||||
pub fn validate_tcpip_forward_request(&self, bind_address: &str, bind_port: u32) -> Result<()> {
|
||||
info!(
|
||||
"Validating tcpip-forward request: bind_address={}, bind_port={}",
|
||||
bind_address, bind_port
|
||||
);
|
||||
|
||||
// 1. AllowTcpForwarding检查
|
||||
if !self.allow_tcp_forwarding {
|
||||
warn!("TCP forwarding disabled by security config");
|
||||
return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no"));
|
||||
}
|
||||
|
||||
|
||||
// 2. GatewayPorts检查
|
||||
if !self.gateway_ports {
|
||||
// 只允许绑定到127.0.0.1或localhost
|
||||
if bind_address != "127.0.0.1" && bind_address != "localhost" && bind_address != "" {
|
||||
warn!("GatewayPorts disabled, bind_address {} not allowed", bind_address);
|
||||
if bind_address != "127.0.0.1" && bind_address != "localhost" && !bind_address.is_empty() {
|
||||
warn!(
|
||||
"GatewayPorts disabled, bind_address {} not allowed",
|
||||
bind_address
|
||||
);
|
||||
return Err(anyhow!("GatewayPorts=no, only 127.0.0.1 allowed"));
|
||||
}
|
||||
info!("GatewayPorts check passed: bind_address={}", bind_address);
|
||||
}
|
||||
|
||||
|
||||
// 3. MaxSessions检查
|
||||
if self.active_sessions >= self.max_sessions {
|
||||
warn!("Max sessions limit reached: {} >= {}", self.active_sessions, self.max_sessions);
|
||||
warn!(
|
||||
"Max sessions limit reached: {} >= {}",
|
||||
self.active_sessions, self.max_sessions
|
||||
);
|
||||
return Err(anyhow!("Max sessions limit reached: {}", self.max_sessions));
|
||||
}
|
||||
|
||||
|
||||
// 4. 特权端口检查(防止<1024)
|
||||
if bind_port < 1024 {
|
||||
warn!("Cannot bind to privileged port: {}", bind_port);
|
||||
return Err(anyhow!("Cannot bind to privileged port < 1024"));
|
||||
}
|
||||
|
||||
|
||||
// 5. 端口范围检查(防止过大端口)
|
||||
if bind_port > 65535 {
|
||||
warn!("Invalid port number: {}", bind_port);
|
||||
return Err(anyhow!("Invalid port number > 65535"));
|
||||
}
|
||||
|
||||
|
||||
info!("tcpip-forward request validated successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 验证direct-tcpip channel请求(安全检查)
|
||||
/// 参考OpenSSH channels.c: channel_connect_direct_tcpip()
|
||||
pub fn validate_direct_tcpip_channel(
|
||||
@@ -153,14 +168,17 @@ impl SshSecurityConfig {
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
) -> Result<()> {
|
||||
info!("Validating direct-tcpip channel: host={}, port={}", host_to_connect, port_to_connect);
|
||||
|
||||
info!(
|
||||
"Validating direct-tcpip channel: host={}, port={}",
|
||||
host_to_connect, port_to_connect
|
||||
);
|
||||
|
||||
// 1. AllowTcpForwarding检查
|
||||
if !self.allow_tcp_forwarding {
|
||||
warn!("TCP forwarding disabled by security config");
|
||||
return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no"));
|
||||
}
|
||||
|
||||
|
||||
// 2. PermitOpen白名单检查
|
||||
if !self.permit_open.is_empty() {
|
||||
let target = format!("{}:{}", host_to_connect, port_to_connect);
|
||||
@@ -173,28 +191,34 @@ impl SshSecurityConfig {
|
||||
target == *pattern
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
if !allowed {
|
||||
warn!("Target {}:{} not in PermitOpen whitelist", host_to_connect, port_to_connect);
|
||||
return Err(anyhow!("Target {}:{} not in PermitOpen whitelist",
|
||||
host_to_connect, port_to_connect));
|
||||
warn!(
|
||||
"Target {}:{} not in PermitOpen whitelist",
|
||||
host_to_connect, port_to_connect
|
||||
);
|
||||
return Err(anyhow!(
|
||||
"Target {}:{} not in PermitOpen whitelist",
|
||||
host_to_connect,
|
||||
port_to_connect
|
||||
));
|
||||
}
|
||||
info!("PermitOpen check passed: target={}", target);
|
||||
} else {
|
||||
// permit_open为空,允许所有目标(不安全,仅用于开发)
|
||||
info!("PermitOpen whitelist empty, allowing all targets (development mode)");
|
||||
}
|
||||
|
||||
|
||||
// 3. 端口范围检查
|
||||
if port_to_connect < 1 || port_to_connect > 65535 {
|
||||
if !(1..=65535).contains(&port_to_connect) {
|
||||
warn!("Invalid port number: {}", port_to_connect);
|
||||
return Err(anyhow!("Invalid port number: {}", port_to_connect));
|
||||
}
|
||||
|
||||
|
||||
info!("direct-tcpip channel validated successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 增加活动会话数
|
||||
pub fn increment_sessions(&mut self) -> Result<()> {
|
||||
if self.active_sessions >= self.max_sessions {
|
||||
@@ -204,7 +228,7 @@ impl SshSecurityConfig {
|
||||
info!("Active sessions: {}", self.active_sessions);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 减少活动会话数
|
||||
pub fn decrement_sessions(&mut self) {
|
||||
if self.active_sessions > 0 {
|
||||
@@ -217,56 +241,76 @@ impl SshSecurityConfig {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_enterprise_default_config() {
|
||||
let config = SshSecurityConfig::enterprise_default();
|
||||
|
||||
|
||||
assert_eq!(config.gateway_ports, false);
|
||||
assert_eq!(config.permit_open, vec!["localhost:*".to_string()]);
|
||||
assert_eq!(config.allow_tcp_forwarding, true);
|
||||
assert_eq!(config.max_sessions, 10);
|
||||
assert_eq!(config.connect_timeout, 30);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_validate_tcpip_forward_request() {
|
||||
let config = SshSecurityConfig::enterprise_default();
|
||||
|
||||
|
||||
// 正常请求应该通过
|
||||
assert!(config.validate_tcpip_forward_request("127.0.0.1", 8080).is_ok());
|
||||
assert!(config.validate_tcpip_forward_request("localhost", 8080).is_ok());
|
||||
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("127.0.0.1", 8080)
|
||||
.is_ok());
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("localhost", 8080)
|
||||
.is_ok());
|
||||
|
||||
// GatewayPorts=false时,0.0.0.0应该被拒绝
|
||||
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_err());
|
||||
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("0.0.0.0", 8080)
|
||||
.is_err());
|
||||
|
||||
// 特权端口应该被拒绝
|
||||
assert!(config.validate_tcpip_forward_request("127.0.0.1", 80).is_err());
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("127.0.0.1", 80)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_validate_direct_tcpip_channel() {
|
||||
let config = SshSecurityConfig::enterprise_default();
|
||||
|
||||
|
||||
// localhost:*应该通过(通配符匹配)
|
||||
assert!(config.validate_direct_tcpip_channel("localhost", 3000).is_ok());
|
||||
assert!(config.validate_direct_tcpip_channel("localhost", 4000).is_ok());
|
||||
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("localhost", 3000)
|
||||
.is_ok());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("localhost", 4000)
|
||||
.is_ok());
|
||||
|
||||
// 其他host应该被拒绝
|
||||
assert!(config.validate_direct_tcpip_channel("192.168.1.100", 3000).is_err());
|
||||
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_err());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("192.168.1.100", 3000)
|
||||
.is_err());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("example.com", 80)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_development_default_config() {
|
||||
let config = SshSecurityConfig::development_default();
|
||||
|
||||
|
||||
assert_eq!(config.gateway_ports, true);
|
||||
assert_eq!(config.permit_open.len(), 0); // 空数组表示允许所有
|
||||
assert_eq!(config.max_sessions, 20);
|
||||
|
||||
|
||||
// 开发配置应该允许所有请求
|
||||
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_ok());
|
||||
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_ok());
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("0.0.0.0", 8080)
|
||||
.is_ok());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("example.com", 80)
|
||||
.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
// SSH Buffer 零拷贝实现(参考 OpenSSH sshbuf.c)
|
||||
// 提供高效的 buffer 管理,消除临时 buffer
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::io::{Read, Write};
|
||||
|
||||
/// SSH Buffer(参考 OpenSSH struct sshbuf)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// struct sshbuf {
|
||||
@@ -16,10 +16,10 @@ use std::io::{Read, Write};
|
||||
/// };
|
||||
/// ```
|
||||
pub struct SshBuf {
|
||||
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
|
||||
off: usize, // Offset (对应 OpenSSH buf->off)
|
||||
size: usize, // Size (对应 OpenSSH buf->size)
|
||||
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
|
||||
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
|
||||
off: usize, // Offset (对应 OpenSSH buf->off)
|
||||
size: usize, // Size (对应 OpenSSH buf->size)
|
||||
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
|
||||
}
|
||||
|
||||
impl SshBuf {
|
||||
@@ -32,7 +32,7 @@ impl SshBuf {
|
||||
max_size: 128 * 1024 * 1024, // 128MB (OpenSSH SSHBUF_SIZE_MAX)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 创建指定大小的 SSH Buffer
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
@@ -42,7 +42,7 @@ impl SshBuf {
|
||||
max_size: 128 * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 设置最大大小
|
||||
pub fn set_max_size(&mut self, max_size: usize) -> Result<()> {
|
||||
if max_size > 128 * 1024 * 1024 {
|
||||
@@ -51,47 +51,47 @@ impl SshBuf {
|
||||
self.max_size = max_size;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 获取 buffer 长度(对应 OpenSSH sshbuf_len)
|
||||
///
|
||||
///
|
||||
/// OpenSSH: `sshbuf_len = buf->size - buf->off`
|
||||
pub fn len(&self) -> usize {
|
||||
self.size - self.off
|
||||
}
|
||||
|
||||
|
||||
/// 检查 buffer 是否为空
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
|
||||
/// 获取可用空间(对应 OpenSSH sshbuf_avail)
|
||||
///
|
||||
///
|
||||
/// OpenSSH: `sshbuf_avail = buf->max_size - buf->size`
|
||||
pub fn avail(&self) -> usize {
|
||||
self.max_size - self.size
|
||||
}
|
||||
|
||||
|
||||
/// 获取可变指针(对应 OpenSSH sshbuf_mutable_ptr)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// u_char *sshbuf_mutable_ptr(const struct sshbuf *buf) {
|
||||
/// return buf->d + buf->off;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:返回 `&mut [u8]` slice(零拷贝)
|
||||
pub fn mutable_ptr(&mut self) -> &mut [u8] {
|
||||
&mut self.data[self.off..self.size]
|
||||
}
|
||||
|
||||
|
||||
/// 获取不可变指针(对应 OpenSSH sshbuf_ptr)
|
||||
pub fn ptr(&self) -> &[u8] {
|
||||
&self.data[self.off..self.size]
|
||||
}
|
||||
|
||||
|
||||
/// 预分配空间(对应 OpenSSH sshbuf_reserve)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_reserve(struct sshbuf *buf, size_t len, u_char **dpp) {
|
||||
@@ -104,31 +104,31 @@ impl SshBuf {
|
||||
/// return 0;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:返回 `&mut [u8]` slice(零拷贝,可直接 write)
|
||||
pub fn reserve(&mut self, len: usize) -> Result<&mut [u8]> {
|
||||
if len > self.avail() {
|
||||
return Err(anyhow!("no buffer space (avail={})", self.avail()));
|
||||
}
|
||||
|
||||
|
||||
// 预分配空间
|
||||
let current_size = self.size;
|
||||
let new_size = current_size + len;
|
||||
|
||||
|
||||
// 确保 Vec 有足够容量
|
||||
if new_size > self.data.len() {
|
||||
self.data.resize(new_size, 0);
|
||||
}
|
||||
|
||||
|
||||
// 更新 size
|
||||
self.size = new_size;
|
||||
|
||||
|
||||
// 返回新空间的 slice(零拷贝)
|
||||
Ok(&mut self.data[current_size..new_size])
|
||||
}
|
||||
|
||||
|
||||
/// 消费数据(对应 OpenSSH sshbuf_consume)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_consume(struct sshbuf *buf, size_t len) {
|
||||
@@ -140,29 +140,33 @@ impl SshBuf {
|
||||
/// return 0;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:移动偏移量(零拷贝,不实际删除数据)
|
||||
pub fn consume(&mut self, len: usize) -> Result<()> {
|
||||
if len > self.len() {
|
||||
return Err(anyhow!("message incomplete (len={}, consume={})", self.len(), len));
|
||||
return Err(anyhow!(
|
||||
"message incomplete (len={}, consume={})",
|
||||
self.len(),
|
||||
len
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
self.off += len;
|
||||
|
||||
|
||||
// 如果 buffer 空,重置
|
||||
if self.off == self.size {
|
||||
self.off = 0;
|
||||
self.size = 0;
|
||||
|
||||
|
||||
// OpenSSH: pack buffer(移除已消费的数据)
|
||||
// Rust: 我们保留 Vec,但重置指针
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 从末尾消费数据(对应 OpenSSH sshbuf_consume_end)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_consume_end(struct sshbuf *buf, size_t len) {
|
||||
@@ -174,13 +178,13 @@ impl SshBuf {
|
||||
if len > self.len() {
|
||||
return Err(anyhow!("message incomplete"));
|
||||
}
|
||||
|
||||
|
||||
self.size -= len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 直接从 fd read 到 buffer(对应 OpenSSH sshbuf_read)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_read(int fd, struct sshbuf *buf, size_t maxlen, size_t *rlen) {
|
||||
@@ -195,71 +199,75 @@ impl SshBuf {
|
||||
/// return 0;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:零拷贝,直接 read 到 buffer
|
||||
pub fn read_from<R: Read>(&mut self, reader: &mut R, maxlen: usize) -> Result<usize> {
|
||||
// 1. reserve 空间
|
||||
let space = self.reserve(maxlen)?;
|
||||
|
||||
|
||||
// 2. 直接 read 到 buffer(零拷贝)
|
||||
let n = reader.read(space)?;
|
||||
|
||||
|
||||
// 3. 调整大小(移除未使用的空间)
|
||||
if maxlen > n {
|
||||
self.consume_end(maxlen - n)?;
|
||||
}
|
||||
|
||||
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
|
||||
/// 直接从 buffer write 到 fd(对应 OpenSSH channel_handle_wfd)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// buf = sshbuf_mutable_ptr(c->output); // 获取指针
|
||||
/// len = write(c->wfd, buf, dlen); // 直接 write
|
||||
/// sshbuf_consume(c->output, len); // 消费已写入的数据
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:零拷贝,直接 write 从 buffer
|
||||
pub fn write_to<W: Write>(&mut self, writer: &mut W) -> Result<usize> {
|
||||
if self.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
|
||||
// 1. 获取数据指针(零拷贝)
|
||||
let data = self.ptr();
|
||||
|
||||
|
||||
// 2. 直接 write(零拷贝)
|
||||
let n = writer.write(data)?;
|
||||
|
||||
|
||||
// 3. 消费已写入的数据(零拷贝,只移动偏移)
|
||||
self.consume(n)?;
|
||||
|
||||
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
|
||||
/// 添加数据(对应 OpenSSH sshbuf_put)
|
||||
///
|
||||
///
|
||||
/// 用于不需要零拷贝的场景
|
||||
pub fn put(&mut self, data: &[u8]) -> Result<()> {
|
||||
let space = self.reserve(data.len())?;
|
||||
space.copy_from_slice(data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 清空 buffer
|
||||
pub fn reset(&mut self) {
|
||||
self.off = 0;
|
||||
self.size = 0;
|
||||
// OpenSSH: 保留 Vec,只重置指针
|
||||
}
|
||||
|
||||
|
||||
/// Debug: 打印 buffer 状态
|
||||
pub fn debug_info(&self) -> String {
|
||||
format!(
|
||||
"SshBuf: off={}, size={}, len={}, alloc={}, max_size={}",
|
||||
self.off, self.size, self.len(), self.data.len(), self.max_size
|
||||
self.off,
|
||||
self.size,
|
||||
self.len(),
|
||||
self.data.len(),
|
||||
self.max_size
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -274,11 +282,11 @@ impl Default for SshBuf {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_basic() {
|
||||
let mut buf = SshBuf::new();
|
||||
|
||||
|
||||
// Test reserve - write into reserved space
|
||||
{
|
||||
let space = buf.reserve(10).unwrap();
|
||||
@@ -286,57 +294,57 @@ mod tests {
|
||||
space[0] = 1;
|
||||
space[1] = 2;
|
||||
} // space dropped, buf accessible
|
||||
|
||||
|
||||
// Verify buffer length after reserve
|
||||
assert_eq!(buf.len(), 10);
|
||||
let ptr = buf.mutable_ptr();
|
||||
assert_eq!(ptr[0], 1);
|
||||
assert_eq!(ptr[1], 2);
|
||||
|
||||
|
||||
// Test consume
|
||||
buf.consume(2).unwrap();
|
||||
assert_eq!(buf.len(), 8);
|
||||
assert_eq!(buf.off, 2);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_zero_copy_read() {
|
||||
let mut buf = SshBuf::with_capacity(100);
|
||||
let mut reader = Cursor::new("hello world");
|
||||
|
||||
|
||||
// 零拷贝 read
|
||||
let n = buf.read_from(&mut reader, 20).unwrap();
|
||||
assert_eq!(n, 11); // "hello world" length
|
||||
assert_eq!(buf.len(), 11);
|
||||
|
||||
|
||||
// 检查数据
|
||||
let data = buf.ptr();
|
||||
assert_eq!(data, "hello world".as_bytes());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_zero_copy_write() {
|
||||
let mut buf = SshBuf::new();
|
||||
buf.put("hello world".as_bytes()).unwrap();
|
||||
|
||||
|
||||
let mut writer = Vec::new();
|
||||
|
||||
|
||||
// 零拷贝 write
|
||||
let n = buf.write_to(&mut writer).unwrap();
|
||||
assert_eq!(n, 11);
|
||||
assert_eq!(buf.len(), 0); // 已消费
|
||||
|
||||
|
||||
// 检查数据
|
||||
assert_eq!(writer, "hello world".as_bytes());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_max_size() {
|
||||
let mut buf = SshBuf::new();
|
||||
buf.set_max_size(1000).unwrap();
|
||||
|
||||
|
||||
// 尝试 reserve 超过 max_size
|
||||
let result = buf.reserve(2000);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
// 参考OpenSSH sshd.c: ssh_exchange_identification()
|
||||
|
||||
use anyhow::Result;
|
||||
use log::{debug, info};
|
||||
use std::io::{Read, Write};
|
||||
use log::{info, debug};
|
||||
|
||||
/// SSH版本字符串
|
||||
pub const SSH_VERSION: &str = "SSH-2.0-MarkBaseSSH_1.0";
|
||||
@@ -15,93 +15,96 @@ impl VersionExchange {
|
||||
/// 执行版本交换(服务器端)
|
||||
pub fn exchange<T: Read + Write>(stream: &mut T) -> Result<String> {
|
||||
info!("Starting SSH version exchange");
|
||||
|
||||
|
||||
// 1. 发送服务器版本
|
||||
Self::send_version(stream)?;
|
||||
|
||||
|
||||
// 2. 接收客户端版本
|
||||
let client_version = Self::receive_version(stream)?;
|
||||
|
||||
info!("Version exchange completed: server={}, client={}", SSH_VERSION, client_version);
|
||||
|
||||
info!(
|
||||
"Version exchange completed: server={}, client={}",
|
||||
SSH_VERSION, client_version
|
||||
);
|
||||
Ok(client_version)
|
||||
}
|
||||
|
||||
|
||||
/// 发送服务器版本(参考OpenSSH ssh_exchange_identification)
|
||||
fn send_version<T: Write>(stream: &mut T) -> Result<()> {
|
||||
let version_line = format!("{}\r\n", SSH_VERSION);
|
||||
stream.write_all(version_line.as_bytes())?;
|
||||
stream.flush()?;
|
||||
|
||||
|
||||
debug!("Sent version: {}", SSH_VERSION);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 接收客户端版本(参考OpenSSH ssh_exchange_identification)
|
||||
fn receive_version<T: Read>(stream: &mut T) -> Result<String> {
|
||||
let mut buffer = Vec::new();
|
||||
let mut byte = [0u8; 1];
|
||||
|
||||
|
||||
// 读取直到遇到'\n'(参考OpenSSH实现)
|
||||
loop {
|
||||
stream.read_exact(&mut byte)?;
|
||||
|
||||
|
||||
// OpenSSH兼容性处理:跳过前导空行和调试信息
|
||||
if buffer.is_empty() && byte[0] == '\n' as u8 {
|
||||
continue; // 跳过空行
|
||||
if buffer.is_empty() && byte[0] == b'\n' {
|
||||
continue; // 跳过空行
|
||||
}
|
||||
|
||||
|
||||
// 调试信息行(以'#'开头),跳过
|
||||
if buffer.is_empty() && byte[0] == '#' as u8 {
|
||||
if buffer.is_empty() && byte[0] == b'#' {
|
||||
// 读取整行调试信息
|
||||
while byte[0] != '\n' as u8 {
|
||||
while byte[0] != b'\n' {
|
||||
stream.read_exact(&mut byte)?;
|
||||
}
|
||||
buffer.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
buffer.push(byte[0]);
|
||||
|
||||
|
||||
// 遇到'\n'结束
|
||||
if byte[0] == '\n' as u8 {
|
||||
if byte[0] == b'\n' {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// 缓冲区溢出保护(OpenSSH限制:255字节)
|
||||
if buffer.len() > 255 {
|
||||
return Err(anyhow::anyhow!("Version string too long"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 解析版本字符串
|
||||
let version_line = String::from_utf8(buffer)?;
|
||||
let version = version_line.trim().trim_matches('\r');
|
||||
|
||||
|
||||
// 验证版本格式(SSH-2.0-*)
|
||||
if !version.starts_with("SSH-2.0-") {
|
||||
return Err(anyhow::anyhow!("Invalid SSH version: {}", version));
|
||||
}
|
||||
|
||||
|
||||
debug!("Received version: {}", version);
|
||||
Ok(version.to_string())
|
||||
}
|
||||
|
||||
|
||||
/// 解析客户端版本信息(兼容性检查)
|
||||
pub fn parse_client_version(version: &str) -> Result<ClientVersionInfo> {
|
||||
// 格式:SSH-protoversion-softwareversion SP comments
|
||||
let parts: Vec<&str> = version.split_whitespace().collect();
|
||||
|
||||
|
||||
let main_part = parts.first().map_or(version, |v| v);
|
||||
let dash_parts: Vec<&str> = main_part.split('-').collect();
|
||||
|
||||
|
||||
if dash_parts.len() < 3 {
|
||||
return Err(anyhow::anyhow!("Invalid version format: {}", version));
|
||||
}
|
||||
|
||||
|
||||
let proto_version = dash_parts.get(1).map_or("2.0", |v| v);
|
||||
let software_version = dash_parts.get(2).map_or("unknown", |v| v);
|
||||
let comments = parts.get(1).map(|s| s.to_string());
|
||||
|
||||
|
||||
Ok(ClientVersionInfo {
|
||||
proto_version: proto_version.to_string(),
|
||||
software_version: software_version.to_string(),
|
||||
@@ -120,12 +123,12 @@ pub struct ClientVersionInfo {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_version_format() {
|
||||
assert!(SSH_VERSION.starts_with("SSH-2.0-"));
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_parse_client_version() {
|
||||
let version = "SSH-2.0-OpenSSH_10.2";
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
// SSH Window Size管理(Phase 13.6)
|
||||
// 参考RFC 4254 Section 5.2: Window Size Adjustment
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use crate::ssh_server::packet::PacketType;
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use log::{info, warn};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// Window Size管理器(Phase 13.6)
|
||||
pub struct WindowManager {
|
||||
initial_window_size: u32, // RFC 4254: 2MB默认
|
||||
initial_window_size: u32, // RFC 4254: 2MB默认
|
||||
current_window_size: Arc<Mutex<u32>>,
|
||||
max_packet_size: u32, // RFC 4254: 32KB默认
|
||||
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
|
||||
max_packet_size: u32, // RFC 4254: 32KB默认
|
||||
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
|
||||
}
|
||||
|
||||
impl WindowManager {
|
||||
@@ -25,89 +25,103 @@ impl WindowManager {
|
||||
consumed_bytes: Arc::new(Mutex::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// RFC 4254默认window size(2MB)
|
||||
pub fn rfc_default() -> Self {
|
||||
Self::new(2097152, 32768) // 2MB window, 32KB packet
|
||||
Self::new(2097152, 32768) // 2MB window, 32KB packet
|
||||
}
|
||||
|
||||
|
||||
/// 检查window size是否足够(Phase 13.6)
|
||||
pub fn check_window_available(&self, data_size: u32) -> bool {
|
||||
let window = self.current_window_size.lock().unwrap();
|
||||
let available = *window >= data_size;
|
||||
|
||||
|
||||
if !available {
|
||||
warn!("Window size insufficient: need {}, have {}", data_size, *window);
|
||||
warn!(
|
||||
"Window size insufficient: need {}, have {}",
|
||||
data_size, *window
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
available
|
||||
}
|
||||
|
||||
|
||||
/// 消耗window size(Phase 13.6:发送数据后)
|
||||
pub fn consume_window(&self, data_size: u32) -> Result<()> {
|
||||
let mut window = self.current_window_size.lock().unwrap();
|
||||
|
||||
|
||||
if *window < data_size {
|
||||
return Err(anyhow!("Window size insufficient: need {}, have {}", data_size, *window));
|
||||
return Err(anyhow!(
|
||||
"Window size insufficient: need {}, have {}",
|
||||
data_size,
|
||||
*window
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
*window -= data_size;
|
||||
|
||||
|
||||
// 统计已消耗bytes
|
||||
let mut consumed = self.consumed_bytes.lock().unwrap();
|
||||
*consumed += data_size;
|
||||
|
||||
info!("Window size consumed: {} bytes, remaining {}, total consumed {}",
|
||||
data_size, *window, *consumed);
|
||||
|
||||
|
||||
info!(
|
||||
"Window size consumed: {} bytes, remaining {}, total consumed {}",
|
||||
data_size, *window, *consumed
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 调整window size(Phase 13.6:收到SSH_MSG_CHANNEL_WINDOW_ADJUST)
|
||||
pub fn adjust_window(&self, bytes_to_add: u32) {
|
||||
let mut window = self.current_window_size.lock().unwrap();
|
||||
*window += bytes_to_add;
|
||||
|
||||
info!("Window size adjusted: added {} bytes, total {}", bytes_to_add, *window);
|
||||
|
||||
info!(
|
||||
"Window size adjusted: added {} bytes, total {}",
|
||||
bytes_to_add, *window
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_CHANNEL_WINDOW_ADJUST packet(Phase 13.6)
|
||||
pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93)
|
||||
packet.write_u8(PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
// Bytes to add
|
||||
packet.write_u32::<BigEndian>(bytes_to_add)?;
|
||||
|
||||
info!("Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
|
||||
channel_id, bytes_to_add);
|
||||
|
||||
|
||||
info!(
|
||||
"Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
|
||||
channel_id, bytes_to_add
|
||||
);
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
|
||||
/// 获取当前window size(Phase 13.6)
|
||||
pub fn get_current_window(&self) -> u32 {
|
||||
*self.current_window_size.lock().unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// 获取已消耗bytes(Phase 13.6)
|
||||
pub fn get_consumed_bytes(&self) -> u32 {
|
||||
*self.consumed_bytes.lock().unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// 重置window size(Phase 13.6:channel重置)
|
||||
pub fn reset_window(&self) {
|
||||
let mut window = self.current_window_size.lock().unwrap();
|
||||
*window = self.initial_window_size;
|
||||
|
||||
|
||||
let mut consumed = self.consumed_bytes.lock().unwrap();
|
||||
*consumed = 0;
|
||||
|
||||
|
||||
info!("Window size reset to initial: {}", self.initial_window_size);
|
||||
}
|
||||
}
|
||||
@@ -128,63 +142,63 @@ impl ChannelLifecycle {
|
||||
close_received: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_CHANNEL_EOF packet(Phase 13.7)
|
||||
pub fn build_eof_packet(channel_id: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_EOF (type 96)
|
||||
packet.write_u8(PacketType::SSH_MSG_CHANNEL_EOF as u8)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
info!("Built SSH_MSG_CHANNEL_EOF for channel {}", channel_id);
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_CHANNEL_CLOSE packet(Phase 13.7)
|
||||
pub fn build_close_packet(channel_id: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_CLOSE (type 97)
|
||||
packet.write_u8(PacketType::SSH_MSG_CHANNEL_CLOSE as u8)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
info!("Built SSH_MSG_CHANNEL_CLOSE for channel {}", channel_id);
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
|
||||
/// 标记EOF已发送(Phase 13.7)
|
||||
pub fn mark_eof_sent(&mut self) {
|
||||
self.eof_sent = true;
|
||||
info!("Channel {} EOF marked as sent", self.channel_id);
|
||||
}
|
||||
|
||||
|
||||
/// 标记CLOSE已接收(Phase 13.7)
|
||||
pub fn mark_close_received(&mut self) {
|
||||
self.close_received = true;
|
||||
info!("Channel {} CLOSE marked as received", self.channel_id);
|
||||
}
|
||||
|
||||
|
||||
/// 检查是否可以清理channel(Phase 13.7)
|
||||
pub fn can_cleanup(&self) -> bool {
|
||||
self.eof_sent && self.close_received
|
||||
}
|
||||
|
||||
|
||||
/// 清理channel资源(Phase 13.7)
|
||||
pub fn cleanup_channel(&self) -> Result<()> {
|
||||
info!("Cleaning up channel {} resources", self.channel_id);
|
||||
|
||||
|
||||
// Phase 13.7: 实际清理逻辑需要在ChannelManager中实现
|
||||
// - 移除channel记录
|
||||
// - 关闭TCP连接
|
||||
// - 清理监听器(如果是forwarded-tcpip)
|
||||
|
||||
|
||||
info!("Channel {} cleanup completed", self.channel_id);
|
||||
Ok(())
|
||||
}
|
||||
@@ -193,42 +207,42 @@ impl ChannelLifecycle {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_manager_creation() {
|
||||
let manager = WindowManager::rfc_default();
|
||||
assert_eq!(manager.get_current_window(), 2097152);
|
||||
assert_eq!(manager.max_packet_size, 32768);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_consumption() {
|
||||
let manager = WindowManager::rfc_default();
|
||||
|
||||
|
||||
// 消耗1000 bytes
|
||||
manager.consume_window(1000).unwrap();
|
||||
assert_eq!(manager.get_current_window(), 2097152 - 1000);
|
||||
assert_eq!(manager.get_consumed_bytes(), 1000);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_adjustment() {
|
||||
let manager = WindowManager::rfc_default();
|
||||
|
||||
|
||||
// 消耗1000 bytes
|
||||
manager.consume_window(1000).unwrap();
|
||||
|
||||
|
||||
// 调整500 bytes
|
||||
manager.adjust_window(500);
|
||||
assert_eq!(manager.get_current_window(), 2097152 - 1000 + 500);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_build_eof_packet() {
|
||||
let packet = ChannelLifecycle::build_eof_packet(1).unwrap();
|
||||
assert_eq!(packet[0], PacketType::SSH_MSG_CHANNEL_EOF as u8);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_build_close_packet() {
|
||||
let packet = ChannelLifecycle::build_close_packet(1).unwrap();
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
use super::util;
|
||||
use super::open_flags::OpenFlags;
|
||||
use super::util;
|
||||
use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat};
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Seek, SeekFrom, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::os::unix::fs::{MetadataExt, PermissionsExt};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// 本地文件系统实现(直接包装 std::fs,不做路径解析)
|
||||
/// 路径解析由上层(SftpHandler)负责
|
||||
pub struct LocalFs;
|
||||
|
||||
impl Default for LocalFs {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalFs {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
@@ -26,7 +32,9 @@ impl VfsFile for LocalFile {
|
||||
}
|
||||
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, VfsError> {
|
||||
self.file.write(buf).map_err(|e| VfsError::Io(e.to_string()))
|
||||
self.file
|
||||
.write(buf)
|
||||
.map_err(|e| VfsError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn seek(&mut self, pos: SeekFrom) -> Result<u64, VfsError> {
|
||||
@@ -38,12 +46,17 @@ impl VfsFile for LocalFile {
|
||||
}
|
||||
|
||||
fn stat(&mut self) -> Result<VfsStat, VfsError> {
|
||||
let meta = self.file.metadata().map_err(|e| VfsError::Io(e.to_string()))?;
|
||||
let meta = self
|
||||
.file
|
||||
.metadata()
|
||||
.map_err(|e| VfsError::Io(e.to_string()))?;
|
||||
Ok(util::stat_from_metadata(&meta, false))
|
||||
}
|
||||
|
||||
fn set_len(&mut self, size: u64) -> Result<(), VfsError> {
|
||||
self.file.set_len(size).map_err(|e| VfsError::Io(e.to_string()))
|
||||
self.file
|
||||
.set_len(size)
|
||||
.map_err(|e| VfsError::Io(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,8 +99,7 @@ impl VfsBackend for LocalFs {
|
||||
if flags.create && !flags.exclusive {
|
||||
if let Ok(meta) = file.metadata() {
|
||||
if flags.mode != 0 && meta.permissions().mode() != flags.mode {
|
||||
fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode))
|
||||
.ok();
|
||||
fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode)).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -157,10 +169,12 @@ impl VfsBackend for LocalFs {
|
||||
stat.atime.duration_since(std::time::UNIX_EPOCH).ok(),
|
||||
stat.mtime.duration_since(std::time::UNIX_EPOCH).ok(),
|
||||
) {
|
||||
filetime::set_file_times(path,
|
||||
filetime::set_file_times(
|
||||
path,
|
||||
filetime::FileTime::from_unix_time(atime.as_secs() as i64, 0),
|
||||
filetime::FileTime::from_unix_time(mtime.as_secs() as i64, 0),
|
||||
).map_err(|e| util::map_io_error(path, e))?;
|
||||
)
|
||||
.map_err(|e| util::map_io_error(path, e))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -174,8 +188,7 @@ impl VfsBackend for LocalFs {
|
||||
fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
std::os::unix::fs::symlink(target, link)
|
||||
.map_err(|e| util::map_io_error(link, e))?;
|
||||
std::os::unix::fs::symlink(target, link).map_err(|e| util::map_io_error(link, e))?;
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
@@ -188,7 +201,9 @@ impl VfsBackend for LocalFs {
|
||||
}
|
||||
|
||||
fn real_path(&self, path: &Path) -> Result<PathBuf, VfsError> {
|
||||
let canonical = path.canonicalize().map_err(|e| util::map_io_error(path, e))?;
|
||||
let canonical = path
|
||||
.canonicalize()
|
||||
.map_err(|e| util::map_io_error(path, e))?;
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
@@ -204,7 +219,9 @@ impl VfsBackend for LocalFs {
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
return Err(VfsError::Unsupported("hard_link not supported on non-Unix systems".to_string()));
|
||||
return Err(VfsError::Unsupported(
|
||||
"hard_link not supported on non-Unix systems".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pub mod open_flags;
|
||||
pub mod local_fs;
|
||||
pub mod open_flags;
|
||||
pub mod s3_fs;
|
||||
pub mod util;
|
||||
|
||||
@@ -120,7 +120,11 @@ pub trait VfsBackend: Send {
|
||||
fn read_dir(&self, path: &Path) -> Result<Vec<VfsDirEntry>, VfsError>;
|
||||
|
||||
/// 打开文件(读/写)
|
||||
fn open_file(&self, path: &Path, flags: &open_flags::OpenFlags) -> Result<Box<dyn VfsFile>, VfsError>;
|
||||
fn open_file(
|
||||
&self,
|
||||
path: &Path,
|
||||
flags: &open_flags::OpenFlags,
|
||||
) -> Result<Box<dyn VfsFile>, VfsError>;
|
||||
|
||||
/// 获取文件/目录元数据
|
||||
fn stat(&self, path: &Path) -> Result<VfsStat, VfsError>;
|
||||
|
||||
@@ -56,7 +56,10 @@ impl S3Vfs {
|
||||
|
||||
let credentials = Credentials::new(access_key, secret_key);
|
||||
|
||||
Ok(Self { bucket, credentials })
|
||||
Ok(Self {
|
||||
bucket,
|
||||
credentials,
|
||||
})
|
||||
}
|
||||
|
||||
fn path_to_key(path: &Path) -> String {
|
||||
@@ -118,7 +121,10 @@ impl S3Vfs {
|
||||
.map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?;
|
||||
|
||||
if resp.status() != 200 {
|
||||
return Err(VfsError::Io(format!("PutObject returned {}", resp.status())));
|
||||
return Err(VfsError::Io(format!(
|
||||
"PutObject returned {}",
|
||||
resp.status()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -149,15 +155,15 @@ impl S3Vfs {
|
||||
.map_err(|e| VfsError::Io(format!("S3 CopyObject failed: {}", e)))?;
|
||||
|
||||
if resp.status() != 200 {
|
||||
return Err(VfsError::Io(format!("CopyObject returned {}", resp.status())));
|
||||
return Err(VfsError::Io(format!(
|
||||
"CopyObject returned {}",
|
||||
resp.status()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn list_objects(
|
||||
&self,
|
||||
prefix: &str,
|
||||
) -> Result<actions::ListObjectsV2Response, VfsError> {
|
||||
fn list_objects(&self, prefix: &str) -> Result<actions::ListObjectsV2Response, VfsError> {
|
||||
let mut action = actions::ListObjectsV2::new(&self.bucket, Some(&self.credentials));
|
||||
if !prefix.is_empty() {
|
||||
action.with_prefix(prefix);
|
||||
@@ -181,9 +187,8 @@ impl S3Vfs {
|
||||
.read_to_string(&mut body)
|
||||
.map_err(|e| VfsError::Io(format!("Failed to read S3 list response: {}", e)))?;
|
||||
|
||||
actions::ListObjectsV2::parse_response(&body).map_err(|e| {
|
||||
VfsError::Io(format!("Failed to parse S3 list response XML: {}", e))
|
||||
})
|
||||
actions::ListObjectsV2::parse_response(&body)
|
||||
.map_err(|e| VfsError::Io(format!("Failed to parse S3 list response XML: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,7 +414,9 @@ impl VfsBackend for S3Vfs {
|
||||
|
||||
impl VfsFile for S3VfsFile {
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, VfsError> {
|
||||
let to_read = buf.len().min((self.size.saturating_sub(self.position)) as usize);
|
||||
let to_read = buf
|
||||
.len()
|
||||
.min((self.size.saturating_sub(self.position)) as usize);
|
||||
if to_read == 0 {
|
||||
return Ok(0);
|
||||
}
|
||||
@@ -443,7 +450,7 @@ impl VfsFile for S3VfsFile {
|
||||
self.position = sz.saturating_add(offset as u64);
|
||||
} else {
|
||||
let abs = offset.unsigned_abs();
|
||||
self.position = if abs <= sz { sz - abs } else { 0 };
|
||||
self.position = sz.saturating_sub(abs);
|
||||
}
|
||||
}
|
||||
std::io::SeekFrom::Current(offset) => {
|
||||
@@ -451,11 +458,7 @@ impl VfsFile for S3VfsFile {
|
||||
self.position = self.position.saturating_add(offset as u64);
|
||||
} else {
|
||||
let abs = offset.unsigned_abs();
|
||||
self.position = if abs <= self.position {
|
||||
self.position - abs
|
||||
} else {
|
||||
0
|
||||
};
|
||||
self.position = self.position.saturating_sub(abs);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -549,7 +552,10 @@ impl S3VfsLike {
|
||||
.map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?;
|
||||
|
||||
if resp.status() != 200 {
|
||||
return Err(VfsError::Io(format!("PutObject returned {}", resp.status())));
|
||||
return Err(VfsError::Io(format!(
|
||||
"PutObject returned {}",
|
||||
resp.status()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -612,10 +618,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_path_to_key() {
|
||||
assert_eq!(
|
||||
S3Vfs::path_to_key(Path::new("/foo/bar.txt")),
|
||||
"foo/bar.txt"
|
||||
);
|
||||
assert_eq!(S3Vfs::path_to_key(Path::new("/foo/bar.txt")), "foo/bar.txt");
|
||||
assert_eq!(S3Vfs::path_to_key(Path::new("/")), "");
|
||||
assert_eq!(
|
||||
S3Vfs::path_to_key(Path::new("relative/path")),
|
||||
|
||||
@@ -7,7 +7,9 @@ use std::path::Path;
|
||||
pub fn map_io_error(path: &Path, e: std::io::Error) -> VfsError {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::NotFound => VfsError::NotFound(path.display().to_string()),
|
||||
std::io::ErrorKind::PermissionDenied => VfsError::PermissionDenied(path.display().to_string()),
|
||||
std::io::ErrorKind::PermissionDenied => {
|
||||
VfsError::PermissionDenied(path.display().to_string())
|
||||
}
|
||||
std::io::ErrorKind::AlreadyExists => VfsError::AlreadyExists(path.display().to_string()),
|
||||
std::io::ErrorKind::DirectoryNotEmpty => VfsError::NotEmpty(path.display().to_string()),
|
||||
std::io::ErrorKind::NotADirectory => VfsError::NotADirectory(path.display().to_string()),
|
||||
@@ -65,13 +67,7 @@ pub fn build_long_name(stat: &VfsStat, name: &str) -> String {
|
||||
|
||||
format!(
|
||||
"{}{} {} {} {} {} {} {}",
|
||||
file_type, perms,
|
||||
link_count,
|
||||
stat.uid,
|
||||
stat.gid,
|
||||
size,
|
||||
mtime,
|
||||
name
|
||||
file_type, perms, link_count, stat.uid, stat.gid, size, mtime, name
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user