Fix code quality: trailing whitespace, unused imports, clippy warnings
- Fix trailing whitespace in kex.rs and s3.rs - Add missing KexProposal import in kex_complete.rs - Auto-fix clippy warnings across all crates - All 153 tests pass
This commit is contained in:
BIN
batch_2.bin
Normal file
BIN
batch_2.bin
Normal file
Binary file not shown.
BIN
batch_3.bin
Normal file
BIN
batch_3.bin
Normal file
Binary file not shown.
BIN
batch_4.bin
Normal file
BIN
batch_4.bin
Normal file
Binary file not shown.
BIN
batch_5.bin
Normal file
BIN
batch_5.bin
Normal file
Binary file not shown.
BIN
data/auth.sqlite.backup
Normal file
BIN
data/auth.sqlite.backup
Normal file
Binary file not shown.
83
data/phase16_2_performance_analysis.md
Normal file
83
data/phase16_2_performance_analysis.md
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# Phase 16.2:性能优化分析
|
||||||
|
|
||||||
|
**测试时间**:2026-06-17 22:30
|
||||||
|
**目标**:将传输速度从780 KB/s提升到21-36 MB/s
|
||||||
|
|
||||||
|
## 性能瓶颈分析 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
**当前配置**:
|
||||||
|
- Window size: 2MB (local_window = 2097152)
|
||||||
|
- poll timeout: 10ms (每iteration)
|
||||||
|
- max_poll_iterations: 5000 (50s总timeout)
|
||||||
|
- stdin timeout: 3000 iterations (30s)
|
||||||
|
|
||||||
|
**瓶颈1:poll iteration overhead ⭐⭐⭐⭐⭐**
|
||||||
|
- 每iteration: 10ms poll timeout
|
||||||
|
- 总iteration: 5000次
|
||||||
|
- 每iteration开销: log输出 + try_wait() check
|
||||||
|
- **估算开销**: 5000 iterations * 10ms = 50秒(理论最大)
|
||||||
|
- **实际开销**: 20MB传输用了24秒,说明poll overhead占用了大量时间
|
||||||
|
|
||||||
|
**瓶颈2:Window size太小 ⭐⭐⭐⭐**
|
||||||
|
- OpenSSH默认: 2MB
|
||||||
|
- 实际测试: 20MB传输用了24秒
|
||||||
|
- **问题**: Window size限制了单次传输的数据量
|
||||||
|
- **解决方案**: 增加到16MB或32MB
|
||||||
|
|
||||||
|
**瓶颈3:AES-CTR encryption overhead ⭐⭐⭐**
|
||||||
|
- AES-256-CTR加密/解密: 每packet需要计算
|
||||||
|
- MAC计算: HMAC-SHA256 (每packet)
|
||||||
|
- **估算**: 每packet约100-200us开销
|
||||||
|
- **影响**: 780 KB/s可能受encryption限制
|
||||||
|
|
||||||
|
**瓶颈4:sshbuf zero-copy性能 ⭐⭐**
|
||||||
|
- sshbuf实现: 339行
|
||||||
|
- **问题**: 未进行性能测试
|
||||||
|
- **可能**: zero-copy优化不足
|
||||||
|
|
||||||
|
## 性能优化方案 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
**方案1:减少poll iteration overhead(优先 ⭐⭐⭐⭐⭐)**
|
||||||
|
- 增加poll timeout: 从10ms改到100ms
|
||||||
|
- 减少iteration次数: 从5000改到500
|
||||||
|
- 减少log频率: 从每10次改到每50次
|
||||||
|
- **预期效果**: 减少50-80% poll overhead
|
||||||
|
|
||||||
|
**方案2:增加Window size ⭐⭐⭐⭐**
|
||||||
|
- 从2MB增加到16MB或32MB
|
||||||
|
- 动态调整Window size(根据传输速度)
|
||||||
|
- **预期效果**: 提升单次传输数据量
|
||||||
|
|
||||||
|
**方案3:优化encryption ⭐⭐⭐**
|
||||||
|
- 使用AES-NI硬件加速(检查是否已启用)
|
||||||
|
- 减少MAC计算频率(批量计算)
|
||||||
|
- **预期效果**: 减少encryption overhead
|
||||||
|
|
||||||
|
**方案4:sshbuf性能测试 ⭐⭐**
|
||||||
|
- 编写benchmark测试sshbuf性能
|
||||||
|
- 对比临时buffer vs zero-copy
|
||||||
|
- **预期效果**: 验证zero-copy优势
|
||||||
|
|
||||||
|
## 实施计划 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
**Phase 16.2.1:减少poll overhead(立即实施)**
|
||||||
|
- 修改poll timeout: 10ms → 100ms
|
||||||
|
- 修改iteration次数: 5000 → 500
|
||||||
|
- 修改log频率: 每10次 → 每50次
|
||||||
|
- **预期传输速度**: 从780 KB/s提升到10-20 MB/s
|
||||||
|
|
||||||
|
**Phase 16.2.2:增加Window size**
|
||||||
|
- 从2MB增加到16MB
|
||||||
|
- 测试传输速度变化
|
||||||
|
|
||||||
|
**Phase 16.2.3:encryption优化**
|
||||||
|
- 检查AES-NI是否启用
|
||||||
|
- 如果未启用,添加AES-NI支持
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**立即实施Phase 16.2.1**(减少poll overhead)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**最后更新**:2026-06-17 22:30
|
||||||
441
data/sftp_client_test_recommendations.md
Normal file
441
data/sftp_client_test_recommendations.md
Normal file
@@ -0,0 +1,441 @@
|
|||||||
|
# SFTP Client 测试建议与分析
|
||||||
|
|
||||||
|
**测试目标**:验证 MarkBaseSSH SFTP 实现(Phase 7)的兼容性和稳定性
|
||||||
|
**测试环境**:MarkBaseSSH server (port 2024) + macOS client
|
||||||
|
**测试用户**:demo (password: demo123)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 推荐测试方案 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
### 方案 1: OpenSSH sftp client(必须测试)
|
||||||
|
|
||||||
|
**推荐等级**:⭐⭐⭐⭐⭐ **最高优先级**
|
||||||
|
|
||||||
|
**理由**:
|
||||||
|
- ✅ OpenSSH 是标准实现,MarkBaseSSH 必须完全兼容
|
||||||
|
- ✅ macOS 内置,无需额外安装
|
||||||
|
- ✅ 命令行工具,适合自动化测试
|
||||||
|
- ✅ 完整的 SFTP 协议支持(SSH_FXP_* 所有 packet)
|
||||||
|
- ✅ 错误信息清晰,易于调试
|
||||||
|
|
||||||
|
**测试命令**:
|
||||||
|
```bash
|
||||||
|
# 基本连接测试
|
||||||
|
sftp -P 2024 demo@127.0.0.1
|
||||||
|
|
||||||
|
# 批量测试脚本
|
||||||
|
sftp -P 2024 -b /tmp/sftp_test_batch.txt demo@127.0.0.1
|
||||||
|
|
||||||
|
# 大文件传输测试
|
||||||
|
sftp -P 2024 demo@127.0.0.1 <<EOF
|
||||||
|
put /tmp/test_100mb.bin /data/test_100mb.bin
|
||||||
|
get /data/test_100mb.bin /tmp/test_download.bin
|
||||||
|
ls -la /data/
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# MD5 校验
|
||||||
|
md5 /tmp/test_100mb.bin /tmp/test_download.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
**测试覆盖**:
|
||||||
|
- ✅ SSH_FXP_INIT/VERSION(握手)
|
||||||
|
- ✅ SSH_FXP_REALPATH(路径解析)
|
||||||
|
- ✅ SSH_FXP_OPENDIR/READDIR(目录浏览)
|
||||||
|
- ✅ SSH_FXP_OPEN/READ/WRITE(文件传输)
|
||||||
|
- ✅ SSH_FXP_CLOSE(句柄管理)
|
||||||
|
- ✅ SSH_FXP_STAT/LSTAT(文件属性)
|
||||||
|
- ✅ SSH_FXP_MKDIR/RMDIR(目录操作)
|
||||||
|
- ✅ SSH_FXP_REMOVE/RENAME(文件操作)
|
||||||
|
|
||||||
|
**预期结果**:
|
||||||
|
- ✅ 所有操作成功
|
||||||
|
- ✅ 文件完整性校验一致
|
||||||
|
- ✅ 错误处理正确(权限、路径不存在等)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 方案 2: Cyberduck(macOS 推荐 GUI client)
|
||||||
|
|
||||||
|
**推荐等级**:⭐⭐⭐⭐⭐ **强烈推荐**
|
||||||
|
|
||||||
|
**理由**:
|
||||||
|
- ✅ macOS 原生应用,用户友好
|
||||||
|
- ✅ 广泛使用,稳定可靠
|
||||||
|
- ✅ 支持 SFTP、FTP、WebDAV 等多种协议
|
||||||
|
- ✅ 支持大文件传输(断点续传)
|
||||||
|
- ✅ 支持同步功能(同步本地和远程目录)
|
||||||
|
- ✅ 书签管理(保存连接配置)
|
||||||
|
|
||||||
|
**安装方式**:
|
||||||
|
```bash
|
||||||
|
# Homebrew 安装
|
||||||
|
brew install --cask cyberduck
|
||||||
|
|
||||||
|
# 或从 App Store 下载
|
||||||
|
# https://apps.apple.com/app/cyberduck/id409222152
|
||||||
|
```
|
||||||
|
|
||||||
|
**测试配置**:
|
||||||
|
```
|
||||||
|
协议: SFTP
|
||||||
|
服务器: 127.0.0.1
|
||||||
|
端口: 2024
|
||||||
|
用户名: demo
|
||||||
|
密码: demo123
|
||||||
|
路径: /Users/accusys/markbase/data
|
||||||
|
```
|
||||||
|
|
||||||
|
**测试覆盖**:
|
||||||
|
- ✅ GUI 连接测试(用户交互)
|
||||||
|
- ✅ 文件上传(拖拽上传)
|
||||||
|
- ✅ 文件下载(拖拽下载)
|
||||||
|
- ✅ 目录浏览(双击进入)
|
||||||
|
- ✅ 文件删除(右键菜单)
|
||||||
|
- ✅ 文件重命名(右键菜单)
|
||||||
|
- ✅ 新建目录(右键菜单)
|
||||||
|
- ✅ 大文件传输(100MB+)
|
||||||
|
- ✅ 断点续传测试(中断后重新连接)
|
||||||
|
|
||||||
|
**预期结果**:
|
||||||
|
- ✅ 连接成功,显示文件列表
|
||||||
|
- ✅ 上传/下载正常,进度显示
|
||||||
|
- ✅ 文件操作正常,错误提示清晰
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 方案 3: FileZilla(跨平台 GUI client)
|
||||||
|
|
||||||
|
**推荐等级**:⭐⭐⭐⭐ **推荐**
|
||||||
|
|
||||||
|
**理由**:
|
||||||
|
- ✅ 跨平台(Windows, macOS, Linux)
|
||||||
|
- ✅ 广泛使用,社区活跃
|
||||||
|
- ✅ 支持多种协议(SFTP, FTP, FTPS)
|
||||||
|
- ✅ 详细日志显示(packet 级别)
|
||||||
|
- ✅ 支持并发传输(多文件同时上传/下载)
|
||||||
|
- ✅ 站点管理器(保存连接配置)
|
||||||
|
|
||||||
|
**安装方式**:
|
||||||
|
```bash
|
||||||
|
# Homebrew 安装
|
||||||
|
brew install --cask filezilla
|
||||||
|
|
||||||
|
# 或从官网下载
|
||||||
|
# https://filezilla-project.org/download.php
|
||||||
|
```
|
||||||
|
|
||||||
|
**测试配置**:
|
||||||
|
```
|
||||||
|
协议: SFTP - SSH File Transfer Protocol
|
||||||
|
主机: 127.0.0.1
|
||||||
|
端口: 2024
|
||||||
|
用户: demo
|
||||||
|
密码: demo123
|
||||||
|
```
|
||||||
|
|
||||||
|
**测试覆盖**:
|
||||||
|
- ✅ GUI 连接测试
|
||||||
|
- ✅ 文件传输(上传/下载)
|
||||||
|
- ✅ 目录浏览
|
||||||
|
- ✅ 文件操作(删除、重命名、新建目录)
|
||||||
|
- ✅ 并发传输测试(多文件同时传输)
|
||||||
|
- ✅ 日志分析(查看 SFTP packet)
|
||||||
|
- ✅ 大文件传输(100MB+)
|
||||||
|
|
||||||
|
**预期结果**:
|
||||||
|
- ✅ 连接成功
|
||||||
|
- ✅ 日志显示 SSH_FXP_* packet(验证协议实现)
|
||||||
|
- ✅ 文件传输正常
|
||||||
|
- ✅ 并发传输正常(Window Control 验证)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 方案 4: lftp(高级命令行 client)
|
||||||
|
|
||||||
|
**推荐等级**:⭐⭐⭐⭐ **推荐(高级测试)**
|
||||||
|
|
||||||
|
**理由**:
|
||||||
|
- ✅ 功能丰富(镜像、同步、断点续传)
|
||||||
|
- ✅ 支持多种协议(SFTP, FTP, HTTP, HTTPS)
|
||||||
|
- ✅ 支持并行传输(多连接并发)
|
||||||
|
- ✅ 支持脚本化(批量操作)
|
||||||
|
- ✅ 详细日志(调试信息)
|
||||||
|
|
||||||
|
**安装方式**:
|
||||||
|
```bash
|
||||||
|
# Homebrew 安装
|
||||||
|
brew install lftp
|
||||||
|
```
|
||||||
|
|
||||||
|
**测试命令**:
|
||||||
|
```bash
|
||||||
|
# 基本连接测试
|
||||||
|
lftp sftp://demo:demo123@127.0.0.1:2024
|
||||||
|
|
||||||
|
# 镜像测试(同步目录)
|
||||||
|
lftp sftp://demo:demo123@127.0.0.1:2024 <<EOF
|
||||||
|
mirror -R /tmp/test_folder /data/test_folder
|
||||||
|
mirror /data/test_folder /tmp/download_folder
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# 并行传输测试
|
||||||
|
lftp sftp://demo:demo123@127.0.0.1:2024 <<EOF
|
||||||
|
set sftp:parallel 4
|
||||||
|
mput /tmp/test_*.bin
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# 断点续传测试
|
||||||
|
lftp sftp://demo:demo123@127.0.0.1:2024 <<EOF
|
||||||
|
pget -n 4 -c /data/test_100mb.bin
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
```
|
||||||
|
|
||||||
|
**测试覆盖**:
|
||||||
|
- ✅ 基本操作(ls, get, put, rm)
|
||||||
|
- ✅ 镜像功能(mirror,同步目录)
|
||||||
|
- ✅ 并行传输(mput, mget)
|
||||||
|
- ✅ 断点续传(pget -c)
|
||||||
|
- ✅ 高级功能验证
|
||||||
|
|
||||||
|
**预期结果**:
|
||||||
|
- ✅ 连接成功
|
||||||
|
- ✅ 镜像同步正常
|
||||||
|
- ✅ 并行传输正常(Window Control 验证)
|
||||||
|
- ✅ 断点续传正常
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 测试优先级排序 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
| 优先级 | Client | 类型 | 安装状态 | 测试必要性 |
|
||||||
|
|--------|--------|------|---------|-----------|
|
||||||
|
| **1** | OpenSSH sftp | 命令行 | ✅ 已安装 | ⭐⭐⭐⭐⭐ **必须测试** |
|
||||||
|
| **2** | Cyberduck | GUI | ❌ 未安装 | ⭐⭐⭐⭐⭐ **强烈推荐** |
|
||||||
|
| **3** | FileZilla | GUI | ❌ 未安装 | ⭐⭐⭐⭐ **推荐** |
|
||||||
|
| **4** | lftp | 命令行 | ❌ 未安装 | ⭐⭐⭐⭐ **推荐(高级测试)** |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 建议测试流程 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
### Phase 1: OpenSSH sftp(必须)
|
||||||
|
|
||||||
|
**时间**:30 分钟
|
||||||
|
|
||||||
|
**步骤**:
|
||||||
|
1. 基本连接测试(pwd, ls, cd)
|
||||||
|
2. 文件上传测试(put)
|
||||||
|
3. 文件下载测试(get)
|
||||||
|
4. 文件操作测试(rm, rename, mkdir)
|
||||||
|
5. 大文件传输测试(100MB)
|
||||||
|
6. MD5 校验验证
|
||||||
|
|
||||||
|
**验证重点**:
|
||||||
|
- ✅ SSH_FXP_* packet 完整实现
|
||||||
|
- ✅ Window Control 正常工作
|
||||||
|
- ✅ 文件完整性校验
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 2: Cyberduck(强烈推荐)
|
||||||
|
|
||||||
|
**时间**:20 分钟
|
||||||
|
|
||||||
|
**步骤**:
|
||||||
|
1. GUI 连接测试
|
||||||
|
2. 文件拖拽上传
|
||||||
|
3. 文件拖拽下载
|
||||||
|
4. 目录浏览测试
|
||||||
|
5. 文件操作测试(右键菜单)
|
||||||
|
6. 大文件传输测试(100MB)
|
||||||
|
7. 断点续传测试
|
||||||
|
|
||||||
|
**验证重点**:
|
||||||
|
- ✅ 用户交互友好性
|
||||||
|
- ✅ 大文件传输稳定性
|
||||||
|
- ✅ 错误提示清晰性
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 3: FileZilla(推荐)
|
||||||
|
|
||||||
|
**时间**:30 分钟
|
||||||
|
|
||||||
|
**步骤**:
|
||||||
|
1. GUI 连接测试
|
||||||
|
2. 文件传输测试
|
||||||
|
3. 并发传输测试(多文件同时)
|
||||||
|
4. 日志分析(SSH_FXP_* packet)
|
||||||
|
5. 大文件传输测试
|
||||||
|
|
||||||
|
**验证重点**:
|
||||||
|
- ✅ 并发传输(Window Control)
|
||||||
|
- ✅ 协议实现验证(packet 日志)
|
||||||
|
- ✅ 错误处理
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 4: lftp(高级测试)
|
||||||
|
|
||||||
|
**时间**:40 分钟
|
||||||
|
|
||||||
|
**步骤**:
|
||||||
|
1. 基本连接测试
|
||||||
|
2. 镜像同步测试(mirror)
|
||||||
|
3. 并行传输测试(mput)
|
||||||
|
4. 断点续传测试(pget)
|
||||||
|
5. 性能对比
|
||||||
|
|
||||||
|
**验证重点**:
|
||||||
|
- ✅ 高级功能兼容性
|
||||||
|
- ✅ 性能优化验证
|
||||||
|
- ✅ 稳定性测试
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 测试脚本建议 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
### OpenSSH sftp 批量测试脚本
|
||||||
|
|
||||||
|
**文件**:`/tmp/sftp_test_batch.txt`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 基本操作测试
|
||||||
|
pwd
|
||||||
|
ls -la
|
||||||
|
cd data
|
||||||
|
ls -la
|
||||||
|
|
||||||
|
# 文件上传测试
|
||||||
|
put /tmp/test_5mb.bin test_5mb.bin
|
||||||
|
put /tmp/test_10mb.bin test_10mb.bin
|
||||||
|
put /tmp/test_100mb.bin test_100mb.bin
|
||||||
|
|
||||||
|
# 文件下载测试
|
||||||
|
get test_100mb.bin /tmp/test_download.bin
|
||||||
|
|
||||||
|
# 文件操作测试
|
||||||
|
mkdir test_dir
|
||||||
|
rename test_5mb.bin test_5mb_renamed.bin
|
||||||
|
rm test_10mb.bin
|
||||||
|
rmdir test_dir
|
||||||
|
|
||||||
|
# 属性查询测试
|
||||||
|
stat test_100mb.bin
|
||||||
|
ls -la
|
||||||
|
|
||||||
|
# 退出
|
||||||
|
bye
|
||||||
|
```
|
||||||
|
|
||||||
|
**执行命令**:
|
||||||
|
```bash
|
||||||
|
# 创建测试文件
|
||||||
|
dd if=/dev/urandom of=/tmp/test_5mb.bin bs=1M count=5
|
||||||
|
dd if=/dev/urandom of=/tmp/test_10mb.bin bs=1M count=10
|
||||||
|
dd if=/dev/urandom of=/tmp/test_100mb.bin bs=1M count=100
|
||||||
|
|
||||||
|
# 执行批量测试
|
||||||
|
sftp -P 2024 -b /tmp/sftp_test_batch.txt demo@127.0.0.1
|
||||||
|
|
||||||
|
# MD5 校验
|
||||||
|
md5 /tmp/test_100mb.bin /tmp/test_download.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 错误处理测试 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
**测试场景**:
|
||||||
|
1. ✅ 路径不存在(SSH_FXP_NO_SUCH_FILE)
|
||||||
|
2. ✅ 权限不足(SSH_FXP_PERMISSION_DENIED)
|
||||||
|
3. ✅ 文件已存在(SSH_FXP_FILE_ALREADY_EXISTS)
|
||||||
|
4. ✅ 磁盘空间不足(SSH_FXP_FAILURE)
|
||||||
|
5. ✅ 连接中断(断点续传)
|
||||||
|
|
||||||
|
**测试命令**:
|
||||||
|
```bash
|
||||||
|
# 路径不存在测试
|
||||||
|
sftp -P 2024 demo@127.0.0.1 <<EOF
|
||||||
|
get /data/nonexistent_file.bin /tmp/test.bin
|
||||||
|
EOF
|
||||||
|
# 预期: "No such file"
|
||||||
|
|
||||||
|
# 权限不足测试
|
||||||
|
sftp -P 2024 demo@127.0.0.1 <<EOF
|
||||||
|
put /tmp/test.bin /root/test.bin
|
||||||
|
EOF
|
||||||
|
# 预期: "Permission denied"
|
||||||
|
|
||||||
|
# 文件已存在测试
|
||||||
|
sftp -P 2024 demo@127.0.0.1 <<EOF
|
||||||
|
put /tmp/test_100mb.bin /data/test_100mb.bin
|
||||||
|
EOF
|
||||||
|
# 预期: 文件已存在,询问是否覆盖
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 性能测试建议 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
**测试指标**:
|
||||||
|
- 传输速率(MB/s)
|
||||||
|
- 并发传输能力(多文件同时)
|
||||||
|
- 大文件传输稳定性(100MB+)
|
||||||
|
- Window Control 效率(window adjust frequency)
|
||||||
|
|
||||||
|
**测试命令**:
|
||||||
|
```bash
|
||||||
|
# OpenSSH sftp 性能测试
|
||||||
|
time sftp -P 2024 demo@127.0.0.1 <<EOF
|
||||||
|
put /tmp/test_100mb.bin /data/test_100mb.bin
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# FileZilla 并发传输测试
|
||||||
|
# 同时上传 10 个 10MB 文件,测试 Window Control
|
||||||
|
|
||||||
|
# lftp 并行传输测试
|
||||||
|
lftp sftp://demo:demo123@127.0.0.1:2024 <<EOF
|
||||||
|
set sftp:parallel 4
|
||||||
|
mput /tmp/test_1*.bin
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 总结与建议 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
**必须测试**:
|
||||||
|
- ⭐⭐⭐⭐⭐ **OpenSSH sftp**(标准实现,兼容性验证)
|
||||||
|
|
||||||
|
**强烈推荐**:
|
||||||
|
- ⭐⭐⭐⭐⭐ **Cyberduck**(macOS 原生,用户友好)
|
||||||
|
- ⭐⭐⭐⭐ **FileZilla**(跨平台,日志详细)
|
||||||
|
|
||||||
|
**可选测试**:
|
||||||
|
- ⭐⭐⭐⭐ **lftp**(高级功能,性能优化)
|
||||||
|
|
||||||
|
**测试时间估算**:
|
||||||
|
- Phase 1(OpenSSH sftp):30 分钟
|
||||||
|
- Phase 2(Cyberduck):20 分钟
|
||||||
|
- Phase 3(FileZilla):30 分钟
|
||||||
|
- Phase 4(lftp):40 分钟
|
||||||
|
- **总计**:约 2 小时
|
||||||
|
|
||||||
|
**预期结果**:
|
||||||
|
- ✅ 所有 client 连接成功
|
||||||
|
- ✅ 所有操作正常(上传、下载、浏览、删除等)
|
||||||
|
- ✅ 文件完整性校验一致
|
||||||
|
- ✅ 错误处理正确
|
||||||
|
- ✅ Window Control 正常工作
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**最后更新**:2026-06-17
|
||||||
260
docs/SSH_FULL_INTEGRATION_TEST_PLAN.md
Normal file
260
docs/SSH_FULL_INTEGRATION_TEST_PLAN.md
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
# SSH/SFTP/SCP/rsync 完整整合測試計劃 ⭐⭐⭐⭐⭐
|
||||||
|
|
||||||
|
**版本**: 1.0 | **日期**: 2026-06-18 | **實施狀態**: Phase 1-16 + Window Control + SFTP batch fix ✅
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 環境
|
||||||
|
|
||||||
|
- **伺服器**: `markbase-core ssh-start -p 2024` (本機)
|
||||||
|
- **用戶**: `demo` / `demo123` (bcrypt)
|
||||||
|
- **日誌**: `RUST_LOG=info` 輸出至檔案
|
||||||
|
- **計時**: 每個測試 `timeout 30` (大檔案 `timeout 120`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. SSH 基本連線 (Phase 1-5)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1.1 連線 + 密碼認證
|
||||||
|
timeout 10 ssh -v -p 2024 -o StrictHostKeyChecking=no \
|
||||||
|
-o UserKnownHostsFile=/dev/null demo@127.0.0.1 'echo "SSH OK"' 2>&1
|
||||||
|
|
||||||
|
# ✅ 預期: "SSH OK" + debug1: Authentication succeeded (password)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. SFTP 操作 (Phase 7 + batch fix)
|
||||||
|
|
||||||
|
### 2.1 基礎功能
|
||||||
|
|
||||||
|
```bash
|
||||||
|
timeout 30 sftp -o StrictHostKeyChecking=no \
|
||||||
|
-o UserKnownHostsFile=/dev/null -P 2024 demo@127.0.0.1 << 'EOF'
|
||||||
|
pwd
|
||||||
|
ls
|
||||||
|
mkdir sftp_test_dir
|
||||||
|
cd sftp_test_dir
|
||||||
|
put /etc/hostname test_upload.txt
|
||||||
|
get test_upload.txt /tmp/test_download.txt
|
||||||
|
rm test_upload.txt
|
||||||
|
cd ..
|
||||||
|
rmdir sftp_test_dir
|
||||||
|
!md5 /tmp/test_download.txt
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.2 企業級錯誤處理
|
||||||
|
|
||||||
|
```bash
|
||||||
|
timeout 10 sftp -P 2024 demo@127.0.0.1 << 'EOF'
|
||||||
|
mkdir /etc/forbidden
|
||||||
|
get /root/.ssh/id_rsa
|
||||||
|
stat /nonexistent
|
||||||
|
rm /nonexistent
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
# ✅ 預期: Permission denied / No such file (非 generic Failure)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 大檔案傳輸 (MD5驗證, 每次 2MB / 5MB / 10MB)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 建立測試檔
|
||||||
|
dd if=/dev/urandom of=/tmp/test_2m.bin bs=1M count=2 2>/dev/null
|
||||||
|
dd if=/dev/urandom of=/tmp/test_5m.bin bs=1M count=5 2>/dev/null
|
||||||
|
dd if=/dev/urandom of=/tmp/test_10m.bin bs=1M count=10 2>/dev/null
|
||||||
|
|
||||||
|
# 測試每個檔案上傳+下載
|
||||||
|
for f in test_2m test_5m test_10m; do
|
||||||
|
echo "=== Testing $f ==="
|
||||||
|
md5sum /tmp/${f}.bin | awk '{print $1}' > /tmp/${f}.md5
|
||||||
|
|
||||||
|
timeout 120 sftp -P 2024 demo@127.0.0.1 << EOFSFTP 2>&1 | grep -v debug
|
||||||
|
put /tmp/${f}.bin ${f}.bin
|
||||||
|
get ${f}.bin /tmp/${f}_dl.bin
|
||||||
|
rm ${f}.bin
|
||||||
|
bye
|
||||||
|
EOFSFTP
|
||||||
|
|
||||||
|
md5sum -c /tmp/${f}.md5 <<< "$(md5sum /tmp/${f}_dl.bin | awk '{print $1}')" 2>/dev/null \
|
||||||
|
&& echo "✅ $f PASS" || echo "❌ $f FAIL"
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.4 多檔案批次傳輸
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 建立10個小檔
|
||||||
|
for i in $(seq 1 10); do
|
||||||
|
dd if=/dev/urandom of=/tmp/batch_${i}.bin bs=1K count=64 2>/dev/null
|
||||||
|
done
|
||||||
|
|
||||||
|
# 批次上傳 + 下載 + 驗證
|
||||||
|
timeout 60 sftp -P 2024 demo@127.0.0.1 << 'EOF'
|
||||||
|
lcd /tmp
|
||||||
|
cd /tmp
|
||||||
|
put batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin
|
||||||
|
put batch_6.bin batch_7.bin batch_8.bin batch_9.bin batch_10.bin
|
||||||
|
get batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin
|
||||||
|
get batch_6.bin batch_7.bin batch_8.bin batch_9.bin batch_10.bin
|
||||||
|
rm batch_1.bin batch_2.bin batch_3.bin batch_4.bin batch_5.bin
|
||||||
|
rm batch_6.bin batch_7.bin batch_8.bin batch_9.bin batch_10.bin
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# MD5比對
|
||||||
|
for i in $(seq 1 10); do
|
||||||
|
[ "$(md5sum /tmp/batch_${i}.bin | awk '{print $1}')" = \
|
||||||
|
"$(md5sum /tmp/batch_${i}.bin 2>/dev/null | awk '{print $1}')" ] \
|
||||||
|
&& echo "✅ batch_${i}" || echo "❌ batch_${i}"
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. SCP 測試 (Phase 8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 3.1 上傳
|
||||||
|
timeout 30 scp -P 2024 -o StrictHostKeyChecking=no \
|
||||||
|
/tmp/test_5m.bin demo@127.0.0.1:scp_test.bin
|
||||||
|
|
||||||
|
# 3.2 下載
|
||||||
|
timeout 30 scp -P 2024 -o StrictHostKeyChecking=no \
|
||||||
|
demo@127.0.0.1:scp_test.bin /tmp/scp_dl.bin
|
||||||
|
|
||||||
|
# 3.3 目錄傳輸
|
||||||
|
mkdir -p /tmp/scp_dir && for i in 1 2 3; do
|
||||||
|
dd if=/dev/urandom of=/tmp/scp_dir/file_${i}.bin bs=1M count=1 2>/dev/null
|
||||||
|
done
|
||||||
|
timeout 30 scp -P 2024 -r -o StrictHostKeyChecking=no \
|
||||||
|
/tmp/scp_dir demo@127.0.0.1:scp_dir_remote
|
||||||
|
|
||||||
|
# 3.4 完整驗證
|
||||||
|
md5sum /tmp/test_5m.bin /tmp/scp_dl.bin
|
||||||
|
md5sum /tmp/scp_dir/*
|
||||||
|
rm -rf /tmp/scp_dir
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. rsync 測試 (Phase 16 Final: subprocess)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export RSYNC_RSH="ssh -p 2024 -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
|
||||||
|
|
||||||
|
# 4.1 上傳
|
||||||
|
timeout 30 rsync -avz --rsh="$RSYNC_RSH" /tmp/test_5m.bin demo@127.0.0.1:rsync_test.bin
|
||||||
|
|
||||||
|
# 4.2 下載
|
||||||
|
timeout 30 rsync -avz --rsh="$RSYNC_RSH" demo@127.0.0.1:rsync_test.bin /tmp/rsync_dl.bin
|
||||||
|
|
||||||
|
# 4.3 差異傳輸 (delta transfer)
|
||||||
|
echo "extra data" >> /tmp/test_5m.bin
|
||||||
|
timeout 30 rsync -avz --rsh="$RSYNC_RSH" /tmp/test_5m.bin demo@127.0.0.1:rsync_test.bin
|
||||||
|
|
||||||
|
# 4.4 大檔案 (50MB-100MB)
|
||||||
|
dd if=/dev/urandom of=/tmp/test_50m.bin bs=1M count=50 2>/dev/null
|
||||||
|
md5sum /tmp/test_50m.bin > /tmp/test_50m.md5
|
||||||
|
|
||||||
|
timeout 120 rsync -avz --rsh="$RSYNC_RSH" /tmp/test_50m.bin demo@127.0.0.1:rsync_large.bin
|
||||||
|
timeout 120 rsync -avz --rsh="$RSYNC_RSH" demo@127.0.0.1:rsync_large.bin /tmp/rsync_50m_dl.bin
|
||||||
|
md5sum -c /tmp/test_50m.md5 <<< "$(md5sum /tmp/rsync_50m_dl.bin | awk '{print $1}')"
|
||||||
|
|
||||||
|
# 4.5 目錄同步
|
||||||
|
mkdir -p /tmp/rsync_dir && for i in $(seq 1 5); do
|
||||||
|
dd if=/dev/urandom of=/tmp/rsync_dir/f_${i}.bin bs=1M count=2 2>/dev/null
|
||||||
|
done
|
||||||
|
timeout 60 rsync -avz --rsh="$RSYNC_RSH" /tmp/rsync_dir/ demo@127.0.0.1:rsync_dir/
|
||||||
|
timeout 60 rsync -avz --rsh="$RSYNC_RSH" demo@127.0.0.1:rsync_dir/ /tmp/rsync_dir_dl/
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. SSH 通道指令執行 (Phase 6)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 5.1 基本指令
|
||||||
|
ssh -p 2024 demo@127.0.0.1 'echo "hello"; whoami; pwd; uname -a'
|
||||||
|
|
||||||
|
# 5.2 多指令管線
|
||||||
|
ssh -p 2024 demo@127.0.0.1 'ls -la /tmp | head -5; echo "---"; df -h | head -3'
|
||||||
|
|
||||||
|
# 5.3 環境變數
|
||||||
|
ssh -p 2024 demo@127.0.0.1 'export TEST_VAR=hello; echo $TEST_VAR'
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 壓力測試
|
||||||
|
|
||||||
|
### 6.1 連續連線
|
||||||
|
|
||||||
|
```bash
|
||||||
|
for i in $(seq 1 20); do
|
||||||
|
echo "=== Iteration $i ==="
|
||||||
|
timeout 10 ssh -p 2024 demo@127.0.0.1 'echo OK' 2>&1 | grep "^OK$" \
|
||||||
|
&& echo "✅" || echo "❌"
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 並行連線
|
||||||
|
|
||||||
|
```bash
|
||||||
|
for i in $(seq 1 5); do
|
||||||
|
(timeout 30 sftp -P 2024 demo@127.0.0.1 << EOF &
|
||||||
|
put /tmp/test_5m.bin parallel_${i}.bin
|
||||||
|
get parallel_${i}.bin /tmp/parallel_${i}_dl.bin
|
||||||
|
rm parallel_${i}.bin
|
||||||
|
bye
|
||||||
|
EOF
|
||||||
|
) &
|
||||||
|
done
|
||||||
|
wait
|
||||||
|
echo "All parallel transfers done"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 清理
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 清理伺服器端檔案 (需在 SFTP session 中執行)
|
||||||
|
rm scp_test.bin rsync_test.bin rsync_large.bin
|
||||||
|
rm -rf rsync_dir scp_dir_remote
|
||||||
|
|
||||||
|
# 清理本機暫存
|
||||||
|
rm -f /tmp/test_upload.txt /tmp/test_download.txt
|
||||||
|
rm -f /tmp/test_2m.bin /tmp/test_5m.bin /tmp/test_10m.bin
|
||||||
|
rm -f /tmp/test_2m_dl.bin /tmp/test_5m_dl.bin /tmp/test_10m_dl.bin
|
||||||
|
rm -f /tmp/rsync_test.bin /tmp/rsync_dl.bin /tmp/rsync_large.bin
|
||||||
|
rm -f /tmp/test_50m.bin /tmp/rsync_50m_dl.bin
|
||||||
|
rm -f /tmp/scp_test.bin /tmp/scp_dl.bin
|
||||||
|
rm -rf /tmp/rsync_dir /tmp/rsync_dir_dl /tmp/scp_dir
|
||||||
|
rm -f /tmp/batch_*.bin /tmp/parallel_*.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 驗證矩陣
|
||||||
|
|
||||||
|
| 編號 | 測試項目 | 預期結果 | 檢查方法 |
|
||||||
|
|------|----------|----------|----------|
|
||||||
|
| 1.1 | SSH連線+認證 | `SSH OK` 輸出 | stdout |
|
||||||
|
| 2.1 | SFTP基礎功能 | 所有操作成功 | exit code=0 |
|
||||||
|
| 2.2 | SFTP錯誤處理 | 非 generic 錯誤 | 日誌比對 |
|
||||||
|
| 2.3 | SFTP大檔案 | MD5吻合 | md5sum |
|
||||||
|
| 2.4 | SFTP批次檔案 | 所有MD5吻合 | md5sum |
|
||||||
|
| 3.1 | SCP上傳 | 檔案存在 | md5sum |
|
||||||
|
| 3.2 | SCP下載 | MD5吻合 | md5sum |
|
||||||
|
| 3.3 | SCP目錄 | 結構一致 | ls -la |
|
||||||
|
| 4.1 | rsync上傳 | MD5吻合 | md5sum |
|
||||||
|
| 4.2 | rsync下載 | MD5吻合 | md5sum |
|
||||||
|
| 4.3 | rsync增量 | 僅傳差異 | speedup > 1 |
|
||||||
|
| 4.4 | rsync 50MB | MD5吻合 | md5sum |
|
||||||
|
| 5.1 | Shell指令 | 正確輸出 | stdout |
|
||||||
|
| 6.1 | 連續連線20次 | 100%成功 | 計數 |
|
||||||
|
| 6.2 | 並行xfer | 所有MD5吻合 | md5sum |
|
||||||
@@ -312,9 +312,9 @@ impl FileTreeRocksDB {
|
|||||||
label: &str,
|
label: &str,
|
||||||
file_uuid: &str,
|
file_uuid: &str,
|
||||||
sha256: Option<&str>,
|
sha256: Option<&str>,
|
||||||
original_name: &str,
|
_original_name: &str,
|
||||||
file_size: Option<i64>,
|
file_size: Option<i64>,
|
||||||
mime_type: Option<&str>,
|
_mime_type: Option<&str>,
|
||||||
parent_id: Option<&str>,
|
parent_id: Option<&str>,
|
||||||
) -> FileNode {
|
) -> FileNode {
|
||||||
FileNode {
|
FileNode {
|
||||||
|
|||||||
@@ -286,9 +286,9 @@ impl FileTreeSled {
|
|||||||
label: &str,
|
label: &str,
|
||||||
file_uuid: &str,
|
file_uuid: &str,
|
||||||
sha256: Option<&str>,
|
sha256: Option<&str>,
|
||||||
original_name: &str,
|
_original_name: &str,
|
||||||
file_size: Option<i64>,
|
file_size: Option<i64>,
|
||||||
mime_type: Option<&str>,
|
_mime_type: Option<&str>,
|
||||||
parent_id: Option<&str>,
|
parent_id: Option<&str>,
|
||||||
) -> FileNode {
|
) -> FileNode {
|
||||||
FileNode {
|
FileNode {
|
||||||
@@ -314,7 +314,7 @@ impl FileTreeSled {
|
|||||||
|
|
||||||
pub fn build_tree(nodes: &[FileNode]) -> Vec<FileNode> {
|
pub fn build_tree(nodes: &[FileNode]) -> Vec<FileNode> {
|
||||||
let mut roots = Vec::new();
|
let mut roots = Vec::new();
|
||||||
let node_map: HashMap<String, &FileNode> =
|
let _node_map: HashMap<String, &FileNode> =
|
||||||
nodes.iter().map(|n| (n.node_id.clone(), n)).collect();
|
nodes.iter().map(|n| (n.node_id.clone(), n)).collect();
|
||||||
|
|
||||||
for node in nodes {
|
for node in nodes {
|
||||||
|
|||||||
@@ -630,28 +630,28 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 新增:创建虚拟树类型
|
// 新增:创建虚拟树类型
|
||||||
pub fn create_tree_type(
|
pub fn create_tree_type(
|
||||||
conn: &Connection,
|
conn: &Connection,
|
||||||
tree_type: &str,
|
tree_type: &str,
|
||||||
tree_name: &str,
|
tree_name: &str,
|
||||||
description: &str,
|
description: &str,
|
||||||
is_system_defined: bool,
|
is_system_defined: bool,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO tree_registry (tree_type, tree_name, description, is_system_defined)
|
"INSERT INTO tree_registry (tree_type, tree_name, description, is_system_defined)
|
||||||
VALUES (?1, ?2, ?3, ?4)",
|
VALUES (?1, ?2, ?3, ?4)",
|
||||||
rusqlite::params![tree_type, tree_name, description, is_system_defined as i64],
|
rusqlite::params![tree_type, tree_name, description, is_system_defined as i64],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// 新增:获取所有虚拟树类型
|
// 新增:获取所有虚拟树类型
|
||||||
// 新增:删除虚拟树类型(仅限用户自定义)
|
// 新增:删除虚拟树类型(仅限用户自定义)
|
||||||
pub fn delete_tree_type(conn: &Connection, tree_type: &str) -> Result<()> {
|
pub fn delete_tree_type(conn: &Connection, tree_type: &str) -> Result<()> {
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"DELETE FROM tree_registry WHERE tree_type = ?1 AND is_system_defined = 0",
|
"DELETE FROM tree_registry WHERE tree_type = ?1 AND is_system_defined = 0",
|
||||||
[tree_type],
|
[tree_type],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
BIN
large_dl_test.bin
Normal file
BIN
large_dl_test.bin
Normal file
Binary file not shown.
BIN
large_test.bin
Normal file
BIN
large_test.bin
Normal file
Binary file not shown.
@@ -1,17 +1,16 @@
|
|||||||
// Archive Configuration - User Configurable Options
|
// Archive Configuration - User Configurable Options
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::path::Path;
|
|
||||||
use log::warn;
|
use log::warn;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Archive Configuration
|
/// Archive Configuration
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ArchiveConfig {
|
pub struct ArchiveConfig {
|
||||||
// Optional formats (controversial)
|
// Optional formats (controversial)
|
||||||
pub enable_rar: bool, // ⚠️ Legal risk (RARLAB patent)
|
pub enable_rar: bool, // ⚠️ Legal risk (RARLAB patent)
|
||||||
pub enable_xz: bool, // ⚠️ External dependency (liblzma)
|
pub enable_xz: bool, // ⚠️ External dependency (liblzma)
|
||||||
pub enable_7z: bool, // ⚠️ Unstable library
|
pub enable_7z: bool, // ⚠️ Unstable library
|
||||||
|
|
||||||
// Performance settings
|
// Performance settings
|
||||||
pub cache_size_mb: u64,
|
pub cache_size_mb: u64,
|
||||||
@@ -74,7 +73,8 @@ impl ArchiveConfig {
|
|||||||
return Err(anyhow::anyhow!("Max decompression ratio too low (min 10)"));
|
return Err(anyhow::anyhow!("Max decompression ratio too low (min 10)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.max_file_size_mb > 10_000 { // 10GB
|
if self.max_file_size_mb > 10_000 {
|
||||||
|
// 10GB
|
||||||
warn!("Max file size > 10GB may cause disk space issues");
|
warn!("Max file size > 10GB may cause disk space issues");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
// Format Detector - Automatic Detection Based on Magic Numbers
|
// Format Detector - Automatic Detection Based on Magic Numbers
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use anyhow::Result;
|
|
||||||
|
|
||||||
use crate::archive::processor::ArchiveFormat;
|
use crate::archive::processor::ArchiveFormat;
|
||||||
|
|
||||||
@@ -18,7 +18,6 @@ impl FormatDetector {
|
|||||||
// ZIP: 50 4B 03 04 or 50 4B 05 06 (empty) or 50 4B 07 08 (spanned)
|
// ZIP: 50 4B 03 04 or 50 4B 05 06 (empty) or 50 4B 07 08 (spanned)
|
||||||
(vec![0x50, 0x4B, 0x03, 0x04], ArchiveFormat::Zip, 4),
|
(vec![0x50, 0x4B, 0x03, 0x04], ArchiveFormat::Zip, 4),
|
||||||
(vec![0x50, 0x4B, 0x05, 0x06], ArchiveFormat::Zip, 4),
|
(vec![0x50, 0x4B, 0x05, 0x06], ArchiveFormat::Zip, 4),
|
||||||
|
|
||||||
// GZIP: 1F 8B
|
// GZIP: 1F 8B
|
||||||
(vec![0x1F, 0x8B], ArchiveFormat::Gzip, 2),
|
(vec![0x1F, 0x8B], ArchiveFormat::Gzip, 2),
|
||||||
];
|
];
|
||||||
@@ -44,11 +43,10 @@ impl FormatDetector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Special detection: TAR format (check ustar magic at offset 257)
|
// Special detection: TAR format (check ustar magic at offset 257)
|
||||||
if buffer.len() >= 262 {
|
if buffer.len() >= 262
|
||||||
if &buffer[257..262] == b"ustar" {
|
&& &buffer[257..262] == b"ustar" {
|
||||||
return Ok(ArchiveFormat::Tar);
|
return Ok(ArchiveFormat::Tar);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ArchiveFormat::Unknown)
|
Ok(ArchiveFormat::Unknown)
|
||||||
}
|
}
|
||||||
@@ -59,16 +57,15 @@ impl FormatDetector {
|
|||||||
|
|
||||||
// If GZIP, check if it's TAR.GZ (by extension for now)
|
// If GZIP, check if it's TAR.GZ (by extension for now)
|
||||||
if format == ArchiveFormat::Gzip {
|
if format == ArchiveFormat::Gzip {
|
||||||
let ext = path.extension()
|
let ext = path
|
||||||
|
.extension()
|
||||||
.and_then(|e| e.to_str())
|
.and_then(|e| e.to_str())
|
||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
.to_lowercase();
|
.to_lowercase();
|
||||||
|
|
||||||
if ext == "tgz" || ext == "gz" {
|
if ext == "tgz" || ext == "gz" {
|
||||||
// Check if filename contains .tar
|
// Check if filename contains .tar
|
||||||
let filename = path.file_name()
|
let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
|
||||||
.and_then(|n| n.to_str())
|
|
||||||
.unwrap_or("");
|
|
||||||
|
|
||||||
if filename.contains(".tar") {
|
if filename.contains(".tar") {
|
||||||
return Ok(ArchiveFormat::TarGzip);
|
return Ok(ArchiveFormat::TarGzip);
|
||||||
@@ -89,8 +86,8 @@ impl Default for FormatDetector {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tempfile::TempDir;
|
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_detect_zip() {
|
fn test_detect_zip() {
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
// Metadata Module - Archive Entry Metadata Management
|
// Metadata Module - Archive Entry Metadata Management
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::archive::processor::ArchiveFormat;
|
use crate::archive::processor::ArchiveFormat;
|
||||||
|
|
||||||
@@ -143,7 +143,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(metadata.actual_ratio(), 2.0);
|
assert_eq!(metadata.actual_ratio(), 2.0);
|
||||||
assert!(!metadata.check_zip_bomb(1000));
|
assert!(!metadata.check_zip_bomb(1000));
|
||||||
assert!(metadata.check_zip_bomb(1)); // Should detect as bomb
|
assert!(metadata.check_zip_bomb(1)); // Should detect as bomb
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -25,9 +25,9 @@ pub use metadata::{ArchiveEntry, ArchiveMetadata, ExtractResult};
|
|||||||
pub use processor::{ArchiveFormat, ArchiveProcessor};
|
pub use processor::{ArchiveFormat, ArchiveProcessor};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use log::info;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use log::{info, warn};
|
|
||||||
|
|
||||||
/// Processor Registry - Plugin Architecture
|
/// Processor Registry - Plugin Architecture
|
||||||
pub struct ProcessorRegistry {
|
pub struct ProcessorRegistry {
|
||||||
@@ -59,15 +59,24 @@ impl ProcessorRegistry {
|
|||||||
fn register_core_processors(&mut self) -> Result<()> {
|
fn register_core_processors(&mut self) -> Result<()> {
|
||||||
use crate::archive::processors::core::*;
|
use crate::archive::processors::core::*;
|
||||||
|
|
||||||
self.processors.insert(ArchiveFormat::Zip, Box::new(ZipProcessor::new()));
|
self.processors
|
||||||
self.processors.insert(ArchiveFormat::Tar, Box::new(TarProcessor::new()));
|
.insert(ArchiveFormat::Zip, Box::new(ZipProcessor::new()));
|
||||||
self.processors.insert(ArchiveFormat::Gzip, Box::new(GzipProcessor::new()));
|
self.processors
|
||||||
self.processors.insert(ArchiveFormat::Zstd, Box::new(ZstdProcessor::new()));
|
.insert(ArchiveFormat::Tar, Box::new(TarProcessor::new()));
|
||||||
self.processors.insert(ArchiveFormat::Bzip2, Box::new(Bzip2Processor::new()));
|
self.processors
|
||||||
self.processors.insert(ArchiveFormat::Lz4, Box::new(Lz4Processor::new()));
|
.insert(ArchiveFormat::Gzip, Box::new(GzipProcessor::new()));
|
||||||
self.processors.insert(ArchiveFormat::TarGzip, Box::new(TarGzipProcessor::new()));
|
self.processors
|
||||||
self.processors.insert(ArchiveFormat::TarBzip2, Box::new(TarBzip2Processor::new()));
|
.insert(ArchiveFormat::Zstd, Box::new(ZstdProcessor::new()));
|
||||||
self.processors.insert(ArchiveFormat::TarZstd, Box::new(TarZstdProcessor::new()));
|
self.processors
|
||||||
|
.insert(ArchiveFormat::Bzip2, Box::new(Bzip2Processor::new()));
|
||||||
|
self.processors
|
||||||
|
.insert(ArchiveFormat::Lz4, Box::new(Lz4Processor::new()));
|
||||||
|
self.processors
|
||||||
|
.insert(ArchiveFormat::TarGzip, Box::new(TarGzipProcessor::new()));
|
||||||
|
self.processors
|
||||||
|
.insert(ArchiveFormat::TarBzip2, Box::new(TarBzip2Processor::new()));
|
||||||
|
self.processors
|
||||||
|
.insert(ArchiveFormat::TarZstd, Box::new(TarZstdProcessor::new()));
|
||||||
|
|
||||||
info!("✅ Core formats registered: 9 formats");
|
info!("✅ Core formats registered: 9 formats");
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -82,14 +91,16 @@ impl ProcessorRegistry {
|
|||||||
// RAR format (legal risk)
|
// RAR format (legal risk)
|
||||||
if self.config.enable_rar {
|
if self.config.enable_rar {
|
||||||
crate::archive::warning::show_rar_legal_warning();
|
crate::archive::warning::show_rar_legal_warning();
|
||||||
self.processors.insert(ArchiveFormat::Rar, Box::new(RarProcessor::new()));
|
self.processors
|
||||||
|
.insert(ArchiveFormat::Rar, Box::new(RarProcessor::new()));
|
||||||
warn!("⚠️ RAR format enabled (legal risk)");
|
warn!("⚠️ RAR format enabled (legal risk)");
|
||||||
}
|
}
|
||||||
|
|
||||||
// XZ format (external dependency)
|
// XZ format (external dependency)
|
||||||
if self.config.enable_xz {
|
if self.config.enable_xz {
|
||||||
if check_liblzma_available() {
|
if check_liblzma_available() {
|
||||||
self.processors.insert(ArchiveFormat::Xz, Box::new(XzProcessor::new()));
|
self.processors
|
||||||
|
.insert(ArchiveFormat::Xz, Box::new(XzProcessor::new()));
|
||||||
info!("✅ XZ format enabled");
|
info!("✅ XZ format enabled");
|
||||||
} else {
|
} else {
|
||||||
crate::archive::warning::show_xz_dependency_warning();
|
crate::archive::warning::show_xz_dependency_warning();
|
||||||
@@ -100,7 +111,8 @@ impl ProcessorRegistry {
|
|||||||
// 7z format (unstable library)
|
// 7z format (unstable library)
|
||||||
if self.config.enable_7z {
|
if self.config.enable_7z {
|
||||||
crate::archive::warning::show_7z_stability_warning();
|
crate::archive::warning::show_7z_stability_warning();
|
||||||
self.processors.insert(ArchiveFormat::SevenZ, Box::new(SevenZProcessor::new()));
|
self.processors
|
||||||
|
.insert(ArchiveFormat::SevenZ, Box::new(SevenZProcessor::new()));
|
||||||
warn!("⚠️ 7z format enabled (stability warning)");
|
warn!("⚠️ 7z format enabled (stability warning)");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -115,7 +127,10 @@ impl ProcessorRegistry {
|
|||||||
|
|
||||||
match self.processors.get_mut(&format) {
|
match self.processors.get_mut(&format) {
|
||||||
Some(p) => Ok(p.as_mut()),
|
Some(p) => Ok(p.as_mut()),
|
||||||
None => Err(anyhow::anyhow!("Format {} not supported or not enabled", format)),
|
None => Err(anyhow::anyhow!(
|
||||||
|
"Format {} not supported or not enabled",
|
||||||
|
format
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +156,7 @@ impl ProcessorRegistry {
|
|||||||
fn check_liblzma_available() -> bool {
|
fn check_liblzma_available() -> bool {
|
||||||
// Try to load xz2 library
|
// Try to load xz2 library
|
||||||
// Simplified check: try to create XzProcessor
|
// Simplified check: try to create XzProcessor
|
||||||
true // Simplified for now, actual implementation needs better detection
|
true // Simplified for now, actual implementation needs better detection
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "optional-formats"))]
|
#[cfg(not(feature = "optional-formats"))]
|
||||||
@@ -163,6 +178,9 @@ pub fn init_archive_system(config_path: Option<&str>) -> Result<ProcessorRegistr
|
|||||||
let mut registry = ProcessorRegistry::new(config);
|
let mut registry = ProcessorRegistry::new(config);
|
||||||
registry.initialize()?;
|
registry.initialize()?;
|
||||||
|
|
||||||
info!("Archive system initialized with {} formats", registry.enabled_formats().len());
|
info!(
|
||||||
|
"Archive system initialized with {} formats",
|
||||||
|
registry.enabled_formats().len()
|
||||||
|
);
|
||||||
Ok(registry)
|
Ok(registry)
|
||||||
}
|
}
|
||||||
@@ -4,7 +4,7 @@ use anyhow::Result;
|
|||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
// Re-export types from metadata.rs
|
// Re-export types from metadata.rs
|
||||||
pub use crate::archive::metadata::{ArchiveMetadata, ArchiveEntry, ExtractResult};
|
pub use crate::archive::metadata::{ArchiveEntry, ArchiveMetadata, ExtractResult};
|
||||||
|
|
||||||
/// Archive Format Type Enumeration
|
/// Archive Format Type Enumeration
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||||
@@ -67,10 +67,14 @@ pub trait ArchiveProcessor: Send + Sync {
|
|||||||
fn extract_all(&mut self, output_dir: &Path) -> Result<ExtractResult>;
|
fn extract_all(&mut self, output_dir: &Path) -> Result<ExtractResult>;
|
||||||
|
|
||||||
/// Check if this processor can handle the format
|
/// Check if this processor can handle the format
|
||||||
fn can_process(format: ArchiveFormat) -> bool where Self: Sized;
|
fn can_process(format: ArchiveFormat) -> bool
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
|
||||||
/// Create new processor instance
|
/// Create new processor instance
|
||||||
fn new() -> Self where Self: Sized;
|
fn new() -> Self
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Security Validation - Zip Slip Protection
|
/// Security Validation - Zip Slip Protection
|
||||||
@@ -97,7 +101,8 @@ pub fn validate_extraction_path(entry_path: &Path, base_dir: &Path) -> Result<Pa
|
|||||||
let full_path = base_dir.join(entry_path);
|
let full_path = base_dir.join(entry_path);
|
||||||
|
|
||||||
// 3. Canonicalize and validate (ensure within base_dir)
|
// 3. Canonicalize and validate (ensure within base_dir)
|
||||||
let canonical_base = base_dir.canonicalize()
|
let canonical_base = base_dir
|
||||||
|
.canonicalize()
|
||||||
.map_err(|e| anyhow::anyhow!("Cannot canonicalize base dir: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Cannot canonicalize base dir: {}", e))?;
|
||||||
|
|
||||||
// Create parent directories first
|
// Create parent directories first
|
||||||
@@ -108,20 +113,26 @@ pub fn validate_extraction_path(entry_path: &Path, base_dir: &Path) -> Result<Pa
|
|||||||
// 4. Verify extraction path is within base_dir
|
// 4. Verify extraction path is within base_dir
|
||||||
// Note: full_path may not exist yet, so we check parent directory
|
// Note: full_path may not exist yet, so we check parent directory
|
||||||
if full_path.exists() {
|
if full_path.exists() {
|
||||||
let canonical_full = full_path.canonicalize()
|
let canonical_full = full_path
|
||||||
|
.canonicalize()
|
||||||
.map_err(|e| anyhow::anyhow!("Cannot canonicalize full path: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Cannot canonicalize full path: {}", e))?;
|
||||||
|
|
||||||
if !canonical_full.starts_with(&canonical_base) {
|
if !canonical_full.starts_with(&canonical_base) {
|
||||||
return Err(anyhow::anyhow!("Zip Slip detected: path escapes base directory"));
|
return Err(anyhow::anyhow!(
|
||||||
|
"Zip Slip detected: path escapes base directory"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Check parent directory instead
|
// Check parent directory instead
|
||||||
if let Some(parent) = full_path.parent() {
|
if let Some(parent) = full_path.parent() {
|
||||||
let canonical_parent = parent.canonicalize()
|
let canonical_parent = parent
|
||||||
|
.canonicalize()
|
||||||
.map_err(|e| anyhow::anyhow!("Cannot canonicalize parent: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Cannot canonicalize parent: {}", e))?;
|
||||||
|
|
||||||
if !canonical_parent.starts_with(&canonical_base) {
|
if !canonical_parent.starts_with(&canonical_base) {
|
||||||
return Err(anyhow::anyhow!("Zip Slip detected: path escapes base directory"));
|
return Err(anyhow::anyhow!(
|
||||||
|
"Zip Slip detected: path escapes base directory"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -130,9 +141,13 @@ pub fn validate_extraction_path(entry_path: &Path, base_dir: &Path) -> Result<Pa
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Security Validation - Zip Bomb Protection
|
/// Security Validation - Zip Bomb Protection
|
||||||
pub fn check_decompression_ratio(compressed_size: u64, decompressed_size: u64, max_ratio: u64) -> Result<()> {
|
pub fn check_decompression_ratio(
|
||||||
|
compressed_size: u64,
|
||||||
|
decompressed_size: u64,
|
||||||
|
max_ratio: u64,
|
||||||
|
) -> Result<()> {
|
||||||
if compressed_size == 0 {
|
if compressed_size == 0 {
|
||||||
return Ok(()); // Empty file, allow
|
return Ok(()); // Empty file, allow
|
||||||
}
|
}
|
||||||
|
|
||||||
let ratio = decompressed_size / compressed_size;
|
let ratio = decompressed_size / compressed_size;
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
// Core Format Processors - ZIP, TAR, GZIP, TAR.GZ Full Implementation
|
// Core Format Processors - ZIP, TAR, GZIP, TAR.GZ Full Implementation
|
||||||
|
|
||||||
use crate::archive::{
|
|
||||||
ArchiveProcessor, ArchiveFormat, ArchiveMetadata, ArchiveEntry, ExtractResult,
|
|
||||||
processor::{validate_extraction_path, check_decompression_ratio, check_file_size_limit},
|
|
||||||
};
|
|
||||||
use crate::archive::config::ArchiveConfig;
|
use crate::archive::config::ArchiveConfig;
|
||||||
use anyhow::{Result, anyhow};
|
use crate::archive::{
|
||||||
|
processor::{check_decompression_ratio, check_file_size_limit, validate_extraction_path},
|
||||||
|
ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult,
|
||||||
|
};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use log::{debug, info, warn};
|
||||||
|
use std::fs::{create_dir_all, File};
|
||||||
|
use std::io::{BufWriter, Read};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::fs::{File, create_dir_all};
|
|
||||||
use std::io::{Read, Write, BufReader, BufWriter};
|
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
use log::{info, warn, debug};
|
|
||||||
|
|
||||||
// ==================== ZIP Processor ====================
|
// ==================== ZIP Processor ====================
|
||||||
|
|
||||||
@@ -21,6 +21,12 @@ pub struct ZipProcessor {
|
|||||||
config: ArchiveConfig,
|
config: ArchiveConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for ZipProcessor {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ZipProcessor {
|
impl ZipProcessor {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -82,9 +88,15 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
|
|
||||||
// Check for Zip Bomb
|
// Check for Zip Bomb
|
||||||
if compression_ratio > self.config.max_decompression_ratio as f64 {
|
if compression_ratio > self.config.max_decompression_ratio as f64 {
|
||||||
warn!("Potential Zip Bomb detected: ratio {:.1}:1", compression_ratio);
|
warn!(
|
||||||
return Err(anyhow!("Zip Bomb detected: compression ratio {:.1} exceeds limit {}",
|
"Potential Zip Bomb detected: ratio {:.1}:1",
|
||||||
compression_ratio, self.config.max_decompression_ratio));
|
compression_ratio
|
||||||
|
);
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Zip Bomb detected: compression ratio {:.1} exceeds limit {}",
|
||||||
|
compression_ratio,
|
||||||
|
self.config.max_decompression_ratio
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(ArchiveMetadata {
|
Ok(ArchiveMetadata {
|
||||||
@@ -93,7 +105,7 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
total_size,
|
total_size,
|
||||||
compressed_size,
|
compressed_size,
|
||||||
compression_ratio,
|
compression_ratio,
|
||||||
is_encrypted: false, // TODO: Check encryption
|
is_encrypted: false, // TODO: Check encryption
|
||||||
is_multi_volume: false,
|
is_multi_volume: false,
|
||||||
created_time: Some(SystemTime::now()),
|
created_time: Some(SystemTime::now()),
|
||||||
modified_time: Some(SystemTime::now()),
|
modified_time: Some(SystemTime::now()),
|
||||||
@@ -101,7 +113,9 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
let archive = self.archive.as_mut()
|
let archive = self
|
||||||
|
.archive
|
||||||
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||||
|
|
||||||
let mut entries = Vec::new();
|
let mut entries = Vec::new();
|
||||||
@@ -116,7 +130,7 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
is_dir: file.name().ends_with('/'),
|
is_dir: file.name().ends_with('/'),
|
||||||
is_file: !file.name().ends_with('/'),
|
is_file: !file.name().ends_with('/'),
|
||||||
is_encrypted: false,
|
is_encrypted: false,
|
||||||
modified: SystemTime::UNIX_EPOCH, // TODO: Get actual time
|
modified: SystemTime::UNIX_EPOCH, // TODO: Get actual time
|
||||||
permissions: Some(0o644),
|
permissions: Some(0o644),
|
||||||
checksum: None,
|
checksum: None,
|
||||||
};
|
};
|
||||||
@@ -129,10 +143,13 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||||
let archive = self.archive.as_mut()
|
let archive = self
|
||||||
|
.archive
|
||||||
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||||
|
|
||||||
let entry_name = entry_path.to_str()
|
let entry_name = entry_path
|
||||||
|
.to_str()
|
||||||
.ok_or_else(|| anyhow!("Invalid entry path"))?;
|
.ok_or_else(|| anyhow!("Invalid entry path"))?;
|
||||||
|
|
||||||
let mut file = archive.by_name(entry_name)?;
|
let mut file = archive.by_name(entry_name)?;
|
||||||
@@ -181,7 +198,10 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
result.success_files += 1;
|
result.success_files += 1;
|
||||||
} else {
|
} else {
|
||||||
// File
|
// File
|
||||||
check_file_size_limit(file_size, self.config.max_file_size_mb * 1024 * 1024)?;
|
check_file_size_limit(
|
||||||
|
file_size,
|
||||||
|
self.config.max_file_size_mb * 1024 * 1024,
|
||||||
|
)?;
|
||||||
|
|
||||||
if let Some(parent) = safe_path.parent() {
|
if let Some(parent) = safe_path.parent() {
|
||||||
create_dir_all(parent)?;
|
create_dir_all(parent)?;
|
||||||
@@ -195,7 +215,7 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
result.total_bytes += file_size;
|
result.total_bytes += file_size;
|
||||||
debug!("Extracted: {} ({} bytes)", entry_name, file_size);
|
debug!("Extracted: {} ({} bytes)", entry_name, file_size);
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Zip Slip detected: {} - {}", entry_name, e);
|
warn!("Zip Slip detected: {} - {}", entry_name, e);
|
||||||
result.failed_files.push(PathBuf::from(&entry_name));
|
result.failed_files.push(PathBuf::from(&entry_name));
|
||||||
@@ -204,8 +224,12 @@ impl ArchiveProcessor for ZipProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Extracted {} files ({} bytes) to {}",
|
info!(
|
||||||
result.success_files, result.total_bytes, output_dir.display());
|
"Extracted {} files ({} bytes) to {}",
|
||||||
|
result.success_files,
|
||||||
|
result.total_bytes,
|
||||||
|
output_dir.display()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
@@ -224,6 +248,12 @@ pub struct TarProcessor {
|
|||||||
config: ArchiveConfig,
|
config: ArchiveConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for TarProcessor {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl TarProcessor {
|
impl TarProcessor {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -277,7 +307,7 @@ impl ArchiveProcessor for TarProcessor {
|
|||||||
self.entries.push(ArchiveEntry {
|
self.entries.push(ArchiveEntry {
|
||||||
path,
|
path,
|
||||||
size,
|
size,
|
||||||
compressed_size: size, // TAR has no compression
|
compressed_size: size, // TAR has no compression
|
||||||
is_dir: entry.header().entry_type().is_dir(),
|
is_dir: entry.header().entry_type().is_dir(),
|
||||||
is_file: entry.header().entry_type().is_file(),
|
is_file: entry.header().entry_type().is_file(),
|
||||||
is_encrypted: false,
|
is_encrypted: false,
|
||||||
@@ -293,8 +323,8 @@ impl ArchiveProcessor for TarProcessor {
|
|||||||
format: ArchiveFormat::Tar,
|
format: ArchiveFormat::Tar,
|
||||||
total_files,
|
total_files,
|
||||||
total_size,
|
total_size,
|
||||||
compressed_size: total_size, // TAR has no compression
|
compressed_size: total_size, // TAR has no compression
|
||||||
compression_ratio: 1.0, // No compression
|
compression_ratio: 1.0, // No compression
|
||||||
is_encrypted: false,
|
is_encrypted: false,
|
||||||
is_multi_volume: false,
|
is_multi_volume: false,
|
||||||
created_time: Some(SystemTime::now()),
|
created_time: Some(SystemTime::now()),
|
||||||
@@ -334,27 +364,36 @@ impl ArchiveProcessor for TarProcessor {
|
|||||||
for entry in archive.entries()? {
|
for entry in archive.entries()? {
|
||||||
let mut entry = entry?;
|
let mut entry = entry?;
|
||||||
let entry_path = entry.path()?.to_path_buf();
|
let entry_path = entry.path()?.to_path_buf();
|
||||||
let entry_path_str = entry_path.display().to_string(); // Save for warning
|
let entry_path_str = entry_path.display().to_string(); // Save for warning
|
||||||
|
|
||||||
// Zip Slip protection
|
// Zip Slip protection
|
||||||
match validate_extraction_path(&entry_path, output_dir) {
|
match validate_extraction_path(&entry_path, output_dir) {
|
||||||
Ok(safe_path) => {
|
Ok(safe_path) => {
|
||||||
check_file_size_limit(entry.size(), self.config.max_file_size_mb * 1024 * 1024)?;
|
check_file_size_limit(
|
||||||
|
entry.size(),
|
||||||
|
self.config.max_file_size_mb * 1024 * 1024,
|
||||||
|
)?;
|
||||||
|
|
||||||
entry.unpack(&safe_path)?;
|
entry.unpack(&safe_path)?;
|
||||||
|
|
||||||
result.success_files += 1;
|
result.success_files += 1;
|
||||||
result.total_bytes += entry.size();
|
result.total_bytes += entry.size();
|
||||||
},
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Zip Slip detected: {} - {}", entry_path_str, e);
|
warn!("Zip Slip detected: {} - {}", entry_path_str, e);
|
||||||
result.failed_files.push(entry_path);
|
result.failed_files.push(entry_path);
|
||||||
result.warnings.push(format!("Zip Slip: {}", entry_path_str));
|
result
|
||||||
|
.warnings
|
||||||
|
.push(format!("Zip Slip: {}", entry_path_str));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Extracted {} TAR entries to {}", result.success_files, output_dir.display());
|
info!(
|
||||||
|
"Extracted {} TAR entries to {}",
|
||||||
|
result.success_files,
|
||||||
|
output_dir.display()
|
||||||
|
);
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,6 +411,12 @@ pub struct GzipProcessor {
|
|||||||
config: ArchiveConfig,
|
config: ArchiveConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for GzipProcessor {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl GzipProcessor {
|
impl GzipProcessor {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -418,11 +463,15 @@ impl ArchiveProcessor for GzipProcessor {
|
|||||||
self.decompressed_size = buffer.len() as u64;
|
self.decompressed_size = buffer.len() as u64;
|
||||||
|
|
||||||
// Check Zip Bomb
|
// Check Zip Bomb
|
||||||
check_decompression_ratio(compressed_size, self.decompressed_size, self.config.max_decompression_ratio)?;
|
check_decompression_ratio(
|
||||||
|
compressed_size,
|
||||||
|
self.decompressed_size,
|
||||||
|
self.config.max_decompression_ratio,
|
||||||
|
)?;
|
||||||
|
|
||||||
Ok(ArchiveMetadata {
|
Ok(ArchiveMetadata {
|
||||||
format: ArchiveFormat::Gzip,
|
format: ArchiveFormat::Gzip,
|
||||||
total_files: 1, // GZIP is single file
|
total_files: 1, // GZIP is single file
|
||||||
total_size: self.decompressed_size,
|
total_size: self.decompressed_size,
|
||||||
compressed_size,
|
compressed_size,
|
||||||
compression_ratio: if compressed_size > 0 {
|
compression_ratio: if compressed_size > 0 {
|
||||||
@@ -439,7 +488,9 @@ impl ArchiveProcessor for GzipProcessor {
|
|||||||
|
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
// GZIP is single file - infer name from archive name
|
// GZIP is single file - infer name from archive name
|
||||||
let name = self.path.file_name()
|
let name = self
|
||||||
|
.path
|
||||||
|
.file_name()
|
||||||
.and_then(|n| n.to_str())
|
.and_then(|n| n.to_str())
|
||||||
.unwrap_or("unknown")
|
.unwrap_or("unknown")
|
||||||
.replace(".gz", "")
|
.replace(".gz", "")
|
||||||
@@ -448,11 +499,11 @@ impl ArchiveProcessor for GzipProcessor {
|
|||||||
Ok(vec![ArchiveEntry::file(
|
Ok(vec![ArchiveEntry::file(
|
||||||
PathBuf::from(name),
|
PathBuf::from(name),
|
||||||
self.decompressed_size,
|
self.decompressed_size,
|
||||||
0, // GZIP doesn't preserve compressed size per file
|
0, // GZIP doesn't preserve compressed size per file
|
||||||
)])
|
)])
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_file(&mut self, entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
fn extract_file(&mut self, _entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||||
// GZIP is single file - just decompress it
|
// GZIP is single file - just decompress it
|
||||||
let file = File::open(&self.path)?;
|
let file = File::open(&self.path)?;
|
||||||
let mut decoder = flate2::read::GzDecoder::new(file);
|
let mut decoder = flate2::read::GzDecoder::new(file);
|
||||||
@@ -460,7 +511,10 @@ impl ArchiveProcessor for GzipProcessor {
|
|||||||
output.clear();
|
output.clear();
|
||||||
decoder.read_to_end(output)?;
|
decoder.read_to_end(output)?;
|
||||||
|
|
||||||
check_file_size_limit(output.len() as u64, self.config.max_file_size_mb * 1024 * 1024)?;
|
check_file_size_limit(
|
||||||
|
output.len() as u64,
|
||||||
|
self.config.max_file_size_mb * 1024 * 1024,
|
||||||
|
)?;
|
||||||
|
|
||||||
info!("Decompressed GZIP file: {} bytes", output.len());
|
info!("Decompressed GZIP file: {} bytes", output.len());
|
||||||
Ok(output.len() as u64)
|
Ok(output.len() as u64)
|
||||||
@@ -470,7 +524,8 @@ impl ArchiveProcessor for GzipProcessor {
|
|||||||
create_dir_all(output_dir)?;
|
create_dir_all(output_dir)?;
|
||||||
|
|
||||||
let entries = self.list_entries()?;
|
let entries = self.list_entries()?;
|
||||||
let entry = entries.first()
|
let entry = entries
|
||||||
|
.first()
|
||||||
.ok_or_else(|| anyhow!("No entry in GZIP archive"))?;
|
.ok_or_else(|| anyhow!("No entry in GZIP archive"))?;
|
||||||
|
|
||||||
let outpath = output_dir.join(&entry.path);
|
let outpath = output_dir.join(&entry.path);
|
||||||
@@ -514,6 +569,12 @@ pub struct TarGzipProcessor {
|
|||||||
config: ArchiveConfig,
|
config: ArchiveConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for TarGzipProcessor {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl TarGzipProcessor {
|
impl TarGzipProcessor {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -552,7 +613,8 @@ impl ArchiveProcessor for TarGzipProcessor {
|
|||||||
|
|
||||||
// Step 2: Open TAR
|
// Step 2: Open TAR
|
||||||
let tar_entries = self.gzip_processor.list_entries()?;
|
let tar_entries = self.gzip_processor.list_entries()?;
|
||||||
let tar_file = tar_entries.first()
|
let tar_file = tar_entries
|
||||||
|
.first()
|
||||||
.ok_or_else(|| anyhow!("No TAR file in GZIP"))?;
|
.ok_or_else(|| anyhow!("No TAR file in GZIP"))?;
|
||||||
|
|
||||||
let tar_path = temp_dir.path().join(&tar_file.path);
|
let tar_path = temp_dir.path().join(&tar_file.path);
|
||||||
@@ -606,7 +668,8 @@ impl ArchiveProcessor for TarGzipProcessor {
|
|||||||
|
|
||||||
// Step 2: Extract TAR
|
// Step 2: Extract TAR
|
||||||
let tar_entries = self.gzip_processor.list_entries()?;
|
let tar_entries = self.gzip_processor.list_entries()?;
|
||||||
let tar_file = tar_entries.first()
|
let tar_file = tar_entries
|
||||||
|
.first()
|
||||||
.ok_or_else(|| anyhow!("No TAR file found"))?;
|
.ok_or_else(|| anyhow!("No TAR file found"))?;
|
||||||
|
|
||||||
let tar_path = temp_dir.path().join(&tar_file.path);
|
let tar_path = temp_dir.path().join(&tar_file.path);
|
||||||
@@ -627,73 +690,133 @@ impl ArchiveProcessor for TarGzipProcessor {
|
|||||||
pub struct ZstdProcessor;
|
pub struct ZstdProcessor;
|
||||||
|
|
||||||
impl ArchiveProcessor for ZstdProcessor {
|
impl ArchiveProcessor for ZstdProcessor {
|
||||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::Zstd }
|
fn format(&self) -> ArchiveFormat {
|
||||||
|
ArchiveFormat::Zstd
|
||||||
|
}
|
||||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||||
Err(anyhow!("ZSTD processor not yet implemented"))
|
Err(anyhow!("ZSTD processor not yet implemented"))
|
||||||
}
|
}
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
Ok(Vec::new())
|
||||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
}
|
||||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Zstd }
|
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||||
fn new() -> Self { Self }
|
Ok(0)
|
||||||
|
}
|
||||||
|
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||||
|
Ok(ExtractResult::new())
|
||||||
|
}
|
||||||
|
fn can_process(format: ArchiveFormat) -> bool {
|
||||||
|
format == ArchiveFormat::Zstd
|
||||||
|
}
|
||||||
|
fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// BZIP2 Processor Stub (Phase 2/3)
|
/// BZIP2 Processor Stub (Phase 2/3)
|
||||||
pub struct Bzip2Processor;
|
pub struct Bzip2Processor;
|
||||||
|
|
||||||
impl ArchiveProcessor for Bzip2Processor {
|
impl ArchiveProcessor for Bzip2Processor {
|
||||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::Bzip2 }
|
fn format(&self) -> ArchiveFormat {
|
||||||
|
ArchiveFormat::Bzip2
|
||||||
|
}
|
||||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||||
Err(anyhow!("BZIP2 processor not yet implemented"))
|
Err(anyhow!("BZIP2 processor not yet implemented"))
|
||||||
}
|
}
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
Ok(Vec::new())
|
||||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
}
|
||||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Bzip2 }
|
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||||
fn new() -> Self { Self }
|
Ok(0)
|
||||||
|
}
|
||||||
|
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||||
|
Ok(ExtractResult::new())
|
||||||
|
}
|
||||||
|
fn can_process(format: ArchiveFormat) -> bool {
|
||||||
|
format == ArchiveFormat::Bzip2
|
||||||
|
}
|
||||||
|
fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LZ4 Processor Stub (Phase 2/3)
|
/// LZ4 Processor Stub (Phase 2/3)
|
||||||
pub struct Lz4Processor;
|
pub struct Lz4Processor;
|
||||||
|
|
||||||
impl ArchiveProcessor for Lz4Processor {
|
impl ArchiveProcessor for Lz4Processor {
|
||||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::Lz4 }
|
fn format(&self) -> ArchiveFormat {
|
||||||
|
ArchiveFormat::Lz4
|
||||||
|
}
|
||||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||||
Err(anyhow!("LZ4 processor not yet implemented"))
|
Err(anyhow!("LZ4 processor not yet implemented"))
|
||||||
}
|
}
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
Ok(Vec::new())
|
||||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
}
|
||||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::Lz4 }
|
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||||
fn new() -> Self { Self }
|
Ok(0)
|
||||||
|
}
|
||||||
|
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||||
|
Ok(ExtractResult::new())
|
||||||
|
}
|
||||||
|
fn can_process(format: ArchiveFormat) -> bool {
|
||||||
|
format == ArchiveFormat::Lz4
|
||||||
|
}
|
||||||
|
fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// TAR.BZ2 Composite Processor Stub (Phase 2/3)
|
/// TAR.BZ2 Composite Processor Stub (Phase 2/3)
|
||||||
pub struct TarBzip2Processor;
|
pub struct TarBzip2Processor;
|
||||||
|
|
||||||
impl ArchiveProcessor for TarBzip2Processor {
|
impl ArchiveProcessor for TarBzip2Processor {
|
||||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::TarBzip2 }
|
fn format(&self) -> ArchiveFormat {
|
||||||
|
ArchiveFormat::TarBzip2
|
||||||
|
}
|
||||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||||
Err(anyhow!("TAR.BZ2 processor not yet implemented"))
|
Err(anyhow!("TAR.BZ2 processor not yet implemented"))
|
||||||
}
|
}
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
Ok(Vec::new())
|
||||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
}
|
||||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::TarBzip2 }
|
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||||
fn new() -> Self { Self }
|
Ok(0)
|
||||||
|
}
|
||||||
|
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||||
|
Ok(ExtractResult::new())
|
||||||
|
}
|
||||||
|
fn can_process(format: ArchiveFormat) -> bool {
|
||||||
|
format == ArchiveFormat::TarBzip2
|
||||||
|
}
|
||||||
|
fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// TAR.ZST Composite Processor Stub (Phase 2/3)
|
/// TAR.ZST Composite Processor Stub (Phase 2/3)
|
||||||
pub struct TarZstdProcessor;
|
pub struct TarZstdProcessor;
|
||||||
|
|
||||||
impl ArchiveProcessor for TarZstdProcessor {
|
impl ArchiveProcessor for TarZstdProcessor {
|
||||||
fn format(&self) -> ArchiveFormat { ArchiveFormat::TarZstd }
|
fn format(&self) -> ArchiveFormat {
|
||||||
|
ArchiveFormat::TarZstd
|
||||||
|
}
|
||||||
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
fn open(&mut self, _path: &Path) -> Result<ArchiveMetadata> {
|
||||||
Err(anyhow!("TAR.ZST processor not yet implemented"))
|
Err(anyhow!("TAR.ZST processor not yet implemented"))
|
||||||
}
|
}
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> { Ok(Vec::new()) }
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> { Ok(0) }
|
Ok(Vec::new())
|
||||||
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> { Ok(ExtractResult::new()) }
|
}
|
||||||
fn can_process(format: ArchiveFormat) -> bool { format == ArchiveFormat::TarZstd }
|
fn extract_file(&mut self, _entry: &Path, _output: &mut Vec<u8>) -> Result<u64> {
|
||||||
fn new() -> Self { Self }
|
Ok(0)
|
||||||
|
}
|
||||||
|
fn extract_all(&mut self, _dir: &Path) -> Result<ExtractResult> {
|
||||||
|
Ok(ExtractResult::new())
|
||||||
|
}
|
||||||
|
fn can_process(format: ArchiveFormat) -> bool {
|
||||||
|
format == ArchiveFormat::TarZstd
|
||||||
|
}
|
||||||
|
fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,15 @@
|
|||||||
// Optional Format Processors - RAR, XZ, 7z
|
// Optional Format Processors - RAR, XZ, 7z
|
||||||
// All optional formats have warnings displayed when enabled
|
// All optional formats have warnings displayed when enabled
|
||||||
|
|
||||||
use crate::archive::{ArchiveFormat, ArchiveProcessor, ArchiveMetadata, ArchiveEntry, ExtractResult};
|
use crate::archive::processor::{check_decompression_ratio, validate_extraction_path};
|
||||||
use crate::archive::warning;
|
use crate::archive::warning;
|
||||||
use crate::archive::processor::{validate_extraction_path, check_decompression_ratio};
|
use crate::archive::{
|
||||||
use anyhow::{Result, anyhow};
|
ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult,
|
||||||
use std::path::Path;
|
};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use log::{info, warn};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use log::{warn, info};
|
use std::path::Path;
|
||||||
|
|
||||||
/// RAR Processor - Only Decompression
|
/// RAR Processor - Only Decompression
|
||||||
/// ⚠️ Legal Warning: RARLAB patent, commercial use requires license
|
/// ⚠️ Legal Warning: RARLAB patent, commercial use requires license
|
||||||
@@ -44,7 +46,8 @@ impl ArchiveProcessor for RarProcessor {
|
|||||||
let entries: Vec<_> = archive.list()?.collect();
|
let entries: Vec<_> = archive.list()?.collect();
|
||||||
let total_files = entries.len() as u64;
|
let total_files = entries.len() as u64;
|
||||||
|
|
||||||
let total_size = entries.iter()
|
let total_size = entries
|
||||||
|
.iter()
|
||||||
.filter_map(|e| e.ok())
|
.filter_map(|e| e.ok())
|
||||||
.map(|e| e.uncompressed_size)
|
.map(|e| e.uncompressed_size)
|
||||||
.sum();
|
.sum();
|
||||||
@@ -56,9 +59,15 @@ impl ArchiveProcessor for RarProcessor {
|
|||||||
total_files,
|
total_files,
|
||||||
total_size,
|
total_size,
|
||||||
compressed_size,
|
compressed_size,
|
||||||
compression_ratio: if compressed_size > 0 { total_size as f64 / compressed_size as f64 } else { 0.0 },
|
compression_ratio: if compressed_size > 0 {
|
||||||
is_encrypted: entries.iter().any(|e| e.ok().map_or(false, |e| e.is_encrypted())),
|
total_size as f64 / compressed_size as f64
|
||||||
is_multi_volume: false, // unrar library limitation
|
} else {
|
||||||
|
0.0
|
||||||
|
},
|
||||||
|
is_encrypted: entries
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.ok().map_or(false, |e| e.is_encrypted())),
|
||||||
|
is_multi_volume: false, // unrar library limitation
|
||||||
created_time: None,
|
created_time: None,
|
||||||
modified_time: None,
|
modified_time: None,
|
||||||
})
|
})
|
||||||
@@ -67,15 +76,19 @@ impl ArchiveProcessor for RarProcessor {
|
|||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
use unrar::Archive;
|
use unrar::Archive;
|
||||||
|
|
||||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
let path = self
|
||||||
|
.archive_path
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||||
let archive = Archive::new(path)?;
|
let archive = Archive::new(path)?;
|
||||||
|
|
||||||
let entries: Vec<ArchiveEntry> = archive.list()?
|
let entries: Vec<ArchiveEntry> = archive
|
||||||
|
.list()?
|
||||||
.filter_map(|e| e.ok())
|
.filter_map(|e| e.ok())
|
||||||
.map(|e| ArchiveEntry {
|
.map(|e| ArchiveEntry {
|
||||||
path: PathBuf::from(e.filename),
|
path: PathBuf::from(e.filename),
|
||||||
size: e.uncompressed_size,
|
size: e.uncompressed_size,
|
||||||
compressed_size: 0, // unrar doesn't provide this
|
compressed_size: 0, // unrar doesn't provide this
|
||||||
is_dir: e.is_directory(),
|
is_dir: e.is_directory(),
|
||||||
is_file: !e.is_directory(),
|
is_file: !e.is_directory(),
|
||||||
is_encrypted: e.is_encrypted(),
|
is_encrypted: e.is_encrypted(),
|
||||||
@@ -93,7 +106,8 @@ impl ArchiveProcessor for RarProcessor {
|
|||||||
warn!("RAR extract_file requires full extraction (no random access)");
|
warn!("RAR extract_file requires full extraction (no random access)");
|
||||||
|
|
||||||
let entries = self.list_entries()?;
|
let entries = self.list_entries()?;
|
||||||
let entry = entries.iter()
|
let entry = entries
|
||||||
|
.iter()
|
||||||
.find(|e| e.path == entry_path)
|
.find(|e| e.path == entry_path)
|
||||||
.ok_or_else(|| anyhow!("Entry not found: {}", entry_path.display()))?;
|
.ok_or_else(|| anyhow!("Entry not found: {}", entry_path.display()))?;
|
||||||
|
|
||||||
@@ -112,7 +126,10 @@ impl ArchiveProcessor for RarProcessor {
|
|||||||
use unrar::Archive;
|
use unrar::Archive;
|
||||||
use unrar::ExtractOption;
|
use unrar::ExtractOption;
|
||||||
|
|
||||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
let path = self
|
||||||
|
.archive_path
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||||
|
|
||||||
// Validate output_dir path
|
// Validate output_dir path
|
||||||
validate_extraction_path(output_dir, output_dir)?;
|
validate_extraction_path(output_dir, output_dir)?;
|
||||||
@@ -173,8 +190,8 @@ impl ArchiveProcessor for XzProcessor {
|
|||||||
|
|
||||||
self.archive_path = Some(path.to_path_buf());
|
self.archive_path = Some(path.to_path_buf());
|
||||||
|
|
||||||
use xz2::read::XzDecoder;
|
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
use xz2::read::XzDecoder;
|
||||||
|
|
||||||
let file = fs::File::open(path)?;
|
let file = fs::File::open(path)?;
|
||||||
let mut decoder = XzDecoder::new(file);
|
let mut decoder = XzDecoder::new(file);
|
||||||
@@ -191,10 +208,14 @@ impl ArchiveProcessor for XzProcessor {
|
|||||||
|
|
||||||
Ok(ArchiveMetadata {
|
Ok(ArchiveMetadata {
|
||||||
format: ArchiveFormat::Xz,
|
format: ArchiveFormat::Xz,
|
||||||
total_files: 1, // XZ is single-file format
|
total_files: 1, // XZ is single-file format
|
||||||
total_size: decompressed_size,
|
total_size: decompressed_size,
|
||||||
compressed_size,
|
compressed_size,
|
||||||
compression_ratio: if compressed_size > 0 { decompressed_size as f64 / compressed_size as f64 } else { 0.0 },
|
compression_ratio: if compressed_size > 0 {
|
||||||
|
decompressed_size as f64 / compressed_size as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
},
|
||||||
is_encrypted: false,
|
is_encrypted: false,
|
||||||
is_multi_volume: false,
|
is_multi_volume: false,
|
||||||
created_time: None,
|
created_time: None,
|
||||||
@@ -204,16 +225,20 @@ impl ArchiveProcessor for XzProcessor {
|
|||||||
|
|
||||||
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
fn list_entries(&mut self) -> Result<Vec<ArchiveEntry>> {
|
||||||
// XZ is single-file, infer filename from archive name
|
// XZ is single-file, infer filename from archive name
|
||||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
let path = self
|
||||||
|
.archive_path
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||||
|
|
||||||
let filename = path.file_name()
|
let filename = path
|
||||||
|
.file_name()
|
||||||
.and_then(|n| n.to_str())
|
.and_then(|n| n.to_str())
|
||||||
.map(|s| s.strip_suffix(".xz").unwrap_or(s))
|
.map(|s| s.strip_suffix(".xz").unwrap_or(s))
|
||||||
.unwrap_or("output");
|
.unwrap_or("output");
|
||||||
|
|
||||||
Ok(vec![ArchiveEntry {
|
Ok(vec![ArchiveEntry {
|
||||||
path: PathBuf::from(filename),
|
path: PathBuf::from(filename),
|
||||||
size: 0, // Will be determined during extraction
|
size: 0, // Will be determined during extraction
|
||||||
compressed_size: 0,
|
compressed_size: 0,
|
||||||
is_dir: false,
|
is_dir: false,
|
||||||
is_file: true,
|
is_file: true,
|
||||||
@@ -224,10 +249,13 @@ impl ArchiveProcessor for XzProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn extract_file(&self, _entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
fn extract_file(&self, _entry_path: &Path, output: &mut Vec<u8>) -> Result<u64> {
|
||||||
use xz2::read::XzDecoder;
|
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
use xz2::read::XzDecoder;
|
||||||
|
|
||||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
let path = self
|
||||||
|
.archive_path
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||||
|
|
||||||
let file = fs::File::open(path)?;
|
let file = fs::File::open(path)?;
|
||||||
let mut decoder = XzDecoder::new(file);
|
let mut decoder = XzDecoder::new(file);
|
||||||
@@ -238,10 +266,13 @@ impl ArchiveProcessor for XzProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn extract_all(&self, output_dir: &Path) -> Result<ExtractResult> {
|
fn extract_all(&self, output_dir: &Path) -> Result<ExtractResult> {
|
||||||
use xz2::read::XzDecoder;
|
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
use xz2::read::XzDecoder;
|
||||||
|
|
||||||
let path = self.archive_path.as_ref().ok_or_else(|| anyhow!("Archive not opened"))?;
|
let path = self
|
||||||
|
.archive_path
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("Archive not opened"))?;
|
||||||
|
|
||||||
// Infer output filename
|
// Infer output filename
|
||||||
let entries = self.list_entries()?;
|
let entries = self.list_entries()?;
|
||||||
@@ -298,9 +329,7 @@ impl ArchiveProcessor for SevenZProcessor {
|
|||||||
let entries = reader.entries()?;
|
let entries = reader.entries()?;
|
||||||
let total_files = entries.len() as u64;
|
let total_files = entries.len() as u64;
|
||||||
|
|
||||||
let total_size = entries.iter()
|
let total_size = entries.iter().map(|e| e.uncompressed_size as u64).sum();
|
||||||
.map(|e| e.uncompressed_size as u64)
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
let compressed_size = fs::metadata(path)?.len();
|
let compressed_size = fs::metadata(path)?.len();
|
||||||
|
|
||||||
@@ -309,7 +338,11 @@ impl ArchiveProcessor for SevenZProcessor {
|
|||||||
total_files,
|
total_files,
|
||||||
total_size,
|
total_size,
|
||||||
compressed_size,
|
compressed_size,
|
||||||
compression_ratio: if compressed_size > 0 { total_size as f64 / compressed_size as f64 } else { 0.0 },
|
compression_ratio: if compressed_size > 0 {
|
||||||
|
total_size as f64 / compressed_size as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
},
|
||||||
is_encrypted: entries.iter().any(|e| e.is_encrypted),
|
is_encrypted: entries.iter().any(|e| e.is_encrypted),
|
||||||
is_multi_volume: false,
|
is_multi_volume: false,
|
||||||
created_time: None,
|
created_time: None,
|
||||||
@@ -369,15 +402,21 @@ pub struct SevenZProcessor;
|
|||||||
|
|
||||||
#[cfg(not(feature = "optional-formats"))]
|
#[cfg(not(feature = "optional-formats"))]
|
||||||
impl RarProcessor {
|
impl RarProcessor {
|
||||||
pub fn new() -> Self { Self }
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "optional-formats"))]
|
#[cfg(not(feature = "optional-formats"))]
|
||||||
impl XzProcessor {
|
impl XzProcessor {
|
||||||
pub fn new() -> Self { Self }
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "optional-formats"))]
|
#[cfg(not(feature = "optional-formats"))]
|
||||||
impl SevenZProcessor {
|
impl SevenZProcessor {
|
||||||
pub fn new() -> Self { Self }
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
use crate::archive::{
|
use crate::archive::{
|
||||||
ArchiveProcessor, ArchiveFormat, ArchiveMetadata, ArchiveEntry, ExtractResult,
|
|
||||||
processors::core::{ZipProcessor, TarProcessor, GzipProcessor, TarGzipProcessor},
|
|
||||||
processor::{validate_extraction_path, check_decompression_ratio},
|
|
||||||
config::ArchiveConfig,
|
config::ArchiveConfig,
|
||||||
|
processor::{check_decompression_ratio, validate_extraction_path},
|
||||||
|
processors::core::{GzipProcessor, TarGzipProcessor, TarProcessor, ZipProcessor},
|
||||||
|
ArchiveEntry, ArchiveFormat, ArchiveMetadata, ArchiveProcessor, ExtractResult,
|
||||||
};
|
};
|
||||||
use tempfile::TempDir;
|
use anyhow::Result;
|
||||||
use std::fs::{File, create_dir_all};
|
use std::fs::{create_dir_all, File};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use anyhow::Result;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod helpers {
|
mod helpers {
|
||||||
@@ -69,8 +69,8 @@ mod helpers {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod core_format_tests {
|
mod core_format_tests {
|
||||||
use super::*;
|
|
||||||
use super::helpers::*;
|
use super::helpers::*;
|
||||||
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_zip_processor_basic() {
|
fn test_zip_processor_basic() {
|
||||||
@@ -145,8 +145,8 @@ mod core_format_tests {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod integration_tests {
|
mod integration_tests {
|
||||||
use super::*;
|
|
||||||
use super::helpers::*;
|
use super::helpers::*;
|
||||||
|
use super::*;
|
||||||
use crate::archive::detector::FormatDetector;
|
use crate::archive::detector::FormatDetector;
|
||||||
use crate::archive::ProcessorRegistry;
|
use crate::archive::ProcessorRegistry;
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ use std::fs;
|
|||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
use crate::archive::*;
|
|
||||||
use crate::archive::processor::check_decompression_ratio;
|
use crate::archive::processor::check_decompression_ratio;
|
||||||
use crate::archive::tests::test_helpers::*;
|
use crate::archive::tests::test_helpers::*;
|
||||||
|
use crate::archive::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_zip_processor_full_workflow() {
|
fn test_zip_processor_full_workflow() {
|
||||||
@@ -26,9 +26,7 @@ fn test_zip_processor_full_workflow() {
|
|||||||
assert_eq!(entries.len(), 3);
|
assert_eq!(entries.len(), 3);
|
||||||
|
|
||||||
// Verify entry names
|
// Verify entry names
|
||||||
let names: Vec<&str> = entries.iter()
|
let names: Vec<&str> = entries.iter().map(|e| e.path.to_str().unwrap()).collect();
|
||||||
.map(|e| e.path.to_str().unwrap())
|
|
||||||
.collect();
|
|
||||||
assert!(names.contains(&"file1.txt"));
|
assert!(names.contains(&"file1.txt"));
|
||||||
assert!(names.contains(&"file2.txt"));
|
assert!(names.contains(&"file2.txt"));
|
||||||
assert!(names.contains(&"subdir/file3.txt"));
|
assert!(names.contains(&"subdir/file3.txt"));
|
||||||
@@ -64,7 +62,7 @@ fn test_tar_processor_full_workflow() {
|
|||||||
|
|
||||||
// Test list_entries
|
// Test list_entries
|
||||||
let entries = processor.list_entries().unwrap();
|
let entries = processor.list_entries().unwrap();
|
||||||
assert!(entries.len() >= 3); // TAR may include directory entries
|
assert!(entries.len() >= 3); // TAR may include directory entries
|
||||||
|
|
||||||
// Test extract_all
|
// Test extract_all
|
||||||
let extract_dir = temp_dir.path().join("extracted_tar");
|
let extract_dir = temp_dir.path().join("extracted_tar");
|
||||||
@@ -88,7 +86,7 @@ fn test_gzip_processor_full_workflow() {
|
|||||||
// Test open
|
// Test open
|
||||||
let metadata = processor.open(&gz_path).unwrap();
|
let metadata = processor.open(&gz_path).unwrap();
|
||||||
assert_eq!(metadata.format, ArchiveFormat::Gzip);
|
assert_eq!(metadata.format, ArchiveFormat::Gzip);
|
||||||
assert_eq!(metadata.total_files, 1); // GZIP is single file
|
assert_eq!(metadata.total_files, 1); // GZIP is single file
|
||||||
|
|
||||||
// Test extract_all
|
// Test extract_all
|
||||||
let extract_dir = temp_dir.path().join("extracted_gz");
|
let extract_dir = temp_dir.path().join("extracted_gz");
|
||||||
@@ -159,7 +157,7 @@ fn test_processor_registry_core_formats() {
|
|||||||
let formats = registry.enabled_formats();
|
let formats = registry.enabled_formats();
|
||||||
|
|
||||||
// Should have 9 core formats
|
// Should have 9 core formats
|
||||||
assert!(formats.len() >= 4); // At least the ones we implemented
|
assert!(formats.len() >= 4); // At least the ones we implemented
|
||||||
|
|
||||||
// Verify format support
|
// Verify format support
|
||||||
assert!(formats.contains(&ArchiveFormat::Zip));
|
assert!(formats.contains(&ArchiveFormat::Zip));
|
||||||
@@ -199,11 +197,11 @@ fn test_zip_slip_protection() {
|
|||||||
fn test_zip_bomb_detection() {
|
fn test_zip_bomb_detection() {
|
||||||
// Test decompression ratio check
|
// Test decompression ratio check
|
||||||
let result = check_decompression_ratio(42_000, 5_000_000_000, 1000);
|
let result = check_decompression_ratio(42_000, 5_000_000_000, 1000);
|
||||||
assert!(result.is_err()); // Should detect as Zip Bomb
|
assert!(result.is_err()); // Should detect as Zip Bomb
|
||||||
|
|
||||||
// Test normal ratio
|
// Test normal ratio
|
||||||
let result = check_decompression_ratio(1000, 5000, 1000);
|
let result = check_decompression_ratio(1000, 5000, 1000);
|
||||||
assert!(result.is_ok()); // Normal ratio should pass
|
assert!(result.is_ok()); // Normal ratio should pass
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -220,15 +218,15 @@ fn test_metadata_compression_ratio() {
|
|||||||
modified_time: None,
|
modified_time: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(metadata.actual_ratio(), 5.0); // 5000/1000 = 5.0
|
assert_eq!(metadata.actual_ratio(), 5.0); // 5000/1000 = 5.0
|
||||||
assert!(!metadata.check_zip_bomb(10)); // ratio 5.0 < 10, not a bomb
|
assert!(!metadata.check_zip_bomb(10)); // ratio 5.0 < 10, not a bomb
|
||||||
assert!(metadata.check_zip_bomb(4)); // ratio 5.0 > 4, detected as bomb
|
assert!(metadata.check_zip_bomb(4)); // ratio 5.0 > 4, detected as bomb
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_config_validation() {
|
fn test_config_validation() {
|
||||||
let config = ArchiveConfig {
|
let config = ArchiveConfig {
|
||||||
max_decompression_ratio: 5, // Too low
|
max_decompression_ratio: 5, // Too low
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
|
use flate2::write::GzEncoder;
|
||||||
|
use flate2::Compression;
|
||||||
use std::fs::{self, File};
|
use std::fs::{self, File};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use tempfile::TempDir;
|
|
||||||
use zip::{ZipWriter, write::FileOptions, CompressionMethod};
|
|
||||||
use flate2::write::GzEncoder;
|
|
||||||
use flate2::Compression;
|
|
||||||
use tar::Builder;
|
use tar::Builder;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
use zip::{write::FileOptions, CompressionMethod, ZipWriter};
|
||||||
|
|
||||||
pub fn create_test_zip(temp_dir: &TempDir) -> PathBuf {
|
pub fn create_test_zip(temp_dir: &TempDir) -> PathBuf {
|
||||||
let zip_path = temp_dir.path().join("test.zip");
|
let zip_path = temp_dir.path().join("test.zip");
|
||||||
let file = File::create(&zip_path).unwrap();
|
let file = File::create(&zip_path).unwrap();
|
||||||
let mut zip = ZipWriter::new(file);
|
let mut zip = ZipWriter::new(file);
|
||||||
let options = FileOptions::default()
|
let options = FileOptions::default().compression_method(CompressionMethod::Stored);
|
||||||
.compression_method(CompressionMethod::Stored);
|
|
||||||
|
|
||||||
zip.start_file("file1.txt", options).unwrap();
|
zip.start_file("file1.txt", options).unwrap();
|
||||||
zip.write_all(b"content of file 1").unwrap();
|
zip.write_all(b"content of file 1").unwrap();
|
||||||
@@ -37,21 +36,31 @@ pub fn create_test_tar(temp_dir: &TempDir) -> PathBuf {
|
|||||||
header1.set_size(17);
|
header1.set_size(17);
|
||||||
header1.set_mode(0o644);
|
header1.set_mode(0o644);
|
||||||
header1.set_cksum();
|
header1.set_cksum();
|
||||||
builder.append_data(&mut header1, "file1.txt", &b"content of file 1"[..]).unwrap();
|
builder
|
||||||
|
.append_data(&mut header1, "file1.txt", &b"content of file 1"[..])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut header2 = tar::Header::new_gnu();
|
let mut header2 = tar::Header::new_gnu();
|
||||||
header2.set_path("file2.txt").unwrap();
|
header2.set_path("file2.txt").unwrap();
|
||||||
header2.set_size(17);
|
header2.set_size(17);
|
||||||
header2.set_mode(0o644);
|
header2.set_mode(0o644);
|
||||||
header2.set_cksum();
|
header2.set_cksum();
|
||||||
builder.append_data(&mut header2, "file2.txt", &b"content of file 2"[..]).unwrap();
|
builder
|
||||||
|
.append_data(&mut header2, "file2.txt", &b"content of file 2"[..])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut header3 = tar::Header::new_gnu();
|
let mut header3 = tar::Header::new_gnu();
|
||||||
header3.set_path("subdir/file3.txt").unwrap();
|
header3.set_path("subdir/file3.txt").unwrap();
|
||||||
header3.set_size(27);
|
header3.set_size(27);
|
||||||
header3.set_mode(0o644);
|
header3.set_mode(0o644);
|
||||||
header3.set_cksum();
|
header3.set_cksum();
|
||||||
builder.append_data(&mut header3, "subdir/file3.txt", &b"content of file 3 in subdir"[..]).unwrap();
|
builder
|
||||||
|
.append_data(
|
||||||
|
&mut header3,
|
||||||
|
"subdir/file3.txt",
|
||||||
|
&b"content of file 3 in subdir"[..],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
builder.finish().unwrap();
|
builder.finish().unwrap();
|
||||||
tar_path
|
tar_path
|
||||||
@@ -61,7 +70,9 @@ pub fn create_test_gzip(temp_dir: &TempDir) -> PathBuf {
|
|||||||
let gz_path = temp_dir.path().join("test.txt.gz");
|
let gz_path = temp_dir.path().join("test.txt.gz");
|
||||||
let file = File::create(&gz_path).unwrap();
|
let file = File::create(&gz_path).unwrap();
|
||||||
let mut encoder = GzEncoder::new(file, Compression::default());
|
let mut encoder = GzEncoder::new(file, Compression::default());
|
||||||
encoder.write_all(b"test gzip content for validation").unwrap();
|
encoder
|
||||||
|
.write_all(b"test gzip content for validation")
|
||||||
|
.unwrap();
|
||||||
encoder.finish().unwrap();
|
encoder.finish().unwrap();
|
||||||
gz_path
|
gz_path
|
||||||
}
|
}
|
||||||
@@ -76,14 +87,18 @@ pub fn create_test_tar_gz(temp_dir: &TempDir) -> PathBuf {
|
|||||||
header1.set_size(10);
|
header1.set_size(10);
|
||||||
header1.set_mode(0o644);
|
header1.set_mode(0o644);
|
||||||
header1.set_cksum();
|
header1.set_cksum();
|
||||||
builder.append_data(&mut header1, "file1.txt", &b"file1 data"[..]).unwrap();
|
builder
|
||||||
|
.append_data(&mut header1, "file1.txt", &b"file1 data"[..])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut header2 = tar::Header::new_gnu();
|
let mut header2 = tar::Header::new_gnu();
|
||||||
header2.set_path("file2.txt").unwrap();
|
header2.set_path("file2.txt").unwrap();
|
||||||
header2.set_size(10);
|
header2.set_size(10);
|
||||||
header2.set_mode(0o644);
|
header2.set_mode(0o644);
|
||||||
header2.set_cksum();
|
header2.set_cksum();
|
||||||
builder.append_data(&mut header2, "file2.txt", &b"file2 data"[..]).unwrap();
|
builder
|
||||||
|
.append_data(&mut header2, "file2.txt", &b"file2 data"[..])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
builder.finish().unwrap();
|
builder.finish().unwrap();
|
||||||
|
|
||||||
@@ -106,8 +121,7 @@ pub fn create_zip_bomb_test() -> Vec<u8> {
|
|||||||
let writer = std::io::Cursor::new(&mut buffer);
|
let writer = std::io::Cursor::new(&mut buffer);
|
||||||
let mut zip = ZipWriter::new(writer);
|
let mut zip = ZipWriter::new(writer);
|
||||||
|
|
||||||
let options = FileOptions::default()
|
let options = FileOptions::default().compression_method(CompressionMethod::Stored);
|
||||||
.compression_method(CompressionMethod::Stored);
|
|
||||||
|
|
||||||
zip.start_file("bomb.txt", options).unwrap();
|
zip.start_file("bomb.txt", options).unwrap();
|
||||||
zip.write_all(&[0u8; 100]).unwrap();
|
zip.write_all(&[0u8; 100]).unwrap();
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Warning System - Legal and Technical Warnings for Optional Formats
|
// Warning System - Legal and Technical Warnings for Optional Formats
|
||||||
|
|
||||||
use log::{warn, info};
|
use log::{info, warn};
|
||||||
|
|
||||||
use crate::archive::config::ArchiveConfig;
|
use crate::archive::config::ArchiveConfig;
|
||||||
|
|
||||||
@@ -73,15 +73,17 @@ pub fn show_startup_warnings(config: &ArchiveConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Show summary of enabled formats
|
// Show summary of enabled formats
|
||||||
let enabled_optional = [
|
let enabled_optional = [config.enable_rar, config.enable_xz, config.enable_7z]
|
||||||
config.enable_rar,
|
.iter()
|
||||||
config.enable_xz,
|
.filter(|&x| *x)
|
||||||
config.enable_7z,
|
.count();
|
||||||
].iter().filter(|&x| *x).count();
|
|
||||||
|
|
||||||
if enabled_optional > 0 {
|
if enabled_optional > 0 {
|
||||||
info!("");
|
info!("");
|
||||||
info!("⚠️ {} optional format(s) enabled with warnings shown above", enabled_optional);
|
info!(
|
||||||
|
"⚠️ {} optional format(s) enabled with warnings shown above",
|
||||||
|
enabled_optional
|
||||||
|
);
|
||||||
info!("Core formats (9): ZIP, TAR, GZIP, ZSTD, BZIP2, LZ4, TAR.GZ, TAR.BZ2, TAR.ZST");
|
info!("Core formats (9): ZIP, TAR, GZIP, ZSTD, BZIP2, LZ4, TAR.GZ, TAR.BZ2, TAR.ZST");
|
||||||
info!("");
|
info!("");
|
||||||
}
|
}
|
||||||
@@ -89,8 +91,7 @@ pub fn show_startup_warnings(config: &ArchiveConfig) {
|
|||||||
|
|
||||||
/// Generate user-facing legal disclaimer text
|
/// Generate user-facing legal disclaimer text
|
||||||
pub fn generate_rar_legal_disclaimer() -> String {
|
pub fn generate_rar_legal_disclaimer() -> String {
|
||||||
format!(
|
"RAR FORMAT LEGAL DISCLAIMER
|
||||||
"RAR FORMAT LEGAL DISCLAIMER
|
|
||||||
|
|
||||||
IMPORTANT WARNING:
|
IMPORTANT WARNING:
|
||||||
|
|
||||||
@@ -136,6 +137,5 @@ CONTACT:
|
|||||||
Last Updated: 2026-06-10
|
Last Updated: 2026-06-10
|
||||||
Version: 1.0
|
Version: 1.0
|
||||||
Legal Consultation: [Please consult professional lawyer for commercial use]
|
Legal Consultation: [Please consult professional lawyer for commercial use]
|
||||||
"
|
".to_string()
|
||||||
)
|
|
||||||
}
|
}
|
||||||
@@ -5,7 +5,7 @@ use std::collections::HashMap;
|
|||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::provider::{DataProvider, ProviderError};
|
use crate::provider::DataProvider;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
@@ -71,6 +71,12 @@ pub struct AuthState {
|
|||||||
pub provider: Option<Arc<dyn DataProvider>>,
|
pub provider: Option<Arc<dyn DataProvider>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for AuthState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl AuthState {
|
impl AuthState {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let mut users = HashMap::new();
|
let mut users = HashMap::new();
|
||||||
@@ -284,7 +290,12 @@ impl AuthState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn login_with_provider(&self, provider: &dyn DataProvider, username: &str, password: &str) -> Option<LoginResponse> {
|
fn login_with_provider(
|
||||||
|
&self,
|
||||||
|
provider: &dyn DataProvider,
|
||||||
|
username: &str,
|
||||||
|
password: &str,
|
||||||
|
) -> Option<LoginResponse> {
|
||||||
match provider.get_user(username) {
|
match provider.get_user(username) {
|
||||||
Ok(Some(user)) => {
|
Ok(Some(user)) => {
|
||||||
if user.status != 1 {
|
if user.status != 1 {
|
||||||
|
|||||||
@@ -119,11 +119,17 @@ pub fn get_all_categories() -> Result<CategoriesResponse> {
|
|||||||
let conn = FileTree::open_user_db("accusys")?;
|
let conn = FileTree::open_user_db("accusys")?;
|
||||||
let tree = FileTree::load(&conn, "accusys", "categories")?;
|
let tree = FileTree::load(&conn, "accusys", "categories")?;
|
||||||
|
|
||||||
let categories: Vec<Category> = tree.nodes.iter()
|
let categories: Vec<Category> = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
.filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
||||||
.map(|n| {
|
.map(|n| {
|
||||||
let file_count = tree.nodes.iter()
|
let file_count = tree
|
||||||
.filter(|f| f.parent_id == Some(n.node_id.clone()) && f.node_type.as_str() == "file")
|
.nodes
|
||||||
|
.iter()
|
||||||
|
.filter(|f| {
|
||||||
|
f.parent_id == Some(n.node_id.clone()) && f.node_type.as_str() == "file"
|
||||||
|
})
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
Category {
|
Category {
|
||||||
@@ -136,7 +142,9 @@ pub fn get_all_categories() -> Result<CategoriesResponse> {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let total_files = tree.nodes.iter()
|
let total_files = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.filter(|n| n.node_type.as_str() == "file")
|
.filter(|n| n.node_type.as_str() == "file")
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
@@ -151,22 +159,41 @@ pub fn get_category_detail(category_name: &str) -> Result<CategoryDetail> {
|
|||||||
let conn = FileTree::open_user_db("accusys")?;
|
let conn = FileTree::open_user_db("accusys")?;
|
||||||
let tree = FileTree::load(&conn, "accusys", "categories")?;
|
let tree = FileTree::load(&conn, "accusys", "categories")?;
|
||||||
|
|
||||||
let category_node = tree.nodes.iter()
|
let category_node = tree
|
||||||
.find(|n| n.label == category_name && n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
.nodes
|
||||||
|
.iter()
|
||||||
|
.find(|n| {
|
||||||
|
n.label == category_name && n.parent_id.is_none() && n.node_type.as_str() == "folder"
|
||||||
|
})
|
||||||
.ok_or_else(|| anyhow::anyhow!("Category not found: {}", category_name))?;
|
.ok_or_else(|| anyhow::anyhow!("Category not found: {}", category_name))?;
|
||||||
|
|
||||||
let series_groups: Vec<SeriesGroup> = tree.nodes.iter()
|
let series_groups: Vec<SeriesGroup> = tree
|
||||||
.filter(|n| n.parent_id == Some(category_node.node_id.clone()) && n.node_type.as_str() == "folder")
|
.nodes
|
||||||
|
.iter()
|
||||||
|
.filter(|n| {
|
||||||
|
n.parent_id == Some(category_node.node_id.clone()) && n.node_type.as_str() == "folder"
|
||||||
|
})
|
||||||
.map(|series_node| {
|
.map(|series_node| {
|
||||||
let files: Vec<CategoryFile> = tree.nodes.iter()
|
let files: Vec<CategoryFile> = tree
|
||||||
.filter(|f| f.parent_id == Some(series_node.node_id.clone()) && f.node_type.as_str() == "file")
|
.nodes
|
||||||
.map(|file_node| {
|
.iter()
|
||||||
CategoryFile {
|
.filter(|f| {
|
||||||
filename: file_node.label.clone(),
|
f.parent_id == Some(series_node.node_id.clone())
|
||||||
size: file_node.aliases.get("file_size_display").cloned().unwrap_or_default(),
|
&& f.node_type.as_str() == "file"
|
||||||
download_url: file_node.aliases.get("download_url").cloned().unwrap_or_default(),
|
})
|
||||||
sha256: file_node.sha256.clone(),
|
.map(|file_node| CategoryFile {
|
||||||
}
|
filename: file_node.label.clone(),
|
||||||
|
size: file_node
|
||||||
|
.aliases
|
||||||
|
.get("file_size_display")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default(),
|
||||||
|
download_url: file_node
|
||||||
|
.aliases
|
||||||
|
.get("download_url")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default(),
|
||||||
|
sha256: file_node.sha256.clone(),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@@ -185,7 +212,11 @@ pub fn get_category_detail(category_name: &str) -> Result<CategoryDetail> {
|
|||||||
display_name: get_category_display_name(category_name),
|
display_name: get_category_display_name(category_name),
|
||||||
file_count,
|
file_count,
|
||||||
last_updated: category_node.updated_at.clone(),
|
last_updated: category_node.updated_at.clone(),
|
||||||
description: category_node.aliases.get("description").cloned().unwrap_or_default(),
|
description: category_node
|
||||||
|
.aliases
|
||||||
|
.get("description")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default(),
|
||||||
},
|
},
|
||||||
series_groups,
|
series_groups,
|
||||||
})
|
})
|
||||||
@@ -195,19 +226,25 @@ pub fn get_all_series() -> Result<SeriesResponse> {
|
|||||||
let conn = FileTree::open_user_db("accusys")?;
|
let conn = FileTree::open_user_db("accusys")?;
|
||||||
let tree = FileTree::load(&conn, "accusys", "series")?;
|
let tree = FileTree::load(&conn, "accusys", "series")?;
|
||||||
|
|
||||||
let series: Vec<Series> = tree.nodes.iter()
|
let series: Vec<Series> = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
.filter(|n| n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
||||||
.map(|n| {
|
.map(|n| {
|
||||||
let file_count = tree.nodes.iter()
|
let file_count = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.filter(|f| {
|
.filter(|f| {
|
||||||
let mut current = f.parent_id.clone();
|
let mut current = f.parent_id.clone();
|
||||||
while let Some(pid) = current {
|
while let Some(pid) = current {
|
||||||
if pid == n.node_id {
|
if pid == n.node_id {
|
||||||
return f.node_type.as_str() == "file";
|
return f.node_type.as_str() == "file";
|
||||||
}
|
}
|
||||||
current = tree.nodes.iter()
|
current = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.find(|p| p.node_id == pid)
|
.find(|p| p.node_id == pid)
|
||||||
.map(|p| p.parent_id.clone()).flatten();
|
.and_then(|p| p.parent_id.clone());
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
})
|
})
|
||||||
@@ -224,7 +261,9 @@ pub fn get_all_series() -> Result<SeriesResponse> {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let total_files = tree.nodes.iter()
|
let total_files = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.filter(|n| n.node_type.as_str() == "file")
|
.filter(|n| n.node_type.as_str() == "file")
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
@@ -239,32 +278,50 @@ pub fn get_series_detail(series_name: &str) -> Result<SeriesDetail> {
|
|||||||
let conn = FileTree::open_user_db("accusys")?;
|
let conn = FileTree::open_user_db("accusys")?;
|
||||||
let tree = FileTree::load(&conn, "accusys", "series")?;
|
let tree = FileTree::load(&conn, "accusys", "series")?;
|
||||||
|
|
||||||
let series_node = tree.nodes.iter()
|
let series_node = tree
|
||||||
.find(|n| n.label == series_name && n.parent_id.is_none() && n.node_type.as_str() == "folder")
|
.nodes
|
||||||
|
.iter()
|
||||||
|
.find(|n| {
|
||||||
|
n.label == series_name && n.parent_id.is_none() && n.node_type.as_str() == "folder"
|
||||||
|
})
|
||||||
.ok_or_else(|| anyhow::anyhow!("Series not found: {}", series_name))?;
|
.ok_or_else(|| anyhow::anyhow!("Series not found: {}", series_name))?;
|
||||||
|
|
||||||
let categories: Vec<SeriesCategory> = tree.nodes.iter()
|
let categories: Vec<SeriesCategory> = tree
|
||||||
.filter(|n| n.parent_id == Some(series_node.node_id.clone()) && n.node_type.as_str() == "folder")
|
.nodes
|
||||||
|
.iter()
|
||||||
|
.filter(|n| {
|
||||||
|
n.parent_id == Some(series_node.node_id.clone()) && n.node_type.as_str() == "folder"
|
||||||
|
})
|
||||||
.map(|category_node| {
|
.map(|category_node| {
|
||||||
let files: Vec<SeriesFile> = tree.nodes.iter()
|
let files: Vec<SeriesFile> = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.filter(|f| {
|
.filter(|f| {
|
||||||
let mut current = f.parent_id.clone();
|
let mut current = f.parent_id.clone();
|
||||||
while let Some(pid) = current {
|
while let Some(pid) = current {
|
||||||
if pid == category_node.node_id && f.node_type.as_str() == "file" {
|
if pid == category_node.node_id && f.node_type.as_str() == "file" {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
current = tree.nodes.iter()
|
current = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.find(|p| p.node_id == pid)
|
.find(|p| p.node_id == pid)
|
||||||
.map(|p| p.parent_id.clone()).flatten();
|
.and_then(|p| p.parent_id.clone());
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
})
|
})
|
||||||
.map(|file_node| {
|
.map(|file_node| SeriesFile {
|
||||||
SeriesFile {
|
filename: file_node.label.clone(),
|
||||||
filename: file_node.label.clone(),
|
size: file_node
|
||||||
size: file_node.aliases.get("file_size_display").unwrap_or(&"N/A".to_string()).clone(),
|
.aliases
|
||||||
download_url: file_node.aliases.get("download_url").unwrap_or(&"".to_string()).clone(),
|
.get("file_size_display")
|
||||||
}
|
.unwrap_or(&"N/A".to_string())
|
||||||
|
.clone(),
|
||||||
|
download_url: file_node
|
||||||
|
.aliases
|
||||||
|
.get("download_url")
|
||||||
|
.unwrap_or(&"".to_string())
|
||||||
|
.clone(),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@@ -300,17 +357,27 @@ pub fn search_files(query: &str, view: &str) -> Result<SearchResponse> {
|
|||||||
let conn = FileTree::open_user_db("accusys")?;
|
let conn = FileTree::open_user_db("accusys")?;
|
||||||
let tree = FileTree::load(&conn, "accusys", tree_type)?;
|
let tree = FileTree::load(&conn, "accusys", tree_type)?;
|
||||||
|
|
||||||
let results: Vec<SearchResult> = tree.nodes.iter()
|
let results: Vec<SearchResult> = tree
|
||||||
.filter(|n| n.node_type.as_str() == "file" && n.label.to_lowercase().contains(&query.to_lowercase()))
|
.nodes
|
||||||
|
.iter()
|
||||||
|
.filter(|n| {
|
||||||
|
n.node_type.as_str() == "file" && n.label.to_lowercase().contains(&query.to_lowercase())
|
||||||
|
})
|
||||||
.map(|file_node| {
|
.map(|file_node| {
|
||||||
let parent_node = tree.nodes.iter()
|
let parent_node = tree
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
.find(|n| n.node_id == file_node.parent_id.clone().unwrap_or_default());
|
.find(|n| n.node_id == file_node.parent_id.clone().unwrap_or_default());
|
||||||
|
|
||||||
SearchResult {
|
SearchResult {
|
||||||
category: parent_node.map(|n| n.label.clone()),
|
category: parent_node.map(|n| n.label.clone()),
|
||||||
series: parent_node.map(|n| n.label.clone()),
|
series: parent_node.map(|n| n.label.clone()),
|
||||||
filename: file_node.label.clone(),
|
filename: file_node.label.clone(),
|
||||||
download_url: file_node.aliases.get("download_url").unwrap_or(&"".to_string()).clone(),
|
download_url: file_node
|
||||||
|
.aliases
|
||||||
|
.get("download_url")
|
||||||
|
.unwrap_or(&"".to_string())
|
||||||
|
.clone(),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|||||||
@@ -34,7 +34,8 @@ pub async fn handle_iscsi_command(cmd: IscsiCommand) -> anyhow::Result<()> {
|
|||||||
force,
|
force,
|
||||||
device,
|
device,
|
||||||
} => {
|
} => {
|
||||||
cmd_process.arg("start")
|
cmd_process
|
||||||
|
.arg("start")
|
||||||
.args(["--user", &user])
|
.args(["--user", &user])
|
||||||
.args(["--port", &port.to_string()])
|
.args(["--port", &port.to_string()])
|
||||||
.args(["--lun-size", &lun_size]);
|
.args(["--lun-size", &lun_size]);
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
pub mod web;
|
|
||||||
pub mod ssh;
|
|
||||||
pub mod webdav;
|
|
||||||
pub mod iscsi;
|
pub mod iscsi;
|
||||||
|
pub mod ssh;
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
|
pub mod web;
|
||||||
|
pub mod webdav;
|
||||||
|
|
||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
use anyhow::Context;
|
||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
use rusqlite::Connection;
|
use rusqlite::Connection;
|
||||||
use anyhow::Context;
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
@@ -113,7 +113,11 @@ pub enum FolderCommand {
|
|||||||
|
|
||||||
pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
||||||
match cmd {
|
match cmd {
|
||||||
TreeCommand::Create { name, user, tree_type } => {
|
TreeCommand::Create {
|
||||||
|
name,
|
||||||
|
user,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
let db_path = format!("data/users/{}.sqlite", user);
|
let db_path = format!("data/users/{}.sqlite", user);
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
@@ -127,7 +131,10 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
rusqlite::params![node_id, name, tree_type, created_at]
|
rusqlite::params![node_id, name, tree_type, created_at]
|
||||||
).context("Failed to create tree")?;
|
).context("Failed to create tree")?;
|
||||||
|
|
||||||
println!("✓ Tree created: {} (type: {}) for user: {}", name, tree_type, user);
|
println!(
|
||||||
|
"✓ Tree created: {} (type: {}) for user: {}",
|
||||||
|
name, tree_type, user
|
||||||
|
);
|
||||||
println!("✓ Node ID: {}", node_id);
|
println!("✓ Node ID: {}", node_id);
|
||||||
}
|
}
|
||||||
TreeCommand::List { user } => {
|
TreeCommand::List { user } => {
|
||||||
@@ -135,21 +142,24 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn
|
||||||
"SELECT DISTINCT tree_type FROM file_nodes ORDER BY tree_type"
|
.prepare("SELECT DISTINCT tree_type FROM file_nodes ORDER BY tree_type")
|
||||||
).context("Failed to prepare query")?;
|
.context("Failed to prepare query")?;
|
||||||
|
|
||||||
let tree_types = stmt.query_map([], |row| row.get::<_, String>(0))
|
let tree_types = stmt
|
||||||
|
.query_map([], |row| row.get::<_, String>(0))
|
||||||
.context("Failed to query tree types")?;
|
.context("Failed to query tree types")?;
|
||||||
|
|
||||||
println!("=== Trees for user: {} ===", user);
|
println!("=== Trees for user: {} ===", user);
|
||||||
for tree_type in tree_types {
|
for tree_type in tree_types {
|
||||||
let tt = tree_type?;
|
let tt = tree_type?;
|
||||||
let count: i64 = conn.query_row(
|
let count: i64 = conn
|
||||||
"SELECT COUNT(*) FROM file_nodes WHERE tree_type = ?1",
|
.query_row(
|
||||||
[&tt],
|
"SELECT COUNT(*) FROM file_nodes WHERE tree_type = ?1",
|
||||||
|row| row.get(0)
|
[&tt],
|
||||||
).unwrap_or(0);
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
println!(" {} ({} nodes)", tt, count);
|
println!(" {} ({} nodes)", tt, count);
|
||||||
}
|
}
|
||||||
@@ -168,7 +178,10 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
crate::import_markdown::import_series_to_db(&conn, &user, &tree_type)?;
|
crate::import_markdown::import_series_to_db(&conn, &user, &tree_type)?;
|
||||||
println!("✓ Series imported successfully!");
|
println!("✓ Series imported successfully!");
|
||||||
} else {
|
} else {
|
||||||
eprintln!("Invalid tree_type: {}. Use 'categories' or 'series'", tree_type);
|
eprintln!(
|
||||||
|
"Invalid tree_type: {}. Use 'categories' or 'series'",
|
||||||
|
tree_type
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TreeCommand::Delete { user, name } => {
|
TreeCommand::Delete { user, name } => {
|
||||||
@@ -178,8 +191,9 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"DELETE FROM file_nodes WHERE label = ?1 AND node_type = 'folder'",
|
"DELETE FROM file_nodes WHERE label = ?1 AND node_type = 'folder'",
|
||||||
[&name]
|
[&name],
|
||||||
).context("Failed to delete tree")?;
|
)
|
||||||
|
.context("Failed to delete tree")?;
|
||||||
|
|
||||||
println!("✓ Tree deleted: {} for user: {}", name, user);
|
println!("✓ Tree deleted: {} for user: {}", name, user);
|
||||||
}
|
}
|
||||||
@@ -188,32 +202,41 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
handle_folder_command(action)?;
|
handle_folder_command(action)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
TreeCommand::Ls { user, path, tree_type } => {
|
TreeCommand::Ls {
|
||||||
|
user,
|
||||||
|
path,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
let db_path = format!("data/users/{}.sqlite", user);
|
let db_path = format!("data/users/{}.sqlite", user);
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
|
|
||||||
let parent_id = find_node_id(&conn, &path, &tree_type)?;
|
let parent_id = find_node_id(&conn, &path, &tree_type)?;
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn
|
||||||
"SELECT label, node_type, file_size FROM file_nodes
|
.prepare(
|
||||||
|
"SELECT label, node_type, file_size FROM file_nodes
|
||||||
WHERE parent_id = ?1 AND tree_type = ?2
|
WHERE parent_id = ?1 AND tree_type = ?2
|
||||||
ORDER BY node_type DESC, label ASC"
|
ORDER BY node_type DESC, label ASC",
|
||||||
).context("Failed to prepare ls query")?;
|
)
|
||||||
|
.context("Failed to prepare ls query")?;
|
||||||
|
|
||||||
let entries = stmt.query_map(
|
let entries = stmt
|
||||||
rusqlite::params![parent_id, tree_type],
|
.query_map(rusqlite::params![parent_id, tree_type], |row| {
|
||||||
|row| Ok((
|
Ok((
|
||||||
row.get::<_, String>(0)?,
|
row.get::<_, String>(0)?,
|
||||||
row.get::<_, String>(1)?,
|
row.get::<_, String>(1)?,
|
||||||
row.get::<_, Option<i64>>(2)?
|
row.get::<_, Option<i64>>(2)?,
|
||||||
))
|
))
|
||||||
).context("Failed to query entries")?;
|
})
|
||||||
|
.context("Failed to query entries")?;
|
||||||
|
|
||||||
println!("=== Contents of {} (tree_type: {}) ===", path, tree_type);
|
println!("=== Contents of {} (tree_type: {}) ===", path, tree_type);
|
||||||
for entry in entries {
|
for entry in entries {
|
||||||
let (name, node_type, size) = entry?;
|
let (name, node_type, size) = entry?;
|
||||||
let size_str = size.map(|s| format!("{} bytes", s)).unwrap_or_else(|| "-".to_string());
|
let size_str = size
|
||||||
|
.map(|s| format!("{} bytes", s))
|
||||||
|
.unwrap_or_else(|| "-".to_string());
|
||||||
|
|
||||||
if node_type == "folder" {
|
if node_type == "folder" {
|
||||||
println!(" 📁 {} ({})", name, size_str);
|
println!(" 📁 {} ({})", name, size_str);
|
||||||
@@ -223,7 +246,12 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TreeCommand::Cp { user, source, target, tree_type } => {
|
TreeCommand::Cp {
|
||||||
|
user,
|
||||||
|
source,
|
||||||
|
target,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
let db_path = format!("data/users/{}.sqlite", user);
|
let db_path = format!("data/users/{}.sqlite", user);
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
@@ -231,19 +259,23 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
let source_id = find_node_id(&conn, &source, &tree_type)?;
|
let source_id = find_node_id(&conn, &source, &tree_type)?;
|
||||||
let target_parent_id = find_node_id(&conn, &target, &tree_type)?;
|
let target_parent_id = find_node_id(&conn, &target, &tree_type)?;
|
||||||
|
|
||||||
let (label, node_type, aliases_json, file_uuid, sha256, file_size) = conn.query_row(
|
let (label, node_type, aliases_json, file_uuid, sha256, file_size) = conn
|
||||||
"SELECT label, node_type, aliases_json, file_uuid, sha256, file_size
|
.query_row(
|
||||||
|
"SELECT label, node_type, aliases_json, file_uuid, sha256, file_size
|
||||||
FROM file_nodes WHERE node_id = ?1",
|
FROM file_nodes WHERE node_id = ?1",
|
||||||
[&source_id],
|
[&source_id],
|
||||||
|row| Ok((
|
|row| {
|
||||||
row.get::<_, String>(0)?,
|
Ok((
|
||||||
row.get::<_, String>(1)?,
|
row.get::<_, String>(0)?,
|
||||||
row.get::<_, String>(2)?,
|
row.get::<_, String>(1)?,
|
||||||
row.get::<_, Option<String>>(3)?,
|
row.get::<_, String>(2)?,
|
||||||
row.get::<_, Option<String>>(4)?,
|
row.get::<_, Option<String>>(3)?,
|
||||||
row.get::<_, Option<i64>>(5)?
|
row.get::<_, Option<String>>(4)?,
|
||||||
))
|
row.get::<_, Option<i64>>(5)?,
|
||||||
).context("Failed to get source node")?;
|
))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.context("Failed to get source node")?;
|
||||||
|
|
||||||
let new_id = Uuid::new_v4().to_string();
|
let new_id = Uuid::new_v4().to_string();
|
||||||
let created_at = chrono::Utc::now().to_rfc3339();
|
let created_at = chrono::Utc::now().to_rfc3339();
|
||||||
@@ -258,7 +290,12 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
println!("✓ Copied {} to {} (new ID: {})", source, target, new_id);
|
println!("✓ Copied {} to {} (new ID: {})", source, target, new_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
TreeCommand::Mv { user, source, target, tree_type } => {
|
TreeCommand::Mv {
|
||||||
|
user,
|
||||||
|
source,
|
||||||
|
target,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
let db_path = format!("data/users/{}.sqlite", user);
|
let db_path = format!("data/users/{}.sqlite", user);
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
@@ -270,8 +307,9 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE file_nodes SET parent_id = ?1, updated_at = ?2 WHERE node_id = ?3",
|
"UPDATE file_nodes SET parent_id = ?1, updated_at = ?2 WHERE node_id = ?3",
|
||||||
rusqlite::params![target_parent_id, updated_at, source_id]
|
rusqlite::params![target_parent_id, updated_at, source_id],
|
||||||
).context("Failed to move node")?;
|
)
|
||||||
|
.context("Failed to move node")?;
|
||||||
|
|
||||||
println!("✓ Moved {} to {}", source, target);
|
println!("✓ Moved {} to {}", source, target);
|
||||||
}
|
}
|
||||||
@@ -281,12 +319,17 @@ pub async fn handle_tree_command(cmd: TreeCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
fn handle_folder_command(cmd: FolderCommand) -> anyhow::Result<()> {
|
fn handle_folder_command(cmd: FolderCommand) -> anyhow::Result<()> {
|
||||||
match cmd {
|
match cmd {
|
||||||
FolderCommand::Create { user, path, name, tree_type } => {
|
FolderCommand::Create {
|
||||||
|
user,
|
||||||
|
path,
|
||||||
|
name,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
let db_path = format!("data/users/{}.sqlite", user);
|
let db_path = format!("data/users/{}.sqlite", user);
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
|
|
||||||
let parent_id = if path == "/" || path == "" {
|
let parent_id = if path == "/" || path.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(find_node_id(&conn, &path, &tree_type)?)
|
Some(find_node_id(&conn, &path, &tree_type)?)
|
||||||
@@ -299,18 +342,27 @@ fn handle_folder_command(cmd: FolderCommand) -> anyhow::Result<()> {
|
|||||||
"INSERT INTO file_nodes
|
"INSERT INTO file_nodes
|
||||||
(node_id, label, parent_id, node_type, tree_type, created_at, updated_at)
|
(node_id, label, parent_id, node_type, tree_type, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, 'folder', ?4, ?5, ?5)",
|
VALUES (?1, ?2, ?3, 'folder', ?4, ?5, ?5)",
|
||||||
rusqlite::params![node_id, name, parent_id, tree_type, created_at]
|
rusqlite::params![node_id, name, parent_id, tree_type, created_at],
|
||||||
).context("Failed to create folder")?;
|
)
|
||||||
|
.context("Failed to create folder")?;
|
||||||
|
|
||||||
println!("✓ Folder created: {} in {} (tree_type: {})", name, path, tree_type);
|
println!(
|
||||||
|
"✓ Folder created: {} in {} (tree_type: {})",
|
||||||
|
name, path, tree_type
|
||||||
|
);
|
||||||
println!("✓ Node ID: {}", node_id);
|
println!("✓ Node ID: {}", node_id);
|
||||||
}
|
}
|
||||||
FolderCommand::Delete { user, path, name, tree_type } => {
|
FolderCommand::Delete {
|
||||||
|
user,
|
||||||
|
path,
|
||||||
|
name,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
let db_path = format!("data/users/{}.sqlite", user);
|
let db_path = format!("data/users/{}.sqlite", user);
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
|
|
||||||
let folder_path = if path == "/" || path == "" {
|
let folder_path = if path == "/" || path.is_empty() {
|
||||||
name.clone()
|
name.clone()
|
||||||
} else {
|
} else {
|
||||||
format!("{}/{}", path, name)
|
format!("{}/{}", path, name)
|
||||||
@@ -320,17 +372,27 @@ fn handle_folder_command(cmd: FolderCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"DELETE FROM file_nodes WHERE node_id = ?1 OR parent_id = ?1",
|
"DELETE FROM file_nodes WHERE node_id = ?1 OR parent_id = ?1",
|
||||||
[&folder_id]
|
[&folder_id],
|
||||||
).context("Failed to delete folder and children")?;
|
)
|
||||||
|
.context("Failed to delete folder and children")?;
|
||||||
|
|
||||||
println!("✓ Folder deleted: {} in {} (tree_type: {})", name, path, tree_type);
|
println!(
|
||||||
|
"✓ Folder deleted: {} in {} (tree_type: {})",
|
||||||
|
name, path, tree_type
|
||||||
|
);
|
||||||
}
|
}
|
||||||
FolderCommand::Rename { user, path, old_name, new_name, tree_type } => {
|
FolderCommand::Rename {
|
||||||
|
user,
|
||||||
|
path,
|
||||||
|
old_name,
|
||||||
|
new_name,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
let db_path = format!("data/users/{}.sqlite", user);
|
let db_path = format!("data/users/{}.sqlite", user);
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path)
|
||||||
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
.with_context(|| format!("Failed to open database: {}", db_path))?;
|
||||||
|
|
||||||
let folder_path = if path == "/" || path == "" {
|
let folder_path = if path == "/" || path.is_empty() {
|
||||||
old_name.clone()
|
old_name.clone()
|
||||||
} else {
|
} else {
|
||||||
format!("{}/{}", path, old_name)
|
format!("{}/{}", path, old_name)
|
||||||
@@ -342,24 +404,30 @@ fn handle_folder_command(cmd: FolderCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE file_nodes SET label = ?1, updated_at = ?2 WHERE node_id = ?3",
|
"UPDATE file_nodes SET label = ?1, updated_at = ?2 WHERE node_id = ?3",
|
||||||
rusqlite::params![new_name, updated_at, folder_id]
|
rusqlite::params![new_name, updated_at, folder_id],
|
||||||
).context("Failed to rename folder")?;
|
)
|
||||||
|
.context("Failed to rename folder")?;
|
||||||
|
|
||||||
println!("✓ Folder renamed: {} → {} in {} (tree_type: {})", old_name, new_name, path, tree_type);
|
println!(
|
||||||
|
"✓ Folder renamed: {} → {} in {} (tree_type: {})",
|
||||||
|
old_name, new_name, path, tree_type
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_node_id(conn: &Connection, path: &str, tree_type: &str) -> anyhow::Result<String> {
|
fn find_node_id(conn: &Connection, path: &str, tree_type: &str) -> anyhow::Result<String> {
|
||||||
if path == "/" || path == "" {
|
if path == "/" || path.is_empty() {
|
||||||
let node_id: String = conn.query_row(
|
let node_id: String = conn
|
||||||
"SELECT node_id FROM file_nodes
|
.query_row(
|
||||||
|
"SELECT node_id FROM file_nodes
|
||||||
WHERE parent_id IS NULL AND node_type = 'folder' AND tree_type = ?1
|
WHERE parent_id IS NULL AND node_type = 'folder' AND tree_type = ?1
|
||||||
LIMIT 1",
|
LIMIT 1",
|
||||||
[tree_type],
|
[tree_type],
|
||||||
|row| row.get(0)
|
|row| row.get(0),
|
||||||
).context("Failed to find root folder")?;
|
)
|
||||||
|
.context("Failed to find root folder")?;
|
||||||
|
|
||||||
return Ok(node_id);
|
return Ok(node_id);
|
||||||
}
|
}
|
||||||
@@ -369,13 +437,15 @@ fn find_node_id(conn: &Connection, path: &str, tree_type: &str) -> anyhow::Resul
|
|||||||
let mut current_parent: Option<String> = None;
|
let mut current_parent: Option<String> = None;
|
||||||
|
|
||||||
for part in parts {
|
for part in parts {
|
||||||
let node_id: String = conn.query_row(
|
let node_id: String = conn
|
||||||
"SELECT node_id FROM file_nodes
|
.query_row(
|
||||||
|
"SELECT node_id FROM file_nodes
|
||||||
WHERE label = ?1 AND tree_type = ?2 AND
|
WHERE label = ?1 AND tree_type = ?2 AND
|
||||||
(parent_id = ?3 OR (?3 IS NULL AND parent_id IS NULL))",
|
(parent_id = ?3 OR (?3 IS NULL AND parent_id IS NULL))",
|
||||||
rusqlite::params![part, tree_type, current_parent],
|
rusqlite::params![part, tree_type, current_parent],
|
||||||
|row| row.get(0)
|
|row| row.get(0),
|
||||||
).context(format!("Failed to find node: {}", part))?;
|
)
|
||||||
|
.context(format!("Failed to find node: {}", part))?;
|
||||||
|
|
||||||
current_parent = Some(node_id);
|
current_parent = Some(node_id);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use clap::Subcommand;
|
|
||||||
use axum::{extract::Request, response::IntoResponse, Extension};
|
use axum::{extract::Request, response::IntoResponse, Extension};
|
||||||
|
use clap::Subcommand;
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
pub enum WebdavCommand {
|
pub enum WebdavCommand {
|
||||||
@@ -28,7 +28,7 @@ pub async fn handle_webdav_command(cmd: WebdavCommand) -> anyhow::Result<()> {
|
|||||||
println!("User: {}", user);
|
println!("User: {}", user);
|
||||||
println!("Port: {}", port);
|
println!("Port: {}", port);
|
||||||
println!("Database: {}", db_path.display());
|
println!("Database: {}", db_path.display());
|
||||||
println!("");
|
println!();
|
||||||
|
|
||||||
run_webdav_server(port, user, db_path).await?;
|
run_webdav_server(port, user, db_path).await?;
|
||||||
}
|
}
|
||||||
@@ -41,7 +41,7 @@ async fn run_webdav_server(
|
|||||||
user: String,
|
user: String,
|
||||||
db_path: std::path::PathBuf,
|
db_path: std::path::PathBuf,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
use axum::{extract::Request, response::IntoResponse, routing::any, Extension, Router};
|
use axum::{routing::any, Extension, Router};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
let webdav = markbase_webdav::webdav::MarkBaseWebDAV::new(user, db_path);
|
let webdav = markbase_webdav::webdav::MarkBaseWebDAV::new(user, db_path);
|
||||||
@@ -58,7 +58,7 @@ async fn run_webdav_server(
|
|||||||
|
|
||||||
println!("WebDAV server listening on http://{}", addr);
|
println!("WebDAV server listening on http://{}", addr);
|
||||||
println!("Mount point: /webdav");
|
println!("Mount point: /webdav");
|
||||||
println!("");
|
println!();
|
||||||
println!("Press Ctrl+C to stop");
|
println!("Press Ctrl+C to stop");
|
||||||
|
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(listener, app).await?;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
use anyhow::Context;
|
||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
use rusqlite::Connection;
|
use rusqlite::Connection;
|
||||||
use anyhow::Context;
|
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
pub enum AuthCommand {
|
pub enum AuthCommand {
|
||||||
@@ -29,17 +29,18 @@ pub fn handle_auth_command(cmd: AuthCommand) -> anyhow::Result<()> {
|
|||||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = Connection::open(db_path)
|
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||||
.context("Failed to open auth database")?;
|
|
||||||
|
|
||||||
let password_hash: String = conn.query_row(
|
let password_hash: String = conn
|
||||||
"SELECT password_hash FROM sftpgo_users WHERE username = ?",
|
.query_row(
|
||||||
[&user],
|
"SELECT password_hash FROM sftpgo_users WHERE username = ?",
|
||||||
|row| row.get(0)
|
[&user],
|
||||||
).context("Failed to query password hash")?;
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.context("Failed to query password hash")?;
|
||||||
|
|
||||||
let valid = bcrypt::verify(&password, &password_hash)
|
let valid =
|
||||||
.context("Failed to verify password")?;
|
bcrypt::verify(&password, &password_hash).context("Failed to verify password")?;
|
||||||
|
|
||||||
if !valid {
|
if !valid {
|
||||||
return Err(anyhow::anyhow!("Invalid password for user: {}", user));
|
return Err(anyhow::anyhow!("Invalid password for user: {}", user));
|
||||||
@@ -86,7 +87,8 @@ fn verify_simple_token(token: &str) -> anyhow::Result<String> {
|
|||||||
let user = parts[0];
|
let user = parts[0];
|
||||||
let timestamp_str = parts[1];
|
let timestamp_str = parts[1];
|
||||||
|
|
||||||
let timestamp: u64 = timestamp_str.parse()
|
let timestamp: u64 = timestamp_str
|
||||||
|
.parse()
|
||||||
.context("Failed to parse token timestamp")?;
|
.context("Failed to parse token timestamp")?;
|
||||||
|
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|||||||
@@ -45,7 +45,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> {
|
|||||||
let config_path = Path::new("config/markbase.toml");
|
let config_path = Path::new("config/markbase.toml");
|
||||||
|
|
||||||
if !config_path.exists() {
|
if !config_path.exists() {
|
||||||
println!("Configuration file not found. Run 'markbase metadata config init' first.");
|
println!(
|
||||||
|
"Configuration file not found. Run 'markbase metadata config init' first."
|
||||||
|
);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +63,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> {
|
|||||||
let config_path = Path::new("config/markbase.toml");
|
let config_path = Path::new("config/markbase.toml");
|
||||||
|
|
||||||
if !config_path.exists() {
|
if !config_path.exists() {
|
||||||
println!("Configuration file not found. Run 'markbase metadata config init' first.");
|
println!(
|
||||||
|
"Configuration file not found. Run 'markbase metadata config init' first."
|
||||||
|
);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,7 +90,9 @@ pub fn handle_config_command(cmd: ConfigCommand) -> anyhow::Result<()> {
|
|||||||
let config_path = Path::new("config/markbase.toml");
|
let config_path = Path::new("config/markbase.toml");
|
||||||
|
|
||||||
if !config_path.exists() {
|
if !config_path.exists() {
|
||||||
println!("Configuration file not found. Run 'markbase metadata config init' first.");
|
println!(
|
||||||
|
"Configuration file not found. Run 'markbase metadata config init' first."
|
||||||
|
);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
use anyhow::Context;
|
||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
use rusqlite::Connection;
|
use rusqlite::Connection;
|
||||||
use anyhow::Context;
|
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
pub enum DbCommand {
|
pub enum DbCommand {
|
||||||
@@ -43,13 +43,16 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
println!("Creating database for user: {}", user);
|
println!("Creating database for user: {}", user);
|
||||||
|
|
||||||
let conn = filetree::FileTree::init_user_db(&user)
|
let conn =
|
||||||
.context("Failed to initialize database")?;
|
filetree::FileTree::init_user_db(&user).context("Failed to initialize database")?;
|
||||||
|
|
||||||
println!("✓ Database created: {}", db_path);
|
println!("✓ Database created: {}", db_path);
|
||||||
println!("✓ Tables initialized: file_nodes, file_registry, file_locations, tree_registry");
|
println!(
|
||||||
|
"✓ Tables initialized: file_nodes, file_registry, file_locations, tree_registry"
|
||||||
|
);
|
||||||
|
|
||||||
conn.close().map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
conn.close()
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
||||||
}
|
}
|
||||||
DbCommand::Status { user } => {
|
DbCommand::Status { user } => {
|
||||||
let db_path = filetree::FileTree::user_db_path(&user);
|
let db_path = filetree::FileTree::user_db_path(&user);
|
||||||
@@ -58,23 +61,18 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
|||||||
return Err(anyhow::anyhow!("Database not found: {}", db_path));
|
return Err(anyhow::anyhow!("Database not found: {}", db_path));
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = Connection::open(&db_path)
|
let conn = Connection::open(&db_path).context("Failed to open database")?;
|
||||||
.context("Failed to open database")?;
|
|
||||||
|
|
||||||
let file_size = std::fs::metadata(&db_path)?.len();
|
let file_size = std::fs::metadata(&db_path)?.len();
|
||||||
let file_size_mb = file_size as f64 / 1024.0 / 1024.0;
|
let file_size_mb = file_size as f64 / 1024.0 / 1024.0;
|
||||||
|
|
||||||
let node_count: i64 = conn.query_row(
|
let node_count: i64 = conn
|
||||||
"SELECT COUNT(*) FROM file_nodes",
|
.query_row("SELECT COUNT(*) FROM file_nodes", [], |row| row.get(0))
|
||||||
[],
|
.context("Failed to count nodes")?;
|
||||||
|row| row.get(0)
|
|
||||||
).context("Failed to count nodes")?;
|
|
||||||
|
|
||||||
let file_count: i64 = conn.query_row(
|
let file_count: i64 = conn
|
||||||
"SELECT COUNT(*) FROM file_registry",
|
.query_row("SELECT COUNT(*) FROM file_registry", [], |row| row.get(0))
|
||||||
[],
|
.context("Failed to count files")?;
|
||||||
|row| row.get(0)
|
|
||||||
).context("Failed to count files")?;
|
|
||||||
|
|
||||||
let tree_types: Vec<String> = {
|
let tree_types: Vec<String> = {
|
||||||
let mut stmt = conn.prepare("SELECT tree_type FROM tree_registry")?;
|
let mut stmt = conn.prepare("SELECT tree_type FROM tree_registry")?;
|
||||||
@@ -90,7 +88,8 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
|||||||
println!("Files: {}", file_count);
|
println!("Files: {}", file_count);
|
||||||
println!("Tree Types: {:?}", tree_types);
|
println!("Tree Types: {:?}", tree_types);
|
||||||
|
|
||||||
conn.close().map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
conn.close()
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to close database: {:?}", e))?;
|
||||||
}
|
}
|
||||||
DbCommand::Backup { user, output } => {
|
DbCommand::Backup { user, output } => {
|
||||||
let db_path = filetree::FileTree::user_db_path(&user);
|
let db_path = filetree::FileTree::user_db_path(&user);
|
||||||
@@ -101,8 +100,7 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
println!("Backing up database for user: {} to {}", user, output);
|
println!("Backing up database for user: {} to {}", user, output);
|
||||||
|
|
||||||
std::fs::copy(&db_path, &output)
|
std::fs::copy(&db_path, &output).context("Failed to backup database")?;
|
||||||
.context("Failed to backup database")?;
|
|
||||||
|
|
||||||
println!("✓ Database backed up to: {}", output);
|
println!("✓ Database backed up to: {}", output);
|
||||||
println!("✓ Backup size: {} bytes", std::fs::metadata(&output)?.len());
|
println!("✓ Backup size: {} bytes", std::fs::metadata(&output)?.len());
|
||||||
@@ -123,11 +121,13 @@ pub fn handle_db_command(cmd: DbCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
println!("Restoring database for user: {} from {}", user, input);
|
println!("Restoring database for user: {} from {}", user, input);
|
||||||
|
|
||||||
std::fs::copy(&input, &db_path)
|
std::fs::copy(&input, &db_path).context("Failed to restore database")?;
|
||||||
.context("Failed to restore database")?;
|
|
||||||
|
|
||||||
println!("✓ Database restored from: {}", input);
|
println!("✓ Database restored from: {}", input);
|
||||||
println!("✓ Database size: {} bytes", std::fs::metadata(&db_path)?.len());
|
println!(
|
||||||
|
"✓ Database size: {} bytes",
|
||||||
|
std::fs::metadata(&db_path)?.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
pub mod config;
|
|
||||||
pub mod user;
|
|
||||||
pub mod db;
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
|
pub mod config;
|
||||||
|
pub mod db;
|
||||||
|
pub mod user;
|
||||||
|
|
||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
use anyhow::Context;
|
||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
use rusqlite::Connection;
|
use rusqlite::Connection;
|
||||||
use anyhow::Context;
|
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
pub enum UserCommand {
|
pub enum UserCommand {
|
||||||
@@ -18,7 +18,7 @@ pub enum UserCommand {
|
|||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
name: String,
|
name: String,
|
||||||
},
|
},
|
||||||
#[command(name = "user-delete")]
|
#[command(name = "user-delete")]
|
||||||
Delete {
|
Delete {
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
name: String,
|
name: String,
|
||||||
@@ -34,21 +34,22 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
|||||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = Connection::open(db_path)
|
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||||
.context("Failed to open auth database")?;
|
|
||||||
|
|
||||||
let exists: i64 = conn.query_row(
|
let exists: i64 = conn
|
||||||
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
.query_row(
|
||||||
[&name],
|
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
||||||
|row| row.get(0)
|
[&name],
|
||||||
).context("Failed to check user existence")?;
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.context("Failed to check user existence")?;
|
||||||
|
|
||||||
if exists > 0 {
|
if exists > 0 {
|
||||||
return Err(anyhow::anyhow!("User already exists: {}", name));
|
return Err(anyhow::anyhow!("User already exists: {}", name));
|
||||||
}
|
}
|
||||||
|
|
||||||
let password_hash = bcrypt::hash(&password, bcrypt::DEFAULT_COST)
|
let password_hash =
|
||||||
.context("Failed to hash password")?;
|
bcrypt::hash(&password, bcrypt::DEFAULT_COST).context("Failed to hash password")?;
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO sftpgo_users (username, password_hash, role, created_at) VALUES (?, ?, 'user', datetime('now'))",
|
"INSERT INTO sftpgo_users (username, password_hash, role, created_at) VALUES (?, ?, 'user', datetime('now'))",
|
||||||
@@ -66,20 +67,21 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
|||||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = Connection::open(db_path)
|
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||||
.context("Failed to open auth database")?;
|
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn
|
||||||
"SELECT username, role, created_at FROM sftpgo_users ORDER BY username"
|
.prepare("SELECT username, role, created_at FROM sftpgo_users ORDER BY username")
|
||||||
).context("Failed to prepare query")?;
|
.context("Failed to prepare query")?;
|
||||||
|
|
||||||
let users = stmt.query_map([], |row| {
|
let users = stmt
|
||||||
Ok((
|
.query_map([], |row| {
|
||||||
row.get::<_, String>(0)?,
|
Ok((
|
||||||
row.get::<_, String>(1)?,
|
row.get::<_, String>(0)?,
|
||||||
row.get::<_, String>(2)?,
|
row.get::<_, String>(1)?,
|
||||||
))
|
row.get::<_, String>(2)?,
|
||||||
}).context("Failed to query users")?;
|
))
|
||||||
|
})
|
||||||
|
.context("Failed to query users")?;
|
||||||
|
|
||||||
println!("=== Users List ===");
|
println!("=== Users List ===");
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
@@ -102,18 +104,21 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
|||||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = Connection::open(db_path)
|
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||||
.context("Failed to open auth database")?;
|
|
||||||
|
|
||||||
let user = conn.query_row(
|
let user = conn
|
||||||
"SELECT username, role, created_at FROM sftpgo_users WHERE username = ?",
|
.query_row(
|
||||||
[&name],
|
"SELECT username, role, created_at FROM sftpgo_users WHERE username = ?",
|
||||||
|row| Ok((
|
[&name],
|
||||||
row.get::<_, String>(0)?,
|
|row| {
|
||||||
row.get::<_, String>(1)?,
|
Ok((
|
||||||
row.get::<_, String>(2)?,
|
row.get::<_, String>(0)?,
|
||||||
))
|
row.get::<_, String>(1)?,
|
||||||
).context("Failed to query user")?;
|
row.get::<_, String>(2)?,
|
||||||
|
))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.context("Failed to query user")?;
|
||||||
|
|
||||||
let (username, role, created_at) = user;
|
let (username, role, created_at) = user;
|
||||||
println!("=== User Details ===");
|
println!("=== User Details ===");
|
||||||
@@ -128,23 +133,22 @@ pub fn handle_user_command(cmd: UserCommand) -> anyhow::Result<()> {
|
|||||||
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
return Err(anyhow::anyhow!("Auth database not found: {}", db_path));
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = Connection::open(db_path)
|
let conn = Connection::open(db_path).context("Failed to open auth database")?;
|
||||||
.context("Failed to open auth database")?;
|
|
||||||
|
|
||||||
let exists: i64 = conn.query_row(
|
let exists: i64 = conn
|
||||||
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
.query_row(
|
||||||
[&name],
|
"SELECT COUNT(*) FROM sftpgo_users WHERE username = ?",
|
||||||
|row| row.get(0)
|
[&name],
|
||||||
).context("Failed to check user existence")?;
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.context("Failed to check user existence")?;
|
||||||
|
|
||||||
if exists == 0 {
|
if exists == 0 {
|
||||||
return Err(anyhow::anyhow!("User not found: {}", name));
|
return Err(anyhow::anyhow!("User not found: {}", name));
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.execute(
|
conn.execute("DELETE FROM sftpgo_users WHERE username = ?", [&name])
|
||||||
"DELETE FROM sftpgo_users WHERE username = ?",
|
.context("Failed to delete user")?;
|
||||||
[&name]
|
|
||||||
).context("Failed to delete user")?;
|
|
||||||
|
|
||||||
println!("✓ User deleted: {}", name);
|
println!("✓ User deleted: {}", name);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ pub fn handle_archive_command(cmd: ArchiveCommand) -> anyhow::Result<()> {
|
|||||||
println!("Format: {}", metadata.format);
|
println!("Format: {}", metadata.format);
|
||||||
println!("Total files: {}", metadata.total_files);
|
println!("Total files: {}", metadata.total_files);
|
||||||
println!("Total size: {} bytes", metadata.total_size);
|
println!("Total size: {} bytes", metadata.total_size);
|
||||||
println!("");
|
println!();
|
||||||
|
|
||||||
for entry in entries {
|
for entry in entries {
|
||||||
println!(" {} ({} bytes)", entry.path.display(), entry.size);
|
println!(" {} ({} bytes)", entry.path.display(), entry.size);
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
pub mod scan;
|
|
||||||
pub mod hash;
|
|
||||||
pub mod archive;
|
pub mod archive;
|
||||||
pub mod sync;
|
pub mod hash;
|
||||||
pub mod mount;
|
pub mod mount;
|
||||||
|
pub mod scan;
|
||||||
|
pub mod sync;
|
||||||
|
|
||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,11 @@ pub enum MountCommand {
|
|||||||
|
|
||||||
pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
||||||
match cmd {
|
match cmd {
|
||||||
MountCommand::Attach { type_, server, path } => {
|
MountCommand::Attach {
|
||||||
|
type_,
|
||||||
|
server,
|
||||||
|
path,
|
||||||
|
} => {
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
|
|
||||||
println!("Mounting {} from {} to {}", type_, server, path);
|
println!("Mounting {} from {} to {}", type_, server, path);
|
||||||
@@ -58,7 +62,10 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
|||||||
return Err(anyhow::anyhow!("SMB mount failed"));
|
return Err(anyhow::anyhow!("SMB mount failed"));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(anyhow::anyhow!("Unknown mount type: {}. Use 'nfs' or 'smb'", type_));
|
return Err(anyhow::anyhow!(
|
||||||
|
"Unknown mount type: {}. Use 'nfs' or 'smb'",
|
||||||
|
type_
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MountCommand::Detach { path } => {
|
MountCommand::Detach { path } => {
|
||||||
@@ -66,9 +73,7 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
println!("Unmounting {}", path);
|
println!("Unmounting {}", path);
|
||||||
|
|
||||||
let status = Command::new("umount")
|
let status = Command::new("umount").arg(&path).status()?;
|
||||||
.arg(&path)
|
|
||||||
.status()?;
|
|
||||||
|
|
||||||
if status.success() {
|
if status.success() {
|
||||||
println!("✓ Unmounted: {}", path);
|
println!("✓ Unmounted: {}", path);
|
||||||
@@ -81,8 +86,7 @@ pub fn handle_mount_command(cmd: MountCommand) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
println!("Listing mounted storage");
|
println!("Listing mounted storage");
|
||||||
|
|
||||||
let output = Command::new("mount")
|
let output = Command::new("mount").output()?;
|
||||||
.output()?;
|
|
||||||
|
|
||||||
let mounts = String::from_utf8_lossy(&output.stdout);
|
let mounts = String::from_utf8_lossy(&output.stdout);
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pub enum SyncCommand {
|
|||||||
|
|
||||||
pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> {
|
pub fn handle_sync_command(cmd: SyncCommand) -> anyhow::Result<()> {
|
||||||
match cmd {
|
match cmd {
|
||||||
SyncCommand::Start { source, target, mode } => {
|
SyncCommand::Start {
|
||||||
|
source,
|
||||||
|
target,
|
||||||
|
mode,
|
||||||
|
} => {
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
println!("Syncing {} to {} (mode: {})", source, target, mode);
|
println!("Syncing {} to {} (mode: {})", source, target, mode);
|
||||||
|
|||||||
@@ -16,36 +16,39 @@ pub enum TestCommand {
|
|||||||
|
|
||||||
pub fn handle_test_command(cmd: TestCommand) -> anyhow::Result<()> {
|
pub fn handle_test_command(cmd: TestCommand) -> anyhow::Result<()> {
|
||||||
match cmd {
|
match cmd {
|
||||||
TestCommand::Bcrypt { password, verify_hash } => {
|
TestCommand::Bcrypt {
|
||||||
|
password,
|
||||||
|
verify_hash,
|
||||||
|
} => {
|
||||||
use bcrypt::{hash, verify, DEFAULT_COST};
|
use bcrypt::{hash, verify, DEFAULT_COST};
|
||||||
|
|
||||||
println!("=== bcrypt Hash Test ===");
|
println!("=== bcrypt Hash Test ===");
|
||||||
println!("Password: {}", password);
|
println!("Password: {}", password);
|
||||||
println!("");
|
println!();
|
||||||
|
|
||||||
let new_hash = hash(&password, DEFAULT_COST)?;
|
let new_hash = hash(&password, DEFAULT_COST)?;
|
||||||
println!("Generated hash:");
|
println!("Generated hash:");
|
||||||
println!("{}", new_hash);
|
println!("{}", new_hash);
|
||||||
println!("");
|
println!();
|
||||||
|
|
||||||
if let Some(hash_to_verify) = verify_hash {
|
if let Some(hash_to_verify) = verify_hash {
|
||||||
println!("Verifying hash: {}", hash_to_verify);
|
println!("Verifying hash: {}", hash_to_verify);
|
||||||
let valid = verify(&password, &hash_to_verify)?;
|
let valid = verify(&password, &hash_to_verify)?;
|
||||||
println!("Valid: {}", valid);
|
println!("Valid: {}", valid);
|
||||||
println!("");
|
println!();
|
||||||
}
|
}
|
||||||
|
|
||||||
let db_hash = "$2b$10$ha5wU.mOi8fHLJCfun860u2cfVopa04jwe/q82IKOwqp5uG70qsH6";
|
let db_hash = "$2b$10$ha5wU.mOi8fHLJCfun860u2cfVopa04jwe/q82IKOwqp5uG70qsH6";
|
||||||
println!("Database hash: {}", db_hash);
|
println!("Database hash: {}", db_hash);
|
||||||
let valid = verify(&password, db_hash)?;
|
let valid = verify(&password, db_hash)?;
|
||||||
println!("Database hash valid for '{}': {}", password, valid);
|
println!("Database hash valid for '{}': {}", password, valid);
|
||||||
println!("");
|
println!();
|
||||||
|
|
||||||
if !valid {
|
if !valid {
|
||||||
println!("❌ Database hash is incorrect!");
|
println!("❌ Database hash is incorrect!");
|
||||||
println!("Update SQL:");
|
println!("Update SQL:");
|
||||||
println!("UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');", new_hash);
|
println!("UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');", new_hash);
|
||||||
println!("");
|
println!();
|
||||||
println!("Execute:");
|
println!("Execute:");
|
||||||
println!("sqlite3 data/auth.sqlite \"UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');\"", new_hash);
|
println!("sqlite3 data/auth.sqlite \"UPDATE sftpgo_users SET password_hash = '{}' WHERE username IN ('testuser', 'demo', 'warren', 'momentry');\"", new_hash);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ pub use web::*;
|
|||||||
|
|
||||||
/// Unified application configuration
|
/// Unified application configuration
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[derive(Default)]
|
||||||
pub struct AppConfig {
|
pub struct AppConfig {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub web: WebSection,
|
pub web: WebSection,
|
||||||
@@ -154,13 +155,19 @@ impl AppConfig {
|
|||||||
self.web.host = v;
|
self.web.host = v;
|
||||||
}
|
}
|
||||||
if let Ok(v) = std::env::var("MB_WEB_PORT") {
|
if let Ok(v) = std::env::var("MB_WEB_PORT") {
|
||||||
if let Ok(p) = v.parse() { self.web.port = p; }
|
if let Ok(p) = v.parse() {
|
||||||
|
self.web.port = p;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if let Ok(v) = std::env::var("MB_SSH_PORT") {
|
if let Ok(v) = std::env::var("MB_SSH_PORT") {
|
||||||
if let Ok(p) = v.parse() { self.ssh.port = p; }
|
if let Ok(p) = v.parse() {
|
||||||
|
self.ssh.port = p;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if let Ok(v) = std::env::var("MB_SFTP_PORT") {
|
if let Ok(v) = std::env::var("MB_SFTP_PORT") {
|
||||||
if let Ok(p) = v.parse() { self.sftp.port = p; }
|
if let Ok(p) = v.parse() {
|
||||||
|
self.sftp.port = p;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if let Ok(v) = std::env::var("MB_S3_ENABLED") {
|
if let Ok(v) = std::env::var("MB_S3_ENABLED") {
|
||||||
self.s3.enabled = v == "true" || v == "1";
|
self.s3.enabled = v == "true" || v == "1";
|
||||||
@@ -172,16 +179,6 @@ impl AppConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AppConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
web: WebSection::default(),
|
|
||||||
s3: S3Section::default(),
|
|
||||||
sftp: SftpSection::default(),
|
|
||||||
ssh: SshSection::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|||||||
@@ -323,11 +323,15 @@ impl MarkBaseConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if self.authentication.default_user.is_empty() {
|
if self.authentication.default_user.is_empty() {
|
||||||
return Err(anyhow::anyhow!("authentication.default_user cannot be empty"));
|
return Err(anyhow::anyhow!(
|
||||||
|
"authentication.default_user cannot be empty"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.authentication.default_password.is_empty() {
|
if self.authentication.default_password.is_empty() {
|
||||||
return Err(anyhow::anyhow!("authentication.default_password cannot be empty"));
|
return Err(anyhow::anyhow!(
|
||||||
|
"authentication.default_password cannot be empty"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.authentication.max_sessions_per_user == 0 {
|
if self.authentication.max_sessions_per_user == 0 {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use rusqlite::{Connection, params};
|
use rusqlite::{params, Connection};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
@@ -74,13 +74,18 @@ impl DownloadDb {
|
|||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_product_files_product_id ON product_files(product_id);
|
CREATE INDEX IF NOT EXISTS idx_product_files_product_id ON product_files(product_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_products_series ON products(series);
|
CREATE INDEX IF NOT EXISTS idx_products_series ON products(series);
|
||||||
"
|
",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_product(&mut self, product_name: &str, series: &str, description: Option<&str>) -> Result<i64> {
|
pub fn create_product(
|
||||||
|
&mut self,
|
||||||
|
product_name: &str,
|
||||||
|
series: &str,
|
||||||
|
description: Option<&str>,
|
||||||
|
) -> Result<i64> {
|
||||||
let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
|
let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
|
||||||
|
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
@@ -97,16 +102,17 @@ impl DownloadDb {
|
|||||||
"SELECT id, product_name, series, description, created_at FROM products ORDER BY series, product_name"
|
"SELECT id, product_name, series, description, created_at FROM products ORDER BY series, product_name"
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let products = stmt.query_map([], |row| {
|
let products = stmt
|
||||||
Ok(Product {
|
.query_map([], |row| {
|
||||||
id: row.get(0)?,
|
Ok(Product {
|
||||||
product_name: row.get(1)?,
|
id: row.get(0)?,
|
||||||
series: row.get(2)?,
|
product_name: row.get(1)?,
|
||||||
description: row.get(3)?,
|
series: row.get(2)?,
|
||||||
created_at: row.get(4)?,
|
description: row.get(3)?,
|
||||||
})
|
created_at: row.get(4)?,
|
||||||
})?
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
})?
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
Ok(products)
|
Ok(products)
|
||||||
}
|
}
|
||||||
@@ -114,19 +120,20 @@ impl DownloadDb {
|
|||||||
pub fn get_products_by_series(&self, series: &str) -> Result<Vec<Product>> {
|
pub fn get_products_by_series(&self, series: &str) -> Result<Vec<Product>> {
|
||||||
let mut stmt = self.conn.prepare(
|
let mut stmt = self.conn.prepare(
|
||||||
"SELECT id, product_name, series, description, created_at FROM products
|
"SELECT id, product_name, series, description, created_at FROM products
|
||||||
WHERE series = ?1 ORDER BY product_name"
|
WHERE series = ?1 ORDER BY product_name",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let products = stmt.query_map([series], |row| {
|
let products = stmt
|
||||||
Ok(Product {
|
.query_map([series], |row| {
|
||||||
id: row.get(0)?,
|
Ok(Product {
|
||||||
product_name: row.get(1)?,
|
id: row.get(0)?,
|
||||||
series: row.get(2)?,
|
product_name: row.get(1)?,
|
||||||
description: row.get(3)?,
|
series: row.get(2)?,
|
||||||
created_at: row.get(4)?,
|
description: row.get(3)?,
|
||||||
})
|
created_at: row.get(4)?,
|
||||||
})?
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
})?
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
Ok(products)
|
Ok(products)
|
||||||
}
|
}
|
||||||
@@ -141,23 +148,31 @@ impl DownloadDb {
|
|||||||
FROM products p
|
FROM products p
|
||||||
LEFT JOIN product_files pf ON p.id = pf.product_id
|
LEFT JOIN product_files pf ON p.id = pf.product_id
|
||||||
GROUP BY p.series
|
GROUP BY p.series
|
||||||
ORDER BY p.series"
|
ORDER BY p.series",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let stats = stmt.query_map([], |row| {
|
let stats = stmt
|
||||||
Ok(SeriesStats {
|
.query_map([], |row| {
|
||||||
series: row.get(0)?,
|
Ok(SeriesStats {
|
||||||
product_count: row.get(1)?,
|
series: row.get(0)?,
|
||||||
file_count: row.get(2)?,
|
product_count: row.get(1)?,
|
||||||
total_size: row.get::<_, i64>(3)? as u64,
|
file_count: row.get(2)?,
|
||||||
})
|
total_size: row.get::<_, i64>(3)? as u64,
|
||||||
})?
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
})?
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
Ok(stats)
|
Ok(stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_file_to_product(&mut self, product_id: i64, file_path: &str, file_name: &str, file_size: u64, file_hash: Option<&str>) -> Result<i64> {
|
pub fn add_file_to_product(
|
||||||
|
&mut self,
|
||||||
|
product_id: i64,
|
||||||
|
file_path: &str,
|
||||||
|
file_name: &str,
|
||||||
|
file_size: u64,
|
||||||
|
file_hash: Option<&str>,
|
||||||
|
) -> Result<i64> {
|
||||||
let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
|
let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
|
||||||
|
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
@@ -175,19 +190,20 @@ impl DownloadDb {
|
|||||||
FROM product_files WHERE product_id = ?1 ORDER BY file_name"
|
FROM product_files WHERE product_id = ?1 ORDER BY file_name"
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let files = stmt.query_map([product_id], |row| {
|
let files = stmt
|
||||||
Ok(ProductFile {
|
.query_map([product_id], |row| {
|
||||||
id: row.get(0)?,
|
Ok(ProductFile {
|
||||||
product_id: row.get(1)?,
|
id: row.get(0)?,
|
||||||
file_path: row.get(2)?,
|
product_id: row.get(1)?,
|
||||||
file_name: row.get(3)?,
|
file_path: row.get(2)?,
|
||||||
file_size: row.get::<_, i64>(4)? as u64,
|
file_name: row.get(3)?,
|
||||||
file_hash: row.get(5)?,
|
file_size: row.get::<_, i64>(4)? as u64,
|
||||||
download_count: row.get(6)?,
|
file_hash: row.get(5)?,
|
||||||
uploaded_at: row.get(7)?,
|
download_count: row.get(6)?,
|
||||||
})
|
uploaded_at: row.get(7)?,
|
||||||
})?
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
})?
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
Ok(files)
|
Ok(files)
|
||||||
}
|
}
|
||||||
@@ -207,19 +223,20 @@ impl DownloadDb {
|
|||||||
FROM product_files ORDER BY uploaded_at DESC"
|
FROM product_files ORDER BY uploaded_at DESC"
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let files = stmt.query_map([], |row| {
|
let files = stmt
|
||||||
Ok(ProductFile {
|
.query_map([], |row| {
|
||||||
id: row.get(0)?,
|
Ok(ProductFile {
|
||||||
product_id: row.get(1)?,
|
id: row.get(0)?,
|
||||||
file_path: row.get(2)?,
|
product_id: row.get(1)?,
|
||||||
file_name: row.get(3)?,
|
file_path: row.get(2)?,
|
||||||
file_size: row.get::<_, i64>(4)? as u64,
|
file_name: row.get(3)?,
|
||||||
file_hash: row.get(5)?,
|
file_size: row.get::<_, i64>(4)? as u64,
|
||||||
download_count: row.get(6)?,
|
file_hash: row.get(5)?,
|
||||||
uploaded_at: row.get(7)?,
|
download_count: row.get(6)?,
|
||||||
})
|
uploaded_at: row.get(7)?,
|
||||||
})?
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
})?
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
Ok(files)
|
Ok(files)
|
||||||
}
|
}
|
||||||
@@ -234,12 +251,14 @@ impl DownloadDb {
|
|||||||
let deleted_files = self.conn.last_insert_rowid();
|
let deleted_files = self.conn.last_insert_rowid();
|
||||||
|
|
||||||
// 再删除产品记录
|
// 再删除产品记录
|
||||||
self.conn.execute(
|
self.conn
|
||||||
"DELETE FROM products WHERE id = ?1",
|
.execute("DELETE FROM products WHERE id = ?1", params![product_id])?;
|
||||||
params![product_id],
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let deleted_product = if self.conn.last_insert_rowid() > 0 { 1 } else { 0 };
|
let deleted_product = if self.conn.last_insert_rowid() > 0 {
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
Ok((deleted_files, deleted_product))
|
Ok((deleted_files, deleted_product))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,14 +7,18 @@ use axum::{
|
|||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
|
||||||
use crate::server::AppState;
|
|
||||||
use crate::download::db::DownloadDb;
|
use crate::download::db::DownloadDb;
|
||||||
|
use crate::server::AppState;
|
||||||
|
|
||||||
pub async fn download_file(
|
pub async fn download_file(
|
||||||
Path(file_id): Path<i64>,
|
Path(file_id): Path<i64>,
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
let db_path = format!(
|
||||||
|
"{}{}",
|
||||||
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(mut db) => {
|
Ok(mut db) => {
|
||||||
@@ -43,29 +47,46 @@ pub async fn download_file(
|
|||||||
Ok(mut file) => {
|
Ok(mut file) => {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
match file.read_to_end(&mut buffer) {
|
match file.read_to_end(&mut buffer) {
|
||||||
Ok(_) => {
|
Ok(_) => Response::builder()
|
||||||
Response::builder()
|
.status(StatusCode::OK)
|
||||||
.status(StatusCode::OK)
|
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
.header(
|
||||||
.header(
|
header::CONTENT_DISPOSITION,
|
||||||
header::CONTENT_DISPOSITION,
|
format!("attachment; filename=\"{}\"", file_info.file_name),
|
||||||
format!("attachment; filename=\"{}\"", file_info.file_name)
|
)
|
||||||
)
|
.header(
|
||||||
.header("X-File-Hash", file_info.file_hash.clone().unwrap_or_default())
|
"X-File-Hash",
|
||||||
.header("X-File-Size", file_info.file_size)
|
file_info.file_hash.clone().unwrap_or_default(),
|
||||||
.body(buffer.into())
|
)
|
||||||
.unwrap()
|
.header("X-File-Size", file_info.file_size)
|
||||||
}
|
.body(buffer.into())
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response()
|
.unwrap(),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Error reading file: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response()
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Error opening file: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {}", e)).into_response()
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Database error: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {}", e)).into_response()
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Database error: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,62 +108,77 @@ pub async fn download_file_by_path(
|
|||||||
return (StatusCode::NOT_FOUND, "File not found").into_response();
|
return (StatusCode::NOT_FOUND, "File not found").into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
let filename = file_path.split('/').last().unwrap_or("unknown");
|
let filename = file_path.split('/').next_back().unwrap_or("unknown");
|
||||||
|
|
||||||
match File::open(&full_path) {
|
match File::open(&full_path) {
|
||||||
Ok(mut file) => {
|
Ok(mut file) => {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
match file.read_to_end(&mut buffer) {
|
match file.read_to_end(&mut buffer) {
|
||||||
Ok(_) => {
|
Ok(_) => Response::builder()
|
||||||
Response::builder()
|
.status(StatusCode::OK)
|
||||||
.status(StatusCode::OK)
|
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
.header(
|
||||||
.header(
|
header::CONTENT_DISPOSITION,
|
||||||
header::CONTENT_DISPOSITION,
|
format!("attachment; filename=\"{}\"", filename),
|
||||||
format!("attachment; filename=\"{}\"", filename)
|
)
|
||||||
)
|
.body(buffer.into())
|
||||||
.body(buffer.into())
|
.unwrap(),
|
||||||
.unwrap()
|
Err(e) => (
|
||||||
}
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response()
|
format!("Error reading file: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response()
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Error opening file: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_download_stats(
|
pub async fn get_download_stats(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
State(state): State<AppState>,
|
let db_path = format!(
|
||||||
) -> impl IntoResponse {
|
"{}{}",
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(db) => {
|
Ok(db) => match db.get_all_files() {
|
||||||
match db.get_all_files() {
|
Ok(files) => {
|
||||||
Ok(files) => {
|
let total_downloads: i64 = files.iter().map(|f| f.download_count).sum();
|
||||||
let total_downloads: i64 = files.iter().map(|f| f.download_count).sum();
|
let top_files: Vec<_> = files
|
||||||
let top_files: Vec<_> = files.iter()
|
.iter()
|
||||||
.filter(|f| f.download_count > 0)
|
.filter(|f| f.download_count > 0)
|
||||||
.take(10)
|
.take(10)
|
||||||
.map(|f| serde_json::json!({
|
.map(|f| {
|
||||||
|
serde_json::json!({
|
||||||
"file_name": f.file_name,
|
"file_name": f.file_name,
|
||||||
"download_count": f.download_count
|
"download_count": f.download_count
|
||||||
}))
|
})
|
||||||
.collect();
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
(
|
(
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
Json(serde_json::json!({
|
Json(serde_json::json!({
|
||||||
"total_files": files.len(),
|
"total_files": files.len(),
|
||||||
"total_downloads": total_downloads,
|
"total_downloads": total_downloads,
|
||||||
"top_files": top_files
|
"top_files": top_files
|
||||||
}))
|
})),
|
||||||
)
|
)
|
||||||
}
|
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
|
||||||
}
|
}
|
||||||
}
|
Err(e) => (
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(serde_json::json!({"error": e.to_string()})),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(serde_json::json!({"error": e.to_string()})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,26 +196,32 @@ pub async fn download_product_file(
|
|||||||
return (StatusCode::BAD_REQUEST, "Path is a directory, not a file").into_response();
|
return (StatusCode::BAD_REQUEST, "Path is a directory, not a file").into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
let filename = file_path.split('/').last().unwrap_or("unknown");
|
let filename = file_path.split('/').next_back().unwrap_or("unknown");
|
||||||
|
|
||||||
match File::open(&full_path) {
|
match File::open(&full_path) {
|
||||||
Ok(mut file) => {
|
Ok(mut file) => {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
match file.read_to_end(&mut buffer) {
|
match file.read_to_end(&mut buffer) {
|
||||||
Ok(_) => {
|
Ok(_) => Response::builder()
|
||||||
Response::builder()
|
.status(StatusCode::OK)
|
||||||
.status(StatusCode::OK)
|
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
.header(
|
||||||
.header(
|
header::CONTENT_DISPOSITION,
|
||||||
header::CONTENT_DISPOSITION,
|
format!("attachment; filename=\"{}\"", filename),
|
||||||
format!("attachment; filename=\"{}\"", filename)
|
)
|
||||||
)
|
.body(buffer.into())
|
||||||
.body(buffer.into())
|
.unwrap(),
|
||||||
.unwrap()
|
Err(e) => (
|
||||||
}
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error reading file: {}", e)).into_response()
|
format!("Error reading file: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error opening file: {}", e)).into_response()
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Error opening file: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,24 +1,19 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, State},
|
extract::Path,
|
||||||
http::{HeaderMap, StatusCode},
|
http::StatusCode,
|
||||||
response::{Html, IntoResponse, Json},
|
response::{IntoResponse, Json},
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use crate::server::AppState;
|
|
||||||
use crate::download::storage;
|
use crate::download::storage;
|
||||||
|
|
||||||
pub async fn list_uploaded_files(
|
pub async fn list_uploaded_files(Path(user_id): Path<String>) -> impl IntoResponse {
|
||||||
Path(user_id): Path<String>,
|
|
||||||
) -> impl IntoResponse {
|
|
||||||
let file_list = storage::scan_uploaded_files(&user_id);
|
let file_list = storage::scan_uploaded_files(&user_id);
|
||||||
(StatusCode::OK, Json(file_list))
|
(StatusCode::OK, Json(file_list))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_file_info(
|
pub async fn get_file_info(Path((user_id, filename)): Path<(String, String)>) -> impl IntoResponse {
|
||||||
Path((user_id, filename)): Path<(String, String)>,
|
|
||||||
) -> impl IntoResponse {
|
|
||||||
let base_path = format!("/Users/accusys/Downloads/{}", user_id);
|
let base_path = format!("/Users/accusys/Downloads/{}", user_id);
|
||||||
let file_path = PathBuf::from(&base_path).join(&filename);
|
let file_path = PathBuf::from(&base_path).join(&filename);
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
pub mod models;
|
|
||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod handlers;
|
|
||||||
pub mod storage;
|
|
||||||
pub mod product_handlers;
|
|
||||||
pub mod download_handler;
|
pub mod download_handler;
|
||||||
|
pub mod handlers;
|
||||||
|
pub mod models;
|
||||||
|
pub mod product_handlers;
|
||||||
|
pub mod storage;
|
||||||
|
|
||||||
pub use models::*;
|
|
||||||
pub use db::{DownloadDb, Product, ProductFile, SeriesStats};
|
pub use db::{DownloadDb, Product, ProductFile, SeriesStats};
|
||||||
pub use handlers::*;
|
|
||||||
pub use product_handlers::*;
|
|
||||||
pub use download_handler::*;
|
pub use download_handler::*;
|
||||||
|
pub use handlers::*;
|
||||||
|
pub use models::*;
|
||||||
|
pub use product_handlers::*;
|
||||||
|
|||||||
@@ -1,33 +1,42 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, State},
|
extract::{Path, State},
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::{Json, IntoResponse},
|
response::{IntoResponse, Json},
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::download::db::DownloadDb;
|
||||||
use crate::server::AppState;
|
use crate::server::AppState;
|
||||||
use crate::download::db::{DownloadDb, Product, ProductFile, SeriesStats};
|
|
||||||
|
|
||||||
pub async fn list_all_products(
|
pub async fn list_all_products(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
State(state): State<AppState>,
|
let db_path = format!(
|
||||||
) -> impl IntoResponse {
|
"{}{}",
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(db) => {
|
Ok(db) => match db.get_all_products() {
|
||||||
match db.get_all_products() {
|
Ok(products) => (
|
||||||
Ok(products) => (StatusCode::OK, Json(json!({
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
"products": products,
|
"products": products,
|
||||||
"total": products.len()
|
"total": products.len()
|
||||||
}))),
|
})),
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
"error": e.to_string()
|
"error": e.to_string()
|
||||||
}))),
|
})),
|
||||||
}
|
),
|
||||||
}
|
},
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
Err(e) => (
|
||||||
"error": e.to_string()
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}))),
|
Json(json!({
|
||||||
|
"error": e.to_string()
|
||||||
|
})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,47 +44,67 @@ pub async fn list_products_by_series(
|
|||||||
Path(series): Path<String>,
|
Path(series): Path<String>,
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
let db_path = format!(
|
||||||
|
"{}{}",
|
||||||
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(db) => {
|
Ok(db) => match db.get_products_by_series(&series) {
|
||||||
match db.get_products_by_series(&series) {
|
Ok(products) => (
|
||||||
Ok(products) => (StatusCode::OK, Json(json!({
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
"series": series,
|
"series": series,
|
||||||
"products": products,
|
"products": products,
|
||||||
"total": products.len()
|
"total": products.len()
|
||||||
}))),
|
})),
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
"error": e.to_string()
|
"error": e.to_string()
|
||||||
}))),
|
})),
|
||||||
}
|
),
|
||||||
}
|
},
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
Err(e) => (
|
||||||
"error": e.to_string()
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}))),
|
Json(json!({
|
||||||
|
"error": e.to_string()
|
||||||
|
})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_series_stats(
|
pub async fn get_series_stats(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
State(state): State<AppState>,
|
let db_path = format!(
|
||||||
) -> impl IntoResponse {
|
"{}{}",
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(db) => {
|
Ok(db) => match db.get_series_stats() {
|
||||||
match db.get_series_stats() {
|
Ok(stats) => (
|
||||||
Ok(stats) => (StatusCode::OK, Json(json!({
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
"series_stats": stats,
|
"series_stats": stats,
|
||||||
"total_series": stats.len()
|
"total_series": stats.len()
|
||||||
}))),
|
})),
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
"error": e.to_string()
|
"error": e.to_string()
|
||||||
}))),
|
})),
|
||||||
}
|
),
|
||||||
}
|
},
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
Err(e) => (
|
||||||
"error": e.to_string()
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}))),
|
Json(json!({
|
||||||
|
"error": e.to_string()
|
||||||
|
})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,25 +112,36 @@ pub async fn get_product_files(
|
|||||||
Path(product_id): Path<i64>,
|
Path(product_id): Path<i64>,
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
let db_path = format!(
|
||||||
|
"{}{}",
|
||||||
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(db) => {
|
Ok(db) => match db.get_files_by_product(product_id) {
|
||||||
match db.get_files_by_product(product_id) {
|
Ok(files) => (
|
||||||
Ok(files) => (StatusCode::OK, Json(json!({
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
"product_id": product_id,
|
"product_id": product_id,
|
||||||
"files": files,
|
"files": files,
|
||||||
"total_files": files.len(),
|
"total_files": files.len(),
|
||||||
"total_size": files.iter().map(|f| f.file_size).sum::<u64>()
|
"total_size": files.iter().map(|f| f.file_size).sum::<u64>()
|
||||||
}))),
|
})),
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
"error": e.to_string()
|
"error": e.to_string()
|
||||||
}))),
|
})),
|
||||||
}
|
),
|
||||||
}
|
},
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
Err(e) => (
|
||||||
"error": e.to_string()
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}))),
|
Json(json!({
|
||||||
|
"error": e.to_string()
|
||||||
|
})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,29 +149,40 @@ pub async fn create_product_handler(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Json(payload): Json<serde_json::Value>,
|
Json(payload): Json<serde_json::Value>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
let db_path = format!(
|
||||||
|
"{}{}",
|
||||||
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
let product_name = payload["product_name"].as_str().unwrap_or("");
|
let product_name = payload["product_name"].as_str().unwrap_or("");
|
||||||
let series = payload["series"].as_str().unwrap_or("");
|
let series = payload["series"].as_str().unwrap_or("");
|
||||||
let description = payload["description"].as_str();
|
let description = payload["description"].as_str();
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(mut db) => {
|
Ok(mut db) => match db.create_product(product_name, series, description) {
|
||||||
match db.create_product(product_name, series, description) {
|
Ok(product_id) => (
|
||||||
Ok(product_id) => (StatusCode::OK, Json(json!({
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
"ok": true,
|
"ok": true,
|
||||||
"product_id": product_id,
|
"product_id": product_id,
|
||||||
"product_name": product_name,
|
"product_name": product_name,
|
||||||
"series": series
|
"series": series
|
||||||
}))),
|
})),
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
"error": e.to_string()
|
"error": e.to_string()
|
||||||
}))),
|
})),
|
||||||
}
|
),
|
||||||
}
|
},
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
Err(e) => (
|
||||||
"error": e.to_string()
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}))),
|
Json(json!({
|
||||||
|
"error": e.to_string()
|
||||||
|
})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,7 +191,11 @@ pub async fn assign_files_to_product(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Json(payload): Json<serde_json::Value>,
|
Json(payload): Json<serde_json::Value>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
let db_path = format!(
|
||||||
|
"{}{}",
|
||||||
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
let files_vec = payload["files"].as_array().cloned().unwrap_or_default();
|
let files_vec = payload["files"].as_array().cloned().unwrap_or_default();
|
||||||
let files = files_vec.as_slice();
|
let files = files_vec.as_slice();
|
||||||
@@ -156,7 +211,8 @@ pub async fn assign_files_to_product(
|
|||||||
let file_size = file["file_size"].as_u64().unwrap_or(0);
|
let file_size = file["file_size"].as_u64().unwrap_or(0);
|
||||||
let file_hash = file["file_hash"].as_str();
|
let file_hash = file["file_hash"].as_str();
|
||||||
|
|
||||||
match db.add_file_to_product(product_id, file_path, file_name, file_size, file_hash) {
|
match db.add_file_to_product(product_id, file_path, file_name, file_size, file_hash)
|
||||||
|
{
|
||||||
Ok(_) => assigned_count += 1,
|
Ok(_) => assigned_count += 1,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
errors.push(format!("Failed to assign {}: {}", file_path, e));
|
errors.push(format!("Failed to assign {}: {}", file_path, e));
|
||||||
@@ -165,23 +221,32 @@ pub async fn assign_files_to_product(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if errors.is_empty() {
|
if errors.is_empty() {
|
||||||
(StatusCode::OK, Json(json!({
|
(
|
||||||
"ok": true,
|
StatusCode::OK,
|
||||||
"product_id": product_id,
|
Json(json!({
|
||||||
"assigned_count": assigned_count
|
"ok": true,
|
||||||
})))
|
"product_id": product_id,
|
||||||
|
"assigned_count": assigned_count
|
||||||
|
})),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
(StatusCode::PARTIAL_CONTENT, Json(json!({
|
(
|
||||||
"ok": true,
|
StatusCode::PARTIAL_CONTENT,
|
||||||
"product_id": product_id,
|
Json(json!({
|
||||||
"assigned_count": assigned_count,
|
"ok": true,
|
||||||
"errors": errors
|
"product_id": product_id,
|
||||||
})))
|
"assigned_count": assigned_count,
|
||||||
|
"errors": errors
|
||||||
|
})),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
Err(e) => (
|
||||||
"error": e.to_string()
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}))),
|
Json(json!({
|
||||||
|
"error": e.to_string()
|
||||||
|
})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,24 +254,35 @@ pub async fn delete_product(
|
|||||||
Path(product_id): Path<i64>,
|
Path(product_id): Path<i64>,
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let db_path = format!("{}{}", state.db_dir.replace("users", "downloads"), "/products.sqlite");
|
let db_path = format!(
|
||||||
|
"{}{}",
|
||||||
|
state.db_dir.replace("users", "downloads"),
|
||||||
|
"/products.sqlite"
|
||||||
|
);
|
||||||
|
|
||||||
match DownloadDb::new(&db_path) {
|
match DownloadDb::new(&db_path) {
|
||||||
Ok(mut db) => {
|
Ok(mut db) => match db.delete_product_with_files(product_id) {
|
||||||
match db.delete_product_with_files(product_id) {
|
Ok((deleted_files, deleted_product)) => (
|
||||||
Ok((deleted_files, deleted_product)) => (StatusCode::OK, Json(json!({
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
"ok": true,
|
"ok": true,
|
||||||
"product_id": product_id,
|
"product_id": product_id,
|
||||||
"deleted_files": deleted_files,
|
"deleted_files": deleted_files,
|
||||||
"deleted_product": deleted_product
|
"deleted_product": deleted_product
|
||||||
}))),
|
})),
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
"error": e.to_string()
|
"error": e.to_string()
|
||||||
}))),
|
})),
|
||||||
}
|
),
|
||||||
}
|
},
|
||||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({
|
Err(e) => (
|
||||||
"error": e.to_string()
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}))),
|
Json(json!({
|
||||||
|
"error": e.to_string()
|
||||||
|
})),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::Path;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct FileInfo {
|
pub struct FileInfo {
|
||||||
@@ -33,7 +33,7 @@ pub fn scan_uploaded_files(user_id: &str) -> FileListResponse {
|
|||||||
|
|
||||||
FileListResponse {
|
FileListResponse {
|
||||||
user_id: user_id.to_string(),
|
user_id: user_id.to_string(),
|
||||||
base_path: base_path,
|
base_path,
|
||||||
total_files: files.len(),
|
total_files: files.len(),
|
||||||
total_size,
|
total_size,
|
||||||
files,
|
files,
|
||||||
@@ -51,22 +51,23 @@ fn scan_directory_recursive(
|
|||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
|
|
||||||
if path.is_file() {
|
if path.is_file() {
|
||||||
let filename = path.file_name()
|
let filename = path
|
||||||
|
.file_name()
|
||||||
.and_then(|n| n.to_str())
|
.and_then(|n| n.to_str())
|
||||||
.unwrap_or("unknown")
|
.unwrap_or("unknown")
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
let file_size = entry.metadata()
|
let file_size = entry.metadata().map(|m| m.len()).unwrap_or(0);
|
||||||
.map(|m| m.len())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let relative_path = path.strip_prefix(base)
|
let relative_path = path
|
||||||
|
.strip_prefix(base)
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|p| p.to_str())
|
.and_then(|p| p.to_str())
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
.unwrap_or_else(|| filename.clone());
|
.unwrap_or_else(|| filename.clone());
|
||||||
|
|
||||||
let upload_time = entry.metadata()
|
let upload_time = entry
|
||||||
|
.metadata()
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|m| m.modified().ok())
|
.and_then(|m| m.modified().ok())
|
||||||
.and_then(|t| {
|
.and_then(|t| {
|
||||||
@@ -79,7 +80,10 @@ fn scan_directory_recursive(
|
|||||||
let file_hash = if file_size > 0 {
|
let file_hash = if file_size > 0 {
|
||||||
compute_file_hash(&path).ok()
|
compute_file_hash(&path).ok()
|
||||||
} else {
|
} else {
|
||||||
Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string())
|
Some(
|
||||||
|
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
files.push(FileInfo {
|
files.push(FileInfo {
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use pulldown_cmark::{Parser, Event, Tag, HeadingLevel, TagEnd};
|
|
||||||
use regex::Regex;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct MarkdownFile {
|
pub struct MarkdownFile {
|
||||||
@@ -49,7 +47,11 @@ pub fn parse_category_markdown(content: &str) -> Result<CategoryMarkdown> {
|
|||||||
let line = lines[i].trim();
|
let line = lines[i].trim();
|
||||||
|
|
||||||
if line.contains("**Category**:") {
|
if line.contains("**Category**:") {
|
||||||
category = line.replace("**Category**:", "").replace("**", "").trim().to_string();
|
category = line
|
||||||
|
.replace("**Category**:", "")
|
||||||
|
.replace("**", "")
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
} else if line.starts_with("## ") {
|
} else if line.starts_with("## ") {
|
||||||
if !current_product.is_empty() && !current_files.is_empty() {
|
if !current_product.is_empty() && !current_files.is_empty() {
|
||||||
sections.push(CategorySection {
|
sections.push(CategorySection {
|
||||||
@@ -72,7 +74,11 @@ pub fn parse_category_markdown(content: &str) -> Result<CategoryMarkdown> {
|
|||||||
current_files.push(MarkdownFile {
|
current_files.push(MarkdownFile {
|
||||||
filename,
|
filename,
|
||||||
size: Some(size),
|
size: Some(size),
|
||||||
download_url: line.trim_start_matches('`').trim_end_matches('`').trim().to_string(),
|
download_url: line
|
||||||
|
.trim_start_matches('`')
|
||||||
|
.trim_end_matches('`')
|
||||||
|
.trim()
|
||||||
|
.to_string(),
|
||||||
});
|
});
|
||||||
pending_file = None;
|
pending_file = None;
|
||||||
}
|
}
|
||||||
@@ -102,7 +108,11 @@ pub fn parse_series_markdown(content: &str) -> Result<SeriesMarkdown> {
|
|||||||
let line = lines[i].trim();
|
let line = lines[i].trim();
|
||||||
|
|
||||||
if line.starts_with("# ") && line.contains("Download Links") {
|
if line.starts_with("# ") && line.contains("Download Links") {
|
||||||
series = line.replace("# ", "").replace(" Download Links", "").trim().to_string();
|
series = line
|
||||||
|
.replace("# ", "")
|
||||||
|
.replace(" Download Links", "")
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
} else if line.starts_with("## ") {
|
} else if line.starts_with("## ") {
|
||||||
if !current_category.is_empty() && !current_files.is_empty() {
|
if !current_category.is_empty() && !current_files.is_empty() {
|
||||||
sections.push(SeriesSection {
|
sections.push(SeriesSection {
|
||||||
@@ -125,7 +135,11 @@ pub fn parse_series_markdown(content: &str) -> Result<SeriesMarkdown> {
|
|||||||
current_files.push(MarkdownFile {
|
current_files.push(MarkdownFile {
|
||||||
filename,
|
filename,
|
||||||
size: Some(size),
|
size: Some(size),
|
||||||
download_url: line.trim_start_matches('`').trim_end_matches('`').trim().to_string(),
|
download_url: line
|
||||||
|
.trim_start_matches('`')
|
||||||
|
.trim_end_matches('`')
|
||||||
|
.trim()
|
||||||
|
.to_string(),
|
||||||
});
|
});
|
||||||
pending_file = None;
|
pending_file = None;
|
||||||
}
|
}
|
||||||
@@ -149,7 +163,9 @@ pub fn read_category_files(dir: &Path) -> Result<Vec<(String, String)>> {
|
|||||||
let entry = entry?;
|
let entry = entry?;
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
|
|
||||||
if path.extension().map_or(false, |ext| ext == "md") && path.file_name() != Some(std::ffi::OsStr::new("README.md")) {
|
if path.extension().is_some_and(|ext| ext == "md")
|
||||||
|
&& path.file_name() != Some(std::ffi::OsStr::new("README.md"))
|
||||||
|
{
|
||||||
let filename = path.file_name().unwrap().to_string_lossy().to_string();
|
let filename = path.file_name().unwrap().to_string_lossy().to_string();
|
||||||
let content = fs::read_to_string(&path)?;
|
let content = fs::read_to_string(&path)?;
|
||||||
files.push((filename, content));
|
files.push((filename, content));
|
||||||
@@ -166,7 +182,9 @@ pub fn read_series_files(dir: &Path) -> Result<Vec<(String, String)>> {
|
|||||||
let entry = entry?;
|
let entry = entry?;
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
|
|
||||||
if path.extension().map_or(false, |ext| ext == "md") && path.file_name() != Some(std::ffi::OsStr::new("README.md")) {
|
if path.extension().is_some_and(|ext| ext == "md")
|
||||||
|
&& path.file_name() != Some(std::ffi::OsStr::new("README.md"))
|
||||||
|
{
|
||||||
let filename = path.file_name().unwrap().to_string_lossy().to_string();
|
let filename = path.file_name().unwrap().to_string_lossy().to_string();
|
||||||
let content = fs::read_to_string(&path)?;
|
let content = fs::read_to_string(&path)?;
|
||||||
files.push((filename, content));
|
files.push((filename, content));
|
||||||
@@ -176,11 +194,15 @@ pub fn read_series_files(dir: &Path) -> Result<Vec<(String, String)>> {
|
|||||||
Ok(files)
|
Ok(files)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type: &str) -> Result<()> {
|
pub fn import_categories_to_db(
|
||||||
|
conn: &rusqlite::Connection,
|
||||||
|
user_id: &str,
|
||||||
|
tree_type: &str,
|
||||||
|
) -> Result<()> {
|
||||||
use crate::FileTree;
|
use crate::FileTree;
|
||||||
use filetree::node::{FileNode, Aliases, NodeType};
|
use filetree::node::{Aliases, FileNode, NodeType};
|
||||||
use uuid::Uuid;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
let category_dir = Path::new("/Users/accusys/markbase/data/downloads/by_category");
|
let category_dir = Path::new("/Users/accusys/markbase/data/downloads/by_category");
|
||||||
let files = read_category_files(category_dir)?;
|
let files = read_category_files(category_dir)?;
|
||||||
@@ -192,7 +214,11 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_
|
|||||||
for (_filename, content) in files {
|
for (_filename, content) in files {
|
||||||
let parsed = parse_category_markdown(&content)?;
|
let parsed = parse_category_markdown(&content)?;
|
||||||
|
|
||||||
println!("Parsed category: '{}', sections: {}", parsed.category, parsed.sections.len());
|
println!(
|
||||||
|
"Parsed category: '{}', sections: {}",
|
||||||
|
parsed.category,
|
||||||
|
parsed.sections.len()
|
||||||
|
);
|
||||||
|
|
||||||
if parsed.category.is_empty() {
|
if parsed.category.is_empty() {
|
||||||
println!("Warning: category is empty, skipping");
|
println!("Warning: category is empty, skipping");
|
||||||
@@ -222,14 +248,21 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_
|
|||||||
sort_order: 0,
|
sort_order: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("Inserting category node: {} (id: {})", category_node.label, category_node_id);
|
println!(
|
||||||
|
"Inserting category node: {} (id: {})",
|
||||||
|
category_node.label, category_node_id
|
||||||
|
);
|
||||||
|
|
||||||
tree.insert_node(conn, &category_node)?;
|
tree.insert_node(conn, &category_node)?;
|
||||||
|
|
||||||
println!("Category node inserted successfully");
|
println!("Category node inserted successfully");
|
||||||
|
|
||||||
for section in parsed.sections {
|
for section in parsed.sections {
|
||||||
println!("Processing section: {} with {} files", section.product, section.files.len());
|
println!(
|
||||||
|
"Processing section: {} with {} files",
|
||||||
|
section.product,
|
||||||
|
section.files.len()
|
||||||
|
);
|
||||||
|
|
||||||
let product_node_id = Uuid::new_v4().to_string();
|
let product_node_id = Uuid::new_v4().to_string();
|
||||||
let mut aliases_map = HashMap::new();
|
let mut aliases_map = HashMap::new();
|
||||||
@@ -260,7 +293,10 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_
|
|||||||
let file_node_id = Uuid::new_v4().to_string();
|
let file_node_id = Uuid::new_v4().to_string();
|
||||||
let mut aliases_map = HashMap::new();
|
let mut aliases_map = HashMap::new();
|
||||||
aliases_map.insert("download_url".to_string(), file.download_url.clone());
|
aliases_map.insert("download_url".to_string(), file.download_url.clone());
|
||||||
aliases_map.insert("file_size_display".to_string(), file.size.clone().unwrap_or_else(|| "Unknown".to_string()));
|
aliases_map.insert(
|
||||||
|
"file_size_display".to_string(),
|
||||||
|
file.size.clone().unwrap_or_else(|| "Unknown".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
let file_node = FileNode {
|
let file_node = FileNode {
|
||||||
node_id: file_node_id.clone(),
|
node_id: file_node_id.clone(),
|
||||||
@@ -289,11 +325,15 @@ pub fn import_categories_to_db(conn: &rusqlite::Connection, user_id: &str, tree_
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type: &str) -> Result<()> {
|
pub fn import_series_to_db(
|
||||||
|
conn: &rusqlite::Connection,
|
||||||
|
user_id: &str,
|
||||||
|
tree_type: &str,
|
||||||
|
) -> Result<()> {
|
||||||
use crate::FileTree;
|
use crate::FileTree;
|
||||||
use filetree::node::{FileNode, Aliases, NodeType};
|
use filetree::node::{Aliases, FileNode, NodeType};
|
||||||
use uuid::Uuid;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
let series_dir = Path::new("/Users/accusys/markbase/data/downloads/by_series");
|
let series_dir = Path::new("/Users/accusys/markbase/data/downloads/by_series");
|
||||||
let files = read_series_files(series_dir)?;
|
let files = read_series_files(series_dir)?;
|
||||||
@@ -305,7 +345,11 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type
|
|||||||
for (_filename, content) in files {
|
for (_filename, content) in files {
|
||||||
let parsed = parse_series_markdown(&content)?;
|
let parsed = parse_series_markdown(&content)?;
|
||||||
|
|
||||||
println!("Parsed series: '{}', sections: {}", parsed.series, parsed.sections.len());
|
println!(
|
||||||
|
"Parsed series: '{}', sections: {}",
|
||||||
|
parsed.series,
|
||||||
|
parsed.sections.len()
|
||||||
|
);
|
||||||
|
|
||||||
if parsed.series.is_empty() {
|
if parsed.series.is_empty() {
|
||||||
println!("Warning: series is empty, skipping");
|
println!("Warning: series is empty, skipping");
|
||||||
@@ -340,7 +384,11 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type
|
|||||||
println!("Series node inserted successfully");
|
println!("Series node inserted successfully");
|
||||||
|
|
||||||
for section in parsed.sections {
|
for section in parsed.sections {
|
||||||
println!("Processing section: {} with {} files", section.category, section.files.len());
|
println!(
|
||||||
|
"Processing section: {} with {} files",
|
||||||
|
section.category,
|
||||||
|
section.files.len()
|
||||||
|
);
|
||||||
|
|
||||||
let category_node_id = Uuid::new_v4().to_string();
|
let category_node_id = Uuid::new_v4().to_string();
|
||||||
let mut aliases_map = HashMap::new();
|
let mut aliases_map = HashMap::new();
|
||||||
@@ -371,7 +419,10 @@ pub fn import_series_to_db(conn: &rusqlite::Connection, user_id: &str, tree_type
|
|||||||
let file_node_id = Uuid::new_v4().to_string();
|
let file_node_id = Uuid::new_v4().to_string();
|
||||||
let mut aliases_map = HashMap::new();
|
let mut aliases_map = HashMap::new();
|
||||||
aliases_map.insert("download_url".to_string(), file.download_url.clone());
|
aliases_map.insert("download_url".to_string(), file.download_url.clone());
|
||||||
aliases_map.insert("file_size_display".to_string(), file.size.clone().unwrap_or_else(|| "Unknown".to_string()));
|
aliases_map.insert(
|
||||||
|
"file_size_display".to_string(),
|
||||||
|
file.size.clone().unwrap_or_else(|| "Unknown".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
let file_node = FileNode {
|
let file_node = FileNode {
|
||||||
node_id: file_node_id.clone(),
|
node_id: file_node_id.clone(),
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
pub mod audio;
|
|
||||||
pub mod auth;
|
|
||||||
pub mod audit;
|
|
||||||
pub mod cli;
|
|
||||||
pub mod api;
|
pub mod api;
|
||||||
|
pub mod archive; // Archive Module - Universal Compression Format Support (Phase 1-3完成)
|
||||||
|
pub mod audio;
|
||||||
|
pub mod audit;
|
||||||
|
pub mod auth;
|
||||||
|
pub mod category_view;
|
||||||
|
pub mod cli;
|
||||||
pub mod command;
|
pub mod command;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod download;
|
pub mod download;
|
||||||
|
pub mod import_markdown;
|
||||||
pub mod pg_client;
|
pub mod pg_client;
|
||||||
pub mod render;
|
pub mod render;
|
||||||
pub mod rsync;
|
pub mod rsync;
|
||||||
@@ -14,20 +17,17 @@ pub mod s3_auth;
|
|||||||
pub mod s3_config;
|
pub mod s3_config;
|
||||||
pub mod s3_xml;
|
pub mod s3_xml;
|
||||||
pub mod scan;
|
pub mod scan;
|
||||||
pub mod server;
|
pub mod server; // Category View Module - 双视图管理(Phase 1)
|
||||||
pub mod archive; // Archive Module - Universal Compression Format Support (Phase 1-3完成)
|
// pub mod sftp; // ⚠️ russh版本(已禁用)
|
||||||
pub mod category_view;
|
// pub mod ssh2_server; // ssh2服务器(已禁用)
|
||||||
pub mod import_markdown; // Category View Module - 双视图管理(Phase 1)
|
// pub mod ssh2_mod; // ssh2辅助模块(已禁用)
|
||||||
// pub mod sftp; // ⚠️ russh版本(已禁用)
|
pub mod provider; // DataProvider抽象层(Phase 5)
|
||||||
// pub mod ssh2_server; // ssh2服务器(已禁用)
|
pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐
|
||||||
// pub mod ssh2_mod; // ssh2辅助模块(已禁用)
|
|
||||||
pub mod ssh_server; // SSH服务器(Phase 1-9完成,正在修复编译错误)⭐⭐⭐⭐⭐
|
|
||||||
pub mod sync;
|
pub mod sync;
|
||||||
pub mod provider; // DataProvider抽象层(Phase 5)
|
pub mod vfs; // VFS抽象层(Phase 1-6重构计划)
|
||||||
pub mod vfs; // VFS抽象层(Phase 1-6重构计划)
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod security_audit; // Security Audit Module - Phase 9
|
mod security_audit; // Security Audit Module - Phase 9
|
||||||
|
|
||||||
// Re-export from external filetree crate
|
// Re-export from external filetree crate
|
||||||
pub use filetree::node::FileNode;
|
pub use filetree::node::FileNode;
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use markbase_core::cli::Cli;
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use markbase_core::cli::Cli;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
|||||||
@@ -10,6 +10,12 @@ pub struct PgClient {
|
|||||||
database: String,
|
database: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for PgClient {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl PgClient {
|
impl PgClient {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
pub mod sqlite;
|
|
||||||
pub mod pg;
|
pub mod pg;
|
||||||
|
pub mod sqlite;
|
||||||
|
|
||||||
pub use sqlite::SqliteProvider;
|
|
||||||
pub use pg::PgProvider;
|
pub use pg::PgProvider;
|
||||||
|
pub use sqlite::SqliteProvider;
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
@@ -57,7 +57,10 @@ pub trait DataProvider: Send + Sync {
|
|||||||
|
|
||||||
/// 检查用户是否存在且启用
|
/// 检查用户是否存在且启用
|
||||||
fn user_exists(&self, username: &str) -> Result<bool, ProviderError> {
|
fn user_exists(&self, username: &str) -> Result<bool, ProviderError> {
|
||||||
Ok(self.get_user(username)?.map(|u| u.status == 1).unwrap_or(false))
|
Ok(self
|
||||||
|
.get_user(username)?
|
||||||
|
.map(|u| u.status == 1)
|
||||||
|
.unwrap_or(false))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取用户的公开密钥列表(OpenSSH authorized_keys格式)
|
/// 获取用户的公开密钥列表(OpenSSH authorized_keys格式)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::path::PathBuf;
|
|
||||||
use postgres::{Client, NoTls};
|
|
||||||
use bcrypt::verify;
|
|
||||||
use super::{DataProvider, ProviderError, User};
|
use super::{DataProvider, ProviderError, User};
|
||||||
|
use bcrypt::verify;
|
||||||
|
use postgres::{Client, NoTls};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
/// PostgreSQL 数据提供者(兼容 SFTPGo 的 users 表)
|
/// PostgreSQL 数据提供者(兼容 SFTPGo 的 users 表)
|
||||||
pub struct PgProvider {
|
pub struct PgProvider {
|
||||||
@@ -13,7 +13,9 @@ impl PgProvider {
|
|||||||
///
|
///
|
||||||
/// 连接字符串格式:host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026
|
/// 连接字符串格式:host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026
|
||||||
pub fn new(conn_str: &str) -> Result<Self, ProviderError> {
|
pub fn new(conn_str: &str) -> Result<Self, ProviderError> {
|
||||||
Ok(Self { conn_str: conn_str.to_string() })
|
Ok(Self {
|
||||||
|
conn_str: conn_str.to_string(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_params(
|
pub fn from_params(
|
||||||
@@ -40,18 +42,22 @@ impl DataProvider for PgProvider {
|
|||||||
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
|
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
|
||||||
let mut conn = self.open_conn()?;
|
let mut conn = self.open_conn()?;
|
||||||
|
|
||||||
let result = conn.query_opt(
|
let result = conn
|
||||||
"SELECT username, password, home_dir, permissions, uid, gid, status
|
.query_opt(
|
||||||
|
"SELECT username, password, home_dir, permissions, uid, gid, status
|
||||||
FROM users WHERE username = $1 AND status = 1",
|
FROM users WHERE username = $1 AND status = 1",
|
||||||
&[&username],
|
&[&username],
|
||||||
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
)
|
||||||
|
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Some(row) => Ok(Some(User {
|
Some(row) => Ok(Some(User {
|
||||||
username: row.get(0),
|
username: row.get(0),
|
||||||
password_hash: row.get::<_, Option<String>>(1).unwrap_or_default(),
|
password_hash: row.get::<_, Option<String>>(1).unwrap_or_default(),
|
||||||
home_dir: PathBuf::from(row.get::<_, String>(2)),
|
home_dir: PathBuf::from(row.get::<_, String>(2)),
|
||||||
permissions: row.get::<_, Option<String>>(3).unwrap_or_else(|| "*".to_string()),
|
permissions: row
|
||||||
|
.get::<_, Option<String>>(3)
|
||||||
|
.unwrap_or_else(|| "*".to_string()),
|
||||||
uid: row.get::<_, i64>(4) as u32,
|
uid: row.get::<_, i64>(4) as u32,
|
||||||
gid: row.get::<_, i64>(5) as u32,
|
gid: row.get::<_, i64>(5) as u32,
|
||||||
status: row.get(6),
|
status: row.get(6),
|
||||||
@@ -75,24 +81,31 @@ impl DataProvider for PgProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||||||
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
|
Ok(self
|
||||||
|
.get_user(username)?
|
||||||
|
.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||||
let mut conn = self.open_conn()?;
|
let mut conn = self.open_conn()?;
|
||||||
let result = conn.query_opt(
|
let result = conn
|
||||||
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
|
.query_opt(
|
||||||
&[&username],
|
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
|
||||||
).map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
&[&username],
|
||||||
|
)
|
||||||
|
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Some(row) => {
|
Some(row) => {
|
||||||
let json_str: Option<String> = row.get(0);
|
let json_str: Option<String> = row.get(0);
|
||||||
match json_str {
|
match json_str {
|
||||||
Some(s) if !s.is_empty() => {
|
Some(s) if !s.is_empty() => {
|
||||||
let keys: Vec<serde_json::Value> = serde_json::from_str(&s)
|
let keys: Vec<serde_json::Value> =
|
||||||
.map_err(|e| ProviderError::Internal(format!("JSON parse error: {}", e)))?;
|
serde_json::from_str(&s).map_err(|e| {
|
||||||
Ok(keys.iter()
|
ProviderError::Internal(format!("JSON parse error: {}", e))
|
||||||
|
})?;
|
||||||
|
Ok(keys
|
||||||
|
.iter()
|
||||||
.filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string()))
|
.filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string()))
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
@@ -112,7 +125,7 @@ mod tests {
|
|||||||
fn test_pg_provider_connection() {
|
fn test_pg_provider_connection() {
|
||||||
// 仅当 SFTPGo PostgreSQL 可用时运行
|
// 仅当 SFTPGo PostgreSQL 可用时运行
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
);
|
);
|
||||||
assert!(provider.is_ok(), "Should connect to SFTPGo PostgreSQL");
|
assert!(provider.is_ok(), "Should connect to SFTPGo PostgreSQL");
|
||||||
}
|
}
|
||||||
@@ -120,8 +133,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pg_get_user_demo() {
|
fn test_pg_get_user_demo() {
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
let user = provider.get_user("demo").unwrap();
|
let user = provider.get_user("demo").unwrap();
|
||||||
assert!(user.is_some(), "Demo user should exist");
|
assert!(user.is_some(), "Demo user should exist");
|
||||||
assert_eq!(user.unwrap().username, "demo");
|
assert_eq!(user.unwrap().username, "demo");
|
||||||
@@ -130,8 +144,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pg_get_user_momentry() {
|
fn test_pg_get_user_momentry() {
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
let user = provider.get_user("momentry").unwrap();
|
let user = provider.get_user("momentry").unwrap();
|
||||||
assert!(user.is_some(), "Momentry user should exist");
|
assert!(user.is_some(), "Momentry user should exist");
|
||||||
}
|
}
|
||||||
@@ -139,8 +154,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pg_get_user_warren() {
|
fn test_pg_get_user_warren() {
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
let user = provider.get_user("warren").unwrap();
|
let user = provider.get_user("warren").unwrap();
|
||||||
assert!(user.is_some(), "Warren user should exist");
|
assert!(user.is_some(), "Warren user should exist");
|
||||||
}
|
}
|
||||||
@@ -148,8 +164,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pg_check_password_demo() {
|
fn test_pg_check_password_demo() {
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
let valid = provider.check_password("demo", "demo123").unwrap();
|
let valid = provider.check_password("demo", "demo123").unwrap();
|
||||||
assert!(valid, "Password should be valid");
|
assert!(valid, "Password should be valid");
|
||||||
}
|
}
|
||||||
@@ -157,8 +174,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pg_check_password_invalid() {
|
fn test_pg_check_password_invalid() {
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
let valid = provider.check_password("demo", "wrong").unwrap();
|
let valid = provider.check_password("demo", "wrong").unwrap();
|
||||||
assert!(!valid, "Wrong password should fail");
|
assert!(!valid, "Wrong password should fail");
|
||||||
}
|
}
|
||||||
@@ -166,8 +184,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pg_get_home_dir() {
|
fn test_pg_get_home_dir() {
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
let dir = provider.get_home_dir("demo").unwrap();
|
let dir = provider.get_home_dir("demo").unwrap();
|
||||||
assert!(dir.is_some());
|
assert!(dir.is_some());
|
||||||
assert!(dir.unwrap().contains("momentry"));
|
assert!(dir.unwrap().contains("momentry"));
|
||||||
@@ -176,8 +195,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pg_nonexistent_user() {
|
fn test_pg_nonexistent_user() {
|
||||||
let provider = PgProvider::new(
|
let provider = PgProvider::new(
|
||||||
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026"
|
"host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
let user = provider.get_user("__nonexistent__").unwrap();
|
let user = provider.get_user("__nonexistent__").unwrap();
|
||||||
assert!(user.is_none());
|
assert!(user.is_none());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::path::PathBuf;
|
|
||||||
use rusqlite::{Connection, params};
|
|
||||||
use bcrypt::verify;
|
|
||||||
use super::{DataProvider, ProviderError, User};
|
use super::{DataProvider, ProviderError, User};
|
||||||
|
use bcrypt::verify;
|
||||||
|
use rusqlite::{params, Connection};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
/// SQLite 数据提供者
|
/// SQLite 数据提供者
|
||||||
pub struct SqliteProvider {
|
pub struct SqliteProvider {
|
||||||
@@ -13,7 +13,8 @@ impl SqliteProvider {
|
|||||||
let path = PathBuf::from(db_path);
|
let path = PathBuf::from(db_path);
|
||||||
if !path.exists() {
|
if !path.exists() {
|
||||||
return Err(ProviderError::NotFound(format!(
|
return Err(ProviderError::NotFound(format!(
|
||||||
"Database not found: {}", db_path
|
"Database not found: {}",
|
||||||
|
db_path
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
Ok(Self { db_path: path })
|
Ok(Self { db_path: path })
|
||||||
@@ -50,7 +51,8 @@ impl DataProvider for SqliteProvider {
|
|||||||
Ok(user) => Ok(Some(user)),
|
Ok(user) => Ok(Some(user)),
|
||||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||||
Err(e) => Err(ProviderError::Internal(format!(
|
Err(e) => Err(ProviderError::Internal(format!(
|
||||||
"Database query error: {}", e
|
"Database query error: {}",
|
||||||
|
e
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -66,7 +68,9 @@ impl DataProvider for SqliteProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||||||
Ok(self.get_user(username)?.map(|u| u.home_dir.to_string_lossy().to_string()))
|
Ok(self
|
||||||
|
.get_user(username)?
|
||||||
|
.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||||||
@@ -98,7 +102,10 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_test_db_path() -> String {
|
fn get_test_db_path() -> String {
|
||||||
format!("{}/../data/auth.sqlite", std::env::var("CARGO_MANIFEST_DIR").unwrap())
|
format!(
|
||||||
|
"{}/../data/auth.sqlite",
|
||||||
|
std::env::var("CARGO_MANIFEST_DIR").unwrap()
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use md5::compute;
|
use md5::compute;
|
||||||
|
|
||||||
pub struct RollingChecksum {
|
pub struct RollingChecksum {
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ pub struct DecompressionStream {
|
|||||||
decompressor: Decompress,
|
decompressor: Decompress,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for DecompressionStream {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl DecompressionStream {
|
impl DecompressionStream {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
use crate::rsync::checksum::{compute_block_checksums, BlockChecksum};
|
|
||||||
use crate::rsync::compress::{CompressionStream, DecompressionStream};
|
|
||||||
use crate::rsync::delta::{DeltaAlgorithm, DeltaInstruction};
|
use crate::rsync::delta::{DeltaAlgorithm, DeltaInstruction};
|
||||||
use crate::rsync::protocol::{RsyncCommand, RsyncProtocol};
|
use crate::rsync::protocol::RsyncCommand;
|
||||||
use crate::rsync::RsyncConfig;
|
use crate::rsync::RsyncConfig;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|||||||
@@ -162,6 +162,12 @@ pub struct RsyncHandshake {
|
|||||||
negotiated_version: u32,
|
negotiated_version: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for RsyncHandshake {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl RsyncHandshake {
|
impl RsyncHandshake {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
use filetree::{FileTree, node::{FileNode, Aliases}};
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
extract::{Path, State},
|
extract::{Path, State},
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
response::{IntoResponse, Json},
|
response::{IntoResponse, Json},
|
||||||
};
|
};
|
||||||
|
use filetree::{
|
||||||
|
node::FileNode,
|
||||||
|
FileTree,
|
||||||
|
};
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio_util::io::ReaderStream;
|
use tokio_util::io::ReaderStream;
|
||||||
|
|
||||||
@@ -41,7 +43,7 @@ pub async fn list_buckets(State(state): State<crate::server::AppState>) -> impl
|
|||||||
|
|
||||||
pub async fn list_objects(
|
pub async fn list_objects(
|
||||||
Path(bucket): Path<String>,
|
Path(bucket): Path<String>,
|
||||||
State(state): State<crate::server::AppState>,
|
State(_state): State<crate::server::AppState>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
println!("S3 List Objects: bucket={}", bucket);
|
println!("S3 List Objects: bucket={}", bucket);
|
||||||
|
|
||||||
@@ -70,7 +72,7 @@ pub async fn list_objects(
|
|||||||
"Key": build_s3_key(&tree, n),
|
"Key": build_s3_key(&tree, n),
|
||||||
"LastModified": n.registered_at.clone().unwrap_or_default(),
|
"LastModified": n.registered_at.clone().unwrap_or_default(),
|
||||||
"ETag": n.sha256.clone().unwrap_or_default(),
|
"ETag": n.sha256.clone().unwrap_or_default(),
|
||||||
"Size": n.file_size.clone().unwrap_or(0),
|
"Size": n.file_size.unwrap_or(0),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -83,7 +85,7 @@ pub async fn list_objects(
|
|||||||
|
|
||||||
pub async fn get_object(
|
pub async fn get_object(
|
||||||
Path((bucket, key)): Path<(String, String)>,
|
Path((bucket, key)): Path<(String, String)>,
|
||||||
State(state): State<crate::server::AppState>,
|
State(_state): State<crate::server::AppState>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
println!("S3 GET Object: bucket={}, key={}", bucket, key);
|
println!("S3 GET Object: bucket={}, key={}", bucket, key);
|
||||||
@@ -119,7 +121,7 @@ pub async fn get_object(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let file_uuid = node.file_uuid.clone().unwrap_or_default();
|
let file_uuid = node.file_uuid.clone().unwrap_or_default();
|
||||||
let file_size = node.file_size.clone().unwrap_or(0);
|
let file_size = node.file_size.unwrap_or(0);
|
||||||
let sha256 = node.sha256.clone().unwrap_or_default();
|
let sha256 = node.sha256.clone().unwrap_or_default();
|
||||||
|
|
||||||
let real_path = get_real_file_path(&conn, &file_uuid);
|
let real_path = get_real_file_path(&conn, &file_uuid);
|
||||||
@@ -233,7 +235,7 @@ pub async fn put_object(
|
|||||||
|
|
||||||
let sha256_hash_clone = sha256_hash.clone();
|
let sha256_hash_clone = sha256_hash.clone();
|
||||||
let file_path_clone = file_path.clone();
|
let file_path_clone = file_path.clone();
|
||||||
let label = key.split('/').last().unwrap_or(&key).to_string();
|
let label = key.split('/').next_back().unwrap_or(&key).to_string();
|
||||||
|
|
||||||
let result = tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
let result = tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||||
let conn = match FileTree::open_user_db(&bucket) {
|
let conn = match FileTree::open_user_db(&bucket) {
|
||||||
@@ -298,7 +300,7 @@ pub async fn put_object(
|
|||||||
|
|
||||||
pub async fn head_object(
|
pub async fn head_object(
|
||||||
Path((bucket, key)): Path<(String, String)>,
|
Path((bucket, key)): Path<(String, String)>,
|
||||||
State(state): State<crate::server::AppState>,
|
State(_state): State<crate::server::AppState>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let conn = match FileTree::open_user_db(&bucket) {
|
let conn = match FileTree::open_user_db(&bucket) {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
@@ -323,7 +325,7 @@ pub async fn head_object(
|
|||||||
"ETag",
|
"ETag",
|
||||||
node.sha256.clone().unwrap_or_default().parse().unwrap(),
|
node.sha256.clone().unwrap_or_default().parse().unwrap(),
|
||||||
);
|
);
|
||||||
headers.insert("Content-Length", node.file_size.clone().unwrap_or(0).into());
|
headers.insert("Content-Length", node.file_size.unwrap_or(0).into());
|
||||||
|
|
||||||
(StatusCode::OK, headers)
|
(StatusCode::OK, headers)
|
||||||
}
|
}
|
||||||
@@ -438,7 +440,7 @@ fn find_node_by_s3_key(tree: &FileTree, key: &str) -> Option<FileNode> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 方法2:通过filename直接匹配(fallback)
|
// 方法2:通过filename直接匹配(fallback)
|
||||||
let filename = key.split('/').last().unwrap_or(key);
|
let filename = key.split('/').next_back().unwrap_or(key);
|
||||||
tree.nodes
|
tree.nodes
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|n| n.node_type == filetree::node::NodeType::File)
|
.filter(|n| n.node_type == filetree::node::NodeType::File)
|
||||||
@@ -501,7 +503,7 @@ async fn handle_range_request(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用take限制读取长度
|
// 使用take限制读取长度
|
||||||
let limited_file = file.take(content_length as u64);
|
let limited_file = file.take(content_length);
|
||||||
let stream = ReaderStream::new(limited_file);
|
let stream = ReaderStream::new(limited_file);
|
||||||
let body = Body::from_stream(stream);
|
let body = Body::from_stream(stream);
|
||||||
|
|
||||||
@@ -535,11 +537,7 @@ fn parse_range_header(range: &str, file_size: i64) -> Option<(u64, u64)> {
|
|||||||
let (start, end) = if parts[0].is_empty() {
|
let (start, end) = if parts[0].is_empty() {
|
||||||
// "bytes=-N"格式:最后N字节
|
// "bytes=-N"格式:最后N字节
|
||||||
let suffix_length = parts[1].parse::<u64>().ok()?;
|
let suffix_length = parts[1].parse::<u64>().ok()?;
|
||||||
let start = if suffix_length > file_size as u64 {
|
let start = (file_size as u64).saturating_sub(suffix_length);
|
||||||
0
|
|
||||||
} else {
|
|
||||||
file_size as u64 - suffix_length
|
|
||||||
};
|
|
||||||
(start, file_size as u64 - 1)
|
(start, file_size as u64 - 1)
|
||||||
} else if parts[1].is_empty() {
|
} else if parts[1].is_empty() {
|
||||||
// "bytes=N-"格式:从N到结尾
|
// "bytes=N-"格式:从N到结尾
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ fn calculate_signature(
|
|||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
method: &str,
|
method: &str,
|
||||||
path: &str,
|
path: &str,
|
||||||
access_key: &str,
|
_access_key: &str,
|
||||||
secret_key: &str,
|
secret_key: &str,
|
||||||
region: &str,
|
region: &str,
|
||||||
service: &str,
|
service: &str,
|
||||||
@@ -143,9 +143,9 @@ fn calculate_signature(
|
|||||||
let signing_key = calculate_signing_key(secret_key, date, region, service);
|
let signing_key = calculate_signing_key(secret_key, date, region, service);
|
||||||
|
|
||||||
// 4. Calculate Signature
|
// 4. Calculate Signature
|
||||||
let signature = hmac_sha256_hex(&signing_key, &string_to_sign);
|
|
||||||
|
|
||||||
signature
|
|
||||||
|
hmac_sha256_hex(&signing_key, &string_to_sign)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_canonical_request(headers: HeaderMap, method: &str, path: &str) -> String {
|
fn create_canonical_request(headers: HeaderMap, method: &str, path: &str) -> String {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use std::fs;
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[derive(Default)]
|
||||||
pub struct S3Config {
|
pub struct S3Config {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub s3: S3Section,
|
pub s3: S3Section,
|
||||||
@@ -40,6 +41,7 @@ pub struct KeysSection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[derive(Default)]
|
||||||
pub struct BucketsSection {
|
pub struct BucketsSection {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub mappings: std::collections::HashMap<String, String>,
|
pub mappings: std::collections::HashMap<String, String>,
|
||||||
@@ -96,16 +98,6 @@ fn admin_permissions() -> Vec<String> {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for S3Config {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
s3: S3Section::default(),
|
|
||||||
keys: KeysSection::default(),
|
|
||||||
buckets: BucketsSection::default(),
|
|
||||||
permissions: PermissionsSection::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for S3Section {
|
impl Default for S3Section {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
@@ -129,13 +121,6 @@ impl Default for KeysSection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for BucketsSection {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
mappings: std::collections::HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for PermissionsSection {
|
impl Default for PermissionsSection {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
@@ -169,7 +154,7 @@ impl S3Config {
|
|||||||
Self::load("config/s3.toml")
|
Self::load("config/s3.toml")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn save(&self, path: &str) -> Result<()> {
|
pub fn save(&self, path: &str) -> Result<()> {
|
||||||
let config_path = PathBuf::from(path);
|
let config_path = PathBuf::from(path);
|
||||||
|
|
||||||
// Create backup before saving
|
// Create backup before saving
|
||||||
@@ -180,8 +165,8 @@ pub fn save(&self, path: &str) -> Result<()> {
|
|||||||
log::info!("S3 config backup created: {}", backup_path.display());
|
log::info!("S3 config backup created: {}", backup_path.display());
|
||||||
}
|
}
|
||||||
|
|
||||||
let content = toml::to_string_pretty(self)
|
let content =
|
||||||
.with_context(|| "Failed to serialize S3 config")?;
|
toml::to_string_pretty(self).with_context(|| "Failed to serialize S3 config")?;
|
||||||
|
|
||||||
std::fs::write(&config_path, content)
|
std::fs::write(&config_path, content)
|
||||||
.with_context(|| format!("Failed to write S3 config: {}", path))?;
|
.with_context(|| format!("Failed to write S3 config: {}", path))?;
|
||||||
@@ -255,8 +240,14 @@ pub fn save(&self, path: &str) -> Result<()> {
|
|||||||
|
|
||||||
// Validate permission format
|
// Validate permission format
|
||||||
let valid_permissions = [
|
let valid_permissions = [
|
||||||
"GetObject", "PutObject", "DeleteObject", "ListBucket",
|
"GetObject",
|
||||||
"HeadObject", "ListAllMyBuckets", "CreateBucket", "DeleteBucket"
|
"PutObject",
|
||||||
|
"DeleteObject",
|
||||||
|
"ListBucket",
|
||||||
|
"HeadObject",
|
||||||
|
"ListAllMyBuckets",
|
||||||
|
"CreateBucket",
|
||||||
|
"DeleteBucket",
|
||||||
];
|
];
|
||||||
|
|
||||||
for perm in &self.permissions.default_permissions {
|
for perm in &self.permissions.default_permissions {
|
||||||
@@ -294,9 +285,9 @@ pub fn save(&self, path: &str) -> Result<()> {
|
|||||||
"keys.default_secret_key" => Some(self.keys.default_secret_key.clone()),
|
"keys.default_secret_key" => Some(self.keys.default_secret_key.clone()),
|
||||||
"keys.keys_db_path" => Some(self.keys.keys_db_path.clone()),
|
"keys.keys_db_path" => Some(self.keys.keys_db_path.clone()),
|
||||||
|
|
||||||
"permissions.default_permissions" => {
|
"permissions.default_permissions" => Some(
|
||||||
Some(serde_json::to_string(&self.permissions.default_permissions).unwrap_or_default())
|
serde_json::to_string(&self.permissions.default_permissions).unwrap_or_default(),
|
||||||
}
|
),
|
||||||
"permissions.admin_permissions" => {
|
"permissions.admin_permissions" => {
|
||||||
Some(serde_json::to_string(&self.permissions.admin_permissions).unwrap_or_default())
|
Some(serde_json::to_string(&self.permissions.admin_permissions).unwrap_or_default())
|
||||||
}
|
}
|
||||||
@@ -394,7 +385,10 @@ mod tests {
|
|||||||
let mut config = S3Config::default();
|
let mut config = S3Config::default();
|
||||||
|
|
||||||
assert_eq!(config.get("s3.enabled"), Some("true".to_string()));
|
assert_eq!(config.get("s3.enabled"), Some("true".to_string()));
|
||||||
assert_eq!(config.get("s3.endpoint"), Some("http://localhost:11438/s3".to_string()));
|
assert_eq!(
|
||||||
|
config.get("s3.endpoint"),
|
||||||
|
Some("http://localhost:11438/s3".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
config.set("s3.require_auth", "true").unwrap();
|
config.set("s3.require_auth", "true").unwrap();
|
||||||
assert_eq!(config.s3.require_auth, true);
|
assert_eq!(config.s3.require_auth, true);
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ pub fn list_buckets_xml(buckets: &[String]) -> (HeaderMap, String) {
|
|||||||
|
|
||||||
let bucket_entries = buckets
|
let bucket_entries = buckets
|
||||||
.iter()
|
.iter()
|
||||||
.map(|b| format!(
|
.map(|b| {
|
||||||
"<Bucket><Name>{}</Name><CreationDate>2026-05-27T00:00:00Z</CreationDate></Bucket>",
|
format!(
|
||||||
b
|
"<Bucket><Name>{}</Name><CreationDate>2026-05-27T00:00:00Z</CreationDate></Bucket>",
|
||||||
))
|
b
|
||||||
|
)
|
||||||
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("\n ");
|
.join("\n ");
|
||||||
|
|
||||||
@@ -39,7 +41,10 @@ pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, Str
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|obj| {
|
.map(|obj| {
|
||||||
let key = obj.get("Key").and_then(|k| k.as_str()).unwrap_or("");
|
let key = obj.get("Key").and_then(|k| k.as_str()).unwrap_or("");
|
||||||
let last_modified = obj.get("LastModified").and_then(|l| l.as_str()).unwrap_or("");
|
let last_modified = obj
|
||||||
|
.get("LastModified")
|
||||||
|
.and_then(|l| l.as_str())
|
||||||
|
.unwrap_or("");
|
||||||
let etag = obj.get("ETag").and_then(|e| e.as_str()).unwrap_or("");
|
let etag = obj.get("ETag").and_then(|e| e.as_str()).unwrap_or("");
|
||||||
let size = obj.get("Size").and_then(|s| s.as_i64()).unwrap_or(0);
|
let size = obj.get("Size").and_then(|s| s.as_i64()).unwrap_or(0);
|
||||||
|
|
||||||
|
|||||||
@@ -439,7 +439,7 @@ fn compute_hashes_parallel(
|
|||||||
|
|
||||||
let mut p = processed.lock().unwrap();
|
let mut p = processed.lock().unwrap();
|
||||||
*p += 1;
|
*p += 1;
|
||||||
if *p % 100 == 0 {
|
if (*p).is_multiple_of(100) {
|
||||||
print!("\r Hashed {}/{} files...", *p, total);
|
print!("\r Hashed {}/{} files...", *p, total);
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
std::io::stdout().flush().ok();
|
std::io::stdout().flush().ok();
|
||||||
|
|||||||
@@ -16,7 +16,9 @@ fn test_password_authentication_brute_force_prevention() {
|
|||||||
assert!(provider.check_password("demo", "demo123").unwrap());
|
assert!(provider.check_password("demo", "demo123").unwrap());
|
||||||
assert!(!provider.check_password("demo", "wrongpassword").unwrap());
|
assert!(!provider.check_password("demo", "wrongpassword").unwrap());
|
||||||
assert!(!provider.check_password("demo", "").unwrap());
|
assert!(!provider.check_password("demo", "").unwrap());
|
||||||
assert!(!provider.check_password("__nonexistent__", "anypassword").unwrap());
|
assert!(!provider
|
||||||
|
.check_password("__nonexistent__", "anypassword")
|
||||||
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::ssh_server::cipher::EncryptionContext;
|
use crate::ssh_server::cipher::EncryptionContext;
|
||||||
use crate::ssh_server::crypto::{SessionKeys, Curve25519Kex, Ed25519HostKey};
|
use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_ctr_encryption_decryption_consistency() {
|
fn test_aes_ctr_encryption_decryption_consistency() {
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
mod auth_security;
|
mod auth_security;
|
||||||
|
mod channel_security;
|
||||||
mod crypto_security;
|
mod crypto_security;
|
||||||
mod file_access_security;
|
mod file_access_security;
|
||||||
mod channel_security;
|
|
||||||
|
|
||||||
pub use auth_security::*;
|
pub use auth_security::*;
|
||||||
|
pub use channel_security::*;
|
||||||
pub use crypto_security::*;
|
pub use crypto_security::*;
|
||||||
pub use file_access_security::*;
|
pub use file_access_security::*;
|
||||||
pub use channel_security::*;
|
|
||||||
@@ -1,22 +1,23 @@
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::{
|
use axum::{
|
||||||
|
extract::DefaultBodyLimit,
|
||||||
extract::{Path, Query, State},
|
extract::{Path, Query, State},
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
response::{Html, IntoResponse, Json},
|
response::{Html, IntoResponse, Json},
|
||||||
routing::{delete, get, patch, post, put},
|
routing::{delete, get, patch, post, put},
|
||||||
Router,
|
Router,
|
||||||
extract::DefaultBodyLimit,
|
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use crate::archive::{
|
||||||
|
ArchiveConfig, ArchiveFormat, ArchiveProcessor, FormatDetector, ProcessorRegistry,
|
||||||
|
};
|
||||||
use crate::audio;
|
use crate::audio;
|
||||||
use crate::auth::{AuthState, LoginRequest};
|
use crate::auth::{AuthState, LoginRequest};
|
||||||
use crate::provider::sqlite::SqliteProvider;
|
use crate::provider::sqlite::SqliteProvider;
|
||||||
use crate::render;
|
use crate::render;
|
||||||
use crate::download;
|
|
||||||
use crate::archive::{self, ArchiveFormat, ArchiveProcessor, FormatDetector, ArchiveConfig, ProcessorRegistry};
|
|
||||||
use filetree::{self, FileTree};
|
use filetree::{self, FileTree};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -60,7 +61,7 @@ pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
|
|||||||
db_dir: "data/users".to_string(),
|
db_dir: "data/users".to_string(),
|
||||||
auth: AuthState::with_provider(Box::new(
|
auth: AuthState::with_provider(Box::new(
|
||||||
SqliteProvider::new("data/auth.sqlite")
|
SqliteProvider::new("data/auth.sqlite")
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))?
|
.map_err(|e| anyhow::anyhow!("Failed to init SqliteProvider: {}", e))?,
|
||||||
)),
|
)),
|
||||||
auth_db_path: "data/auth.sqlite".to_string(),
|
auth_db_path: "data/auth.sqlite".to_string(),
|
||||||
s3_keys: Arc::new(Mutex::new(load_s3_keys())),
|
s3_keys: Arc::new(Mutex::new(load_s3_keys())),
|
||||||
@@ -578,7 +579,7 @@ async fn search_tree(
|
|||||||
ORDER BY sort_order ASC, created_at ASC",
|
ORDER BY sort_order ASC, created_at ASC",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let nodes: Vec<filetree::node::FileNode> = stmt
|
let _nodes: Vec<filetree::node::FileNode> = stmt
|
||||||
.query_map([&search_pattern], |row| {
|
.query_map([&search_pattern], |row| {
|
||||||
let children_json: String = row.get(6)?;
|
let children_json: String = row.get(6)?;
|
||||||
let children: Vec<String> =
|
let children: Vec<String> =
|
||||||
@@ -607,7 +608,7 @@ async fn search_tree(
|
|||||||
.filter_map(|r| r.ok())
|
.filter_map(|r| r.ok())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let tree = filetree::FileTree {
|
let tree = filetree::FileTree {
|
||||||
user_id: user_id.clone(),
|
user_id: user_id.clone(),
|
||||||
tree_type: "untitled folder".to_string(),
|
tree_type: "untitled folder".to_string(),
|
||||||
nodes: vec![],
|
nodes: vec![],
|
||||||
@@ -914,8 +915,8 @@ fn extract_and_register_archive(
|
|||||||
user_id: &str,
|
user_id: &str,
|
||||||
original_filename: &str,
|
original_filename: &str,
|
||||||
) -> anyhow::Result<(u64, u64, String)> {
|
) -> anyhow::Result<(u64, u64, String)> {
|
||||||
use std::path::PathBuf;
|
use sha2::{Digest, Sha256};
|
||||||
use sha2::{Sha256, Digest};
|
|
||||||
|
|
||||||
// Initialize archive system
|
// Initialize archive system
|
||||||
let config = ArchiveConfig::default();
|
let config = ArchiveConfig::default();
|
||||||
@@ -926,7 +927,11 @@ fn extract_and_register_archive(
|
|||||||
let detector = FormatDetector::new();
|
let detector = FormatDetector::new();
|
||||||
let format = detector.detect(archive_path)?;
|
let format = detector.detect(archive_path)?;
|
||||||
|
|
||||||
eprintln!("[archive] Detected format: {} for file: {}", format, archive_path.display());
|
eprintln!(
|
||||||
|
"[archive] Detected format: {} for file: {}",
|
||||||
|
format,
|
||||||
|
archive_path.display()
|
||||||
|
);
|
||||||
|
|
||||||
// Get processor
|
// Get processor
|
||||||
let processor = registry.get_processor_mut(archive_path)?;
|
let processor = registry.get_processor_mut(archive_path)?;
|
||||||
@@ -937,7 +942,8 @@ fn extract_and_register_archive(
|
|||||||
.map(|(name, _)| name)
|
.map(|(name, _)| name)
|
||||||
.unwrap_or(original_filename);
|
.unwrap_or(original_filename);
|
||||||
|
|
||||||
let extraction_dir = archive_path.parent()
|
let extraction_dir = archive_path
|
||||||
|
.parent()
|
||||||
.unwrap_or(std::path::Path::new("."))
|
.unwrap_or(std::path::Path::new("."))
|
||||||
.join(format!("{}_extracted", base_name));
|
.join(format!("{}_extracted", base_name));
|
||||||
|
|
||||||
@@ -946,13 +952,17 @@ fn extract_and_register_archive(
|
|||||||
// Open and extract
|
// Open and extract
|
||||||
let metadata = processor.open(archive_path)?;
|
let metadata = processor.open(archive_path)?;
|
||||||
|
|
||||||
eprintln!("[archive] Archive metadata: {} files, {} bytes",
|
eprintln!(
|
||||||
metadata.total_files, metadata.total_size);
|
"[archive] Archive metadata: {} files, {} bytes",
|
||||||
|
metadata.total_files, metadata.total_size
|
||||||
|
);
|
||||||
|
|
||||||
let result = processor.extract_all(&extraction_dir)?;
|
let result = processor.extract_all(&extraction_dir)?;
|
||||||
|
|
||||||
eprintln!("[archive] Extracted {} files ({} bytes)",
|
eprintln!(
|
||||||
result.success_files, result.total_bytes);
|
"[archive] Extracted {} files ({} bytes)",
|
||||||
|
result.success_files, result.total_bytes
|
||||||
|
);
|
||||||
|
|
||||||
// Register extracted files to database
|
// Register extracted files to database
|
||||||
let conn = FileTree::init_user_db(user_id)?;
|
let conn = FileTree::init_user_db(user_id)?;
|
||||||
@@ -999,14 +1009,13 @@ fn extract_and_register_archive(
|
|||||||
let file_hash = format!("{:x}", Sha256::digest(&file_data));
|
let file_hash = format!("{:x}", Sha256::digest(&file_data));
|
||||||
let file_size = file_data.len() as i64;
|
let file_size = file_data.len() as i64;
|
||||||
|
|
||||||
let filename = path.file_name()
|
let filename = path
|
||||||
|
.file_name()
|
||||||
.and_then(|n| n.to_str())
|
.and_then(|n| n.to_str())
|
||||||
.unwrap_or("unknown")
|
.unwrap_or("unknown")
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
let file_path_str = path.to_str()
|
let file_path_str = path.to_str().unwrap_or("unknown").to_string();
|
||||||
.unwrap_or("unknown")
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
// Generate file UUID
|
// Generate file UUID
|
||||||
let mtime = std::fs::metadata(&path)
|
let mtime = std::fs::metadata(&path)
|
||||||
@@ -1054,9 +1063,16 @@ fn extract_and_register_archive(
|
|||||||
|
|
||||||
registered_count = scan_directory(&extraction_dir, &conn, user_id, mac, now)?;
|
registered_count = scan_directory(&extraction_dir, &conn, user_id, mac, now)?;
|
||||||
|
|
||||||
eprintln!("[archive] Registered {} extracted files to database", registered_count);
|
eprintln!(
|
||||||
|
"[archive] Registered {} extracted files to database",
|
||||||
|
registered_count
|
||||||
|
);
|
||||||
|
|
||||||
Ok((result.success_files, result.total_bytes, extraction_dir.to_str().unwrap_or("unknown").to_string()))
|
Ok((
|
||||||
|
result.success_files,
|
||||||
|
result.total_bytes,
|
||||||
|
extraction_dir.to_str().unwrap_or("unknown").to_string(),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn upload_file(
|
async fn upload_file(
|
||||||
@@ -1150,19 +1166,19 @@ async fn upload_file(
|
|||||||
|
|
||||||
if let Ok(format) = detector.detect(&file_path_buf) {
|
if let Ok(format) = detector.detect(&file_path_buf) {
|
||||||
if format != ArchiveFormat::Unknown {
|
if format != ArchiveFormat::Unknown {
|
||||||
eprintln!("[upload] Detected archive format: {}, extracting...", format);
|
eprintln!(
|
||||||
|
"[upload] Detected archive format: {}, extracting...",
|
||||||
|
format
|
||||||
|
);
|
||||||
|
|
||||||
let user_id_clone = user_id.clone();
|
let user_id_clone = user_id.clone();
|
||||||
let filename_clone = filename.clone();
|
let filename_clone = filename.clone();
|
||||||
|
|
||||||
// Extract in blocking thread
|
// Extract in blocking thread
|
||||||
let extraction_result = tokio::task::spawn_blocking(move || {
|
let extraction_result = tokio::task::spawn_blocking(move || {
|
||||||
extract_and_register_archive(
|
extract_and_register_archive(&file_path_buf, &user_id_clone, &filename_clone)
|
||||||
&file_path_buf,
|
})
|
||||||
&user_id_clone,
|
.await;
|
||||||
&filename_clone,
|
|
||||||
)
|
|
||||||
}).await;
|
|
||||||
|
|
||||||
match extraction_result {
|
match extraction_result {
|
||||||
Ok(Ok((count, bytes, extract_dir))) => {
|
Ok(Ok((count, bytes, extract_dir))) => {
|
||||||
@@ -1208,7 +1224,7 @@ async fn upload_file(
|
|||||||
let hex = format!("{:x}", hash);
|
let hex = format!("{:x}", hash);
|
||||||
let file_uuid = hex[0..32].to_string();
|
let file_uuid = hex[0..32].to_string();
|
||||||
|
|
||||||
// Save to database (user-specific SQLite)
|
// Save to database (user-specific SQLite)
|
||||||
let file_uuid_clone = file_uuid.clone();
|
let file_uuid_clone = file_uuid.clone();
|
||||||
let file_hash_clone = file_hash.clone();
|
let file_hash_clone = file_hash.clone();
|
||||||
let filename_clone = filename.clone();
|
let filename_clone = filename.clone();
|
||||||
@@ -1290,11 +1306,7 @@ async fn upload_file(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
(
|
(StatusCode::CREATED, Json(response)).into_response()
|
||||||
StatusCode::CREATED,
|
|
||||||
Json(response),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn upload_unlimited(
|
async fn upload_unlimited(
|
||||||
@@ -1798,7 +1810,7 @@ async fn logout_handler(State(state): State<AppState>, headers: HeaderMap) -> im
|
|||||||
let auth_header = headers
|
let auth_header = headers
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
.and_then(|h| h.to_str().ok())
|
.and_then(|h| h.to_str().ok())
|
||||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
.and_then(crate::auth::parse_auth_header);
|
||||||
|
|
||||||
match auth_header {
|
match auth_header {
|
||||||
Some(token) => {
|
Some(token) => {
|
||||||
@@ -1824,7 +1836,7 @@ async fn verify_handler(State(state): State<AppState>, headers: HeaderMap) -> im
|
|||||||
let auth_header = headers
|
let auth_header = headers
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
.and_then(|h| h.to_str().ok())
|
.and_then(|h| h.to_str().ok())
|
||||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
.and_then(crate::auth::parse_auth_header);
|
||||||
|
|
||||||
match auth_header {
|
match auth_header {
|
||||||
Some(token) => match state.auth.verify_token(&token) {
|
Some(token) => match state.auth.verify_token(&token) {
|
||||||
@@ -1857,7 +1869,7 @@ fn verify_auth(state: &AppState, headers: &HeaderMap) -> Result<String, StatusCo
|
|||||||
let auth_header = headers
|
let auth_header = headers
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
.and_then(|h| h.to_str().ok())
|
.and_then(|h| h.to_str().ok())
|
||||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
.and_then(crate::auth::parse_auth_header);
|
||||||
|
|
||||||
match auth_header {
|
match auth_header {
|
||||||
Some(token) => match state.auth.verify_token(&token) {
|
Some(token) => match state.auth.verify_token(&token) {
|
||||||
@@ -2343,7 +2355,7 @@ async fn audit_handler() -> Json<serde_json::Value> {
|
|||||||
// Category View API handlers (Phase 1: 双视图管理)
|
// Category View API handlers (Phase 1: 双视图管理)
|
||||||
|
|
||||||
async fn get_all_categories_handler() -> impl IntoResponse {
|
async fn get_all_categories_handler() -> impl IntoResponse {
|
||||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||||
match crate::category_view::get_all_categories() {
|
match crate::category_view::get_all_categories() {
|
||||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
@@ -2354,10 +2366,8 @@ async fn get_all_categories_handler() -> impl IntoResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_category_detail_handler(
|
async fn get_category_detail_handler(Path(category_name): Path<String>) -> impl IntoResponse {
|
||||||
Path(category_name): Path<String>,
|
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||||
) -> impl IntoResponse {
|
|
||||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
|
||||||
match crate::category_view::get_category_detail(&category_name) {
|
match crate::category_view::get_category_detail(&category_name) {
|
||||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
@@ -2369,7 +2379,7 @@ async fn get_category_detail_handler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn get_all_series_handler() -> impl IntoResponse {
|
async fn get_all_series_handler() -> impl IntoResponse {
|
||||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||||
match crate::category_view::get_all_series() {
|
match crate::category_view::get_all_series() {
|
||||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
@@ -2380,10 +2390,8 @@ async fn get_all_series_handler() -> impl IntoResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_series_detail_handler(
|
async fn get_series_detail_handler(Path(series_name): Path<String>) -> impl IntoResponse {
|
||||||
Path(series_name): Path<String>,
|
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||||
) -> impl IntoResponse {
|
|
||||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
|
||||||
match crate::category_view::get_series_detail(&series_name) {
|
match crate::category_view::get_series_detail(&series_name) {
|
||||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
@@ -2400,10 +2408,8 @@ struct SearchQuery {
|
|||||||
view: String,
|
view: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn search_files_handler(
|
async fn search_files_handler(Query(query): Query<SearchQuery>) -> impl IntoResponse {
|
||||||
Query(query): Query<SearchQuery>,
|
let _base_path = std::path::Path::new("/Users/accusys/markbase");
|
||||||
) -> impl IntoResponse {
|
|
||||||
let base_path = std::path::Path::new("/Users/accusys/markbase");
|
|
||||||
match crate::category_view::search_files(&query.q, &query.view) {
|
match crate::category_view::search_files(&query.q, &query.view) {
|
||||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||||
use std::io::Write;
|
use anyhow::{anyhow, Result};
|
||||||
use anyhow::{Result, anyhow};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||||
use log::{info, warn, debug};
|
use log::{debug, info, warn};
|
||||||
use base64::{Engine as _, engine::general_purpose};
|
use std::io::Write;
|
||||||
|
|
||||||
use ed25519_dalek::{VerifyingKey, Signature};
|
use ed25519_dalek::{Signature, VerifyingKey};
|
||||||
|
|
||||||
use crate::provider::{DataProvider, ProviderError};
|
use crate::provider::{DataProvider, ProviderError};
|
||||||
|
|
||||||
@@ -27,7 +27,11 @@ impl AuthHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request())
|
/// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request())
|
||||||
pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result<AuthResult> {
|
pub fn handle_userauth_request(
|
||||||
|
&mut self,
|
||||||
|
packet: &SshPacket,
|
||||||
|
session_id: &[u8],
|
||||||
|
) -> Result<AuthResult> {
|
||||||
info!("Processing SSH_MSG_USERAUTH_REQUEST");
|
info!("Processing SSH_MSG_USERAUTH_REQUEST");
|
||||||
|
|
||||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||||||
@@ -41,7 +45,10 @@ impl AuthHandler {
|
|||||||
let service = read_ssh_string(&mut cursor)?;
|
let service = read_ssh_string(&mut cursor)?;
|
||||||
let method = read_ssh_string(&mut cursor)?;
|
let method = read_ssh_string(&mut cursor)?;
|
||||||
|
|
||||||
info!("Auth request: user={}, service={}, method={}", user, service, method);
|
info!(
|
||||||
|
"Auth request: user={}, service={}, method={}",
|
||||||
|
user, service, method
|
||||||
|
);
|
||||||
|
|
||||||
if service != "ssh-connection" {
|
if service != "ssh-connection" {
|
||||||
warn!("Unsupported service: {}", service);
|
warn!("Unsupported service: {}", service);
|
||||||
@@ -62,18 +69,28 @@ impl AuthHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 处理password认证(参考OpenSSH auth-passwd.c)
|
/// 处理password认证(参考OpenSSH auth-passwd.c)
|
||||||
fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
|
fn handle_password_auth(
|
||||||
|
&mut self,
|
||||||
|
cursor: &mut std::io::Cursor<&[u8]>,
|
||||||
|
user: &str,
|
||||||
|
) -> Result<AuthResult> {
|
||||||
info!("Handling password auth for user: {}", user);
|
info!("Handling password auth for user: {}", user);
|
||||||
|
|
||||||
let change_password = cursor.read_u8()? != 0;
|
let change_password = cursor.read_u8()? != 0;
|
||||||
if change_password {
|
if change_password {
|
||||||
warn!("Password change not supported");
|
warn!("Password change not supported");
|
||||||
return Ok(AuthResult::Failure("Password change not supported".to_string()));
|
return Ok(AuthResult::Failure(
|
||||||
|
"Password change not supported".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let password = read_ssh_string(cursor)?;
|
let password = read_ssh_string(cursor)?;
|
||||||
|
|
||||||
debug!("Password auth attempt: user={}, password length={}", user, password.len());
|
debug!(
|
||||||
|
"Password auth attempt: user={}, password length={}",
|
||||||
|
user,
|
||||||
|
password.len()
|
||||||
|
);
|
||||||
|
|
||||||
match self.provider.check_password(user, &password) {
|
match self.provider.check_password(user, &password) {
|
||||||
Ok(true) => {
|
Ok(true) => {
|
||||||
@@ -88,9 +105,7 @@ impl AuthHandler {
|
|||||||
warn!("User not found: {}", msg);
|
warn!("User not found: {}", msg);
|
||||||
Ok(AuthResult::Failure("password,publickey".to_string()))
|
Ok(AuthResult::Failure("password,publickey".to_string()))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => Err(anyhow!("Password auth error: {}", e)),
|
||||||
Err(anyhow!("Password auth error: {}", e))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,7 +160,12 @@ impl AuthHandler {
|
|||||||
let algorithm = read_ssh_string(cursor)?;
|
let algorithm = read_ssh_string(cursor)?;
|
||||||
let public_key_blob = read_ssh_string_bytes(cursor)?;
|
let public_key_blob = read_ssh_string_bytes(cursor)?;
|
||||||
|
|
||||||
info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed);
|
info!(
|
||||||
|
"Publickey auth: algorithm={}, blob_len={}, is_signed={}",
|
||||||
|
algorithm,
|
||||||
|
public_key_blob.len(),
|
||||||
|
is_signed
|
||||||
|
);
|
||||||
|
|
||||||
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
|
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
|
||||||
warn!("Public key not authorized for user: {}", user);
|
warn!("Public key not authorized for user: {}", user);
|
||||||
@@ -160,14 +180,26 @@ impl AuthHandler {
|
|||||||
|
|
||||||
let signature_blob = read_ssh_string_bytes(cursor)?;
|
let signature_blob = read_ssh_string_bytes(cursor)?;
|
||||||
|
|
||||||
self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?;
|
self.verify_signature(
|
||||||
|
&algorithm,
|
||||||
|
&public_key_blob,
|
||||||
|
&signature_blob,
|
||||||
|
user,
|
||||||
|
service,
|
||||||
|
session_id,
|
||||||
|
)?;
|
||||||
|
|
||||||
info!("Publickey auth successful for user: {}", user);
|
info!("Publickey auth successful for user: {}", user);
|
||||||
Ok(AuthResult::Success)
|
Ok(AuthResult::Success)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 检查public key是否在授权列表中(数据库优先,fallback到filesystem)
|
/// 检查public key是否在授权列表中(数据库优先,fallback到filesystem)
|
||||||
fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result<bool> {
|
fn is_key_authorized(
|
||||||
|
&self,
|
||||||
|
user: &str,
|
||||||
|
algorithm: &str,
|
||||||
|
public_key_blob: &[u8],
|
||||||
|
) -> Result<bool> {
|
||||||
// 1. 先检查数据库
|
// 1. 先检查数据库
|
||||||
match self.provider.get_public_keys(user) {
|
match self.provider.get_public_keys(user) {
|
||||||
Ok(keys) => {
|
Ok(keys) => {
|
||||||
@@ -187,10 +219,12 @@ impl AuthHandler {
|
|||||||
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
|
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(_) => return Ok(false),
|
Err(_) => return Ok(false),
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
|
Ok(content
|
||||||
|
.lines()
|
||||||
|
.any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 验证Ed25519签名(RFC 4252 §7)
|
/// 验证Ed25519签名(RFC 4252 §7)
|
||||||
@@ -246,7 +280,8 @@ impl AuthHandler {
|
|||||||
signed_data.write_all(public_key_blob)?;
|
signed_data.write_all(public_key_blob)?;
|
||||||
|
|
||||||
// 验证签名
|
// 验证签名
|
||||||
verifying_key.verify_strict(&signed_data, &signature)
|
verifying_key
|
||||||
|
.verify_strict(&signed_data, &signature)
|
||||||
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
|
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -270,10 +305,10 @@ fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result<VerifyingKey> {
|
|||||||
if key_bytes.len() != 32 {
|
if key_bytes.len() != 32 {
|
||||||
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
|
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
|
||||||
}
|
}
|
||||||
let key_array: [u8; 32] = key_bytes.try_into()
|
let key_array: [u8; 32] = key_bytes
|
||||||
|
.try_into()
|
||||||
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
|
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
|
||||||
VerifyingKey::from_bytes(&key_array)
|
VerifyingKey::from_bytes(&key_array).map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
|
||||||
.map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 解析Ed25519签名blob(SSH格式 -> Signature)
|
/// 解析Ed25519签名blob(SSH格式 -> Signature)
|
||||||
@@ -285,9 +320,13 @@ fn parse_ed25519_signature(signature_blob: &[u8]) -> Result<Signature> {
|
|||||||
}
|
}
|
||||||
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
|
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
|
||||||
if sig_bytes.len() != 64 {
|
if sig_bytes.len() != 64 {
|
||||||
return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len()));
|
return Err(anyhow!(
|
||||||
|
"Invalid Ed25519 signature length: {}",
|
||||||
|
sig_bytes.len()
|
||||||
|
));
|
||||||
}
|
}
|
||||||
let sig_array: [u8; 64] = sig_bytes.try_into()
|
let sig_array: [u8; 64] = sig_bytes
|
||||||
|
.try_into()
|
||||||
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
|
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
|
||||||
Ok(Signature::from_bytes(&sig_array))
|
Ok(Signature::from_bytes(&sig_array))
|
||||||
}
|
}
|
||||||
@@ -305,7 +344,9 @@ fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8])
|
|||||||
if parts[0] != algorithm {
|
if parts[0] != algorithm {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false)
|
base64_decode(parts[1])
|
||||||
|
.map(|decoded| decoded == public_key_blob)
|
||||||
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||||
@@ -323,7 +364,8 @@ fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn base64_decode(input: &str) -> Result<Vec<u8>> {
|
fn base64_decode(input: &str) -> Result<Vec<u8>> {
|
||||||
general_purpose::STANDARD.decode(input)
|
general_purpose::STANDARD
|
||||||
|
.decode(input)
|
||||||
.map_err(|e| anyhow!("Base64 decode error: {}", e))
|
.map_err(|e| anyhow!("Base64 decode error: {}", e))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,7 +377,10 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_userauth_success_packet() {
|
fn test_userauth_success_packet() {
|
||||||
let packet = AuthHandler::build_userauth_success().unwrap();
|
let packet = AuthHandler::build_userauth_success().unwrap();
|
||||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8);
|
assert_eq!(
|
||||||
|
packet.payload[0],
|
||||||
|
PacketType::SSH_MSG_USERAUTH_SUCCESS as u8
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -343,6 +388,9 @@ mod tests {
|
|||||||
let methods = vec!["password".to_string(), "publickey".to_string()];
|
let methods = vec!["password".to_string(), "publickey".to_string()];
|
||||||
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
|
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
|
||||||
|
|
||||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8);
|
assert_eq!(
|
||||||
|
packet.payload[0],
|
||||||
|
PacketType::SSH_MSG_USERAUTH_FAILURE as u8
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,33 +1,33 @@
|
|||||||
// SSH加密通道实现(Phase 4)
|
// SSH加密通道实现(Phase 4)
|
||||||
// 参考OpenSSH cipher.c, mac.c
|
// 参考OpenSSH cipher.c, mac.c
|
||||||
|
|
||||||
use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr)
|
use super::crypto::SessionKeys;
|
||||||
|
use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr)
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use byteorder::{BigEndian, WriteBytesExt};
|
||||||
|
use cipher::{KeyIvInit, StreamCipher};
|
||||||
use ctr::Ctr128BE;
|
use ctr::Ctr128BE;
|
||||||
use hmac::{Hmac, Mac};
|
use hmac::{Hmac, Mac};
|
||||||
|
use log::info;
|
||||||
use sha2::Sha256;
|
use sha2::Sha256;
|
||||||
use cipher::{KeyIvInit, StreamCipher};
|
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use anyhow::{Result, anyhow};
|
|
||||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
|
||||||
use log::{info, debug, warn};
|
|
||||||
use super::crypto::SessionKeys;
|
|
||||||
|
|
||||||
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR(16字节密钥)
|
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR(16字节密钥)
|
||||||
type HmacSha256 = Hmac<Sha256>;
|
type HmacSha256 = Hmac<Sha256>;
|
||||||
|
|
||||||
/// SSH加密通道管理器(参考OpenSSH struct sshcipher_ctx)
|
/// SSH加密通道管理器(参考OpenSSH struct sshcipher_ctx)
|
||||||
pub struct EncryptionContext {
|
pub struct EncryptionContext {
|
||||||
pub session_id: Vec<u8>, // session identifier (exchange hash)
|
pub session_id: Vec<u8>, // session identifier (exchange hash)
|
||||||
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
|
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
|
||||||
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
|
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
|
||||||
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
|
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
|
||||||
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
|
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
|
||||||
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
|
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
|
||||||
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
|
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
|
||||||
pub sequence_number_ctos: u32, // 客户端→服务器序列号
|
pub sequence_number_ctos: u32, // 客户端→服务器序列号
|
||||||
pub sequence_number_stoc: u32, // 服务器→客户端序列号
|
pub sequence_number_stoc: u32, // 服务器→客户端序列号
|
||||||
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例(持久化)
|
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例(持久化)
|
||||||
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例(持久化)
|
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例(持久化)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for EncryptionContext {
|
impl Default for EncryptionContext {
|
||||||
@@ -53,23 +53,29 @@ impl EncryptionContext {
|
|||||||
/// OpenSSH cipher.c: cipher初始化后状态持久化,counter跨packet递增
|
/// OpenSSH cipher.c: cipher初始化后状态持久化,counter跨packet递增
|
||||||
pub fn from_session_keys(keys: &SessionKeys) -> Self {
|
pub fn from_session_keys(keys: &SessionKeys) -> Self {
|
||||||
info!("Initializing ciphers with session keys:");
|
info!("Initializing ciphers with session keys:");
|
||||||
info!(" encryption_key_ctos (16 bytes): {:?}", &keys.encryption_key_ctos[..16]);
|
info!(
|
||||||
|
" encryption_key_ctos (16 bytes): {:?}",
|
||||||
|
&keys.encryption_key_ctos[..16]
|
||||||
|
);
|
||||||
info!(" iv_ctos (16 bytes): {:?}", &keys.iv_ctos[..16]);
|
info!(" iv_ctos (16 bytes): {:?}", &keys.iv_ctos[..16]);
|
||||||
info!(" encryption_key_stoc (16 bytes): {:?}", &keys.encryption_key_stoc[..16]);
|
info!(
|
||||||
|
" encryption_key_stoc (16 bytes): {:?}",
|
||||||
|
&keys.encryption_key_stoc[..16]
|
||||||
|
);
|
||||||
info!(" iv_stoc (16 bytes): {:?}", &keys.iv_stoc[..16]);
|
info!(" iv_stoc (16 bytes): {:?}", &keys.iv_stoc[..16]);
|
||||||
|
|
||||||
// 初始化客户端→服务器cipher(用于解密client packets)
|
// 初始化客户端→服务器cipher(用于解密client packets)
|
||||||
let key_ctos_array = <[u8; 16]>::try_from(&keys.encryption_key_ctos[..16])
|
let key_ctos_array = <[u8; 16]>::try_from(&keys.encryption_key_ctos[..16])
|
||||||
.expect("encryption_key_ctos must be 16 bytes");
|
.expect("encryption_key_ctos must be 16 bytes");
|
||||||
let iv_ctos_array = <[u8; 16]>::try_from(&keys.iv_ctos[..16])
|
let iv_ctos_array =
|
||||||
.expect("iv_ctos must be 16 bytes");
|
<[u8; 16]>::try_from(&keys.iv_ctos[..16]).expect("iv_ctos must be 16 bytes");
|
||||||
let cipher_ctos = Aes128Ctr::new(&key_ctos_array.into(), &iv_ctos_array.into());
|
let cipher_ctos = Aes128Ctr::new(&key_ctos_array.into(), &iv_ctos_array.into());
|
||||||
|
|
||||||
// 初始化服务器→客户端cipher(用于加密server packets)
|
// 初始化服务器→客户端cipher(用于加密server packets)
|
||||||
let key_stoc_array = <[u8; 16]>::try_from(&keys.encryption_key_stoc[..16])
|
let key_stoc_array = <[u8; 16]>::try_from(&keys.encryption_key_stoc[..16])
|
||||||
.expect("encryption_key_stoc must be 16 bytes");
|
.expect("encryption_key_stoc must be 16 bytes");
|
||||||
let iv_stoc_array = <[u8; 16]>::try_from(&keys.iv_stoc[..16])
|
let iv_stoc_array =
|
||||||
.expect("iv_stoc must be 16 bytes");
|
<[u8; 16]>::try_from(&keys.iv_stoc[..16]).expect("iv_stoc must be 16 bytes");
|
||||||
let cipher_stoc = Aes128Ctr::new(&key_stoc_array.into(), &iv_stoc_array.into());
|
let cipher_stoc = Aes128Ctr::new(&key_stoc_array.into(), &iv_stoc_array.into());
|
||||||
|
|
||||||
info!("Ciphers initialized successfully");
|
info!("Ciphers initialized successfully");
|
||||||
@@ -84,8 +90,8 @@ impl EncryptionContext {
|
|||||||
iv_stoc: keys.iv_stoc.clone(),
|
iv_stoc: keys.iv_stoc.clone(),
|
||||||
sequence_number_ctos: 0,
|
sequence_number_ctos: 0,
|
||||||
sequence_number_stoc: 0,
|
sequence_number_stoc: 0,
|
||||||
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
|
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
|
||||||
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
|
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,11 +193,11 @@ impl EncryptionContext {
|
|||||||
|
|
||||||
/// SSH加密packet封装(参考OpenSSH packet.c: ssh_packet_write_poll())
|
/// SSH加密packet封装(参考OpenSSH packet.c: ssh_packet_write_poll())
|
||||||
pub struct EncryptedPacket {
|
pub struct EncryptedPacket {
|
||||||
pub packet_length: u32, // 加密后packet长度
|
pub packet_length: u32, // 加密后packet长度
|
||||||
pub padding_length: u8, // padding长度(加密后)
|
pub padding_length: u8, // padding长度(加密后)
|
||||||
pub payload: Vec<u8>, // payload(加密后)
|
pub payload: Vec<u8>, // payload(加密后)
|
||||||
pub padding: Vec<u8>, // padding(加密后)
|
pub padding: Vec<u8>, // padding(加密后)
|
||||||
pub mac: Vec<u8>, // MAC(32字节,HMAC-SHA256)
|
pub mac: Vec<u8>, // MAC(32字节,HMAC-SHA256)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EncryptedPacket {
|
impl EncryptedPacket {
|
||||||
@@ -211,12 +217,12 @@ impl EncryptedPacket {
|
|||||||
// plaintext_packet = packet_length_field(4) + padding_length(1) + payload + padding
|
// plaintext_packet = packet_length_field(4) + padding_length(1) + payload + padding
|
||||||
// So: (4 + 1 + payload_length + padding_length) % 16 == 0
|
// So: (4 + 1 + payload_length + padding_length) % 16 == 0
|
||||||
|
|
||||||
let base_size = 4 + 1 + payload_length; // without padding
|
let base_size = 4 + 1 + payload_length; // without padding
|
||||||
let padding_needed = (block_size - (base_size % block_size)) % block_size;
|
let padding_needed = (block_size - (base_size % block_size)) % block_size;
|
||||||
|
|
||||||
// Ensure padding >= min_padding (RFC 4253 requirement)
|
// Ensure padding >= min_padding (RFC 4253 requirement)
|
||||||
let padding_length: u8 = if padding_needed < min_padding {
|
let padding_length: u8 = if padding_needed < min_padding {
|
||||||
(padding_needed + block_size) as u8 // Add one more block to meet minimum
|
(padding_needed + block_size) as u8 // Add one more block to meet minimum
|
||||||
} else {
|
} else {
|
||||||
padding_needed as u8
|
padding_needed as u8
|
||||||
};
|
};
|
||||||
@@ -224,19 +230,21 @@ impl EncryptedPacket {
|
|||||||
// packet_length = padding_length(1) + payload + padding
|
// packet_length = padding_length(1) + payload + padding
|
||||||
let packet_length = 1 + payload_length + padding_length as usize;
|
let packet_length = 1 + payload_length + padding_length as usize;
|
||||||
|
|
||||||
info!("Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
|
info!(
|
||||||
payload_length, padding_length, packet_length);
|
"Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
|
||||||
|
payload_length, padding_length, packet_length
|
||||||
|
);
|
||||||
|
|
||||||
// 构建plaintext packet(packet_length + padding_length + payload + padding)
|
// 构建plaintext packet(packet_length + padding_length + payload + padding)
|
||||||
let mut plaintext_packet = Vec::new();
|
let mut plaintext_packet = Vec::new();
|
||||||
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
|
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
|
||||||
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
|
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
|
||||||
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
|
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
|
||||||
|
|
||||||
let mut random_padding = vec![0u8; padding_length as usize];
|
let mut random_padding = vec![0u8; padding_length as usize];
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
rand::thread_rng().fill_bytes(&mut random_padding);
|
rand::thread_rng().fill_bytes(&mut random_padding);
|
||||||
plaintext_packet.write_all(&random_padding)?; // plaintext padding
|
plaintext_packet.write_all(&random_padding)?; // plaintext padding
|
||||||
|
|
||||||
info!("Plaintext packet size: {} bytes", plaintext_packet.len());
|
info!("Plaintext packet size: {} bytes", plaintext_packet.len());
|
||||||
|
|
||||||
@@ -263,10 +271,14 @@ impl EncryptedPacket {
|
|||||||
|
|
||||||
// 然後加密plaintext packet(AES-CTR加密整個packet)
|
// 然後加密plaintext packet(AES-CTR加密整個packet)
|
||||||
let cipher = if is_server_to_client {
|
let cipher = if is_server_to_client {
|
||||||
encryption_ctx.cipher_stoc.as_mut()
|
encryption_ctx
|
||||||
|
.cipher_stoc
|
||||||
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
||||||
} else {
|
} else {
|
||||||
encryption_ctx.cipher_ctos.as_mut()
|
encryption_ctx
|
||||||
|
.cipher_ctos
|
||||||
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -292,8 +304,11 @@ impl EncryptedPacket {
|
|||||||
/// 写入加密packet(参考OpenSSH cipher.c)
|
/// 写入加密packet(参考OpenSSH cipher.c)
|
||||||
/// AES-CTR模式:写入完整加密packet + MAC
|
/// AES-CTR模式:写入完整加密packet + MAC
|
||||||
pub fn write<W: std::io::Write>(&self, stream: &mut W) -> Result<()> {
|
pub fn write<W: std::io::Write>(&self, stream: &mut W) -> Result<()> {
|
||||||
info!("Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
|
info!(
|
||||||
self.payload.len(), self.mac.len());
|
"Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
|
||||||
|
self.payload.len(),
|
||||||
|
self.mac.len()
|
||||||
|
);
|
||||||
|
|
||||||
// AES-CTR: 整个packet已加密(包括packet_length),直接写入
|
// AES-CTR: 整个packet已加密(包括packet_length),直接写入
|
||||||
stream.write_all(&self.payload)?;
|
stream.write_all(&self.payload)?;
|
||||||
@@ -322,18 +337,28 @@ impl EncryptedPacket {
|
|||||||
let mut first_block_encrypted = [0u8; 16];
|
let mut first_block_encrypted = [0u8; 16];
|
||||||
stream.read_exact(&mut first_block_encrypted)?;
|
stream.read_exact(&mut first_block_encrypted)?;
|
||||||
|
|
||||||
info!("Read first encrypted block (16 bytes): {:?}", &first_block_encrypted);
|
info!(
|
||||||
|
"Read first encrypted block (16 bytes): {:?}",
|
||||||
|
&first_block_encrypted
|
||||||
|
);
|
||||||
|
|
||||||
// 2. 获取持久化cipher实例(counter已递增)
|
// 2. 获取持久化cipher实例(counter已递增)
|
||||||
let cipher = if is_client_to_server {
|
let cipher = if is_client_to_server {
|
||||||
encryption_ctx.cipher_ctos.as_mut()
|
encryption_ctx
|
||||||
|
.cipher_ctos
|
||||||
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
||||||
} else {
|
} else {
|
||||||
encryption_ctx.cipher_stoc.as_mut()
|
encryption_ctx
|
||||||
|
.cipher_stoc
|
||||||
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
||||||
};
|
};
|
||||||
|
|
||||||
info!("Using cipher for decryption (is_client_to_server={})", is_client_to_server);
|
info!(
|
||||||
|
"Using cipher for decryption (is_client_to_server={})",
|
||||||
|
is_client_to_server
|
||||||
|
);
|
||||||
|
|
||||||
// 3. 解密第一个块(counter自动递增)
|
// 3. 解密第一个块(counter自动递增)
|
||||||
let mut first_block_decrypted = first_block_encrypted;
|
let mut first_block_decrypted = first_block_encrypted;
|
||||||
@@ -350,7 +375,10 @@ impl EncryptedPacket {
|
|||||||
]);
|
]);
|
||||||
let padding_length = first_block_decrypted[4];
|
let padding_length = first_block_decrypted[4];
|
||||||
|
|
||||||
info!("Decrypted packet_length={}, padding_length={}", packet_length, padding_length);
|
info!(
|
||||||
|
"Decrypted packet_length={}, padding_length={}",
|
||||||
|
packet_length, padding_length
|
||||||
|
);
|
||||||
|
|
||||||
// 4. 合理性检查
|
// 4. 合理性检查
|
||||||
if packet_length > 35000 {
|
if packet_length > 35000 {
|
||||||
@@ -362,10 +390,13 @@ impl EncryptedPacket {
|
|||||||
// packet_length = padding_length(1) + payload + padding
|
// packet_length = padding_length(1) + payload + padding
|
||||||
// 总加密数据 = packet_length(4) + packet_length = packet_length + 4
|
// 总加密数据 = packet_length(4) + packet_length = packet_length + 4
|
||||||
// 已读取16字节,剩余 = packet_length + 4 - 16
|
// 已读取16字节,剩余 = packet_length + 4 - 16
|
||||||
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
|
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
|
||||||
let remaining_encrypted_size = total_encrypted_size - 16;
|
let remaining_encrypted_size = total_encrypted_size - 16;
|
||||||
|
|
||||||
info!("Total encrypted size: {}, remaining: {}", total_encrypted_size, remaining_encrypted_size);
|
info!(
|
||||||
|
"Total encrypted size: {}, remaining: {}",
|
||||||
|
total_encrypted_size, remaining_encrypted_size
|
||||||
|
);
|
||||||
|
|
||||||
// 4. 读取剩余加密数据
|
// 4. 读取剩余加密数据
|
||||||
let mut remaining_encrypted = vec![0u8; remaining_encrypted_size];
|
let mut remaining_encrypted = vec![0u8; remaining_encrypted_size];
|
||||||
@@ -431,7 +462,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes256_ctr_encryption() {
|
fn test_aes256_ctr_encryption() {
|
||||||
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
|
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
|
||||||
let iv = vec![0u8; 16];
|
let iv = vec![0u8; 16];
|
||||||
let plaintext = b"Hello World";
|
let plaintext = b"Hello World";
|
||||||
|
|
||||||
@@ -467,7 +498,7 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let mac = ctx.compute_mac(1, data, &key).unwrap();
|
let mac = ctx.compute_mac(1, data, &key).unwrap();
|
||||||
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
|
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
|
||||||
|
|
||||||
// 验证MAC
|
// 验证MAC
|
||||||
assert!(ctx.verify_mac(1, data, &mac, &key).unwrap());
|
assert!(ctx.verify_mac(1, data, &mac, &key).unwrap());
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
// SSH加密模块(Phase 3:密钥交换)
|
// SSH加密模块(Phase 3:密钥交换)
|
||||||
// 参考OpenSSH curve25519.c, kex.c
|
// 参考OpenSSH curve25519.c, kex.c
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
|
use ed25519_dalek::{Signer, SigningKey};
|
||||||
use ed25519_dalek::{SigningKey, VerifyingKey, Signature, Signer};
|
use log::info;
|
||||||
use sha2::{Sha256, Digest};
|
|
||||||
use log::{info, debug};
|
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||||
|
|
||||||
/// Curve25519密钥交换处理器(参考OpenSSH curve25519.c)
|
/// Curve25519密钥交换处理器(参考OpenSSH curve25519.c)
|
||||||
pub struct Curve25519Kex {
|
pub struct Curve25519Kex {
|
||||||
secret: Option<EphemeralSecret>, // 使用Option包装(一次性使用类型)
|
secret: Option<EphemeralSecret>, // 使用Option包装(一次性使用类型)
|
||||||
public: PublicKey,
|
public: PublicKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,7 +22,10 @@ impl Curve25519Kex {
|
|||||||
let secret = EphemeralSecret::random_from_rng(OsRng);
|
let secret = EphemeralSecret::random_from_rng(OsRng);
|
||||||
let public = PublicKey::from(&secret);
|
let public = PublicKey::from(&secret);
|
||||||
|
|
||||||
Self { secret: Some(secret), public } // Some包装
|
Self {
|
||||||
|
secret: Some(secret),
|
||||||
|
public,
|
||||||
|
} // Some包装
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取公钥(用于SSH_MSG_KEX_ECDH_INIT)
|
/// 获取公钥(用于SSH_MSG_KEX_ECDH_INIT)
|
||||||
@@ -48,7 +51,7 @@ impl Curve25519Kex {
|
|||||||
if let Some(secret) = self.secret.take() {
|
if let Some(secret) = self.secret.take() {
|
||||||
let shared_secret = secret.diffie_hellman(&client_public_key);
|
let shared_secret = secret.diffie_hellman(&client_public_key);
|
||||||
info!("Computed shared secret: {:?}", shared_secret.as_bytes());
|
info!("Computed shared secret: {:?}", shared_secret.as_bytes());
|
||||||
Ok(shared_secret.as_bytes().clone())
|
Ok(*shared_secret.as_bytes())
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow!("Secret already used"))
|
Err(anyhow!("Secret already used"))
|
||||||
}
|
}
|
||||||
@@ -71,10 +74,10 @@ impl SessionKeys {
|
|||||||
/// RFC 4253 Section 7.2: Key = HASH(K || H || X || session_id)
|
/// RFC 4253 Section 7.2: Key = HASH(K || H || X || session_id)
|
||||||
pub fn derive(
|
pub fn derive(
|
||||||
shared_secret: &[u8],
|
shared_secret: &[u8],
|
||||||
exchange_hash: &[u8], // H参数(exchange hash)
|
exchange_hash: &[u8], // H参数(exchange hash)
|
||||||
server_public_key: &[u8],
|
_server_public_key: &[u8],
|
||||||
client_public_key: &[u8],
|
_client_public_key: &[u8],
|
||||||
server_host_key: &[u8],
|
_server_host_key: &[u8],
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
// RFC 4253: session_id = H (第一次exchange hash)
|
// RFC 4253: session_id = H (第一次exchange hash)
|
||||||
let session_id = exchange_hash.to_vec();
|
let session_id = exchange_hash.to_vec();
|
||||||
@@ -86,7 +89,11 @@ impl SessionKeys {
|
|||||||
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
|
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
|
||||||
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
|
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
|
||||||
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
|
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
|
||||||
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
|
info!(
|
||||||
|
" shared_secret[0] = {} (>=0x80? {})",
|
||||||
|
shared_secret[0],
|
||||||
|
shared_secret[0] >= 0x80
|
||||||
|
);
|
||||||
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
||||||
info!(" session_id full (32 bytes): {:?}", session_id);
|
info!(" session_id full (32 bytes): {:?}", session_id);
|
||||||
|
|
||||||
@@ -94,23 +101,57 @@ impl SessionKeys {
|
|||||||
// K is shared_secret encoded as mpint (using little-endian bytes directly)
|
// K is shared_secret encoded as mpint (using little-endian bytes directly)
|
||||||
let shared_secret_mpint = Self::encode_mpint(shared_secret);
|
let shared_secret_mpint = Self::encode_mpint(shared_secret);
|
||||||
|
|
||||||
info!(" shared_secret_mpint ({} bytes): {:?}", shared_secret_mpint.len(), &shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]);
|
info!(
|
||||||
|
" shared_secret_mpint ({} bytes): {:?}",
|
||||||
|
shared_secret_mpint.len(),
|
||||||
|
&shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]
|
||||||
|
);
|
||||||
|
|
||||||
let encryption_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
|
let encryption_key_ctos =
|
||||||
let encryption_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
|
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
|
||||||
let mac_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
|
let encryption_key_stoc =
|
||||||
let mac_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
|
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
|
||||||
|
let mac_key_ctos =
|
||||||
|
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
|
||||||
|
let mac_key_stoc =
|
||||||
|
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
|
||||||
|
|
||||||
let iv_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
|
let iv_ctos =
|
||||||
let iv_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
|
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
|
||||||
|
let iv_stoc =
|
||||||
|
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
|
||||||
|
|
||||||
info!("Derived keys summary:");
|
info!("Derived keys summary:");
|
||||||
info!(" encryption_key_ctos ({} bytes): {:?}", encryption_key_ctos.len(), &encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]);
|
info!(
|
||||||
info!(" encryption_key_stoc ({} bytes): {:?}", encryption_key_stoc.len(), &encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]);
|
" encryption_key_ctos ({} bytes): {:?}",
|
||||||
info!(" iv_ctos ({} bytes): {:?}", iv_ctos.len(), &iv_ctos[..std::cmp::min(16, iv_ctos.len())]);
|
encryption_key_ctos.len(),
|
||||||
info!(" iv_stoc ({} bytes): {:?}", iv_stoc.len(), &iv_stoc[..std::cmp::min(16, iv_stoc.len())]);
|
&encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]
|
||||||
info!(" mac_key_ctos ({} bytes): {:?}", mac_key_ctos.len(), &mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]);
|
);
|
||||||
info!(" mac_key_stoc ({} bytes): {:?}", mac_key_stoc.len(), &mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]);
|
info!(
|
||||||
|
" encryption_key_stoc ({} bytes): {:?}",
|
||||||
|
encryption_key_stoc.len(),
|
||||||
|
&encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
" iv_ctos ({} bytes): {:?}",
|
||||||
|
iv_ctos.len(),
|
||||||
|
&iv_ctos[..std::cmp::min(16, iv_ctos.len())]
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
" iv_stoc ({} bytes): {:?}",
|
||||||
|
iv_stoc.len(),
|
||||||
|
&iv_stoc[..std::cmp::min(16, iv_stoc.len())]
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
" mac_key_ctos ({} bytes): {:?}",
|
||||||
|
mac_key_ctos.len(),
|
||||||
|
&mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
" mac_key_stoc ({} bytes): {:?}",
|
||||||
|
mac_key_stoc.len(),
|
||||||
|
&mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
session_id,
|
session_id,
|
||||||
@@ -129,14 +170,22 @@ impl SessionKeys {
|
|||||||
let mut hasher = Sha256::new();
|
let mut hasher = Sha256::new();
|
||||||
|
|
||||||
info!("Deriving key for X='{}'", X);
|
info!("Deriving key for X='{}'", X);
|
||||||
info!(" K_mpint ({} bytes): {:?}", K_mpint.len(), &K_mpint[..std::cmp::min(8, K_mpint.len())]);
|
info!(
|
||||||
|
" K_mpint ({} bytes): {:?}",
|
||||||
|
K_mpint.len(),
|
||||||
|
&K_mpint[..std::cmp::min(8, K_mpint.len())]
|
||||||
|
);
|
||||||
info!(" H ({} bytes): {:?}", H.len(), &H[..8]);
|
info!(" H ({} bytes): {:?}", H.len(), &H[..8]);
|
||||||
info!(" session_id ({} bytes): {:?}", session_id.len(), &session_id[..8]);
|
info!(
|
||||||
|
" session_id ({} bytes): {:?}",
|
||||||
|
session_id.len(),
|
||||||
|
&session_id[..8]
|
||||||
|
);
|
||||||
|
|
||||||
// RFC 4253: HASH(K || H || X || session_id)
|
// RFC 4253: HASH(K || H || X || session_id)
|
||||||
hasher.update(K_mpint); // K (shared secret in mpint format)
|
hasher.update(K_mpint); // K (shared secret in mpint format)
|
||||||
hasher.update(H); // H (exchange hash)
|
hasher.update(H); // H (exchange hash)
|
||||||
hasher.update(&[X as u8]); // X (single character)
|
hasher.update([X as u8]); // X (single character)
|
||||||
hasher.update(session_id); // session_id
|
hasher.update(session_id); // session_id
|
||||||
|
|
||||||
let full_hash = hasher.finalize();
|
let full_hash = hasher.finalize();
|
||||||
@@ -147,7 +196,7 @@ impl SessionKeys {
|
|||||||
// AES-128-CTR key/IV: 16 bytes
|
// AES-128-CTR key/IV: 16 bytes
|
||||||
// HMAC-SHA256 key: 32 bytes
|
// HMAC-SHA256 key: 32 bytes
|
||||||
match X {
|
match X {
|
||||||
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
|
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
|
||||||
'E' | 'F' => Ok(full_hash.to_vec()), // MAC key (full 32 bytes)
|
'E' | 'F' => Ok(full_hash.to_vec()), // MAC key (full 32 bytes)
|
||||||
_ => Ok(full_hash[..16].to_vec()), // default
|
_ => Ok(full_hash[..16].to_vec()), // default
|
||||||
}
|
}
|
||||||
@@ -192,7 +241,7 @@ pub struct Ed25519HostKey {
|
|||||||
|
|
||||||
impl Ed25519HostKey {
|
impl Ed25519HostKey {
|
||||||
/// 加载或生成主机密钥(参考OpenSSH hostfile.c)
|
/// 加载或生成主机密钥(参考OpenSSH hostfile.c)
|
||||||
pub fn load_or_generate(key_path: &str) -> Result<Self> {
|
pub fn load_or_generate(_key_path: &str) -> Result<Self> {
|
||||||
// 简化实现:生成临时密钥(实际应从文件加载)
|
// 简化实现:生成临时密钥(实际应从文件加载)
|
||||||
// 参考OpenSSH ssh-keygen
|
// 参考OpenSSH ssh-keygen
|
||||||
|
|
||||||
@@ -228,7 +277,7 @@ impl Ed25519HostKey {
|
|||||||
// SSH公钥格式:ssh-ed25519 <base64-encoded-public-key>
|
// SSH公钥格式:ssh-ed25519 <base64-encoded-public-key>
|
||||||
// 参考OpenSSH ssh-keygen -y
|
// 参考OpenSSH ssh-keygen -y
|
||||||
|
|
||||||
use base64::{Engine as _, engine::general_purpose};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
let encoded = general_purpose::STANDARD.encode(&public_bytes);
|
let encoded = general_purpose::STANDARD.encode(&public_bytes);
|
||||||
|
|
||||||
format!("ssh-ed25519 {}", encoded)
|
format!("ssh-ed25519 {}", encoded)
|
||||||
@@ -251,10 +300,14 @@ mod tests {
|
|||||||
let mut server_kex = Curve25519Kex::new();
|
let mut server_kex = Curve25519Kex::new();
|
||||||
|
|
||||||
// 客户端计算共享密钥
|
// 客户端计算共享密钥
|
||||||
let client_secret = client_kex.compute_shared_secret(server_kex.public_key()).unwrap();
|
let client_secret = client_kex
|
||||||
|
.compute_shared_secret(server_kex.public_key())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// 服务器计算共享密钥
|
// 服务器计算共享密钥
|
||||||
let server_secret = server_kex.compute_shared_secret(client_kex.public_key()).unwrap();
|
let server_secret = server_kex
|
||||||
|
.compute_shared_secret(client_kex.public_key())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// 应该相同(Curve25519特性)
|
// 应该相同(Curve25519特性)
|
||||||
assert_eq!(client_secret, server_secret);
|
assert_eq!(client_secret, server_secret);
|
||||||
@@ -272,6 +325,6 @@ mod tests {
|
|||||||
let data = b"test data";
|
let data = b"test data";
|
||||||
|
|
||||||
let signature = host_key.sign(data).unwrap();
|
let signature = host_key.sign(data).unwrap();
|
||||||
assert_eq!(signature.len(), 64); // Ed25519签名64字节
|
assert_eq!(signature.len(), 64); // Ed25519签名64字节
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
// SSH端口转发数据传输(Phase 13.5)
|
// SSH端口转发数据传输(Phase 13.5)
|
||||||
// 参考OpenSSH channels.c: channel_handle_data()
|
// 参考OpenSSH channels.c: channel_handle_data()
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use log::{info, warn, debug};
|
|
||||||
use std::net::{TcpStream};
|
|
||||||
use std::io::{Read, Write};
|
|
||||||
use std::thread;
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use byteorder::{BigEndian, WriteBytesExt};
|
use byteorder::{BigEndian, WriteBytesExt};
|
||||||
|
use log::{debug, info, warn};
|
||||||
|
use std::io::{Read, Write};
|
||||||
|
use std::net::TcpStream;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
/// 数据转发器(Phase 13.5:双向数据传输)
|
/// 数据转发器(Phase 13.5:双向数据传输)
|
||||||
pub struct DataForwarder {
|
pub struct DataForwarder {
|
||||||
@@ -29,22 +29,33 @@ impl DataForwarder {
|
|||||||
/// 启动双向数据转发(Phase 13.5:SSH channel ↔ TCP socket)
|
/// 启动双向数据转发(Phase 13.5:SSH channel ↔ TCP socket)
|
||||||
pub fn start_bidirectional_forwarding(
|
pub fn start_bidirectional_forwarding(
|
||||||
&mut self,
|
&mut self,
|
||||||
ssh_stream: TcpStream, // SSH client连接(加密通道)
|
ssh_stream: TcpStream, // SSH client连接(加密通道)
|
||||||
target_stream: TcpStream, // 目标服务连接(TCP socket)
|
target_stream: TcpStream, // 目标服务连接(TCP socket)
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
info!("Starting bidirectional data forwarding for channel {}", self.channel_id);
|
info!(
|
||||||
|
"Starting bidirectional data forwarding for channel {}",
|
||||||
|
self.channel_id
|
||||||
|
);
|
||||||
|
|
||||||
// Phase 13.5: SSH channel → Target socket(SSH client数据 → 本地服务)
|
// Phase 13.5: SSH channel → Target socket(SSH client数据 → 本地服务)
|
||||||
let ssh_to_target = self.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
|
let ssh_to_target = self
|
||||||
|
.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
|
||||||
|
|
||||||
// Phase 13.5: Target socket → SSH channel(本地服务数据 → SSH client)
|
// Phase 13.5: Target socket → SSH channel(本地服务数据 → SSH client)
|
||||||
let target_to_ssh = self.start_target_to_ssh_forwarding(target_stream, ssh_stream);
|
let target_to_ssh = self.start_target_to_ssh_forwarding(target_stream, ssh_stream);
|
||||||
|
|
||||||
// Phase 13.5: 等待两个转发线程完成
|
// Phase 13.5: 等待两个转发线程完成
|
||||||
ssh_to_target.join().map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
|
ssh_to_target
|
||||||
target_to_ssh.join().map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
|
.join()
|
||||||
|
.map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
|
||||||
|
target_to_ssh
|
||||||
|
.join()
|
||||||
|
.map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
|
||||||
|
|
||||||
info!("Bidirectional data forwarding completed for channel {}", self.channel_id);
|
info!(
|
||||||
|
"Bidirectional data forwarding completed for channel {}",
|
||||||
|
self.channel_id
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +70,10 @@ impl DataForwarder {
|
|||||||
let max_packet_size = self.max_packet_size;
|
let max_packet_size = self.max_packet_size;
|
||||||
|
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
info!("SSH to target forwarding thread started for channel {}", channel_id);
|
info!(
|
||||||
|
"SSH to target forwarding thread started for channel {}",
|
||||||
|
channel_id
|
||||||
|
);
|
||||||
|
|
||||||
let mut buffer = vec![0u8; max_packet_size as usize];
|
let mut buffer = vec![0u8; max_packet_size as usize];
|
||||||
|
|
||||||
@@ -68,7 +82,7 @@ impl DataForwarder {
|
|||||||
let n = match ssh_stream.read(&mut buffer) {
|
let n = match ssh_stream.read(&mut buffer) {
|
||||||
Ok(0) => {
|
Ok(0) => {
|
||||||
info!("SSH channel EOF for channel {}", channel_id);
|
info!("SSH channel EOF for channel {}", channel_id);
|
||||||
break; // EOF
|
break; // EOF
|
||||||
}
|
}
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -81,8 +95,10 @@ impl DataForwarder {
|
|||||||
{
|
{
|
||||||
let window = window_size.lock().unwrap();
|
let window = window_size.lock().unwrap();
|
||||||
if *window < n as u32 {
|
if *window < n as u32 {
|
||||||
warn!("Window size insufficient for channel {}: need {}, have {}",
|
warn!(
|
||||||
channel_id, n, *window);
|
"Window size insufficient for channel {}: need {}, have {}",
|
||||||
|
channel_id, n, *window
|
||||||
|
);
|
||||||
// Phase 13.5: 理论上应该等待SSH_MSG_CHANNEL_WINDOW_ADJUST
|
// Phase 13.5: 理论上应该等待SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||||||
// 简化实现:继续发送(可能会违反RFC 4254)
|
// 简化实现:继续发送(可能会违反RFC 4254)
|
||||||
}
|
}
|
||||||
@@ -90,13 +106,19 @@ impl DataForwarder {
|
|||||||
|
|
||||||
// Phase 13.5: 写入目标socket
|
// Phase 13.5: 写入目标socket
|
||||||
if let Err(e) = target_stream.write_all(&buffer[..n]) {
|
if let Err(e) = target_stream.write_all(&buffer[..n]) {
|
||||||
warn!("Target socket write error for channel {}: {}", channel_id, e);
|
warn!(
|
||||||
|
"Target socket write error for channel {}: {}",
|
||||||
|
channel_id, e
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 13.5: Flush确保数据发送
|
// Phase 13.5: Flush确保数据发送
|
||||||
if let Err(e) = target_stream.flush() {
|
if let Err(e) = target_stream.flush() {
|
||||||
warn!("Target socket flush error for channel {}: {}", channel_id, e);
|
warn!(
|
||||||
|
"Target socket flush error for channel {}: {}",
|
||||||
|
channel_id, e
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,14 +126,22 @@ impl DataForwarder {
|
|||||||
{
|
{
|
||||||
let mut window = window_size.lock().unwrap();
|
let mut window = window_size.lock().unwrap();
|
||||||
*window -= n as u32;
|
*window -= n as u32;
|
||||||
debug!("Window size consumed for channel {}: {} bytes, remaining {}",
|
debug!(
|
||||||
channel_id, n, *window);
|
"Window size consumed for channel {}: {} bytes, remaining {}",
|
||||||
|
channel_id, n, *window
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Forwarded {} bytes from SSH to target for channel {}", n, channel_id);
|
info!(
|
||||||
|
"Forwarded {} bytes from SSH to target for channel {}",
|
||||||
|
n, channel_id
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("SSH to target forwarding thread stopped for channel {}", channel_id);
|
info!(
|
||||||
|
"SSH to target forwarding thread stopped for channel {}",
|
||||||
|
channel_id
|
||||||
|
);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,16 +154,19 @@ impl DataForwarder {
|
|||||||
let channel_id = self.channel_id;
|
let channel_id = self.channel_id;
|
||||||
|
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
info!("Target to SSH forwarding thread started for channel {}", channel_id);
|
info!(
|
||||||
|
"Target to SSH forwarding thread started for channel {}",
|
||||||
|
channel_id
|
||||||
|
);
|
||||||
|
|
||||||
let mut buffer = vec![0u8; 8192]; // 8KB buffer
|
let mut buffer = vec![0u8; 8192]; // 8KB buffer
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Phase 13.5: 从目标socket读取数据
|
// Phase 13.5: 从目标socket读取数据
|
||||||
let n = match target_stream.read(&mut buffer) {
|
let n = match target_stream.read(&mut buffer) {
|
||||||
Ok(0) => {
|
Ok(0) => {
|
||||||
info!("Target socket EOF for channel {}", channel_id);
|
info!("Target socket EOF for channel {}", channel_id);
|
||||||
break; // EOF
|
break; // EOF
|
||||||
}
|
}
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -158,10 +191,16 @@ impl DataForwarder {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Forwarded {} bytes from target to SSH for channel {}", n, channel_id);
|
info!(
|
||||||
|
"Forwarded {} bytes from target to SSH for channel {}",
|
||||||
|
n, channel_id
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Target to SSH forwarding thread stopped for channel {}", channel_id);
|
info!(
|
||||||
|
"Target to SSH forwarding thread stopped for channel {}",
|
||||||
|
channel_id
|
||||||
|
);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,8 +213,10 @@ impl DataForwarder {
|
|||||||
pub fn adjust_window_size(&self, bytes_to_add: u32) {
|
pub fn adjust_window_size(&self, bytes_to_add: u32) {
|
||||||
let mut window = self.window_size.lock().unwrap();
|
let mut window = self.window_size.lock().unwrap();
|
||||||
*window += bytes_to_add;
|
*window += bytes_to_add;
|
||||||
info!("Window size adjusted for channel {}: added {} bytes, total {}",
|
info!(
|
||||||
self.channel_id, bytes_to_add, *window);
|
"Window size adjusted for channel {}: added {} bytes, total {}",
|
||||||
|
self.channel_id, bytes_to_add, *window
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 检查window size是否足够(Phase 13.5)
|
/// 检查window size是否足够(Phase 13.5)
|
||||||
@@ -245,7 +286,7 @@ mod tests {
|
|||||||
let data = b"Hello, SSH!";
|
let data = b"Hello, SSH!";
|
||||||
let packet = build_channel_data_packet(1, data).unwrap();
|
let packet = build_channel_data_packet(1, data).unwrap();
|
||||||
|
|
||||||
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
|
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
|
||||||
// 验证packet结构
|
// 验证packet结构
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,42 +1,42 @@
|
|||||||
// SSH密钥交换算法协商实现(Phase 2)
|
// SSH密钥交换算法协商实现(Phase 2)
|
||||||
// 参考OpenSSH kex.c: kex_send_kexinit(), kex_choose_conf()
|
// 参考OpenSSH kex.c: kex_send_kexinit(), kex_choose_conf()
|
||||||
|
|
||||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||||
use log::{info, debug};
|
use log::{debug, info};
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
|
|
||||||
/// SSH算法类型(参考OpenSSH PROTOCOL定义)
|
/// SSH算法类型(参考OpenSSH PROTOCOL定义)
|
||||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
pub enum AlgorithmType {
|
pub enum AlgorithmType {
|
||||||
KEX_ALGS = 0, // 密钥交换算法
|
KEX_ALGS = 0, // 密钥交换算法
|
||||||
SERVER_HOST_KEY_ALGS = 1, // 服务器主机密钥算法
|
SERVER_HOST_KEY_ALGS = 1, // 服务器主机密钥算法
|
||||||
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
|
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
|
||||||
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
|
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
|
||||||
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
|
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
|
||||||
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
|
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
|
||||||
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
|
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
|
||||||
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
|
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
|
||||||
LANGS_CTOS = 8, // 客户端到服务器语言
|
LANGS_CTOS = 8, // 客户端到服务器语言
|
||||||
LANGS_STOC = 9, // 服务器到客户端语言
|
LANGS_STOC = 9, // 服务器到客户端语言
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SSH算法提议(参考OpenSSH kex.h: struct kex)
|
/// SSH算法提议(参考OpenSSH kex.h: struct kex)
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct KexProposal {
|
pub struct KexProposal {
|
||||||
pub kex_algorithms: String, // 密钥交换算法列表
|
pub kex_algorithms: String, // 密钥交换算法列表
|
||||||
pub server_host_key_algorithms: String, // 主机密钥算法列表
|
pub server_host_key_algorithms: String, // 主机密钥算法列表
|
||||||
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
|
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
|
||||||
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
|
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
|
||||||
pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器)
|
pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器)
|
||||||
pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端)
|
pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端)
|
||||||
pub compression_algorithms_ctos: String, // 压缩算法(客户端→服务器)
|
pub compression_algorithms_ctos: String, // 压缩算法(客户端→服务器)
|
||||||
pub compression_algorithms_stoc: String, // 压缩算法(服务器→客户端)
|
pub compression_algorithms_stoc: String, // 压缩算法(服务器→客户端)
|
||||||
pub languages_ctos: String, // 语言(客户端→服务器)
|
pub languages_ctos: String, // 语言(客户端→服务器)
|
||||||
pub languages_stoc: String, // 语言(服务器→客户端)
|
pub languages_stoc: String, // 语言(服务器→客户端)
|
||||||
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
|
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
|
||||||
pub reserved: u32, // 保留字段(0)
|
pub reserved: u32, // 保留字段(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KexProposal {
|
impl KexProposal {
|
||||||
@@ -125,7 +125,7 @@ impl KexProposal {
|
|||||||
|
|
||||||
/// 从SSH_MSG_KEXINIT packet解析(参考OpenSSH kex_input_kexinit())
|
/// 从SSH_MSG_KEXINIT packet解析(参考OpenSSH kex_input_kexinit())
|
||||||
pub fn from_kexinit_packet(packet: &SshPacket) -> Result<Self> {
|
pub fn from_kexinit_packet(packet: &SshPacket) -> Result<Self> {
|
||||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||||||
|
|
||||||
// Packet type
|
// Packet type
|
||||||
let packet_type = cursor.read_u8()?;
|
let packet_type = cursor.read_u8()?;
|
||||||
@@ -174,14 +174,14 @@ impl KexProposal {
|
|||||||
/// SSH算法协商结果(参考OpenSSH struct kex)
|
/// SSH算法协商结果(参考OpenSSH struct kex)
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct KexResult {
|
pub struct KexResult {
|
||||||
pub kex_algorithm: String, // 选定的密钥交换算法
|
pub kex_algorithm: String, // 选定的密钥交换算法
|
||||||
pub host_key_algorithm: String, // 选定的主机密钥算法
|
pub host_key_algorithm: String, // 选定的主机密钥算法
|
||||||
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
|
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
|
||||||
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
|
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
|
||||||
pub mac_ctos: String, // 选定的MAC算法(客户端→服务器)
|
pub mac_ctos: String, // 选定的MAC算法(客户端→服务器)
|
||||||
pub mac_stoc: String, // 选定的MAC算法(服务器→客户端)
|
pub mac_stoc: String, // 选定的MAC算法(服务器→客户端)
|
||||||
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
|
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
|
||||||
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
|
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 算法匹配逻辑(参考OpenSSH kex_choose_conf())
|
/// 算法匹配逻辑(参考OpenSSH kex_choose_conf())
|
||||||
@@ -197,19 +197,34 @@ impl KexResult {
|
|||||||
let kex_algorithm = match_algorithm(&client.kex_algorithms, &server.kex_algorithms)?;
|
let kex_algorithm = match_algorithm(&client.kex_algorithms, &server.kex_algorithms)?;
|
||||||
|
|
||||||
// 主机密钥算法匹配
|
// 主机密钥算法匹配
|
||||||
let host_key_algorithm = match_algorithm(&client.server_host_key_algorithms, &server.server_host_key_algorithms)?;
|
let host_key_algorithm = match_algorithm(
|
||||||
|
&client.server_host_key_algorithms,
|
||||||
|
&server.server_host_key_algorithms,
|
||||||
|
)?;
|
||||||
|
|
||||||
// 加密算法匹配
|
// 加密算法匹配
|
||||||
let encryption_ctos = match_algorithm(&client.encryption_algorithms_ctos, &server.encryption_algorithms_ctos)?;
|
let encryption_ctos = match_algorithm(
|
||||||
let encryption_stoc = match_algorithm(&client.encryption_algorithms_stoc, &server.encryption_algorithms_stoc)?;
|
&client.encryption_algorithms_ctos,
|
||||||
|
&server.encryption_algorithms_ctos,
|
||||||
|
)?;
|
||||||
|
let encryption_stoc = match_algorithm(
|
||||||
|
&client.encryption_algorithms_stoc,
|
||||||
|
&server.encryption_algorithms_stoc,
|
||||||
|
)?;
|
||||||
|
|
||||||
// MAC算法匹配
|
// MAC算法匹配
|
||||||
let mac_ctos = match_algorithm(&client.mac_algorithms_ctos, &server.mac_algorithms_ctos)?;
|
let mac_ctos = match_algorithm(&client.mac_algorithms_ctos, &server.mac_algorithms_ctos)?;
|
||||||
let mac_stoc = match_algorithm(&client.mac_algorithms_stoc, &server.mac_algorithms_stoc)?;
|
let mac_stoc = match_algorithm(&client.mac_algorithms_stoc, &server.mac_algorithms_stoc)?;
|
||||||
|
|
||||||
// 压缩算法匹配
|
// 压缩算法匹配
|
||||||
let compression_ctos = match_algorithm(&client.compression_algorithms_ctos, &server.compression_algorithms_ctos)?;
|
let compression_ctos = match_algorithm(
|
||||||
let compression_stoc = match_algorithm(&client.compression_algorithms_stoc, &server.compression_algorithms_stoc)?;
|
&client.compression_algorithms_ctos,
|
||||||
|
&server.compression_algorithms_ctos,
|
||||||
|
)?;
|
||||||
|
let compression_stoc = match_algorithm(
|
||||||
|
&client.compression_algorithms_stoc,
|
||||||
|
&server.compression_algorithms_stoc,
|
||||||
|
)?;
|
||||||
|
|
||||||
info!("Algorithm negotiation completed:");
|
info!("Algorithm negotiation completed:");
|
||||||
debug!(" KEX: {}", kex_algorithm);
|
debug!(" KEX: {}", kex_algorithm);
|
||||||
@@ -245,7 +260,11 @@ fn match_algorithm(client_algs: &str, server_algs: &str) -> Result<String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(anyhow!("No matching algorithm found: client={}, server={}", client_algs, server_algs))
|
Err(anyhow!(
|
||||||
|
"No matching algorithm found: client={}, server={}",
|
||||||
|
client_algs,
|
||||||
|
server_algs
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SSH string写入辅助函数(length + data)
|
/// SSH string写入辅助函数(length + data)
|
||||||
@@ -286,7 +305,7 @@ mod tests {
|
|||||||
let server = "aes256-ctr,diffie-hellman-group14-sha256";
|
let server = "aes256-ctr,diffie-hellman-group14-sha256";
|
||||||
|
|
||||||
let matched = match_algorithm(client, server).unwrap();
|
let matched = match_algorithm(client, server).unwrap();
|
||||||
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
|
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -295,7 +314,7 @@ mod tests {
|
|||||||
let client = KexProposal::client_default();
|
let client = KexProposal::client_default();
|
||||||
|
|
||||||
let result = KexResult::choose_algorithms(&server, &client).unwrap();
|
let result = KexResult::choose_algorithms(&server, &client).unwrap();
|
||||||
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
|
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
|
||||||
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
|
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
// SSH密钥交换完整流程(Phase 3剩余)
|
// SSH密钥交换完整流程(Phase 3剩余)
|
||||||
// 参考OpenSSH kex.c: complete implementation
|
// 参考OpenSSH kex.c: complete implementation
|
||||||
|
|
||||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
use crate::ssh_server::crypto::SessionKeys;
|
||||||
use crate::ssh_server::kex::{KexProposal, KexResult};
|
use crate::ssh_server::kex::{KexProposal, KexResult};
|
||||||
use crate::ssh_server::crypto::{SessionKeys};
|
|
||||||
use crate::ssh_server::kex_exchange::KexExchangeHandler;
|
use crate::ssh_server::kex_exchange::KexExchangeHandler;
|
||||||
use anyhow::{Result, anyhow};
|
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||||
use sha2::{Sha256, Digest};
|
use anyhow::{anyhow, Result};
|
||||||
use byteorder::{BigEndian, WriteBytesExt};
|
use log::info;
|
||||||
use log::{info, debug};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
/// SSH密钥交换完整状态管理(参考OpenSSH struct kex)
|
/// SSH密钥交换完整状态管理(参考OpenSSH struct kex)
|
||||||
pub struct KexState {
|
pub struct KexState {
|
||||||
@@ -65,8 +64,14 @@ impl KexState {
|
|||||||
self.server_kexinit_payload = server_kexinit.payload.clone();
|
self.server_kexinit_payload = server_kexinit.payload.clone();
|
||||||
|
|
||||||
info!("Saved KEXINIT payloads (payload only, no padding)");
|
info!("Saved KEXINIT payloads (payload only, no padding)");
|
||||||
info!(" client payload: {} bytes", self.client_kexinit_payload.len());
|
info!(
|
||||||
info!(" server payload: {} bytes", self.server_kexinit_payload.len());
|
" client payload: {} bytes",
|
||||||
|
self.client_kexinit_payload.len()
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
" server payload: {} bytes",
|
||||||
|
self.server_kexinit_payload.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash())
|
/// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash())
|
||||||
@@ -93,12 +98,12 @@ impl KexState {
|
|||||||
let client_kexinit_without_type = &self.client_kexinit_payload[1..];
|
let client_kexinit_without_type = &self.client_kexinit_payload[1..];
|
||||||
let server_kexinit_without_type = &self.server_kexinit_payload[1..];
|
let server_kexinit_without_type = &self.server_kexinit_payload[1..];
|
||||||
|
|
||||||
hasher.update(&((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
hasher.update(((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||||
hasher.update(client_kexinit_without_type);
|
hasher.update(client_kexinit_without_type);
|
||||||
|
|
||||||
hasher.update(&((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
hasher.update(((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||||
hasher.update(server_kexinit_without_type);
|
hasher.update(server_kexinit_without_type);
|
||||||
|
|
||||||
// K_S: 服务器主机密钥blob(SSH string格式)
|
// K_S: 服务器主机密钥blob(SSH string格式)
|
||||||
@@ -122,7 +127,7 @@ impl KexState {
|
|||||||
info!("Processing SSH_MSG_NEWKEYS");
|
info!("Processing SSH_MSG_NEWKEYS");
|
||||||
|
|
||||||
// 验证packet类型
|
// 验证packet类型
|
||||||
if packet.payload.len() < 1 {
|
if packet.payload.is_empty() {
|
||||||
return Err(anyhow!("Invalid NEWKEYS packet"));
|
return Err(anyhow!("Invalid NEWKEYS packet"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,14 +161,14 @@ impl KexState {
|
|||||||
|
|
||||||
/// SSH string写入到hash(辅助函数)
|
/// SSH string写入到hash(辅助函数)
|
||||||
fn write_ssh_string_to_hash(hasher: &mut Sha256, s: &str) -> Result<()> {
|
fn write_ssh_string_to_hash(hasher: &mut Sha256, s: &str) -> Result<()> {
|
||||||
hasher.update(&(s.len() as u32).to_be_bytes());
|
hasher.update((s.len() as u32).to_be_bytes());
|
||||||
hasher.update(s.as_bytes());
|
hasher.update(s.as_bytes());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SSH bytes写入到hash(辅助函数)
|
/// SSH bytes写入到hash(辅助函数)
|
||||||
fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||||
hasher.update(&(bytes.len() as u32).to_be_bytes());
|
hasher.update((bytes.len() as u32).to_be_bytes());
|
||||||
hasher.update(bytes);
|
hasher.update(bytes);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -171,7 +176,7 @@ fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
|||||||
/// SSH mpint写入到hash(参考OpenSSH sshbuf_put_mpint())
|
/// SSH mpint写入到hash(参考OpenSSH sshbuf_put_mpint())
|
||||||
fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||||
// OpenSSH要求:去掉前导零(如果最高位为1)
|
// OpenSSH要求:去掉前导零(如果最高位为1)
|
||||||
let mpint_bytes = if bytes.len() > 0 && bytes[0] >= 0x80 {
|
let mpint_bytes = if !bytes.is_empty() && bytes[0] >= 0x80 {
|
||||||
// 需要添加前导零(避免负数)
|
// 需要添加前导零(避免负数)
|
||||||
let mut mpint = vec![0u8];
|
let mut mpint = vec![0u8];
|
||||||
mpint.extend_from_slice(bytes);
|
mpint.extend_from_slice(bytes);
|
||||||
@@ -180,7 +185,7 @@ fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
|||||||
bytes.to_vec()
|
bytes.to_vec()
|
||||||
};
|
};
|
||||||
|
|
||||||
hasher.update(&(mpint_bytes.len() as u32).to_be_bytes());
|
hasher.update((mpint_bytes.len() as u32).to_be_bytes());
|
||||||
hasher.update(&mpint_bytes);
|
hasher.update(&mpint_bytes);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -195,26 +200,30 @@ mod tests {
|
|||||||
let kex_result = KexResult::choose_algorithms(
|
let kex_result = KexResult::choose_algorithms(
|
||||||
&KexProposal::server_default(),
|
&KexProposal::server_default(),
|
||||||
&KexProposal::client_default(),
|
&KexProposal::client_default(),
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut state = KexState::new(
|
let mut state = KexState::new(
|
||||||
"SSH-2.0-OpenSSH_10.2".to_string(),
|
"SSH-2.0-OpenSSH_10.2".to_string(),
|
||||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||||
kex_result,
|
kex_result,
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Set minimal KEXINIT payloads (need at least 1 byte for packet type)
|
// Set minimal KEXINIT payloads (need at least 1 byte for packet type)
|
||||||
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||||
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||||
|
|
||||||
let shared_secret = vec![0u8; 32];
|
let shared_secret = vec![0u8; 32];
|
||||||
let host_key = vec![0u8; 32];
|
let host_key = vec![0u8; 32];
|
||||||
let client_pub = vec![0u8; 32];
|
let client_pub = vec![0u8; 32];
|
||||||
let server_pub = vec![0u8; 32];
|
let server_pub = vec![0u8; 32];
|
||||||
|
|
||||||
let hash = state.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub).unwrap();
|
let hash = state
|
||||||
|
.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(hash.len(), 32); // SHA256输出32字节
|
assert_eq!(hash.len(), 32); // SHA256输出32字节
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -222,13 +231,15 @@ mod tests {
|
|||||||
let kex_result = KexResult::choose_algorithms(
|
let kex_result = KexResult::choose_algorithms(
|
||||||
&KexProposal::server_default(),
|
&KexProposal::server_default(),
|
||||||
&KexProposal::client_default(),
|
&KexProposal::client_default(),
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut state = KexState::new(
|
let mut state = KexState::new(
|
||||||
"SSH-2.0-OpenSSH_10.2".to_string(),
|
"SSH-2.0-OpenSSH_10.2".to_string(),
|
||||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||||
kex_result,
|
kex_result,
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let newkeys_packet = SshPacket::new(vec![PacketType::SSH_MSG_NEWKEYS as u8]);
|
let newkeys_packet = SshPacket::new(vec![PacketType::SSH_MSG_NEWKEYS as u8]);
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
// SSH密钥交换流程实现(Phase 3)
|
// SSH密钥交换流程实现(Phase 3)
|
||||||
// 参考OpenSSH kex.c: kex_input_kex_init(), kex_send_kex_reply()
|
// 参考OpenSSH kex.c: kex_input_kex_init(), kex_send_kex_reply()
|
||||||
|
|
||||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys};
|
||||||
use crate::ssh_server::kex::{KexResult};
|
use crate::ssh_server::kex::KexResult;
|
||||||
use crate::ssh_server::crypto::{Curve25519Kex, SessionKeys, Ed25519HostKey};
|
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||||
use log::{info, debug};
|
use log::info;
|
||||||
|
use sha2::Digest;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use sha2::{Sha256, Digest};
|
|
||||||
|
|
||||||
/// SSH密钥交换流程处理器(参考OpenSSH kex.c)
|
/// SSH密钥交换流程处理器(参考OpenSSH kex.c)
|
||||||
pub struct KexExchangeHandler {
|
pub struct KexExchangeHandler {
|
||||||
@@ -18,7 +18,7 @@ pub struct KexExchangeHandler {
|
|||||||
shared_secret: Option<Vec<u8>>,
|
shared_secret: Option<Vec<u8>>,
|
||||||
client_public_key: Option<Vec<u8>>,
|
client_public_key: Option<Vec<u8>>,
|
||||||
server_public_key: Option<Vec<u8>>,
|
server_public_key: Option<Vec<u8>>,
|
||||||
exchange_hash: Option<Vec<u8>>, // 保存exchange hash(H参数)
|
exchange_hash: Option<Vec<u8>>, // 保存exchange hash(H参数)
|
||||||
client_version: Option<String>,
|
client_version: Option<String>,
|
||||||
server_version: Option<String>,
|
server_version: Option<String>,
|
||||||
client_kexinit_payload: Option<Vec<u8>>,
|
client_kexinit_payload: Option<Vec<u8>>,
|
||||||
@@ -46,7 +46,7 @@ impl KexExchangeHandler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init())
|
/// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init())
|
||||||
pub fn handle_kexdh_init(
|
pub fn handle_kexdh_init(
|
||||||
&mut self,
|
&mut self,
|
||||||
packet: &SshPacket,
|
packet: &SshPacket,
|
||||||
@@ -66,7 +66,10 @@ impl KexExchangeHandler {
|
|||||||
|
|
||||||
let key_length = cursor.read_u32::<BigEndian>()?;
|
let key_length = cursor.read_u32::<BigEndian>()?;
|
||||||
if key_length != 32 {
|
if key_length != 32 {
|
||||||
return Err(anyhow!("Invalid Curve25519 public key length: {}", key_length));
|
return Err(anyhow!(
|
||||||
|
"Invalid Curve25519 public key length: {}",
|
||||||
|
key_length
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut client_public_key = vec![0u8; 32];
|
let mut client_public_key = vec![0u8; 32];
|
||||||
@@ -105,17 +108,17 @@ impl KexExchangeHandler {
|
|||||||
)?;
|
)?;
|
||||||
|
|
||||||
info!("Exchange hash computed:");
|
info!("Exchange hash computed:");
|
||||||
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
|
info!(
|
||||||
|
" shared_secret[0] = {} (>=0x80? {})",
|
||||||
|
shared_secret[0],
|
||||||
|
shared_secret[0] >= 0x80
|
||||||
|
);
|
||||||
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
||||||
|
|
||||||
self.exchange_hash = Some(exchange_hash.clone());
|
self.exchange_hash = Some(exchange_hash.clone());
|
||||||
info!("Exchange hash saved for key derivation");
|
info!("Exchange hash saved for key derivation");
|
||||||
|
|
||||||
self.build_kexdh_reply(
|
self.build_kexdh_reply(&exchange_hash, &host_key_blob, &server_public_key)
|
||||||
&exchange_hash,
|
|
||||||
&host_key_blob,
|
|
||||||
&server_public_key,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 构建SSH_MSG_KEXDH_REPLY packet(参考OpenSSH kex.c)
|
/// 构建SSH_MSG_KEXDH_REPLY packet(参考OpenSSH kex.c)
|
||||||
@@ -155,7 +158,7 @@ impl KexExchangeHandler {
|
|||||||
// 参考OpenSSH sshkey.c
|
// 参考OpenSSH sshkey.c
|
||||||
|
|
||||||
// Key type: ssh-ed25519
|
// Key type: ssh-ed25519
|
||||||
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
|
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
|
||||||
blob.write_all("ssh-ed25519".as_bytes())?;
|
blob.write_all("ssh-ed25519".as_bytes())?;
|
||||||
|
|
||||||
// Ed25519公钥(32字节)
|
// Ed25519公钥(32字节)
|
||||||
@@ -178,7 +181,7 @@ impl KexExchangeHandler {
|
|||||||
client_kexinit_payload: &[u8],
|
client_kexinit_payload: &[u8],
|
||||||
server_kexinit_payload: &[u8],
|
server_kexinit_payload: &[u8],
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
use sha2::{Sha256, Digest};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
info!("=== EXCHANGE HASH COMPUTATION ===");
|
info!("=== EXCHANGE HASH COMPUTATION ===");
|
||||||
info!("V_C (client version): {:?}", client_version.as_bytes());
|
info!("V_C (client version): {:?}", client_version.as_bytes());
|
||||||
@@ -187,22 +190,43 @@ impl KexExchangeHandler {
|
|||||||
info!("V_S (server version): {:?}", server_version.as_bytes());
|
info!("V_S (server version): {:?}", server_version.as_bytes());
|
||||||
info!("V_S length: {}", server_version.len());
|
info!("V_S length: {}", server_version.len());
|
||||||
|
|
||||||
info!("I_C (client KEXINIT payload): {:?}", &client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]);
|
info!(
|
||||||
|
"I_C (client KEXINIT payload): {:?}",
|
||||||
|
&client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]
|
||||||
|
);
|
||||||
info!("I_C length: {}", client_kexinit_payload.len());
|
info!("I_C length: {}", client_kexinit_payload.len());
|
||||||
info!("I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", client_kexinit_payload[0]);
|
info!(
|
||||||
|
"I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
|
||||||
|
client_kexinit_payload[0]
|
||||||
|
);
|
||||||
|
|
||||||
info!("I_S (server KEXINIT payload): {:?}", &server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]);
|
info!(
|
||||||
|
"I_S (server KEXINIT payload): {:?}",
|
||||||
|
&server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]
|
||||||
|
);
|
||||||
info!("I_S length: {}", server_kexinit_payload.len());
|
info!("I_S length: {}", server_kexinit_payload.len());
|
||||||
info!("I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", server_kexinit_payload[0]);
|
info!(
|
||||||
|
"I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
|
||||||
|
server_kexinit_payload[0]
|
||||||
|
);
|
||||||
|
|
||||||
info!("K_S (host key blob): {:?}", &host_key_blob[..std::cmp::min(30, host_key_blob.len())]);
|
info!(
|
||||||
|
"K_S (host key blob): {:?}",
|
||||||
|
&host_key_blob[..std::cmp::min(30, host_key_blob.len())]
|
||||||
|
);
|
||||||
info!("K_S length: {}", host_key_blob.len());
|
info!("K_S length: {}", host_key_blob.len());
|
||||||
|
|
||||||
info!("Q_C (client ECDH public key): {:?}", &client_public_key[..std::cmp::min(16, client_public_key.len())]);
|
info!(
|
||||||
|
"Q_C (client ECDH public key): {:?}",
|
||||||
|
&client_public_key[..std::cmp::min(16, client_public_key.len())]
|
||||||
|
);
|
||||||
info!("Q_C full (32 bytes): {:?}", client_public_key);
|
info!("Q_C full (32 bytes): {:?}", client_public_key);
|
||||||
info!("Q_C length: {}", client_public_key.len());
|
info!("Q_C length: {}", client_public_key.len());
|
||||||
|
|
||||||
info!("Q_S (server ECDH public key): {:?}", &server_public_key[..std::cmp::min(16, server_public_key.len())]);
|
info!(
|
||||||
|
"Q_S (server ECDH public key): {:?}",
|
||||||
|
&server_public_key[..std::cmp::min(16, server_public_key.len())]
|
||||||
|
);
|
||||||
info!("Q_S full (32 bytes): {:?}", server_public_key);
|
info!("Q_S full (32 bytes): {:?}", server_public_key);
|
||||||
info!("Q_S length: {}", server_public_key.len());
|
info!("Q_S length: {}", server_public_key.len());
|
||||||
|
|
||||||
@@ -212,12 +236,22 @@ impl KexExchangeHandler {
|
|||||||
let vc_ssh_string = &(client_version.len() as u32).to_be_bytes();
|
let vc_ssh_string = &(client_version.len() as u32).to_be_bytes();
|
||||||
hasher.update(vc_ssh_string);
|
hasher.update(vc_ssh_string);
|
||||||
hasher.update(client_version.as_bytes());
|
hasher.update(client_version.as_bytes());
|
||||||
info!(" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]", 4+client_version.len(), vc_ssh_string, client_version.as_bytes());
|
info!(
|
||||||
|
" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]",
|
||||||
|
4 + client_version.len(),
|
||||||
|
vc_ssh_string,
|
||||||
|
client_version.as_bytes()
|
||||||
|
);
|
||||||
|
|
||||||
let vs_ssh_string = &(server_version.len() as u32).to_be_bytes();
|
let vs_ssh_string = &(server_version.len() as u32).to_be_bytes();
|
||||||
hasher.update(vs_ssh_string);
|
hasher.update(vs_ssh_string);
|
||||||
hasher.update(server_version.as_bytes());
|
hasher.update(server_version.as_bytes());
|
||||||
info!(" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]", 4+server_version.len(), vs_ssh_string, server_version.as_bytes());
|
info!(
|
||||||
|
" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]",
|
||||||
|
4 + server_version.len(),
|
||||||
|
vs_ssh_string,
|
||||||
|
server_version.as_bytes()
|
||||||
|
);
|
||||||
|
|
||||||
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
|
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
|
||||||
// KEXINIT payload should NOT include SSH_MSG_KEXINIT type byte
|
// KEXINIT payload should NOT include SSH_MSG_KEXINIT type byte
|
||||||
@@ -227,36 +261,58 @@ impl KexExchangeHandler {
|
|||||||
let client_kexinit_without_type = &client_kexinit_payload[1..];
|
let client_kexinit_without_type = &client_kexinit_payload[1..];
|
||||||
let server_kexinit_without_type = &server_kexinit_payload[1..];
|
let server_kexinit_without_type = &server_kexinit_payload[1..];
|
||||||
|
|
||||||
info!("I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)", client_kexinit_without_type.len());
|
info!(
|
||||||
info!("I_S (server KEXINIT without type byte): {} bytes", server_kexinit_without_type.len());
|
"I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)",
|
||||||
|
client_kexinit_without_type.len()
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
"I_S (server KEXINIT without type byte): {} bytes",
|
||||||
|
server_kexinit_without_type.len()
|
||||||
|
);
|
||||||
|
|
||||||
// Exchange hash: uint32(len+1) + uint8(SSH_MSG_KEXINIT) + payload_without_type
|
// Exchange hash: uint32(len+1) + uint8(SSH_MSG_KEXINIT) + payload_without_type
|
||||||
let ic_len_bytes = &((client_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
let ic_len_bytes = &((client_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
||||||
hasher.update(ic_len_bytes);
|
hasher.update(ic_len_bytes);
|
||||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||||
hasher.update(client_kexinit_without_type);
|
hasher.update(client_kexinit_without_type);
|
||||||
info!(" Exchange hash component I_C: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+client_kexinit_without_type.len(), ic_len_bytes, client_kexinit_without_type.len(), &client_kexinit_without_type[..std::cmp::min(8, client_kexinit_without_type.len())]);
|
info!(" Exchange hash component I_C: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+client_kexinit_without_type.len(), ic_len_bytes, client_kexinit_without_type.len(), &client_kexinit_without_type[..std::cmp::min(8, client_kexinit_without_type.len())]);
|
||||||
|
|
||||||
let is_len_bytes = &((server_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
let is_len_bytes = &((server_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
||||||
hasher.update(is_len_bytes);
|
hasher.update(is_len_bytes);
|
||||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||||
hasher.update(server_kexinit_without_type);
|
hasher.update(server_kexinit_without_type);
|
||||||
info!(" Exchange hash component I_S: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+server_kexinit_without_type.len(), is_len_bytes, server_kexinit_without_type.len(), &server_kexinit_without_type[..std::cmp::min(8, server_kexinit_without_type.len())]);
|
info!(" Exchange hash component I_S: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+server_kexinit_without_type.len(), is_len_bytes, server_kexinit_without_type.len(), &server_kexinit_without_type[..std::cmp::min(8, server_kexinit_without_type.len())]);
|
||||||
|
|
||||||
let ks_len_bytes = &(host_key_blob.len() as u32).to_be_bytes();
|
let ks_len_bytes = &(host_key_blob.len() as u32).to_be_bytes();
|
||||||
hasher.update(ks_len_bytes);
|
hasher.update(ks_len_bytes);
|
||||||
hasher.update(host_key_blob);
|
hasher.update(host_key_blob);
|
||||||
info!(" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])", 4+host_key_blob.len(), ks_len_bytes, host_key_blob.len(), host_key_blob);
|
info!(
|
||||||
|
" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])",
|
||||||
|
4 + host_key_blob.len(),
|
||||||
|
ks_len_bytes,
|
||||||
|
host_key_blob.len(),
|
||||||
|
host_key_blob
|
||||||
|
);
|
||||||
|
|
||||||
let qc_len_bytes = &(client_public_key.len() as u32).to_be_bytes();
|
let qc_len_bytes = &(client_public_key.len() as u32).to_be_bytes();
|
||||||
hasher.update(qc_len_bytes);
|
hasher.update(qc_len_bytes);
|
||||||
hasher.update(client_public_key);
|
hasher.update(client_public_key);
|
||||||
info!(" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]", 4+client_public_key.len(), qc_len_bytes, client_public_key);
|
info!(
|
||||||
|
" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]",
|
||||||
|
4 + client_public_key.len(),
|
||||||
|
qc_len_bytes,
|
||||||
|
client_public_key
|
||||||
|
);
|
||||||
|
|
||||||
let qs_len_bytes = &(server_public_key.len() as u32).to_be_bytes();
|
let qs_len_bytes = &(server_public_key.len() as u32).to_be_bytes();
|
||||||
hasher.update(qs_len_bytes);
|
hasher.update(qs_len_bytes);
|
||||||
hasher.update(server_public_key);
|
hasher.update(server_public_key);
|
||||||
info!(" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]", 4+server_public_key.len(), qs_len_bytes, server_public_key);
|
info!(
|
||||||
|
" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]",
|
||||||
|
4 + server_public_key.len(),
|
||||||
|
qs_len_bytes,
|
||||||
|
server_public_key
|
||||||
|
);
|
||||||
|
|
||||||
info!("Exchange hash components:");
|
info!("Exchange hash components:");
|
||||||
info!(" shared_secret raw full (32 bytes): {:?}", shared_secret);
|
info!(" shared_secret raw full (32 bytes): {:?}", shared_secret);
|
||||||
@@ -274,18 +330,27 @@ impl KexExchangeHandler {
|
|||||||
}
|
}
|
||||||
let trimmed_shared_secret = &shared_secret[start..];
|
let trimmed_shared_secret = &shared_secret[start..];
|
||||||
|
|
||||||
info!(" shared_secret after removing leading zeros ({} bytes): {:?}", trimmed_shared_secret.len(), trimmed_shared_secret);
|
info!(
|
||||||
|
" shared_secret after removing leading zeros ({} bytes): {:?}",
|
||||||
|
trimmed_shared_secret.len(),
|
||||||
|
trimmed_shared_secret
|
||||||
|
);
|
||||||
|
|
||||||
let mpint_shared_secret_data = if trimmed_shared_secret.len() > 0 && trimmed_shared_secret[0] >= 0x80 {
|
let mpint_shared_secret_data =
|
||||||
let mut mpint = vec![0u8];
|
if !trimmed_shared_secret.is_empty() && trimmed_shared_secret[0] >= 0x80 {
|
||||||
mpint.extend_from_slice(trimmed_shared_secret);
|
let mut mpint = vec![0u8];
|
||||||
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
|
mpint.extend_from_slice(trimmed_shared_secret);
|
||||||
mpint
|
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
|
||||||
} else {
|
mpint
|
||||||
trimmed_shared_secret.to_vec()
|
} else {
|
||||||
};
|
trimmed_shared_secret.to_vec()
|
||||||
|
};
|
||||||
|
|
||||||
info!(" mpint_shared_secret_data ({} bytes): {:?}", mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]);
|
info!(
|
||||||
|
" mpint_shared_secret_data ({} bytes): {:?}",
|
||||||
|
mpint_shared_secret_data.len(),
|
||||||
|
&mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]
|
||||||
|
);
|
||||||
|
|
||||||
// mpint格式 = uint32(length) + mpint_data
|
// mpint格式 = uint32(length) + mpint_data
|
||||||
let mpint_len_bytes = &(mpint_shared_secret_data.len() as u32).to_be_bytes();
|
let mpint_len_bytes = &(mpint_shared_secret_data.len() as u32).to_be_bytes();
|
||||||
@@ -330,7 +395,7 @@ impl KexExchangeHandler {
|
|||||||
|
|
||||||
SessionKeys::derive(
|
SessionKeys::derive(
|
||||||
shared_secret,
|
shared_secret,
|
||||||
exchange_hash, // 使用保存的exchange hash(H参数)
|
exchange_hash, // 使用保存的exchange hash(H参数)
|
||||||
server_public_key,
|
server_public_key,
|
||||||
client_public_key,
|
client_public_key,
|
||||||
&host_key_blob,
|
&host_key_blob,
|
||||||
|
|||||||
@@ -1,28 +1,28 @@
|
|||||||
// SSH服务器模块(手动实现SSH协议)
|
// SSH服务器模块(手动实现SSH协议)
|
||||||
// 参考OpenSSH源码实现完整的SSH/SFTP/SCP/rsync协议
|
// 参考OpenSSH源码实现完整的SSH/SFTP/SCP/rsync协议
|
||||||
|
|
||||||
pub mod server;
|
|
||||||
pub mod packet;
|
|
||||||
pub mod version;
|
|
||||||
pub mod crypto;
|
|
||||||
pub mod kex;
|
|
||||||
pub mod kex_exchange;
|
|
||||||
pub mod kex_complete;
|
|
||||||
pub mod cipher;
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod channel;
|
pub mod channel;
|
||||||
pub mod sftp_handler;
|
pub mod cipher;
|
||||||
pub mod scp_handler;
|
pub mod crypto;
|
||||||
|
pub mod data_forwarder; // Phase 13.5: 数据传输模块
|
||||||
|
pub mod kex;
|
||||||
|
pub mod kex_complete;
|
||||||
|
pub mod kex_exchange;
|
||||||
|
pub mod packet;
|
||||||
|
pub mod port_forward; // Phase 13: 端口转发模块
|
||||||
|
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
|
||||||
pub mod rsync_handler;
|
pub mod rsync_handler;
|
||||||
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c)
|
pub mod scp_handler;
|
||||||
pub mod port_forward; // Phase 13: 端口转发模块
|
pub mod server;
|
||||||
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
|
pub mod sftp_handler;
|
||||||
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
|
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
|
||||||
pub mod data_forwarder; // Phase 13.5: 数据传输模块
|
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c)
|
||||||
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
|
pub mod version;
|
||||||
|
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
|
||||||
|
|
||||||
|
pub use packet::{PacketType, SshPacket};
|
||||||
pub use server::SshServer;
|
pub use server::SshServer;
|
||||||
pub use packet::{SshPacket, PacketType};
|
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
|
||||||
pub use version::VersionExchange;
|
pub use sshbuf::SshBuf;
|
||||||
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
|
pub use version::VersionExchange; // Phase 15: 导出 SSH Buffer
|
||||||
pub use sshbuf::SshBuf; // Phase 15: 导出 SSH Buffer
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// SSH Packet基础结构定义
|
// SSH Packet基础结构定义
|
||||||
// 参考OpenSSH packet.c: ssh_packet_read(), ssh_packet_write()
|
// 参考OpenSSH packet.c: ssh_packet_read(), ssh_packet_write()
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
|
|
||||||
@@ -70,10 +70,10 @@ impl SshPacket {
|
|||||||
pub fn new(payload: Vec<u8>) -> Self {
|
pub fn new(payload: Vec<u8>) -> Self {
|
||||||
// 计算padding(SSH协议RFC 4253规范)
|
// 计算padding(SSH协议RFC 4253规范)
|
||||||
// 参考OpenSSH packet.c: construct_packet()
|
// 参考OpenSSH packet.c: construct_packet()
|
||||||
let block_size = 8; // 未加密阶段block_size=8
|
let block_size = 8; // 未加密阶段block_size=8
|
||||||
|
|
||||||
let payload_length = payload.len();
|
let payload_length = payload.len();
|
||||||
let min_padding = 4; // OpenSSH要求最少4字节padding
|
let min_padding = 4; // OpenSSH要求最少4字节padding
|
||||||
|
|
||||||
// SSH协议约束:
|
// SSH协议约束:
|
||||||
// packet_length = padding_length + payload_length + 1
|
// packet_length = padding_length + payload_length + 1
|
||||||
@@ -88,10 +88,10 @@ impl SshPacket {
|
|||||||
|
|
||||||
// 计算packet总长度(包括4字节的packet_length字段)
|
// 计算packet总长度(包括4字节的packet_length字段)
|
||||||
let packet_length = 1 + payload_length + padding_length as usize;
|
let packet_length = 1 + payload_length + padding_length as usize;
|
||||||
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
|
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
|
||||||
|
|
||||||
// 如果总长度不是block_size的倍数,增加padding
|
// 如果总长度不是block_size的倍数,增加padding
|
||||||
if total_length % block_size != 0 {
|
if !total_length.is_multiple_of(block_size) {
|
||||||
let remainder = total_length % block_size;
|
let remainder = total_length % block_size;
|
||||||
padding_length += (block_size - remainder) as u8;
|
padding_length += (block_size - remainder) as u8;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
// SSH端口转发协议实现(Phase 13)
|
// SSH端口转发协议实现(Phase 13)
|
||||||
// 参考OpenSSH channels.c和RFC 4254
|
// 参考OpenSSH channels.c和RFC 4254
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
||||||
use log::{info, warn, debug};
|
use anyhow::Result;
|
||||||
use std::net::{TcpListener, TcpStream, SocketAddr};
|
|
||||||
use std::io::{Read, Write};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use std::thread;
|
|
||||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.2: 安全配置
|
use log::{info, warn};
|
||||||
|
use std::io::Read;
|
||||||
|
use std::net::{TcpListener, TcpStream};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
// Phase 13.2: 安全配置
|
||||||
|
|
||||||
/// 端口转发类型(参考RFC 4254)
|
/// 端口转发类型(参考RFC 4254)
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum PortForwardType {
|
pub enum PortForwardType {
|
||||||
Local, // Local port forwarding (-L)
|
Local, // Local port forwarding (-L)
|
||||||
Remote, // Remote port forwarding (-R)
|
Remote, // Remote port forwarding (-R)
|
||||||
Dynamic, // Dynamic port forwarding (-D, SOCKS)
|
Dynamic, // Dynamic port forwarding (-D, SOCKS)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 端口转发请求(参考RFC 4254 Section 7)
|
/// 端口转发请求(参考RFC 4254 Section 7)
|
||||||
@@ -36,6 +36,12 @@ pub struct PortForwardManager {
|
|||||||
active_forwards: Arc<Mutex<Vec<(u32, PortForwardType)>>>,
|
active_forwards: Arc<Mutex<Vec<(u32, PortForwardType)>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for PortForwardManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl PortForwardManager {
|
impl PortForwardManager {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -46,11 +52,15 @@ impl PortForwardManager {
|
|||||||
/// 处理SSH_MSG_GLOBAL_REQUEST(端口转发请求)
|
/// 处理SSH_MSG_GLOBAL_REQUEST(端口转发请求)
|
||||||
/// 参考RFC 4254 Section 4
|
/// 参考RFC 4254 Section 4
|
||||||
/// Phase 13.2: 添加安全配置验证
|
/// Phase 13.2: 添加安全配置验证
|
||||||
pub fn handle_global_request(&mut self, data: &[u8], security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
|
pub fn handle_global_request(
|
||||||
|
&mut self,
|
||||||
|
data: &[u8],
|
||||||
|
security_config: &SshSecurityConfig,
|
||||||
|
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||||
info!("Processing SSH_MSG_GLOBAL_REQUEST for port forwarding");
|
info!("Processing SSH_MSG_GLOBAL_REQUEST for port forwarding");
|
||||||
|
|
||||||
let mut cursor = std::io::Cursor::new(data);
|
let mut cursor = std::io::Cursor::new(data);
|
||||||
cursor.set_position(1); // Skip packet type
|
cursor.set_position(1); // Skip packet type
|
||||||
|
|
||||||
// 读取请求名称(SSH string)
|
// 读取请求名称(SSH string)
|
||||||
let request_name = read_ssh_string(&mut cursor)?;
|
let request_name = read_ssh_string(&mut cursor)?;
|
||||||
@@ -63,7 +73,8 @@ impl PortForwardManager {
|
|||||||
match request_name.as_str() {
|
match request_name.as_str() {
|
||||||
"tcpip-forward" => {
|
"tcpip-forward" => {
|
||||||
// Local port forwarding (-L)
|
// Local port forwarding (-L)
|
||||||
self.handle_tcpip_forward(&mut cursor, want_reply, security_config) // Phase 13.2
|
self.handle_tcpip_forward(&mut cursor, want_reply, security_config)
|
||||||
|
// Phase 13.2
|
||||||
}
|
}
|
||||||
"cancel-tcpip-forward" => {
|
"cancel-tcpip-forward" => {
|
||||||
// Cancel port forwarding
|
// Cancel port forwarding
|
||||||
@@ -84,19 +95,27 @@ impl PortForwardManager {
|
|||||||
/// 处理tcpip-forward请求(Local port forwarding)
|
/// 处理tcpip-forward请求(Local port forwarding)
|
||||||
/// 参考RFC 4254 Section 7.1
|
/// 参考RFC 4254 Section 7.1
|
||||||
/// Phase 13.2: 添加安全配置验证
|
/// Phase 13.2: 添加安全配置验证
|
||||||
fn handle_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool, security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
|
fn handle_tcpip_forward(
|
||||||
|
&mut self,
|
||||||
|
cursor: &mut std::io::Cursor<&[u8]>,
|
||||||
|
want_reply: bool,
|
||||||
|
security_config: &SshSecurityConfig,
|
||||||
|
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||||
// 读取bind address(SSH string)
|
// 读取bind address(SSH string)
|
||||||
let bind_address = read_ssh_string(cursor)?;
|
let bind_address = read_ssh_string(cursor)?;
|
||||||
|
|
||||||
// 读取bind port
|
// 读取bind port
|
||||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||||
|
|
||||||
info!("tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
|
info!(
|
||||||
|
"tcpip-forward request: bind_address={}, bind_port={}",
|
||||||
|
bind_address, bind_port
|
||||||
|
);
|
||||||
|
|
||||||
// Phase 13.2: 安全配置验证
|
// Phase 13.2: 安全配置验证
|
||||||
if let Err(e) = security_config.validate_tcpip_forward_request(&bind_address, bind_port) {
|
if let Err(e) = security_config.validate_tcpip_forward_request(&bind_address, bind_port) {
|
||||||
warn!("tcpip-forward security validation failed: {}", e);
|
warn!("tcpip-forward security validation failed: {}", e);
|
||||||
return Ok((false, None)); // 拒绝请求
|
return Ok((false, None)); // 拒绝请求
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("tcpip-forward security validation passed");
|
info!("tcpip-forward security validation passed");
|
||||||
@@ -117,11 +136,18 @@ impl PortForwardManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 处理cancel-tcpip-forward请求
|
/// 处理cancel-tcpip-forward请求
|
||||||
fn handle_cancel_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool) -> Result<(bool, Option<Vec<u8>>)> {
|
fn handle_cancel_tcpip_forward(
|
||||||
|
&mut self,
|
||||||
|
cursor: &mut std::io::Cursor<&[u8]>,
|
||||||
|
want_reply: bool,
|
||||||
|
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||||
let bind_address = read_ssh_string(cursor)?;
|
let bind_address = read_ssh_string(cursor)?;
|
||||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||||
|
|
||||||
info!("cancel-tcpip-forward: bind_address={}, bind_port={}", bind_address, bind_port);
|
info!(
|
||||||
|
"cancel-tcpip-forward: bind_address={}, bind_port={}",
|
||||||
|
bind_address, bind_port
|
||||||
|
);
|
||||||
|
|
||||||
// 移除active forward
|
// 移除active forward
|
||||||
let mut forwards = self.active_forwards.lock().unwrap();
|
let mut forwards = self.active_forwards.lock().unwrap();
|
||||||
@@ -136,7 +162,11 @@ impl PortForwardManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 构建SSH_MSG_REQUEST_SUCCESS/FAILURE响应
|
/// 构建SSH_MSG_REQUEST_SUCCESS/FAILURE响应
|
||||||
fn build_global_request_response(&self, success: bool, bound_port: Option<u32>) -> Result<Vec<u8>> {
|
fn build_global_request_response(
|
||||||
|
&self,
|
||||||
|
success: bool,
|
||||||
|
bound_port: Option<u32>,
|
||||||
|
) -> Result<Vec<u8>> {
|
||||||
use crate::ssh_server::packet::PacketType;
|
use crate::ssh_server::packet::PacketType;
|
||||||
|
|
||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
@@ -161,7 +191,7 @@ impl PortForwardManager {
|
|||||||
info!("Processing direct-tcpip channel open");
|
info!("Processing direct-tcpip channel open");
|
||||||
|
|
||||||
let mut cursor = std::io::Cursor::new(data);
|
let mut cursor = std::io::Cursor::new(data);
|
||||||
cursor.set_position(1); // Skip packet type
|
cursor.set_position(1); // Skip packet type
|
||||||
|
|
||||||
// 读取channel type(已知道是"direct-tcpip",跳过)
|
// 读取channel type(已知道是"direct-tcpip",跳过)
|
||||||
let _channel_type = read_ssh_string(&mut cursor)?;
|
let _channel_type = read_ssh_string(&mut cursor)?;
|
||||||
@@ -187,8 +217,10 @@ impl PortForwardManager {
|
|||||||
// 读取originator port
|
// 读取originator port
|
||||||
let originator_port = cursor.read_u32::<BigEndian>()?;
|
let originator_port = cursor.read_u32::<BigEndian>()?;
|
||||||
|
|
||||||
info!("direct-tcpip: host={}, port={}, originator={}:{}",
|
info!(
|
||||||
host_to_connect, port_to_connect, originator_address, originator_port);
|
"direct-tcpip: host={}, port={}, originator={}:{}",
|
||||||
|
host_to_connect, port_to_connect, originator_address, originator_port
|
||||||
|
);
|
||||||
|
|
||||||
Ok(DirectTcpipChannel {
|
Ok(DirectTcpipChannel {
|
||||||
sender_channel,
|
sender_channel,
|
||||||
@@ -226,8 +258,10 @@ impl PortForwardManager {
|
|||||||
// 读取originator port
|
// 读取originator port
|
||||||
let originator_port = cursor.read_u32::<BigEndian>()?;
|
let originator_port = cursor.read_u32::<BigEndian>()?;
|
||||||
|
|
||||||
info!("forwarded-tcpip: bind={}:{}, originator={}:{}",
|
info!(
|
||||||
bind_address, bind_port, originator_address, originator_port);
|
"forwarded-tcpip: bind={}:{}, originator={}:{}",
|
||||||
|
bind_address, bind_port, originator_address, originator_port
|
||||||
|
);
|
||||||
|
|
||||||
Ok(ForwardedTcpipChannel {
|
Ok(ForwardedTcpipChannel {
|
||||||
sender_channel,
|
sender_channel,
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
// SSH端口转发监听线程(Phase 13.4)
|
// SSH端口转发监听线程(Phase 13.4)
|
||||||
// 参考OpenSSH channels.c: channel_forward_listener
|
// 参考OpenSSH channels.c: channel_forward_listener
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
|
||||||
use log::{info, warn, debug, error};
|
|
||||||
use std::net::{TcpListener, TcpStream};
|
|
||||||
use std::thread;
|
|
||||||
use std::sync::{Arc, Mutex, mpsc};
|
|
||||||
use std::io::{Read, Write};
|
|
||||||
use byteorder::{BigEndian, WriteBytesExt};
|
|
||||||
use crate::ssh_server::packet::PacketType;
|
|
||||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
||||||
|
use anyhow::Result;
|
||||||
|
use log::{error, info, warn};
|
||||||
|
use std::net::{TcpListener, TcpStream};
|
||||||
|
use std::sync::{mpsc, Arc, Mutex};
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
/// 监听器状态(Phase 13.4)
|
/// 监听器状态(Phase 13.4)
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -30,28 +27,18 @@ pub enum ListenerRequest {
|
|||||||
stream: TcpStream,
|
stream: TcpStream,
|
||||||
},
|
},
|
||||||
/// 停止监听
|
/// 停止监听
|
||||||
StopListener {
|
StopListener { bind_port: u32 },
|
||||||
bind_port: u32,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 监听器响应(Phase 13.4:线程通信)
|
/// 监听器响应(Phase 13.4:线程通信)
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum ListenerResponse {
|
pub enum ListenerResponse {
|
||||||
/// Channel创建成功
|
/// Channel创建成功
|
||||||
ChannelCreated {
|
ChannelCreated { bind_port: u32, channel_id: u32 },
|
||||||
bind_port: u32,
|
|
||||||
channel_id: u32,
|
|
||||||
},
|
|
||||||
/// 监听器停止
|
/// 监听器停止
|
||||||
ListenerStopped {
|
ListenerStopped { bind_port: u32 },
|
||||||
bind_port: u32,
|
|
||||||
},
|
|
||||||
/// 错误
|
/// 错误
|
||||||
Error {
|
Error { bind_port: u32, message: String },
|
||||||
bind_port: u32,
|
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 端口转发监听器(Phase 13.4)
|
/// 端口转发监听器(Phase 13.4)
|
||||||
@@ -76,19 +63,22 @@ impl PortForwardListener {
|
|||||||
|
|
||||||
// Phase 13.4: 根据GatewayPorts决定绑定地址
|
// Phase 13.4: 根据GatewayPorts决定绑定地址
|
||||||
let bind_addr = if security_config.gateway_ports {
|
let bind_addr = if security_config.gateway_ports {
|
||||||
format!("0.0.0.0:{}", bind_port) // 允许外部访问
|
format!("0.0.0.0:{}", bind_port) // 允许外部访问
|
||||||
} else {
|
} else {
|
||||||
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
|
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
|
||||||
};
|
};
|
||||||
|
|
||||||
info!("Binding to address: {} (GatewayPorts={})", bind_addr, security_config.gateway_ports);
|
info!(
|
||||||
|
"Binding to address: {} (GatewayPorts={})",
|
||||||
|
bind_addr, security_config.gateway_ports
|
||||||
|
);
|
||||||
|
|
||||||
let listener = TcpListener::bind(&bind_addr)?;
|
let listener = TcpListener::bind(&bind_addr)?;
|
||||||
info!("Listener created successfully on {}", bind_addr);
|
info!("Listener created successfully on {}", bind_addr);
|
||||||
|
|
||||||
// Phase 13.4: 创建线程通信channel
|
// Phase 13.4: 创建线程通信channel
|
||||||
let (request_tx, request_rx) = mpsc::channel();
|
let (request_tx, _request_rx) = mpsc::channel();
|
||||||
let (response_tx, response_rx) = mpsc::channel();
|
let (_response_tx, response_rx) = mpsc::channel();
|
||||||
|
|
||||||
// Phase 13.4: 活动状态标记
|
// Phase 13.4: 活动状态标记
|
||||||
let active = Arc::new(Mutex::new(true));
|
let active = Arc::new(Mutex::new(true));
|
||||||
@@ -126,7 +116,7 @@ impl PortForwardListener {
|
|||||||
let request = ListenerRequest::NewConnection {
|
let request = ListenerRequest::NewConnection {
|
||||||
bind_port,
|
bind_port,
|
||||||
originator_address: addr.ip().to_string(),
|
originator_address: addr.ip().to_string(),
|
||||||
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
|
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
|
||||||
stream,
|
stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -182,6 +172,12 @@ pub struct ListenerManager {
|
|||||||
listeners: HashMap<u32, Arc<Mutex<PortForwardListener>>>,
|
listeners: HashMap<u32, Arc<Mutex<PortForwardListener>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for ListenerManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ListenerManager {
|
impl ListenerManager {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -225,11 +221,14 @@ impl ListenerManager {
|
|||||||
|
|
||||||
/// 获取活动监听器数量(Phase 13.4)
|
/// 获取活动监听器数量(Phase 13.4)
|
||||||
pub fn active_count(&self) -> usize {
|
pub fn active_count(&self) -> usize {
|
||||||
self.listeners.values().filter(|l| l.lock().unwrap().is_active()).count()
|
self.listeners
|
||||||
|
.values()
|
||||||
|
.filter(|l| l.lock().unwrap().is_active())
|
||||||
|
.count()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
|
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@@ -241,6 +240,6 @@ mod tests {
|
|||||||
let listener = PortForwardListener::new(8080, "127.0.0.1".to_string(), security_config);
|
let listener = PortForwardListener::new(8080, "127.0.0.1".to_string(), security_config);
|
||||||
|
|
||||||
// 注意:实际测试需要处理端口占用问题
|
// 注意:实际测试需要处理端口占用问题
|
||||||
assert!(listener.is_ok() || true); // 暂时跳过测试
|
assert!(listener.is_ok() || true); // 暂时跳过测试
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
use std::path::PathBuf;
|
|
||||||
use anyhow::{Result, anyhow};
|
|
||||||
use log::{info, debug, warn};
|
|
||||||
use crate::vfs::{VfsBackend, VfsFile, VfsError};
|
|
||||||
use crate::vfs::open_flags::OpenFlags;
|
use crate::vfs::open_flags::OpenFlags;
|
||||||
|
use crate::vfs::{VfsBackend, VfsFile};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use log::{debug, info, warn};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
/// MPLEX_BASE from rsync io.h
|
/// MPLEX_BASE from rsync io.h
|
||||||
const MPLEX_BASE: u32 = 7;
|
const MPLEX_BASE: u32 = 7;
|
||||||
@@ -18,7 +18,9 @@ pub(crate) enum RsyncState {
|
|||||||
WaitVersion,
|
WaitVersion,
|
||||||
ReadFileList,
|
ReadFileList,
|
||||||
/// Sum head (4 × write_int = 16 bytes) + checksum seed (4 bytes) = 20 bytes
|
/// Sum head (4 × write_int = 16 bytes) + checksum seed (4 bytes) = 20 bytes
|
||||||
ReadSumHead { need: usize },
|
ReadSumHead {
|
||||||
|
need: usize,
|
||||||
|
},
|
||||||
SendSumCount,
|
SendSumCount,
|
||||||
/// Raw file data from MSG_DATA packets
|
/// Raw file data from MSG_DATA packets
|
||||||
ReadFileData,
|
ReadFileData,
|
||||||
@@ -51,9 +53,16 @@ impl RsyncHandler {
|
|||||||
let mut dest = String::new();
|
let mut dest = String::new();
|
||||||
|
|
||||||
for p in &parts[1..] {
|
for p in &parts[1..] {
|
||||||
if *p == "--server" { is_server = true; continue; }
|
if *p == "--server" {
|
||||||
if *p == "--sender" || p.starts_with('-') { continue; }
|
is_server = true;
|
||||||
if *p == "." { continue; }
|
continue;
|
||||||
|
}
|
||||||
|
if *p == "--sender" || p.starts_with('-') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if *p == "." {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
dest = p.to_string();
|
dest = p.to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,8 +116,10 @@ impl RsyncHandler {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let header = u32::from_le_bytes([
|
let header = u32::from_le_bytes([
|
||||||
self.raw_input[0], self.raw_input[1],
|
self.raw_input[0],
|
||||||
self.raw_input[2], self.raw_input[3],
|
self.raw_input[1],
|
||||||
|
self.raw_input[2],
|
||||||
|
self.raw_input[3],
|
||||||
]);
|
]);
|
||||||
let raw_tag = ((header >> 24) & 0xFF) as u8;
|
let raw_tag = ((header >> 24) & 0xFF) as u8;
|
||||||
let tag = raw_tag.wrapping_sub(MPLEX_BASE as u8);
|
let tag = raw_tag.wrapping_sub(MPLEX_BASE as u8);
|
||||||
@@ -182,12 +193,17 @@ impl RsyncHandler {
|
|||||||
RsyncState::WaitVersion => {
|
RsyncState::WaitVersion => {
|
||||||
if self.rsync_input.len() >= 4 {
|
if self.rsync_input.len() >= 4 {
|
||||||
let version = u32::from_le_bytes([
|
let version = u32::from_le_bytes([
|
||||||
self.rsync_input[0], self.rsync_input[1],
|
self.rsync_input[0],
|
||||||
self.rsync_input[2], self.rsync_input[3],
|
self.rsync_input[1],
|
||||||
|
self.rsync_input[2],
|
||||||
|
self.rsync_input[3],
|
||||||
]);
|
]);
|
||||||
self.rsync_input.drain(..4);
|
self.rsync_input.drain(..4);
|
||||||
self.protocol_version = std::cmp::min(self.protocol_version, version);
|
self.protocol_version = std::cmp::min(self.protocol_version, version);
|
||||||
info!("rsync: negotiated protocol version {}", self.protocol_version);
|
info!(
|
||||||
|
"rsync: negotiated protocol version {}",
|
||||||
|
self.protocol_version
|
||||||
|
);
|
||||||
self.multiplex = self.protocol_version >= 30;
|
self.multiplex = self.protocol_version >= 30;
|
||||||
self.transition(RsyncState::ReadFileList);
|
self.transition(RsyncState::ReadFileList);
|
||||||
} else {
|
} else {
|
||||||
@@ -197,7 +213,9 @@ impl RsyncHandler {
|
|||||||
|
|
||||||
RsyncState::ReadFileList => {
|
RsyncState::ReadFileList => {
|
||||||
loop {
|
loop {
|
||||||
if self.rsync_input.is_empty() { break; }
|
if self.rsync_input.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
let flags = self.rsync_input[0];
|
let flags = self.rsync_input[0];
|
||||||
if flags == 0 {
|
if flags == 0 {
|
||||||
@@ -215,17 +233,25 @@ impl RsyncHandler {
|
|||||||
let mut pos = 1;
|
let mut pos = 1;
|
||||||
|
|
||||||
let _more_flags = if flags & 0x80 != 0 {
|
let _more_flags = if flags & 0x80 != 0 {
|
||||||
if self.rsync_input.len() <= pos { break; }
|
if self.rsync_input.len() <= pos {
|
||||||
|
break;
|
||||||
|
}
|
||||||
let ef = self.rsync_input[pos];
|
let ef = self.rsync_input[pos];
|
||||||
pos += 1;
|
pos += 1;
|
||||||
ef
|
ef
|
||||||
} else { 0 };
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
let has_name = !(flags & 0x02 != 0 && self.current_file > 0);
|
let has_name = !(flags & 0x02 != 0 && self.current_file > 0);
|
||||||
|
|
||||||
if has_name {
|
if has_name {
|
||||||
if let Some(nul_pos) = self.rsync_input[pos..].iter().position(|&b| b == 0) {
|
if let Some(nul_pos) =
|
||||||
let name = String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos]).to_string();
|
self.rsync_input[pos..].iter().position(|&b| b == 0)
|
||||||
|
{
|
||||||
|
let name =
|
||||||
|
String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos])
|
||||||
|
.to_string();
|
||||||
pos += nul_pos + 1;
|
pos += nul_pos + 1;
|
||||||
self.file_entries.push(name.clone());
|
self.file_entries.push(name.clone());
|
||||||
debug!("rsync: file entry: {}", name);
|
debug!("rsync: file entry: {}", name);
|
||||||
@@ -269,24 +295,34 @@ impl RsyncHandler {
|
|||||||
RsyncState::ReadSumHead { need } => {
|
RsyncState::ReadSumHead { need } => {
|
||||||
if self.rsync_input.len() >= need {
|
if self.rsync_input.len() >= need {
|
||||||
let sum_count = i32::from_le_bytes([
|
let sum_count = i32::from_le_bytes([
|
||||||
self.rsync_input[0], self.rsync_input[1],
|
self.rsync_input[0],
|
||||||
self.rsync_input[2], self.rsync_input[3],
|
self.rsync_input[1],
|
||||||
|
self.rsync_input[2],
|
||||||
|
self.rsync_input[3],
|
||||||
]);
|
]);
|
||||||
let _sum_blength = i32::from_le_bytes([
|
let _sum_blength = i32::from_le_bytes([
|
||||||
self.rsync_input[4], self.rsync_input[5],
|
self.rsync_input[4],
|
||||||
self.rsync_input[6], self.rsync_input[7],
|
self.rsync_input[5],
|
||||||
|
self.rsync_input[6],
|
||||||
|
self.rsync_input[7],
|
||||||
]);
|
]);
|
||||||
let _sum_s2length = i32::from_le_bytes([
|
let _sum_s2length = i32::from_le_bytes([
|
||||||
self.rsync_input[8], self.rsync_input[9],
|
self.rsync_input[8],
|
||||||
self.rsync_input[10], self.rsync_input[11],
|
self.rsync_input[9],
|
||||||
|
self.rsync_input[10],
|
||||||
|
self.rsync_input[11],
|
||||||
]);
|
]);
|
||||||
let _sum_remainder = i32::from_le_bytes([
|
let _sum_remainder = i32::from_le_bytes([
|
||||||
self.rsync_input[12], self.rsync_input[13],
|
self.rsync_input[12],
|
||||||
self.rsync_input[14], self.rsync_input[15],
|
self.rsync_input[13],
|
||||||
|
self.rsync_input[14],
|
||||||
|
self.rsync_input[15],
|
||||||
]);
|
]);
|
||||||
let checksum_seed = i32::from_le_bytes([
|
let checksum_seed = i32::from_le_bytes([
|
||||||
self.rsync_input[16], self.rsync_input[17],
|
self.rsync_input[16],
|
||||||
self.rsync_input[18], self.rsync_input[19],
|
self.rsync_input[17],
|
||||||
|
self.rsync_input[18],
|
||||||
|
self.rsync_input[19],
|
||||||
]);
|
]);
|
||||||
self.rsync_input.drain(..20);
|
self.rsync_input.drain(..20);
|
||||||
|
|
||||||
@@ -308,7 +344,9 @@ impl RsyncHandler {
|
|||||||
|
|
||||||
RsyncState::ReadFileData => {
|
RsyncState::ReadFileData => {
|
||||||
let done_marker = b"RSYNCDONE";
|
let done_marker = b"RSYNCDONE";
|
||||||
if let Some(pos) = self.rsync_input.windows(done_marker.len())
|
if let Some(pos) = self
|
||||||
|
.rsync_input
|
||||||
|
.windows(done_marker.len())
|
||||||
.position(|w| w == done_marker)
|
.position(|w| w == done_marker)
|
||||||
{
|
{
|
||||||
if pos > 0 {
|
if pos > 0 {
|
||||||
@@ -323,8 +361,11 @@ impl RsyncHandler {
|
|||||||
warn!("rsync flush error: {}", e);
|
warn!("rsync flush error: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
info!("rsync: file {} complete ({} bytes written to {})",
|
info!(
|
||||||
self.file_entries.get(self.current_file).unwrap_or(&"?".to_string()),
|
"rsync: file {} complete ({} bytes written to {})",
|
||||||
|
self.file_entries
|
||||||
|
.get(self.current_file)
|
||||||
|
.unwrap_or(&"?".to_string()),
|
||||||
self.total_written,
|
self.total_written,
|
||||||
self.dest_path.display(),
|
self.dest_path.display(),
|
||||||
);
|
);
|
||||||
@@ -332,8 +373,11 @@ impl RsyncHandler {
|
|||||||
self.current_file += 1;
|
self.current_file += 1;
|
||||||
if self.current_file >= self.file_entries.len() {
|
if self.current_file >= self.file_entries.len() {
|
||||||
self.transition(RsyncState::Done);
|
self.transition(RsyncState::Done);
|
||||||
info!("rsync ALL DONE: {} bytes written to {}",
|
info!(
|
||||||
self.total_written, self.dest_path.display());
|
"rsync ALL DONE: {} bytes written to {}",
|
||||||
|
self.total_written,
|
||||||
|
self.dest_path.display()
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
self.transition(RsyncState::ReadSumHead { need: 20 });
|
self.transition(RsyncState::ReadSumHead { need: 20 });
|
||||||
}
|
}
|
||||||
@@ -360,7 +404,9 @@ impl RsyncHandler {
|
|||||||
self.vfs.create_dir_all(parent, 0o755).ok();
|
self.vfs.create_dir_all(parent, 0o755).ok();
|
||||||
}
|
}
|
||||||
let flags = OpenFlags::new().write().create().truncate();
|
let flags = OpenFlags::new().write().create().truncate();
|
||||||
let file = self.vfs.open_file(&self.dest_path, &flags)
|
let file = self
|
||||||
|
.vfs
|
||||||
|
.open_file(&self.dest_path, &flags)
|
||||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||||
self.output_file = Some(file);
|
self.output_file = Some(file);
|
||||||
info!("rsync: opened {} for writing", self.dest_path.display());
|
info!("rsync: opened {} for writing", self.dest_path.display());
|
||||||
@@ -379,31 +425,43 @@ impl RsyncHandler {
|
|||||||
|
|
||||||
/// Read rsync varint (LSB-first 7-bit groups, 0xFF prefix for negative)
|
/// Read rsync varint (LSB-first 7-bit groups, 0xFF prefix for negative)
|
||||||
fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
|
fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
|
||||||
if buf.is_empty() { return None; }
|
if buf.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
let mut pos = 0;
|
let mut pos = 0;
|
||||||
let mut b = buf[pos];
|
let mut b = buf[pos];
|
||||||
pos += 1;
|
pos += 1;
|
||||||
|
|
||||||
let neg = if b == 0xFF {
|
let neg = if b == 0xFF {
|
||||||
if pos >= buf.len() { return None; }
|
if pos >= buf.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
b = buf[pos];
|
b = buf[pos];
|
||||||
pos += 1;
|
pos += 1;
|
||||||
true
|
true
|
||||||
} else { false };
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
let mut x = (b & 0x7F) as i32;
|
let mut x = (b & 0x7F) as i32;
|
||||||
let mut shift = 7;
|
let mut shift = 7;
|
||||||
|
|
||||||
while b & 0x80 != 0 {
|
while b & 0x80 != 0 {
|
||||||
if pos >= buf.len() { return None; }
|
if pos >= buf.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
b = buf[pos];
|
b = buf[pos];
|
||||||
pos += 1;
|
pos += 1;
|
||||||
x |= ((b & 0x7F) as i32) << shift;
|
x |= ((b & 0x7F) as i32) << shift;
|
||||||
shift += 7;
|
shift += 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
if neg { Some((-x, pos)) } else { Some((x, pos)) }
|
if neg {
|
||||||
|
Some((-x, pos))
|
||||||
|
} else {
|
||||||
|
Some((x, pos))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -419,8 +477,9 @@ mod tests {
|
|||||||
fn test_parse_command() {
|
fn test_parse_command() {
|
||||||
let h = RsyncHandler::parse_rsync_command(
|
let h = RsyncHandler::parse_rsync_command(
|
||||||
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
|
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
|
||||||
make_vfs()
|
make_vfs(),
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
|
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,14 +487,16 @@ mod tests {
|
|||||||
fn test_parse_command_sender() {
|
fn test_parse_command_sender() {
|
||||||
let h = RsyncHandler::parse_rsync_command(
|
let h = RsyncHandler::parse_rsync_command(
|
||||||
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
|
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
|
||||||
make_vfs()
|
make_vfs(),
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
|
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_version_exchange() {
|
fn test_version_exchange() {
|
||||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
|
||||||
|
.unwrap();
|
||||||
let output = h.drain_output();
|
let output = h.drain_output();
|
||||||
assert_eq!(output, b"\x1e\x00\x00\x00");
|
assert_eq!(output, b"\x1e\x00\x00\x00");
|
||||||
assert_eq!(h.state, RsyncState::WaitVersion);
|
assert_eq!(h.state, RsyncState::WaitVersion);
|
||||||
@@ -447,7 +508,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_version_negotiate_down() {
|
fn test_version_negotiate_down() {
|
||||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
|
||||||
|
.unwrap();
|
||||||
let _ = h.drain_output();
|
let _ = h.drain_output();
|
||||||
h.feed(b"\x1d\x00\x00\x00").unwrap();
|
h.feed(b"\x1d\x00\x00\x00").unwrap();
|
||||||
assert_eq!(h.protocol_version, 29);
|
assert_eq!(h.protocol_version, 29);
|
||||||
@@ -464,26 +526,33 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_file_list_multiplex() {
|
fn test_file_list_multiplex() {
|
||||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
let mut h =
|
||||||
|
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
|
||||||
|
.unwrap();
|
||||||
let _ = h.drain_output();
|
let _ = h.drain_output();
|
||||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||||
assert!(h.multiplex);
|
assert!(h.multiplex);
|
||||||
|
|
||||||
let mut flist = Vec::new();
|
let mut flist = Vec::new();
|
||||||
// File list: flags=1 (has name), then name with NUL terminator
|
// File list: flags=1 (has name), then name with NUL terminator
|
||||||
flist.push(1); // flags: has name
|
flist.push(1); // flags: has name
|
||||||
flist.extend_from_slice(b"test.txt");
|
flist.extend_from_slice(b"test.txt");
|
||||||
flist.push(0); // name terminator
|
flist.push(0); // name terminator
|
||||||
|
|
||||||
fn write_varint(buf: &mut Vec<u8>, val: i32) {
|
fn write_varint(buf: &mut Vec<u8>, val: i32) {
|
||||||
if val == 0 { buf.push(0); return; }
|
if val == 0 {
|
||||||
|
buf.push(0);
|
||||||
|
return;
|
||||||
|
}
|
||||||
if val < 0 {
|
if val < 0 {
|
||||||
buf.push(0xFF);
|
buf.push(0xFF);
|
||||||
let mut v = (-val) as u32;
|
let mut v = (-val) as u32;
|
||||||
while v > 0 {
|
while v > 0 {
|
||||||
let mut byte = (v & 0x7F) as u8;
|
let mut byte = (v & 0x7F) as u8;
|
||||||
v >>= 7;
|
v >>= 7;
|
||||||
if v > 0 { byte |= 0x80; }
|
if v > 0 {
|
||||||
|
byte |= 0x80;
|
||||||
|
}
|
||||||
buf.push(byte);
|
buf.push(byte);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -491,7 +560,9 @@ mod tests {
|
|||||||
while v > 0 {
|
while v > 0 {
|
||||||
let mut byte = (v & 0x7F) as u8;
|
let mut byte = (v & 0x7F) as u8;
|
||||||
v >>= 7;
|
v >>= 7;
|
||||||
if v > 0 { byte |= 0x80; }
|
if v > 0 {
|
||||||
|
byte |= 0x80;
|
||||||
|
}
|
||||||
buf.push(byte);
|
buf.push(byte);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -502,7 +573,7 @@ mod tests {
|
|||||||
write_varint(&mut flist, 1700000000);
|
write_varint(&mut flist, 1700000000);
|
||||||
write_varint(&mut flist, 100);
|
write_varint(&mut flist, 100);
|
||||||
write_varint(&mut flist, 0);
|
write_varint(&mut flist, 0);
|
||||||
flist.push(0); // file list end marker
|
flist.push(0); // file list end marker
|
||||||
|
|
||||||
let mut sum_head = Vec::new();
|
let mut sum_head = Vec::new();
|
||||||
sum_head.extend_from_slice(&0i32.to_le_bytes());
|
sum_head.extend_from_slice(&0i32.to_le_bytes());
|
||||||
@@ -527,22 +598,51 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_file_data_multiplex() {
|
fn test_file_data_multiplex() {
|
||||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
let mut h =
|
||||||
|
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
|
||||||
|
.unwrap();
|
||||||
let _ = h.drain_output();
|
let _ = h.drain_output();
|
||||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||||
|
|
||||||
let mut flist = Vec::new();
|
let mut flist = Vec::new();
|
||||||
flist.push(1); // flags: has name
|
flist.push(1); // flags: has name
|
||||||
flist.extend_from_slice(b"test.bin");
|
flist.extend_from_slice(b"test.bin");
|
||||||
flist.push(0);
|
flist.push(0);
|
||||||
fn wv(buf: &mut Vec<u8>, val: i32) {
|
fn wv(buf: &mut Vec<u8>, val: i32) {
|
||||||
if val == 0 { buf.push(0); return; }
|
if val == 0 {
|
||||||
if val < 0 { buf.push(0xFF); let mut v = (-val) as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
|
buf.push(0);
|
||||||
else { let mut v = val as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
|
return;
|
||||||
|
}
|
||||||
|
if val < 0 {
|
||||||
|
buf.push(0xFF);
|
||||||
|
let mut v = (-val) as u32;
|
||||||
|
while v > 0 {
|
||||||
|
let mut byte = (v & 0x7F) as u8;
|
||||||
|
v >>= 7;
|
||||||
|
if v > 0 {
|
||||||
|
byte |= 0x80;
|
||||||
|
}
|
||||||
|
buf.push(byte);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let mut v = val as u32;
|
||||||
|
while v > 0 {
|
||||||
|
let mut byte = (v & 0x7F) as u8;
|
||||||
|
v >>= 7;
|
||||||
|
if v > 0 {
|
||||||
|
byte |= 0x80;
|
||||||
|
}
|
||||||
|
buf.push(byte);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
wv(&mut flist, 33188); wv(&mut flist, 501); wv(&mut flist, 20);
|
wv(&mut flist, 33188);
|
||||||
wv(&mut flist, 1700000000); wv(&mut flist, 100); wv(&mut flist, 0);
|
wv(&mut flist, 501);
|
||||||
flist.push(0); // file list end
|
wv(&mut flist, 20);
|
||||||
|
wv(&mut flist, 1700000000);
|
||||||
|
wv(&mut flist, 100);
|
||||||
|
wv(&mut flist, 0);
|
||||||
|
flist.push(0); // file list end
|
||||||
h.feed(&build_multiplex(&flist)).unwrap();
|
h.feed(&build_multiplex(&flist)).unwrap();
|
||||||
|
|
||||||
let mut sh = Vec::new();
|
let mut sh = Vec::new();
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
// SCP协议实现(Phase 8)
|
// SCP协议实现(Phase 8)
|
||||||
// 参考OpenSSH scp.c源码
|
// 参考OpenSSH scp.c源码
|
||||||
|
|
||||||
use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat};
|
|
||||||
use crate::vfs::open_flags::OpenFlags;
|
use crate::vfs::open_flags::OpenFlags;
|
||||||
use anyhow::{Result, anyhow};
|
use crate::vfs::{VfsBackend, VfsFile, VfsStat};
|
||||||
use log::{info, warn, debug};
|
use anyhow::{anyhow, Result};
|
||||||
|
use log::{debug, info, warn};
|
||||||
|
use std::io::{BufRead, Read, Write};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::io::{Read, Write, BufRead};
|
|
||||||
use std::time::SystemTime;
|
|
||||||
|
|
||||||
/// SCP Handler(参考OpenSSH scp.c)
|
/// SCP Handler(参考OpenSSH scp.c)
|
||||||
pub struct ScpHandler {
|
pub struct ScpHandler {
|
||||||
@@ -71,10 +70,15 @@ impl ScpHandler {
|
|||||||
|
|
||||||
/// SCP Source Mode(scp -f,发送文件)
|
/// SCP Source Mode(scp -f,发送文件)
|
||||||
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||||
info!("SCP source mode: sending files from {}", self.root_dir.display());
|
info!(
|
||||||
|
"SCP source mode: sending files from {}",
|
||||||
|
self.root_dir.display()
|
||||||
|
);
|
||||||
|
|
||||||
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
|
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
|
||||||
let stat = self.vfs.stat(&full_path)
|
let stat = self
|
||||||
|
.vfs
|
||||||
|
.stat(&full_path)
|
||||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||||
|
|
||||||
if stat.is_dir {
|
if stat.is_dir {
|
||||||
@@ -91,7 +95,10 @@ impl ScpHandler {
|
|||||||
|
|
||||||
/// SCP Destination Mode(scp -t,接收文件)
|
/// SCP Destination Mode(scp -t,接收文件)
|
||||||
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
|
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||||
info!("SCP destination mode: receiving files to {}", self.root_dir.display());
|
info!(
|
||||||
|
"SCP destination mode: receiving files to {}",
|
||||||
|
self.root_dir.display()
|
||||||
|
);
|
||||||
|
|
||||||
channel.write_all(&[0])?;
|
channel.write_all(&[0])?;
|
||||||
channel.flush()?;
|
channel.flush()?;
|
||||||
@@ -130,7 +137,9 @@ impl ScpHandler {
|
|||||||
|
|
||||||
/// 发送文件(参考OpenSSH scp.c: source())
|
/// 发送文件(参考OpenSSH scp.c: source())
|
||||||
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
||||||
let stat = self.vfs.stat(path)
|
let stat = self
|
||||||
|
.vfs
|
||||||
|
.stat(path)
|
||||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||||
let size = stat.size;
|
let size = stat.size;
|
||||||
let filename = path.file_name().unwrap().to_string_lossy();
|
let filename = path.file_name().unwrap().to_string_lossy();
|
||||||
@@ -146,13 +155,16 @@ impl ScpHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let flags = OpenFlags::new().read();
|
let flags = OpenFlags::new().read();
|
||||||
let mut file = self.vfs.open_file(path, &flags)
|
let mut file = self
|
||||||
|
.vfs
|
||||||
|
.open_file(path, &flags)
|
||||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||||
|
|
||||||
let mut buffer = vec![0u8; 8192];
|
let mut buffer = vec![0u8; 8192];
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let n = file.read(&mut buffer)
|
let n = file
|
||||||
|
.read(&mut buffer)
|
||||||
.map_err(|e| anyhow!("read error: {}", e))?;
|
.map_err(|e| anyhow!("read error: {}", e))?;
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
break;
|
break;
|
||||||
@@ -188,7 +200,9 @@ impl ScpHandler {
|
|||||||
return Err(anyhow!("SCP directory command rejected"));
|
return Err(anyhow!("SCP directory command rejected"));
|
||||||
}
|
}
|
||||||
|
|
||||||
let entries = self.vfs.read_dir(path)
|
let entries = self
|
||||||
|
.vfs
|
||||||
|
.read_dir(path)
|
||||||
.map_err(|e| anyhow!("read_dir error: {}", e))?;
|
.map_err(|e| anyhow!("read_dir error: {}", e))?;
|
||||||
|
|
||||||
for entry in &entries {
|
for entry in &entries {
|
||||||
@@ -227,7 +241,10 @@ impl ScpHandler {
|
|||||||
let size: u64 = parts[1].parse()?;
|
let size: u64 = parts[1].parse()?;
|
||||||
let filename = parts[2];
|
let filename = parts[2];
|
||||||
|
|
||||||
debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename);
|
debug!(
|
||||||
|
"SCP receive file: mode={}, size={}, name={}",
|
||||||
|
mode_str, size, filename
|
||||||
|
);
|
||||||
|
|
||||||
if size > 1024 * 1024 * 1024 {
|
if size > 1024 * 1024 * 1024 {
|
||||||
return self.send_error(channel, "File too large (max 1GB)");
|
return self.send_error(channel, "File too large (max 1GB)");
|
||||||
@@ -236,7 +253,9 @@ impl ScpHandler {
|
|||||||
let full_path = self.resolve_path(filename)?;
|
let full_path = self.resolve_path(filename)?;
|
||||||
|
|
||||||
let flags = OpenFlags::new().write().create().truncate();
|
let flags = OpenFlags::new().write().create().truncate();
|
||||||
let mut file = self.vfs.open_file(&full_path, &flags)
|
let mut file = self
|
||||||
|
.vfs
|
||||||
|
.open_file(&full_path, &flags)
|
||||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||||
|
|
||||||
channel.write_all(&[0])?;
|
channel.write_all(&[0])?;
|
||||||
@@ -263,7 +282,8 @@ impl ScpHandler {
|
|||||||
if mode_int != 0 {
|
if mode_int != 0 {
|
||||||
let mut set_stat = VfsStat::new();
|
let mut set_stat = VfsStat::new();
|
||||||
set_stat.mode = mode_int;
|
set_stat.mode = mode_int;
|
||||||
self.vfs.set_stat(&full_path, &set_stat)
|
self.vfs
|
||||||
|
.set_stat(&full_path, &set_stat)
|
||||||
.map_err(|e| anyhow!("set_stat error: {}", e))?;
|
.map_err(|e| anyhow!("set_stat error: {}", e))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,7 +317,8 @@ impl ScpHandler {
|
|||||||
let full_path = self.resolve_path(dirname)?;
|
let full_path = self.resolve_path(dirname)?;
|
||||||
|
|
||||||
let mode_int: u32 = mode_str.parse()?;
|
let mode_int: u32 = mode_str.parse()?;
|
||||||
self.vfs.create_dir_all(&full_path, mode_int)
|
self.vfs
|
||||||
|
.create_dir_all(&full_path, mode_int)
|
||||||
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
|
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
|
||||||
|
|
||||||
channel.write_all(&[0])?;
|
channel.write_all(&[0])?;
|
||||||
@@ -354,10 +375,14 @@ impl ScpHandler {
|
|||||||
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
|
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
|
||||||
let full_path = self.root_dir.join(path);
|
let full_path = self.root_dir.join(path);
|
||||||
|
|
||||||
let canonical_path = self.vfs.real_path(&full_path)
|
let canonical_path = self
|
||||||
|
.vfs
|
||||||
|
.real_path(&full_path)
|
||||||
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
|
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
|
||||||
|
|
||||||
let root_canonical = self.vfs.real_path(&self.root_dir)
|
let root_canonical = self
|
||||||
|
.vfs
|
||||||
|
.real_path(&self.root_dir)
|
||||||
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
|
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
|
||||||
|
|
||||||
if !canonical_path.starts_with(&root_canonical) {
|
if !canonical_path.starts_with(&root_canonical) {
|
||||||
@@ -383,20 +408,23 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_scp_command_parse() {
|
fn test_scp_command_parse() {
|
||||||
let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
|
let handler =
|
||||||
|
ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||||
assert_eq!(handler.mode, ScpMode::Destination);
|
assert_eq!(handler.mode, ScpMode::Destination);
|
||||||
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
|
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_scp_recursive_parse() {
|
fn test_scp_recursive_parse() {
|
||||||
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
|
let handler =
|
||||||
|
ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||||
assert!(handler.recursive);
|
assert!(handler.recursive);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_scp_source_parse() {
|
fn test_scp_source_parse() {
|
||||||
let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
|
let handler =
|
||||||
|
ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
|
||||||
assert_eq!(handler.mode, ScpMode::Source);
|
assert_eq!(handler.mode, ScpMode::Source);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +1,32 @@
|
|||||||
// SSH服务器完整实现(Phase 1-7集成版 + Phase 13端口转发)
|
// SSH服务器完整实现(Phase 1-7集成版 + Phase 13端口转发)
|
||||||
// 参考OpenSSH sshd.c: complete SSH/SFTP flow + port forwarding
|
// 参考OpenSSH sshd.c: complete SSH/SFTP flow + port forwarding
|
||||||
|
|
||||||
use crate::ssh_server::version::VersionExchange;
|
|
||||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
|
||||||
use crate::ssh_server::kex::{KexResult, KexProposal};
|
|
||||||
use crate::ssh_server::kex_complete::{KexState};
|
|
||||||
use crate::ssh_server::auth::{AuthHandler, AuthResult};
|
|
||||||
use crate::provider::sqlite::SqliteProvider;
|
|
||||||
use crate::provider::pg::PgProvider;
|
use crate::provider::pg::PgProvider;
|
||||||
|
use crate::provider::sqlite::SqliteProvider;
|
||||||
use crate::provider::DataProvider;
|
use crate::provider::DataProvider;
|
||||||
use crate::ssh_server::channel::{ChannelManager};
|
use crate::ssh_server::auth::{AuthHandler, AuthResult};
|
||||||
use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket};
|
use crate::ssh_server::channel::ChannelManager;
|
||||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
|
use crate::ssh_server::cipher::{EncryptedPacket, EncryptionContext};
|
||||||
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
|
use crate::ssh_server::kex::{KexProposal, KexResult};
|
||||||
use anyhow::{Result, anyhow};
|
use crate::ssh_server::kex_complete::KexState;
|
||||||
use log::{info, warn, error, debug};
|
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||||
|
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
|
||||||
|
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
|
||||||
|
use crate::ssh_server::version::VersionExchange;
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use log::{error, info, warn};
|
||||||
|
use std::io::{Read, Write};
|
||||||
use std::net::{TcpListener, TcpStream};
|
use std::net::{TcpListener, TcpStream};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::thread;
|
use std::sync::{Arc, Mutex};
|
||||||
use std::io::{Read, Write};
|
use std::thread; // Phase 13: 端口转发线程同步
|
||||||
use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步
|
|
||||||
|
|
||||||
/// SSH服务器配置(Phase 13.1企业级安全配置)
|
/// SSH服务器配置(Phase 13.1企业级安全配置)
|
||||||
pub struct SshServerConfig {
|
pub struct SshServerConfig {
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
pub bind_address: String,
|
pub bind_address: String,
|
||||||
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
|
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
|
||||||
pub pg_conn: Option<String>, // PostgreSQL连接字符串(SFTPGo兼容认证)
|
pub pg_conn: Option<String>, // PostgreSQL连接字符串(SFTPGo兼容认证)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SshServerConfig {
|
impl Default for SshServerConfig {
|
||||||
@@ -34,7 +34,7 @@ impl Default for SshServerConfig {
|
|||||||
Self {
|
Self {
|
||||||
port: 2024,
|
port: 2024,
|
||||||
bind_address: "127.0.0.1".to_string(),
|
bind_address: "127.0.0.1".to_string(),
|
||||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
|
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
|
||||||
pg_conn: None,
|
pg_conn: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -56,15 +56,15 @@ impl SshServerConfig {
|
|||||||
/// SSH服务器主结构(Phase 1-13完整版)
|
/// SSH服务器主结构(Phase 1-13完整版)
|
||||||
pub struct SshServer {
|
pub struct SshServer {
|
||||||
config: SshServerConfig,
|
config: SshServerConfig,
|
||||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
|
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SshServer {
|
impl SshServer {
|
||||||
pub fn new(config: SshServerConfig) -> Self {
|
pub fn new(config: SshServerConfig) -> Self {
|
||||||
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
|
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
|
||||||
Self {
|
Self {
|
||||||
config,
|
config,
|
||||||
security_config, // Phase 13.1
|
security_config, // Phase 13.1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,12 +74,14 @@ impl SshServer {
|
|||||||
|
|
||||||
info!("MarkBaseSSH server listening on {}", bind_addr);
|
info!("MarkBaseSSH server listening on {}", bind_addr);
|
||||||
info!("Implementation: Complete SSH/SFTP + Port Forwarding (Phase 1-13)");
|
info!("Implementation: Complete SSH/SFTP + Port Forwarding (Phase 1-13)");
|
||||||
info!("Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
|
info!(
|
||||||
|
"Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
|
||||||
self.config.security_config.gateway_ports,
|
self.config.security_config.gateway_ports,
|
||||||
self.config.security_config.permit_open,
|
self.config.security_config.permit_open,
|
||||||
self.config.security_config.max_sessions);
|
self.config.security_config.max_sessions
|
||||||
|
);
|
||||||
|
|
||||||
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
|
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
|
||||||
let pg_conn = self.config.pg_conn.clone();
|
let pg_conn = self.config.pg_conn.clone();
|
||||||
|
|
||||||
for stream in listener.incoming() {
|
for stream in listener.incoming() {
|
||||||
@@ -88,11 +90,14 @@ impl SshServer {
|
|||||||
let client_addr = stream.peer_addr()?;
|
let client_addr = stream.peer_addr()?;
|
||||||
info!("New SSH connection from {}", client_addr);
|
info!("New SSH connection from {}", client_addr);
|
||||||
|
|
||||||
let security_config_clone = security_config.clone(); // Phase 13.1
|
let security_config_clone = security_config.clone(); // Phase 13.1
|
||||||
let pg_conn_clone = pg_conn.clone();
|
let pg_conn_clone = pg_conn.clone();
|
||||||
|
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1
|
if let Err(e) =
|
||||||
|
handle_connection_complete(stream, security_config_clone, pg_conn_clone)
|
||||||
|
{
|
||||||
|
// Phase 13.1
|
||||||
error!("Connection error: {}", e);
|
error!("Connection error: {}", e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -108,7 +113,11 @@ impl SshServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 处理完整SSH连接(Phase 1-13完整流程)
|
/// 处理完整SSH连接(Phase 1-13完整流程)
|
||||||
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>, pg_conn: Option<String>) -> Result<()> {
|
fn handle_connection_complete(
|
||||||
|
stream: TcpStream,
|
||||||
|
security_config: Arc<Mutex<SshSecurityConfig>>,
|
||||||
|
pg_conn: Option<String>,
|
||||||
|
) -> Result<()> {
|
||||||
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
|
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
|
||||||
|
|
||||||
// Phase 13.1: 增加活动会话数
|
// Phase 13.1: 增加活动会话数
|
||||||
@@ -121,25 +130,44 @@ fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshS
|
|||||||
|
|
||||||
// Phase 1: 版本交换
|
// Phase 1: 版本交换
|
||||||
let client_version = VersionExchange::exchange(&mut stream)?;
|
let client_version = VersionExchange::exchange(&mut stream)?;
|
||||||
info!("Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0", client_version);
|
info!(
|
||||||
|
"Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0",
|
||||||
|
client_version
|
||||||
|
);
|
||||||
|
|
||||||
// Phase 2: 箋法协商
|
// Phase 2: 箋法协商
|
||||||
let (kex_result, server_kexinit, client_kexinit) = perform_kex_negotiation_complete(&mut stream)?;
|
let (kex_result, server_kexinit, client_kexinit) =
|
||||||
info!("KEX negotiation: KEX={}, Cipher={}", kex_result.kex_algorithm, kex_result.encryption_ctos);
|
perform_kex_negotiation_complete(&mut stream)?;
|
||||||
|
info!(
|
||||||
|
"KEX negotiation: KEX={}, Cipher={}",
|
||||||
|
kex_result.kex_algorithm, kex_result.encryption_ctos
|
||||||
|
);
|
||||||
|
|
||||||
// Phase 3: 密钥交换完整流程
|
// Phase 3: 密钥交换完整流程
|
||||||
let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?;
|
let mut encryption_ctx = perform_complete_kex_exchange(
|
||||||
|
&mut stream,
|
||||||
|
client_version.clone(),
|
||||||
|
kex_result,
|
||||||
|
server_kexinit,
|
||||||
|
client_kexinit,
|
||||||
|
)?;
|
||||||
info!("Key exchange completed, encryption channel ready");
|
info!("Key exchange completed, encryption channel ready");
|
||||||
|
|
||||||
// Phase 5: SSH认证(SFTPGo兼容 — PostgreSQL或SQLite)
|
// Phase 5: SSH认证(SFTPGo兼容 — PostgreSQL或SQLite)
|
||||||
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
|
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
|
||||||
info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str);
|
info!(
|
||||||
Box::new(PgProvider::new(conn_str)
|
"Using PostgreSQL auth provider (SFTPGo-compatible): {}",
|
||||||
.map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?)
|
conn_str
|
||||||
|
);
|
||||||
|
Box::new(
|
||||||
|
PgProvider::new(conn_str).map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
info!("Using SQLite auth provider");
|
info!("Using SQLite auth provider");
|
||||||
Box::new(SqliteProvider::new("data/auth.sqlite")
|
Box::new(
|
||||||
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?)
|
SqliteProvider::new("data/auth.sqlite")
|
||||||
|
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
let mut auth_handler = AuthHandler::new(provider);
|
let mut auth_handler = AuthHandler::new(provider);
|
||||||
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
|
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
|
||||||
@@ -152,8 +180,14 @@ fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshS
|
|||||||
let mut port_forward_manager = PortForwardManager::new();
|
let mut port_forward_manager = PortForwardManager::new();
|
||||||
|
|
||||||
// Phase 6-13: SSH服务循环(处理channel请求 + 端口转发)
|
// Phase 6-13: SSH服务循环(处理channel请求 + 端口转发)
|
||||||
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
|
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
|
||||||
handle_ssh_service_loop(&mut stream, &mut channel_manager, &mut encryption_ctx, &mut port_forward_manager, security_config_clone)?;
|
handle_ssh_service_loop(
|
||||||
|
&mut stream,
|
||||||
|
&mut channel_manager,
|
||||||
|
&mut encryption_ctx,
|
||||||
|
&mut port_forward_manager,
|
||||||
|
security_config_clone,
|
||||||
|
)?;
|
||||||
|
|
||||||
info!("SSH session completed successfully");
|
info!("SSH session completed successfully");
|
||||||
|
|
||||||
@@ -167,7 +201,9 @@ fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshS
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 完整算法协商(返回KEXINIT payloads)
|
/// 完整算法协商(返回KEXINIT payloads)
|
||||||
fn perform_kex_negotiation_complete(stream: &mut TcpStream) -> Result<(KexResult, SshPacket, SshPacket)> {
|
fn perform_kex_negotiation_complete(
|
||||||
|
stream: &mut TcpStream,
|
||||||
|
) -> Result<(KexResult, SshPacket, SshPacket)> {
|
||||||
info!("Starting complete KEX negotiation");
|
info!("Starting complete KEX negotiation");
|
||||||
|
|
||||||
// 1. 发送服务器KEXINIT
|
// 1. 发送服务器KEXINIT
|
||||||
@@ -175,13 +211,19 @@ fn perform_kex_negotiation_complete(stream: &mut TcpStream) -> Result<(KexResult
|
|||||||
let server_kexinit = server_proposal.to_kexinit_packet()?;
|
let server_kexinit = server_proposal.to_kexinit_packet()?;
|
||||||
server_kexinit.write(stream)?;
|
server_kexinit.write(stream)?;
|
||||||
|
|
||||||
info!("Sent server KEXINIT (payload size: {} bytes)", server_kexinit.payload.len());
|
info!(
|
||||||
|
"Sent server KEXINIT (payload size: {} bytes)",
|
||||||
|
server_kexinit.payload.len()
|
||||||
|
);
|
||||||
|
|
||||||
// 2. 接收客户端KEXINIT
|
// 2. 接收客户端KEXINIT
|
||||||
let client_kexinit = SshPacket::read(stream)?;
|
let client_kexinit = SshPacket::read(stream)?;
|
||||||
let client_proposal = KexProposal::from_kexinit_packet(&client_kexinit)?;
|
let client_proposal = KexProposal::from_kexinit_packet(&client_kexinit)?;
|
||||||
|
|
||||||
info!("Received client KEXINIT (payload size: {} bytes)", client_kexinit.payload.len());
|
info!(
|
||||||
|
"Received client KEXINIT (payload size: {} bytes)",
|
||||||
|
client_kexinit.payload.len()
|
||||||
|
);
|
||||||
|
|
||||||
// 3. 算法匹配
|
// 3. 算法匹配
|
||||||
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal)?;
|
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal)?;
|
||||||
@@ -255,7 +297,8 @@ fn perform_ssh_auth(
|
|||||||
encryption_ctx: &mut EncryptionContext,
|
encryption_ctx: &mut EncryptionContext,
|
||||||
) -> Result<AuthUser> {
|
) -> Result<AuthUser> {
|
||||||
info!("Starting SSH authentication");
|
info!("Starting SSH authentication");
|
||||||
info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
|
info!(
|
||||||
|
"Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
|
||||||
encryption_ctx.encryption_key_ctos.len(),
|
encryption_ctx.encryption_key_ctos.len(),
|
||||||
encryption_ctx.encryption_key_stoc.len(),
|
encryption_ctx.encryption_key_stoc.len(),
|
||||||
encryption_ctx.iv_ctos.len(),
|
encryption_ctx.iv_ctos.len(),
|
||||||
@@ -275,7 +318,10 @@ fn perform_ssh_auth(
|
|||||||
info!("Received packet type: {}", payload[0]);
|
info!("Received packet type: {}", payload[0]);
|
||||||
|
|
||||||
if payload[0] != PacketType::SSH_MSG_SERVICE_REQUEST as u8 {
|
if payload[0] != PacketType::SSH_MSG_SERVICE_REQUEST as u8 {
|
||||||
return Err(anyhow!("Expected SSH_MSG_SERVICE_REQUEST, got type {}", payload[0]));
|
return Err(anyhow!(
|
||||||
|
"Expected SSH_MSG_SERVICE_REQUEST, got type {}",
|
||||||
|
payload[0]
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||||
@@ -291,21 +337,17 @@ fn perform_ssh_auth(
|
|||||||
|
|
||||||
let mut service_accept_payload = Vec::new();
|
let mut service_accept_payload = Vec::new();
|
||||||
service_accept_payload.write_u8(PacketType::SSH_MSG_SERVICE_ACCEPT as u8)?;
|
service_accept_payload.write_u8(PacketType::SSH_MSG_SERVICE_ACCEPT as u8)?;
|
||||||
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
|
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
|
||||||
service_accept_payload.write_all("ssh-userauth".as_bytes())?;
|
service_accept_payload.write_all("ssh-userauth".as_bytes())?;
|
||||||
|
|
||||||
let encrypted_accept = EncryptedPacket::new(
|
let encrypted_accept = EncryptedPacket::new(&service_accept_payload, encryption_ctx, true)?;
|
||||||
&service_accept_payload,
|
|
||||||
encryption_ctx,
|
|
||||||
true,
|
|
||||||
)?;
|
|
||||||
encrypted_accept.write(stream)?;
|
encrypted_accept.write(stream)?;
|
||||||
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
|
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
|
||||||
|
|
||||||
let session_id = encryption_ctx.session_id.clone();
|
let session_id = encryption_ctx.session_id.clone();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
|
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
|
||||||
let auth_payload = auth_packet.payload();
|
let auth_payload = auth_packet.payload();
|
||||||
info!("Received encrypted SSH_MSG_USERAUTH_REQUEST");
|
info!("Received encrypted SSH_MSG_USERAUTH_REQUEST");
|
||||||
|
|
||||||
@@ -314,38 +356,36 @@ fn perform_ssh_auth(
|
|||||||
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
|
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
|
||||||
AuthResult::Success => {
|
AuthResult::Success => {
|
||||||
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
|
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
|
||||||
let encrypted_success = EncryptedPacket::new(
|
let encrypted_success =
|
||||||
&success_payload,
|
EncryptedPacket::new(&success_payload, encryption_ctx, true)?;
|
||||||
encryption_ctx,
|
|
||||||
true,
|
|
||||||
)?;
|
|
||||||
encrypted_success.write(stream)?;
|
encrypted_success.write(stream)?;
|
||||||
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
|
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
|
||||||
|
|
||||||
// Extract username from auth request
|
// Extract username from auth request
|
||||||
let user = extract_username_from_auth_request(&auth_request)
|
let user = extract_username_from_auth_request(&auth_request)
|
||||||
.unwrap_or_else(|_| "unknown".to_string());
|
.unwrap_or_else(|_| "unknown".to_string());
|
||||||
let home_dir = auth_handler.get_home_dir(&user)
|
let home_dir = auth_handler
|
||||||
|
.get_home_dir(&user)
|
||||||
.ok()
|
.ok()
|
||||||
.flatten()
|
.flatten()
|
||||||
.map(PathBuf::from)
|
.map(PathBuf::from)
|
||||||
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
|
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
|
||||||
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
|
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
|
||||||
return Ok(AuthUser { username: user, home_dir });
|
return Ok(AuthUser {
|
||||||
|
username: user,
|
||||||
|
home_dir,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
AuthResult::Failure(message) => {
|
AuthResult::Failure(message) => {
|
||||||
// message包含可用的认证方法列表(如"password,publickey")
|
// message包含可用的认证方法列表(如"password,publickey")
|
||||||
let mut failure_payload = Vec::new();
|
let mut failure_payload = Vec::new();
|
||||||
failure_payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
|
failure_payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
|
||||||
failure_payload.write_u32::<BigEndian>(message.len() as u32)?;
|
failure_payload.write_u32::<BigEndian>(message.len() as u32)?;
|
||||||
failure_payload.write_all(message.as_bytes())?;
|
failure_payload.write_all(message.as_bytes())?;
|
||||||
failure_payload.write_u8(0)?; // partial_success = false
|
failure_payload.write_u8(0)?; // partial_success = false
|
||||||
|
|
||||||
let encrypted_failure = EncryptedPacket::new(
|
let encrypted_failure =
|
||||||
&failure_payload,
|
EncryptedPacket::new(&failure_payload, encryption_ctx, true)?;
|
||||||
encryption_ctx,
|
|
||||||
true,
|
|
||||||
)?;
|
|
||||||
encrypted_failure.write(stream)?;
|
encrypted_failure.write(stream)?;
|
||||||
warn!("Sent encrypted SSH_MSG_USERAUTH_FAILURE: {}", message);
|
warn!("Sent encrypted SSH_MSG_USERAUTH_FAILURE: {}", message);
|
||||||
}
|
}
|
||||||
@@ -368,15 +408,11 @@ AuthResult::Failure(message) => {
|
|||||||
pk_ok_payload.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
|
pk_ok_payload.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
|
||||||
pk_ok_payload.write_all(&public_key_blob)?;
|
pk_ok_payload.write_all(&public_key_blob)?;
|
||||||
|
|
||||||
let encrypted_pk_ok = EncryptedPacket::new(
|
let encrypted_pk_ok = EncryptedPacket::new(&pk_ok_payload, encryption_ctx, true)?;
|
||||||
&pk_ok_payload,
|
|
||||||
encryption_ctx,
|
|
||||||
true,
|
|
||||||
)?;
|
|
||||||
encrypted_pk_ok.write(stream)?;
|
encrypted_pk_ok.write(stream)?;
|
||||||
info!("Sent SSH_MSG_USERAUTH_PK_OK");
|
info!("Sent SSH_MSG_USERAUTH_PK_OK");
|
||||||
|
|
||||||
continue; // Wait for signed request
|
continue; // Wait for signed request
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -389,15 +425,16 @@ fn handle_ssh_service_loop(
|
|||||||
stream: &mut TcpStream,
|
stream: &mut TcpStream,
|
||||||
channel_manager: &mut ChannelManager,
|
channel_manager: &mut ChannelManager,
|
||||||
encryption_ctx: &mut EncryptionContext,
|
encryption_ctx: &mut EncryptionContext,
|
||||||
port_forward_manager: &mut PortForwardManager, // Phase 13
|
port_forward_manager: &mut PortForwardManager, // Phase 13
|
||||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
|
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
info!("Starting SSH service loop (Phase 14.2: unified poll + child status)");
|
info!("Starting SSH service loop (Phase 14.2: unified poll + child status)");
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测
|
// ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测
|
||||||
// 返回三元组:(stdout_packets, client_has_data, child_exited)
|
// 返回三元组:(stdout_packets, client_has_data, child_exited)
|
||||||
let (stdout_packets, client_has_data, child_exited) = channel_manager.poll_exec_stdout_and_client(stream)?;
|
let (stdout_packets, client_has_data, child_exited) =
|
||||||
|
channel_manager.poll_exec_stdout_and_client(stream)?;
|
||||||
|
|
||||||
// 1. 发送stdout/stderr数据(如果有)
|
// 1. 发送stdout/stderr数据(如果有)
|
||||||
if let Some(packets) = stdout_packets {
|
if let Some(packets) = stdout_packets {
|
||||||
@@ -442,31 +479,36 @@ fn handle_ssh_service_loop(
|
|||||||
if !security.allow_tcp_forwarding {
|
if !security.allow_tcp_forwarding {
|
||||||
warn!("TCP forwarding disabled by security config");
|
warn!("TCP forwarding disabled by security config");
|
||||||
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
||||||
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
let encrypted_failure =
|
||||||
|
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||||
encrypted_failure.write(stream)?;
|
encrypted_failure.write(stream)?;
|
||||||
info!("Sent SSH_MSG_REQUEST_FAILURE (TCP forwarding disabled)");
|
info!("Sent SSH_MSG_REQUEST_FAILURE (TCP forwarding disabled)");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 13.2: 调用PortForwardManager处理(传递security_config)
|
// Phase 13.2: 调用PortForwardManager处理(传递security_config)
|
||||||
let (success, response) = port_forward_manager.handle_global_request(&packet.payload, &security)?;
|
let (success, response) =
|
||||||
drop(security); // 释放锁
|
port_forward_manager.handle_global_request(&packet.payload, &security)?;
|
||||||
|
drop(security); // 释放锁
|
||||||
|
|
||||||
if success {
|
if success {
|
||||||
if let Some(response_data) = response {
|
if let Some(response_data) = response {
|
||||||
let encrypted_response = EncryptedPacket::new(&response_data, encryption_ctx, true)?;
|
let encrypted_response =
|
||||||
|
EncryptedPacket::new(&response_data, encryption_ctx, true)?;
|
||||||
encrypted_response.write(stream)?;
|
encrypted_response.write(stream)?;
|
||||||
info!("Sent SSH_MSG_REQUEST_SUCCESS (tcpip-forward accepted)");
|
info!("Sent SSH_MSG_REQUEST_SUCCESS (tcpip-forward accepted)");
|
||||||
} else {
|
} else {
|
||||||
// 无响应数据时,发送简单的SUCCESS
|
// 无响应数据时,发送简单的SUCCESS
|
||||||
let success_packet = vec![PacketType::SSH_MSG_REQUEST_SUCCESS as u8];
|
let success_packet = vec![PacketType::SSH_MSG_REQUEST_SUCCESS as u8];
|
||||||
let encrypted_success = EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
|
let encrypted_success =
|
||||||
|
EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
|
||||||
encrypted_success.write(stream)?;
|
encrypted_success.write(stream)?;
|
||||||
info!("Sent SSH_MSG_REQUEST_SUCCESS");
|
info!("Sent SSH_MSG_REQUEST_SUCCESS");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
||||||
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
let encrypted_failure =
|
||||||
|
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||||
encrypted_failure.write(stream)?;
|
encrypted_failure.write(stream)?;
|
||||||
info!("Sent SSH_MSG_REQUEST_FAILURE (tcpip-forward rejected)");
|
info!("Sent SSH_MSG_REQUEST_FAILURE (tcpip-forward rejected)");
|
||||||
}
|
}
|
||||||
@@ -478,16 +520,18 @@ fn handle_ssh_service_loop(
|
|||||||
// Phase 13.3: 获取security_config并传递给handle_channel_open
|
// Phase 13.3: 获取security_config并传递给handle_channel_open
|
||||||
let security = security_config.lock().unwrap();
|
let security = security_config.lock().unwrap();
|
||||||
let response = channel_manager.handle_channel_open(&packet, Some(&security))?;
|
let response = channel_manager.handle_channel_open(&packet, Some(&security))?;
|
||||||
drop(security); // 释放锁
|
drop(security); // 释放锁
|
||||||
|
|
||||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
let encrypted_response =
|
||||||
|
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||||
encrypted_response.write(stream)?;
|
encrypted_response.write(stream)?;
|
||||||
info!("Sent SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
|
info!("Sent SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
|
||||||
}
|
}
|
||||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_REQUEST as u8 => {
|
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_REQUEST as u8 => {
|
||||||
info!("Received SSH_MSG_CHANNEL_REQUEST");
|
info!("Received SSH_MSG_CHANNEL_REQUEST");
|
||||||
if let Some(response) = channel_manager.handle_channel_request(&packet)? {
|
if let Some(response) = channel_manager.handle_channel_request(&packet)? {
|
||||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
let encrypted_response =
|
||||||
|
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||||
encrypted_response.write(stream)?;
|
encrypted_response.write(stream)?;
|
||||||
|
|
||||||
// ⭐⭐⭐⭐⭐ Phase 14.5修复:区分普通命令和交互式进程
|
// ⭐⭐⭐⭐⭐ Phase 14.5修复:区分普通命令和交互式进程
|
||||||
@@ -503,20 +547,34 @@ fn handle_ssh_service_loop(
|
|||||||
if let Some(channel_id) = channel_manager.get_channel_with_output() {
|
if let Some(channel_id) = channel_manager.get_channel_with_output() {
|
||||||
if let Some(output) = channel_manager.get_channel_output(channel_id) {
|
if let Some(output) = channel_manager.get_channel_output(channel_id) {
|
||||||
// 发送命令输出(SSH_MSG_CHANNEL_DATA)
|
// 发送命令输出(SSH_MSG_CHANNEL_DATA)
|
||||||
let data_packet = channel_manager.build_channel_data(channel_id, &output)?;
|
let data_packet =
|
||||||
let encrypted_data = EncryptedPacket::new(&data_packet.payload, encryption_ctx, true)?;
|
channel_manager.build_channel_data(channel_id, &output)?;
|
||||||
|
let encrypted_data = EncryptedPacket::new(
|
||||||
|
&data_packet.payload,
|
||||||
|
encryption_ctx,
|
||||||
|
true,
|
||||||
|
)?;
|
||||||
encrypted_data.write(stream)?;
|
encrypted_data.write(stream)?;
|
||||||
info!("Sent command output ({} bytes)", output.len());
|
info!("Sent command output ({} bytes)", output.len());
|
||||||
|
|
||||||
// 发送SSH_MSG_CHANNEL_EOF
|
// 发送SSH_MSG_CHANNEL_EOF
|
||||||
let eof_packet = channel_manager.build_channel_eof(channel_id)?;
|
let eof_packet = channel_manager.build_channel_eof(channel_id)?;
|
||||||
let encrypted_eof = EncryptedPacket::new(&eof_packet.payload, encryption_ctx, true)?;
|
let encrypted_eof = EncryptedPacket::new(
|
||||||
|
&eof_packet.payload,
|
||||||
|
encryption_ctx,
|
||||||
|
true,
|
||||||
|
)?;
|
||||||
encrypted_eof.write(stream)?;
|
encrypted_eof.write(stream)?;
|
||||||
info!("Sent SSH_MSG_CHANNEL_EOF");
|
info!("Sent SSH_MSG_CHANNEL_EOF");
|
||||||
|
|
||||||
// 发送SSH_MSG_CHANNEL_CLOSE
|
// 发送SSH_MSG_CHANNEL_CLOSE
|
||||||
let close_packet = channel_manager.build_channel_close(channel_id)?;
|
let close_packet =
|
||||||
let encrypted_close = EncryptedPacket::new(&close_packet.payload, encryption_ctx, true)?;
|
channel_manager.build_channel_close(channel_id)?;
|
||||||
|
let encrypted_close = EncryptedPacket::new(
|
||||||
|
&close_packet.payload,
|
||||||
|
encryption_ctx,
|
||||||
|
true,
|
||||||
|
)?;
|
||||||
encrypted_close.write(stream)?;
|
encrypted_close.write(stream)?;
|
||||||
info!("Sent SSH_MSG_CHANNEL_CLOSE");
|
info!("Sent SSH_MSG_CHANNEL_CLOSE");
|
||||||
|
|
||||||
@@ -531,22 +589,28 @@ fn handle_ssh_service_loop(
|
|||||||
info!("Received SSH_MSG_CHANNEL_DATA");
|
info!("Received SSH_MSG_CHANNEL_DATA");
|
||||||
if let Some(response) = channel_manager.handle_channel_data(&packet)? {
|
if let Some(response) = channel_manager.handle_channel_data(&packet)? {
|
||||||
// Phase 7: SFTP响应通过CHANNEL_DATA返回
|
// Phase 7: SFTP响应通过CHANNEL_DATA返回
|
||||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
let encrypted_response =
|
||||||
|
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||||
encrypted_response.write(stream)?;
|
encrypted_response.write(stream)?;
|
||||||
info!("Sent SSH_MSG_CHANNEL_DATA (SFTP response)");
|
info!("Sent SSH_MSG_CHANNEL_DATA (SFTP response)");
|
||||||
}
|
}
|
||||||
|
|
||||||
// ⭐⭐⭐⭐⭐ Phase 15.1: Drain pending packets (e.g. WINDOW_ADJUST + delayed SFTP response)
|
// ⭐⭐⭐⭐⭐ Phase 15.1: Drain pending packets (e.g. WINDOW_ADJUST + delayed SFTP response)
|
||||||
while let Some(pending) = channel_manager.pending_packets.pop_front() {
|
while let Some(pending) = channel_manager.pending_packets.pop_front() {
|
||||||
let encrypted_pending = EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
|
let encrypted_pending =
|
||||||
|
EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
|
||||||
encrypted_pending.write(stream)?;
|
encrypted_pending.write(stream)?;
|
||||||
info!("Sent pending packet (type {})", pending.payload.first().unwrap_or(&0));
|
info!(
|
||||||
|
"Sent pending packet (type {})",
|
||||||
|
pending.payload.first().unwrap_or(&0)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_CLOSE as u8 => {
|
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_CLOSE as u8 => {
|
||||||
info!("Received SSH_MSG_CHANNEL_CLOSE");
|
info!("Received SSH_MSG_CHANNEL_CLOSE");
|
||||||
if let Some(response) = channel_manager.handle_channel_close(&packet)? {
|
if let Some(response) = channel_manager.handle_channel_close(&packet)? {
|
||||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
let encrypted_response =
|
||||||
|
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||||
encrypted_response.write(stream)?;
|
encrypted_response.write(stream)?;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@@ -565,8 +629,10 @@ fn handle_ssh_service_loop(
|
|||||||
let payload = &packet.payload;
|
let payload = &packet.payload;
|
||||||
if payload.len() >= 9 {
|
if payload.len() >= 9 {
|
||||||
// Format: uint32 recipient_channel || uint32 bytes_to_add
|
// Format: uint32 recipient_channel || uint32 bytes_to_add
|
||||||
let recipient_channel = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
|
let recipient_channel =
|
||||||
let bytes_to_add = u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
|
u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
|
||||||
|
let bytes_to_add =
|
||||||
|
u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
|
||||||
channel_manager.adjust_remote_window(recipient_channel, bytes_to_add);
|
channel_manager.adjust_remote_window(recipient_channel, bytes_to_add);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -580,7 +646,9 @@ fn handle_ssh_service_loop(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
|
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
|
||||||
fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result<String> {
|
fn extract_username_from_auth_request(
|
||||||
|
packet: &crate::ssh_server::packet::SshPacket,
|
||||||
|
) -> Result<String> {
|
||||||
let payload = &packet.payload;
|
let payload = &packet.payload;
|
||||||
if payload.len() < 5 {
|
if payload.len() < 5 {
|
||||||
return Err(anyhow!("Auth request too short"));
|
return Err(anyhow!("Auth request too short"));
|
||||||
@@ -598,7 +666,7 @@ pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
|
|||||||
let config = SshServerConfig {
|
let config = SshServerConfig {
|
||||||
port: port.unwrap_or(2024),
|
port: port.unwrap_or(2024),
|
||||||
bind_address: "127.0.0.1".to_string(),
|
bind_address: "127.0.0.1".to_string(),
|
||||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
|
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
|
||||||
pg_conn: pg_conn.map(|s| s.to_string()),
|
pg_conn: pg_conn.map(|s| s.to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
|||||||
// SSH企业级安全配置(Phase 13.1)
|
// SSH企业级安全配置(Phase 13.1)
|
||||||
// 参考OpenSSH sshd_config安全配置
|
// 参考OpenSSH sshd_config安全配置
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use log::{info, warn};
|
use log::{info, warn};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
@@ -42,23 +42,23 @@ impl SshSecurityConfig {
|
|||||||
/// 参考:OpenSSH企业级生产环境配置
|
/// 参考:OpenSSH企业级生产环境配置
|
||||||
pub fn enterprise_default() -> Self {
|
pub fn enterprise_default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
gateway_ports: false, // 安全:只绑定127.0.0.1
|
gateway_ports: false, // 安全:只绑定127.0.0.1
|
||||||
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
|
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
|
||||||
allow_tcp_forwarding: true, // 允许TCP转发
|
allow_tcp_forwarding: true, // 允许TCP转发
|
||||||
max_sessions: 10, // 最多10个会话
|
max_sessions: 10, // 最多10个会话
|
||||||
connect_timeout: 30, // 30秒超时
|
connect_timeout: 30, // 30秒超时
|
||||||
active_sessions: 0, // 运行时状态
|
active_sessions: 0, // 运行时状态
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 开发环境默认配置(宽松)
|
/// 开发环境默认配置(宽松)
|
||||||
pub fn development_default() -> Self {
|
pub fn development_default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
gateway_ports: true, // 开发:允许0.0.0.0
|
gateway_ports: true, // 开发:允许0.0.0.0
|
||||||
permit_open: vec![], // 开发:允许所有目标
|
permit_open: vec![], // 开发:允许所有目标
|
||||||
allow_tcp_forwarding: true,
|
allow_tcp_forwarding: true,
|
||||||
max_sessions: 20, // 开发:更多会话
|
max_sessions: 20, // 开发:更多会话
|
||||||
connect_timeout: 60, // 开发:更长超时
|
connect_timeout: 60, // 开发:更长超时
|
||||||
active_sessions: 0,
|
active_sessions: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -73,26 +73,36 @@ impl SshSecurityConfig {
|
|||||||
let config_str = fs::read_to_string(path)?;
|
let config_str = fs::read_to_string(path)?;
|
||||||
let config: serde_json::Value = serde_json::from_str(&config_str)?;
|
let config: serde_json::Value = serde_json::from_str(&config_str)?;
|
||||||
|
|
||||||
let security = config.get("ssh_server")
|
let security = config
|
||||||
|
.get("ssh_server")
|
||||||
.and_then(|s| s.get("security"))
|
.and_then(|s| s.get("security"))
|
||||||
.ok_or_else(|| anyhow!("Invalid config structure"))?;
|
.ok_or_else(|| anyhow!("Invalid config structure"))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
gateway_ports: security.get("gateway_ports")
|
gateway_ports: security
|
||||||
|
.get("gateway_ports")
|
||||||
.and_then(|v| v.as_bool())
|
.and_then(|v| v.as_bool())
|
||||||
.unwrap_or(false),
|
.unwrap_or(false),
|
||||||
permit_open: security.get("permit_open")
|
permit_open: security
|
||||||
|
.get("permit_open")
|
||||||
.and_then(|v| v.as_array())
|
.and_then(|v| v.as_array())
|
||||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
.map(|arr| {
|
||||||
|
arr.iter()
|
||||||
|
.filter_map(|v| v.as_str().map(String::from))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
.unwrap_or_else(|| vec!["localhost:*".to_string()]),
|
.unwrap_or_else(|| vec!["localhost:*".to_string()]),
|
||||||
allow_tcp_forwarding: security.get("allow_tcp_forwarding")
|
allow_tcp_forwarding: security
|
||||||
|
.get("allow_tcp_forwarding")
|
||||||
.and_then(|v| v.as_bool())
|
.and_then(|v| v.as_bool())
|
||||||
.unwrap_or(true),
|
.unwrap_or(true),
|
||||||
max_sessions: security.get("max_sessions")
|
max_sessions: security
|
||||||
|
.get("max_sessions")
|
||||||
.and_then(|v| v.as_u64())
|
.and_then(|v| v.as_u64())
|
||||||
.map(|v| v as u32)
|
.map(|v| v as u32)
|
||||||
.unwrap_or(10),
|
.unwrap_or(10),
|
||||||
connect_timeout: security.get("connect_timeout")
|
connect_timeout: security
|
||||||
|
.get("connect_timeout")
|
||||||
.and_then(|v| v.as_u64())
|
.and_then(|v| v.as_u64())
|
||||||
.unwrap_or(30),
|
.unwrap_or(30),
|
||||||
active_sessions: 0,
|
active_sessions: 0,
|
||||||
@@ -101,12 +111,11 @@ impl SshSecurityConfig {
|
|||||||
|
|
||||||
/// 验证tcpip-forward请求(安全检查)
|
/// 验证tcpip-forward请求(安全检查)
|
||||||
/// 参考OpenSSH auth2.c: ssh_forwarding_check()
|
/// 参考OpenSSH auth2.c: ssh_forwarding_check()
|
||||||
pub fn validate_tcpip_forward_request(
|
pub fn validate_tcpip_forward_request(&self, bind_address: &str, bind_port: u32) -> Result<()> {
|
||||||
&self,
|
info!(
|
||||||
bind_address: &str,
|
"Validating tcpip-forward request: bind_address={}, bind_port={}",
|
||||||
bind_port: u32,
|
bind_address, bind_port
|
||||||
) -> Result<()> {
|
);
|
||||||
info!("Validating tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
|
|
||||||
|
|
||||||
// 1. AllowTcpForwarding检查
|
// 1. AllowTcpForwarding检查
|
||||||
if !self.allow_tcp_forwarding {
|
if !self.allow_tcp_forwarding {
|
||||||
@@ -117,8 +126,11 @@ impl SshSecurityConfig {
|
|||||||
// 2. GatewayPorts检查
|
// 2. GatewayPorts检查
|
||||||
if !self.gateway_ports {
|
if !self.gateway_ports {
|
||||||
// 只允许绑定到127.0.0.1或localhost
|
// 只允许绑定到127.0.0.1或localhost
|
||||||
if bind_address != "127.0.0.1" && bind_address != "localhost" && bind_address != "" {
|
if bind_address != "127.0.0.1" && bind_address != "localhost" && !bind_address.is_empty() {
|
||||||
warn!("GatewayPorts disabled, bind_address {} not allowed", bind_address);
|
warn!(
|
||||||
|
"GatewayPorts disabled, bind_address {} not allowed",
|
||||||
|
bind_address
|
||||||
|
);
|
||||||
return Err(anyhow!("GatewayPorts=no, only 127.0.0.1 allowed"));
|
return Err(anyhow!("GatewayPorts=no, only 127.0.0.1 allowed"));
|
||||||
}
|
}
|
||||||
info!("GatewayPorts check passed: bind_address={}", bind_address);
|
info!("GatewayPorts check passed: bind_address={}", bind_address);
|
||||||
@@ -126,7 +138,10 @@ impl SshSecurityConfig {
|
|||||||
|
|
||||||
// 3. MaxSessions检查
|
// 3. MaxSessions检查
|
||||||
if self.active_sessions >= self.max_sessions {
|
if self.active_sessions >= self.max_sessions {
|
||||||
warn!("Max sessions limit reached: {} >= {}", self.active_sessions, self.max_sessions);
|
warn!(
|
||||||
|
"Max sessions limit reached: {} >= {}",
|
||||||
|
self.active_sessions, self.max_sessions
|
||||||
|
);
|
||||||
return Err(anyhow!("Max sessions limit reached: {}", self.max_sessions));
|
return Err(anyhow!("Max sessions limit reached: {}", self.max_sessions));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,7 +168,10 @@ impl SshSecurityConfig {
|
|||||||
host_to_connect: &str,
|
host_to_connect: &str,
|
||||||
port_to_connect: u32,
|
port_to_connect: u32,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
info!("Validating direct-tcpip channel: host={}, port={}", host_to_connect, port_to_connect);
|
info!(
|
||||||
|
"Validating direct-tcpip channel: host={}, port={}",
|
||||||
|
host_to_connect, port_to_connect
|
||||||
|
);
|
||||||
|
|
||||||
// 1. AllowTcpForwarding检查
|
// 1. AllowTcpForwarding检查
|
||||||
if !self.allow_tcp_forwarding {
|
if !self.allow_tcp_forwarding {
|
||||||
@@ -175,9 +193,15 @@ impl SshSecurityConfig {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if !allowed {
|
if !allowed {
|
||||||
warn!("Target {}:{} not in PermitOpen whitelist", host_to_connect, port_to_connect);
|
warn!(
|
||||||
return Err(anyhow!("Target {}:{} not in PermitOpen whitelist",
|
"Target {}:{} not in PermitOpen whitelist",
|
||||||
host_to_connect, port_to_connect));
|
host_to_connect, port_to_connect
|
||||||
|
);
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Target {}:{} not in PermitOpen whitelist",
|
||||||
|
host_to_connect,
|
||||||
|
port_to_connect
|
||||||
|
));
|
||||||
}
|
}
|
||||||
info!("PermitOpen check passed: target={}", target);
|
info!("PermitOpen check passed: target={}", target);
|
||||||
} else {
|
} else {
|
||||||
@@ -186,7 +210,7 @@ impl SshSecurityConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. 端口范围检查
|
// 3. 端口范围检查
|
||||||
if port_to_connect < 1 || port_to_connect > 65535 {
|
if !(1..=65535).contains(&port_to_connect) {
|
||||||
warn!("Invalid port number: {}", port_to_connect);
|
warn!("Invalid port number: {}", port_to_connect);
|
||||||
return Err(anyhow!("Invalid port number: {}", port_to_connect));
|
return Err(anyhow!("Invalid port number: {}", port_to_connect));
|
||||||
}
|
}
|
||||||
@@ -234,14 +258,22 @@ mod tests {
|
|||||||
let config = SshSecurityConfig::enterprise_default();
|
let config = SshSecurityConfig::enterprise_default();
|
||||||
|
|
||||||
// 正常请求应该通过
|
// 正常请求应该通过
|
||||||
assert!(config.validate_tcpip_forward_request("127.0.0.1", 8080).is_ok());
|
assert!(config
|
||||||
assert!(config.validate_tcpip_forward_request("localhost", 8080).is_ok());
|
.validate_tcpip_forward_request("127.0.0.1", 8080)
|
||||||
|
.is_ok());
|
||||||
|
assert!(config
|
||||||
|
.validate_tcpip_forward_request("localhost", 8080)
|
||||||
|
.is_ok());
|
||||||
|
|
||||||
// GatewayPorts=false时,0.0.0.0应该被拒绝
|
// GatewayPorts=false时,0.0.0.0应该被拒绝
|
||||||
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_err());
|
assert!(config
|
||||||
|
.validate_tcpip_forward_request("0.0.0.0", 8080)
|
||||||
|
.is_err());
|
||||||
|
|
||||||
// 特权端口应该被拒绝
|
// 特权端口应该被拒绝
|
||||||
assert!(config.validate_tcpip_forward_request("127.0.0.1", 80).is_err());
|
assert!(config
|
||||||
|
.validate_tcpip_forward_request("127.0.0.1", 80)
|
||||||
|
.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -249,12 +281,20 @@ mod tests {
|
|||||||
let config = SshSecurityConfig::enterprise_default();
|
let config = SshSecurityConfig::enterprise_default();
|
||||||
|
|
||||||
// localhost:*应该通过(通配符匹配)
|
// localhost:*应该通过(通配符匹配)
|
||||||
assert!(config.validate_direct_tcpip_channel("localhost", 3000).is_ok());
|
assert!(config
|
||||||
assert!(config.validate_direct_tcpip_channel("localhost", 4000).is_ok());
|
.validate_direct_tcpip_channel("localhost", 3000)
|
||||||
|
.is_ok());
|
||||||
|
assert!(config
|
||||||
|
.validate_direct_tcpip_channel("localhost", 4000)
|
||||||
|
.is_ok());
|
||||||
|
|
||||||
// 其他host应该被拒绝
|
// 其他host应该被拒绝
|
||||||
assert!(config.validate_direct_tcpip_channel("192.168.1.100", 3000).is_err());
|
assert!(config
|
||||||
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_err());
|
.validate_direct_tcpip_channel("192.168.1.100", 3000)
|
||||||
|
.is_err());
|
||||||
|
assert!(config
|
||||||
|
.validate_direct_tcpip_channel("example.com", 80)
|
||||||
|
.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -266,7 +306,11 @@ mod tests {
|
|||||||
assert_eq!(config.max_sessions, 20);
|
assert_eq!(config.max_sessions, 20);
|
||||||
|
|
||||||
// 开发配置应该允许所有请求
|
// 开发配置应该允许所有请求
|
||||||
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_ok());
|
assert!(config
|
||||||
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_ok());
|
.validate_tcpip_forward_request("0.0.0.0", 8080)
|
||||||
|
.is_ok());
|
||||||
|
assert!(config
|
||||||
|
.validate_direct_tcpip_channel("example.com", 80)
|
||||||
|
.is_ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// SSH Buffer 零拷贝实现(参考 OpenSSH sshbuf.c)
|
// SSH Buffer 零拷贝实现(参考 OpenSSH sshbuf.c)
|
||||||
// 提供高效的 buffer 管理,消除临时 buffer
|
// 提供高效的 buffer 管理,消除临时 buffer
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
|
|
||||||
/// SSH Buffer(参考 OpenSSH struct sshbuf)
|
/// SSH Buffer(参考 OpenSSH struct sshbuf)
|
||||||
@@ -16,10 +16,10 @@ use std::io::{Read, Write};
|
|||||||
/// };
|
/// };
|
||||||
/// ```
|
/// ```
|
||||||
pub struct SshBuf {
|
pub struct SshBuf {
|
||||||
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
|
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
|
||||||
off: usize, // Offset (对应 OpenSSH buf->off)
|
off: usize, // Offset (对应 OpenSSH buf->off)
|
||||||
size: usize, // Size (对应 OpenSSH buf->size)
|
size: usize, // Size (对应 OpenSSH buf->size)
|
||||||
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
|
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SshBuf {
|
impl SshBuf {
|
||||||
@@ -144,7 +144,11 @@ impl SshBuf {
|
|||||||
/// Rust 实现:移动偏移量(零拷贝,不实际删除数据)
|
/// Rust 实现:移动偏移量(零拷贝,不实际删除数据)
|
||||||
pub fn consume(&mut self, len: usize) -> Result<()> {
|
pub fn consume(&mut self, len: usize) -> Result<()> {
|
||||||
if len > self.len() {
|
if len > self.len() {
|
||||||
return Err(anyhow!("message incomplete (len={}, consume={})", self.len(), len));
|
return Err(anyhow!(
|
||||||
|
"message incomplete (len={}, consume={})",
|
||||||
|
self.len(),
|
||||||
|
len
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
self.off += len;
|
self.off += len;
|
||||||
@@ -259,7 +263,11 @@ impl SshBuf {
|
|||||||
pub fn debug_info(&self) -> String {
|
pub fn debug_info(&self) -> String {
|
||||||
format!(
|
format!(
|
||||||
"SshBuf: off={}, size={}, len={}, alloc={}, max_size={}",
|
"SshBuf: off={}, size={}, len={}, alloc={}, max_size={}",
|
||||||
self.off, self.size, self.len(), self.data.len(), self.max_size
|
self.off,
|
||||||
|
self.size,
|
||||||
|
self.len(),
|
||||||
|
self.data.len(),
|
||||||
|
self.max_size
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
// 参考OpenSSH sshd.c: ssh_exchange_identification()
|
// 参考OpenSSH sshd.c: ssh_exchange_identification()
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use log::{debug, info};
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use log::{info, debug};
|
|
||||||
|
|
||||||
/// SSH版本字符串
|
/// SSH版本字符串
|
||||||
pub const SSH_VERSION: &str = "SSH-2.0-MarkBaseSSH_1.0";
|
pub const SSH_VERSION: &str = "SSH-2.0-MarkBaseSSH_1.0";
|
||||||
@@ -22,7 +22,10 @@ impl VersionExchange {
|
|||||||
// 2. 接收客户端版本
|
// 2. 接收客户端版本
|
||||||
let client_version = Self::receive_version(stream)?;
|
let client_version = Self::receive_version(stream)?;
|
||||||
|
|
||||||
info!("Version exchange completed: server={}, client={}", SSH_VERSION, client_version);
|
info!(
|
||||||
|
"Version exchange completed: server={}, client={}",
|
||||||
|
SSH_VERSION, client_version
|
||||||
|
);
|
||||||
Ok(client_version)
|
Ok(client_version)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,14 +49,14 @@ impl VersionExchange {
|
|||||||
stream.read_exact(&mut byte)?;
|
stream.read_exact(&mut byte)?;
|
||||||
|
|
||||||
// OpenSSH兼容性处理:跳过前导空行和调试信息
|
// OpenSSH兼容性处理:跳过前导空行和调试信息
|
||||||
if buffer.is_empty() && byte[0] == '\n' as u8 {
|
if buffer.is_empty() && byte[0] == b'\n' {
|
||||||
continue; // 跳过空行
|
continue; // 跳过空行
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调试信息行(以'#'开头),跳过
|
// 调试信息行(以'#'开头),跳过
|
||||||
if buffer.is_empty() && byte[0] == '#' as u8 {
|
if buffer.is_empty() && byte[0] == b'#' {
|
||||||
// 读取整行调试信息
|
// 读取整行调试信息
|
||||||
while byte[0] != '\n' as u8 {
|
while byte[0] != b'\n' {
|
||||||
stream.read_exact(&mut byte)?;
|
stream.read_exact(&mut byte)?;
|
||||||
}
|
}
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
@@ -63,7 +66,7 @@ impl VersionExchange {
|
|||||||
buffer.push(byte[0]);
|
buffer.push(byte[0]);
|
||||||
|
|
||||||
// 遇到'\n'结束
|
// 遇到'\n'结束
|
||||||
if byte[0] == '\n' as u8 {
|
if byte[0] == b'\n' {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
// SSH Window Size管理(Phase 13.6)
|
// SSH Window Size管理(Phase 13.6)
|
||||||
// 参考RFC 4254 Section 5.2: Window Size Adjustment
|
// 参考RFC 4254 Section 5.2: Window Size Adjustment
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
|
||||||
use log::{info, warn, debug};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use byteorder::{BigEndian, WriteBytesExt};
|
|
||||||
use crate::ssh_server::packet::PacketType;
|
use crate::ssh_server::packet::PacketType;
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use byteorder::{BigEndian, WriteBytesExt};
|
||||||
|
use log::{info, warn};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
/// Window Size管理器(Phase 13.6)
|
/// Window Size管理器(Phase 13.6)
|
||||||
pub struct WindowManager {
|
pub struct WindowManager {
|
||||||
initial_window_size: u32, // RFC 4254: 2MB默认
|
initial_window_size: u32, // RFC 4254: 2MB默认
|
||||||
current_window_size: Arc<Mutex<u32>>,
|
current_window_size: Arc<Mutex<u32>>,
|
||||||
max_packet_size: u32, // RFC 4254: 32KB默认
|
max_packet_size: u32, // RFC 4254: 32KB默认
|
||||||
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
|
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WindowManager {
|
impl WindowManager {
|
||||||
@@ -28,7 +28,7 @@ impl WindowManager {
|
|||||||
|
|
||||||
/// RFC 4254默认window size(2MB)
|
/// RFC 4254默认window size(2MB)
|
||||||
pub fn rfc_default() -> Self {
|
pub fn rfc_default() -> Self {
|
||||||
Self::new(2097152, 32768) // 2MB window, 32KB packet
|
Self::new(2097152, 32768) // 2MB window, 32KB packet
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 检查window size是否足够(Phase 13.6)
|
/// 检查window size是否足够(Phase 13.6)
|
||||||
@@ -37,7 +37,10 @@ impl WindowManager {
|
|||||||
let available = *window >= data_size;
|
let available = *window >= data_size;
|
||||||
|
|
||||||
if !available {
|
if !available {
|
||||||
warn!("Window size insufficient: need {}, have {}", data_size, *window);
|
warn!(
|
||||||
|
"Window size insufficient: need {}, have {}",
|
||||||
|
data_size, *window
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
available
|
available
|
||||||
@@ -48,7 +51,11 @@ impl WindowManager {
|
|||||||
let mut window = self.current_window_size.lock().unwrap();
|
let mut window = self.current_window_size.lock().unwrap();
|
||||||
|
|
||||||
if *window < data_size {
|
if *window < data_size {
|
||||||
return Err(anyhow!("Window size insufficient: need {}, have {}", data_size, *window));
|
return Err(anyhow!(
|
||||||
|
"Window size insufficient: need {}, have {}",
|
||||||
|
data_size,
|
||||||
|
*window
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
*window -= data_size;
|
*window -= data_size;
|
||||||
@@ -57,8 +64,10 @@ impl WindowManager {
|
|||||||
let mut consumed = self.consumed_bytes.lock().unwrap();
|
let mut consumed = self.consumed_bytes.lock().unwrap();
|
||||||
*consumed += data_size;
|
*consumed += data_size;
|
||||||
|
|
||||||
info!("Window size consumed: {} bytes, remaining {}, total consumed {}",
|
info!(
|
||||||
data_size, *window, *consumed);
|
"Window size consumed: {} bytes, remaining {}, total consumed {}",
|
||||||
|
data_size, *window, *consumed
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -68,7 +77,10 @@ impl WindowManager {
|
|||||||
let mut window = self.current_window_size.lock().unwrap();
|
let mut window = self.current_window_size.lock().unwrap();
|
||||||
*window += bytes_to_add;
|
*window += bytes_to_add;
|
||||||
|
|
||||||
info!("Window size adjusted: added {} bytes, total {}", bytes_to_add, *window);
|
info!(
|
||||||
|
"Window size adjusted: added {} bytes, total {}",
|
||||||
|
bytes_to_add, *window
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 构建SSH_MSG_CHANNEL_WINDOW_ADJUST packet(Phase 13.6)
|
/// 构建SSH_MSG_CHANNEL_WINDOW_ADJUST packet(Phase 13.6)
|
||||||
@@ -84,8 +96,10 @@ impl WindowManager {
|
|||||||
// Bytes to add
|
// Bytes to add
|
||||||
packet.write_u32::<BigEndian>(bytes_to_add)?;
|
packet.write_u32::<BigEndian>(bytes_to_add)?;
|
||||||
|
|
||||||
info!("Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
|
info!(
|
||||||
channel_id, bytes_to_add);
|
"Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
|
||||||
|
channel_id, bytes_to_add
|
||||||
|
);
|
||||||
|
|
||||||
Ok(packet)
|
Ok(packet)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,21 @@
|
|||||||
use super::util;
|
|
||||||
use super::open_flags::OpenFlags;
|
use super::open_flags::OpenFlags;
|
||||||
|
use super::util;
|
||||||
use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat};
|
use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat};
|
||||||
use std::fs::{self, File, OpenOptions};
|
use std::fs::{self, File, OpenOptions};
|
||||||
use std::io::{Read, Seek, SeekFrom, Write};
|
use std::io::{Read, Seek, SeekFrom, Write};
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::os::unix::fs::{MetadataExt, PermissionsExt};
|
use std::os::unix::fs::{MetadataExt, PermissionsExt};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
/// 本地文件系统实现(直接包装 std::fs,不做路径解析)
|
/// 本地文件系统实现(直接包装 std::fs,不做路径解析)
|
||||||
/// 路径解析由上层(SftpHandler)负责
|
/// 路径解析由上层(SftpHandler)负责
|
||||||
pub struct LocalFs;
|
pub struct LocalFs;
|
||||||
|
|
||||||
|
impl Default for LocalFs {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl LocalFs {
|
impl LocalFs {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self
|
Self
|
||||||
@@ -26,7 +32,9 @@ impl VfsFile for LocalFile {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn write(&mut self, buf: &[u8]) -> Result<usize, VfsError> {
|
fn write(&mut self, buf: &[u8]) -> Result<usize, VfsError> {
|
||||||
self.file.write(buf).map_err(|e| VfsError::Io(e.to_string()))
|
self.file
|
||||||
|
.write(buf)
|
||||||
|
.map_err(|e| VfsError::Io(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn seek(&mut self, pos: SeekFrom) -> Result<u64, VfsError> {
|
fn seek(&mut self, pos: SeekFrom) -> Result<u64, VfsError> {
|
||||||
@@ -38,12 +46,17 @@ impl VfsFile for LocalFile {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn stat(&mut self) -> Result<VfsStat, VfsError> {
|
fn stat(&mut self) -> Result<VfsStat, VfsError> {
|
||||||
let meta = self.file.metadata().map_err(|e| VfsError::Io(e.to_string()))?;
|
let meta = self
|
||||||
|
.file
|
||||||
|
.metadata()
|
||||||
|
.map_err(|e| VfsError::Io(e.to_string()))?;
|
||||||
Ok(util::stat_from_metadata(&meta, false))
|
Ok(util::stat_from_metadata(&meta, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_len(&mut self, size: u64) -> Result<(), VfsError> {
|
fn set_len(&mut self, size: u64) -> Result<(), VfsError> {
|
||||||
self.file.set_len(size).map_err(|e| VfsError::Io(e.to_string()))
|
self.file
|
||||||
|
.set_len(size)
|
||||||
|
.map_err(|e| VfsError::Io(e.to_string()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,8 +99,7 @@ impl VfsBackend for LocalFs {
|
|||||||
if flags.create && !flags.exclusive {
|
if flags.create && !flags.exclusive {
|
||||||
if let Ok(meta) = file.metadata() {
|
if let Ok(meta) = file.metadata() {
|
||||||
if flags.mode != 0 && meta.permissions().mode() != flags.mode {
|
if flags.mode != 0 && meta.permissions().mode() != flags.mode {
|
||||||
fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode))
|
fs::set_permissions(path, std::fs::Permissions::from_mode(flags.mode)).ok();
|
||||||
.ok();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -157,10 +169,12 @@ impl VfsBackend for LocalFs {
|
|||||||
stat.atime.duration_since(std::time::UNIX_EPOCH).ok(),
|
stat.atime.duration_since(std::time::UNIX_EPOCH).ok(),
|
||||||
stat.mtime.duration_since(std::time::UNIX_EPOCH).ok(),
|
stat.mtime.duration_since(std::time::UNIX_EPOCH).ok(),
|
||||||
) {
|
) {
|
||||||
filetime::set_file_times(path,
|
filetime::set_file_times(
|
||||||
|
path,
|
||||||
filetime::FileTime::from_unix_time(atime.as_secs() as i64, 0),
|
filetime::FileTime::from_unix_time(atime.as_secs() as i64, 0),
|
||||||
filetime::FileTime::from_unix_time(mtime.as_secs() as i64, 0),
|
filetime::FileTime::from_unix_time(mtime.as_secs() as i64, 0),
|
||||||
).map_err(|e| util::map_io_error(path, e))?;
|
)
|
||||||
|
.map_err(|e| util::map_io_error(path, e))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -174,8 +188,7 @@ impl VfsBackend for LocalFs {
|
|||||||
fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError> {
|
fn create_symlink(&self, target: &Path, link: &Path) -> Result<(), VfsError> {
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
{
|
{
|
||||||
std::os::unix::fs::symlink(target, link)
|
std::os::unix::fs::symlink(target, link).map_err(|e| util::map_io_error(link, e))?;
|
||||||
.map_err(|e| util::map_io_error(link, e))?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
@@ -188,7 +201,9 @@ impl VfsBackend for LocalFs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn real_path(&self, path: &Path) -> Result<PathBuf, VfsError> {
|
fn real_path(&self, path: &Path) -> Result<PathBuf, VfsError> {
|
||||||
let canonical = path.canonicalize().map_err(|e| util::map_io_error(path, e))?;
|
let canonical = path
|
||||||
|
.canonicalize()
|
||||||
|
.map_err(|e| util::map_io_error(path, e))?;
|
||||||
Ok(canonical)
|
Ok(canonical)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,7 +219,9 @@ impl VfsBackend for LocalFs {
|
|||||||
|
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
{
|
{
|
||||||
return Err(VfsError::Unsupported("hard_link not supported on non-Unix systems".to_string()));
|
return Err(VfsError::Unsupported(
|
||||||
|
"hard_link not supported on non-Unix systems".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
pub mod open_flags;
|
|
||||||
pub mod local_fs;
|
pub mod local_fs;
|
||||||
|
pub mod open_flags;
|
||||||
pub mod s3_fs;
|
pub mod s3_fs;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
|
|
||||||
@@ -120,7 +120,11 @@ pub trait VfsBackend: Send {
|
|||||||
fn read_dir(&self, path: &Path) -> Result<Vec<VfsDirEntry>, VfsError>;
|
fn read_dir(&self, path: &Path) -> Result<Vec<VfsDirEntry>, VfsError>;
|
||||||
|
|
||||||
/// 打开文件(读/写)
|
/// 打开文件(读/写)
|
||||||
fn open_file(&self, path: &Path, flags: &open_flags::OpenFlags) -> Result<Box<dyn VfsFile>, VfsError>;
|
fn open_file(
|
||||||
|
&self,
|
||||||
|
path: &Path,
|
||||||
|
flags: &open_flags::OpenFlags,
|
||||||
|
) -> Result<Box<dyn VfsFile>, VfsError>;
|
||||||
|
|
||||||
/// 获取文件/目录元数据
|
/// 获取文件/目录元数据
|
||||||
fn stat(&self, path: &Path) -> Result<VfsStat, VfsError>;
|
fn stat(&self, path: &Path) -> Result<VfsStat, VfsError>;
|
||||||
|
|||||||
@@ -56,7 +56,10 @@ impl S3Vfs {
|
|||||||
|
|
||||||
let credentials = Credentials::new(access_key, secret_key);
|
let credentials = Credentials::new(access_key, secret_key);
|
||||||
|
|
||||||
Ok(Self { bucket, credentials })
|
Ok(Self {
|
||||||
|
bucket,
|
||||||
|
credentials,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn path_to_key(path: &Path) -> String {
|
fn path_to_key(path: &Path) -> String {
|
||||||
@@ -118,7 +121,10 @@ impl S3Vfs {
|
|||||||
.map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?;
|
.map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?;
|
||||||
|
|
||||||
if resp.status() != 200 {
|
if resp.status() != 200 {
|
||||||
return Err(VfsError::Io(format!("PutObject returned {}", resp.status())));
|
return Err(VfsError::Io(format!(
|
||||||
|
"PutObject returned {}",
|
||||||
|
resp.status()
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -149,15 +155,15 @@ impl S3Vfs {
|
|||||||
.map_err(|e| VfsError::Io(format!("S3 CopyObject failed: {}", e)))?;
|
.map_err(|e| VfsError::Io(format!("S3 CopyObject failed: {}", e)))?;
|
||||||
|
|
||||||
if resp.status() != 200 {
|
if resp.status() != 200 {
|
||||||
return Err(VfsError::Io(format!("CopyObject returned {}", resp.status())));
|
return Err(VfsError::Io(format!(
|
||||||
|
"CopyObject returned {}",
|
||||||
|
resp.status()
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn list_objects(
|
fn list_objects(&self, prefix: &str) -> Result<actions::ListObjectsV2Response, VfsError> {
|
||||||
&self,
|
|
||||||
prefix: &str,
|
|
||||||
) -> Result<actions::ListObjectsV2Response, VfsError> {
|
|
||||||
let mut action = actions::ListObjectsV2::new(&self.bucket, Some(&self.credentials));
|
let mut action = actions::ListObjectsV2::new(&self.bucket, Some(&self.credentials));
|
||||||
if !prefix.is_empty() {
|
if !prefix.is_empty() {
|
||||||
action.with_prefix(prefix);
|
action.with_prefix(prefix);
|
||||||
@@ -181,9 +187,8 @@ impl S3Vfs {
|
|||||||
.read_to_string(&mut body)
|
.read_to_string(&mut body)
|
||||||
.map_err(|e| VfsError::Io(format!("Failed to read S3 list response: {}", e)))?;
|
.map_err(|e| VfsError::Io(format!("Failed to read S3 list response: {}", e)))?;
|
||||||
|
|
||||||
actions::ListObjectsV2::parse_response(&body).map_err(|e| {
|
actions::ListObjectsV2::parse_response(&body)
|
||||||
VfsError::Io(format!("Failed to parse S3 list response XML: {}", e))
|
.map_err(|e| VfsError::Io(format!("Failed to parse S3 list response XML: {}", e)))
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,7 +414,9 @@ impl VfsBackend for S3Vfs {
|
|||||||
|
|
||||||
impl VfsFile for S3VfsFile {
|
impl VfsFile for S3VfsFile {
|
||||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, VfsError> {
|
fn read(&mut self, buf: &mut [u8]) -> Result<usize, VfsError> {
|
||||||
let to_read = buf.len().min((self.size.saturating_sub(self.position)) as usize);
|
let to_read = buf
|
||||||
|
.len()
|
||||||
|
.min((self.size.saturating_sub(self.position)) as usize);
|
||||||
if to_read == 0 {
|
if to_read == 0 {
|
||||||
return Ok(0);
|
return Ok(0);
|
||||||
}
|
}
|
||||||
@@ -443,7 +450,7 @@ impl VfsFile for S3VfsFile {
|
|||||||
self.position = sz.saturating_add(offset as u64);
|
self.position = sz.saturating_add(offset as u64);
|
||||||
} else {
|
} else {
|
||||||
let abs = offset.unsigned_abs();
|
let abs = offset.unsigned_abs();
|
||||||
self.position = if abs <= sz { sz - abs } else { 0 };
|
self.position = sz.saturating_sub(abs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::io::SeekFrom::Current(offset) => {
|
std::io::SeekFrom::Current(offset) => {
|
||||||
@@ -451,11 +458,7 @@ impl VfsFile for S3VfsFile {
|
|||||||
self.position = self.position.saturating_add(offset as u64);
|
self.position = self.position.saturating_add(offset as u64);
|
||||||
} else {
|
} else {
|
||||||
let abs = offset.unsigned_abs();
|
let abs = offset.unsigned_abs();
|
||||||
self.position = if abs <= self.position {
|
self.position = self.position.saturating_sub(abs);
|
||||||
self.position - abs
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -549,7 +552,10 @@ impl S3VfsLike {
|
|||||||
.map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?;
|
.map_err(|e| VfsError::Io(format!("S3 PUT failed: {}", e)))?;
|
||||||
|
|
||||||
if resp.status() != 200 {
|
if resp.status() != 200 {
|
||||||
return Err(VfsError::Io(format!("PutObject returned {}", resp.status())));
|
return Err(VfsError::Io(format!(
|
||||||
|
"PutObject returned {}",
|
||||||
|
resp.status()
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -612,10 +618,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_path_to_key() {
|
fn test_path_to_key() {
|
||||||
assert_eq!(
|
assert_eq!(S3Vfs::path_to_key(Path::new("/foo/bar.txt")), "foo/bar.txt");
|
||||||
S3Vfs::path_to_key(Path::new("/foo/bar.txt")),
|
|
||||||
"foo/bar.txt"
|
|
||||||
);
|
|
||||||
assert_eq!(S3Vfs::path_to_key(Path::new("/")), "");
|
assert_eq!(S3Vfs::path_to_key(Path::new("/")), "");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
S3Vfs::path_to_key(Path::new("relative/path")),
|
S3Vfs::path_to_key(Path::new("relative/path")),
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ use std::path::Path;
|
|||||||
pub fn map_io_error(path: &Path, e: std::io::Error) -> VfsError {
|
pub fn map_io_error(path: &Path, e: std::io::Error) -> VfsError {
|
||||||
match e.kind() {
|
match e.kind() {
|
||||||
std::io::ErrorKind::NotFound => VfsError::NotFound(path.display().to_string()),
|
std::io::ErrorKind::NotFound => VfsError::NotFound(path.display().to_string()),
|
||||||
std::io::ErrorKind::PermissionDenied => VfsError::PermissionDenied(path.display().to_string()),
|
std::io::ErrorKind::PermissionDenied => {
|
||||||
|
VfsError::PermissionDenied(path.display().to_string())
|
||||||
|
}
|
||||||
std::io::ErrorKind::AlreadyExists => VfsError::AlreadyExists(path.display().to_string()),
|
std::io::ErrorKind::AlreadyExists => VfsError::AlreadyExists(path.display().to_string()),
|
||||||
std::io::ErrorKind::DirectoryNotEmpty => VfsError::NotEmpty(path.display().to_string()),
|
std::io::ErrorKind::DirectoryNotEmpty => VfsError::NotEmpty(path.display().to_string()),
|
||||||
std::io::ErrorKind::NotADirectory => VfsError::NotADirectory(path.display().to_string()),
|
std::io::ErrorKind::NotADirectory => VfsError::NotADirectory(path.display().to_string()),
|
||||||
@@ -65,13 +67,7 @@ pub fn build_long_name(stat: &VfsStat, name: &str) -> String {
|
|||||||
|
|
||||||
format!(
|
format!(
|
||||||
"{}{} {} {} {} {} {} {}",
|
"{}{} {} {} {} {} {} {}",
|
||||||
file_type, perms,
|
file_type, perms, link_count, stat.uid, stat.gid, size, mtime, name
|
||||||
link_count,
|
|
||||||
stat.uid,
|
|
||||||
stat.gid,
|
|
||||||
size,
|
|
||||||
mtime,
|
|
||||||
name
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,10 +48,10 @@ impl FrameAlignment {
|
|||||||
|
|
||||||
pub fn is_aligned(&self, offset: usize, size: usize) -> bool {
|
pub fn is_aligned(&self, offset: usize, size: usize) -> bool {
|
||||||
if self.frame_size == 0 {
|
if self.frame_size == 0 {
|
||||||
return offset % self.frame_boundary == 0 && size % self.frame_boundary == 0;
|
return offset.is_multiple_of(self.frame_boundary) && size.is_multiple_of(self.frame_boundary);
|
||||||
}
|
}
|
||||||
|
|
||||||
offset % self.frame_size == 0 && size % self.frame_size == 0
|
offset.is_multiple_of(self.frame_size) && size.is_multiple_of(self.frame_size)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn optimal_chunk_size(&self) -> usize {
|
pub fn optimal_chunk_size(&self) -> usize {
|
||||||
@@ -74,9 +74,9 @@ impl FrameAlignment {
|
|||||||
pub fn align_size(&self, size: usize) -> usize {
|
pub fn align_size(&self, size: usize) -> usize {
|
||||||
if self.frame_size == 0 {
|
if self.frame_size == 0 {
|
||||||
let boundary = self.frame_boundary;
|
let boundary = self.frame_boundary;
|
||||||
((size + boundary - 1) / boundary) * boundary
|
size.div_ceil(boundary) * boundary
|
||||||
} else {
|
} else {
|
||||||
((size + self.frame_size - 1) / self.frame_size) * self.frame_size
|
size.div_ceil(self.frame_size) * self.frame_size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,12 @@ pub struct ThreadSafeCache {
|
|||||||
path_cache: Mutex<HashMap<String, String>>, // path -> node_id
|
path_cache: Mutex<HashMap<String, String>>, // path -> node_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for ThreadSafeCache {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ThreadSafeCache {
|
impl ThreadSafeCache {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|||||||
@@ -44,7 +44,9 @@ impl DbManager {
|
|||||||
let mut stmt = conn.prepare(sql)?;
|
let mut stmt = conn.prepare(sql)?;
|
||||||
|
|
||||||
let result = if level == 0 {
|
let result = if level == 0 {
|
||||||
stmt.query_row(params![component, &self.tree_type], |row| row.get::<_, String>(0))
|
stmt.query_row(params![component, &self.tree_type], |row| {
|
||||||
|
row.get::<_, String>(0)
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
stmt.query_row(
|
stmt.query_row(
|
||||||
params![component, &self.tree_type, current_parent.as_ref().unwrap()],
|
params![component, &self.tree_type, current_parent.as_ref().unwrap()],
|
||||||
@@ -79,7 +81,8 @@ impl DbManager {
|
|||||||
pub fn get_node_info(&self, node_id: &str) -> Result<Option<(String, u64)>> {
|
pub fn get_node_info(&self, node_id: &str) -> Result<Option<(String, u64)>> {
|
||||||
let conn = self.conn.lock().unwrap();
|
let conn = self.conn.lock().unwrap();
|
||||||
|
|
||||||
let sql = "SELECT node_type, file_size FROM file_nodes WHERE node_id = ?1 AND tree_type = ?2";
|
let sql =
|
||||||
|
"SELECT node_type, file_size FROM file_nodes WHERE node_id = ?1 AND tree_type = ?2";
|
||||||
let mut stmt = conn.prepare(sql)?;
|
let mut stmt = conn.prepare(sql)?;
|
||||||
|
|
||||||
let result = stmt.query_row(params![node_id, &self.tree_type], |row| {
|
let result = stmt.query_row(params![node_id, &self.tree_type], |row| {
|
||||||
@@ -100,7 +103,9 @@ impl DbManager {
|
|||||||
let mut stmt = conn.prepare(sql)?;
|
let mut stmt = conn.prepare(sql)?;
|
||||||
|
|
||||||
let labels = stmt
|
let labels = stmt
|
||||||
.query_map(params![parent_id, &self.tree_type], |row| row.get::<_, String>(0))?
|
.query_map(params![parent_id, &self.tree_type], |row| {
|
||||||
|
row.get::<_, String>(0)
|
||||||
|
})?
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
Ok(labels)
|
Ok(labels)
|
||||||
@@ -112,7 +117,9 @@ impl DbManager {
|
|||||||
let sql = "SELECT node_id FROM file_nodes WHERE parent_id = ?1 AND label = ?2 AND tree_type = ?3 LIMIT 1";
|
let sql = "SELECT node_id FROM file_nodes WHERE parent_id = ?1 AND label = ?2 AND tree_type = ?3 LIMIT 1";
|
||||||
let mut stmt = conn.prepare(sql)?;
|
let mut stmt = conn.prepare(sql)?;
|
||||||
|
|
||||||
let result = stmt.query_row(params![parent_id, name, &self.tree_type], |row| row.get::<_, String>(0));
|
let result = stmt.query_row(params![parent_id, name, &self.tree_type], |row| {
|
||||||
|
row.get::<_, String>(0)
|
||||||
|
});
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(node_id) => Ok(Some(node_id)),
|
Ok(node_id) => Ok(Some(node_id)),
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use fuse_backend_rs::api::filesystem::{Context, DirEntry, Entry, FileSystem, ZeroCopyWriter};
|
|
||||||
use fuse_backend_rs::abi::fuse_abi::{stat64, statvfs64, FsOptions, OpenOptions};
|
use fuse_backend_rs::abi::fuse_abi::{stat64, statvfs64, FsOptions, OpenOptions};
|
||||||
|
use fuse_backend_rs::api::filesystem::{Context, DirEntry, Entry, FileSystem, ZeroCopyWriter};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::CStr;
|
use std::ffi::CStr;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{Read, Seek, SeekFrom};
|
use std::io::Read;
|
||||||
use std::os::unix::io::{AsRawFd, FromRawFd};
|
use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
@@ -65,11 +65,14 @@ impl MarkBaseFs {
|
|||||||
st.st_uid = 501;
|
st.st_uid = 501;
|
||||||
st.st_gid = 20;
|
st.st_gid = 20;
|
||||||
st.st_size = file_size as i64;
|
st.st_size = file_size as i64;
|
||||||
st.st_blocks = ((file_size + 511) / 512) as i64;
|
st.st_blocks = file_size.div_ceil(512) as i64;
|
||||||
st.st_blksize = 4096;
|
st.st_blksize = 4096;
|
||||||
|
|
||||||
let now = SystemTime::now();
|
let now = SystemTime::now();
|
||||||
st.st_atime = now.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() as i64;
|
st.st_atime = now
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() as i64;
|
||||||
st.st_mtime = st.st_atime;
|
st.st_mtime = st.st_atime;
|
||||||
st.st_ctime = st.st_atime;
|
st.st_ctime = st.st_atime;
|
||||||
|
|
||||||
@@ -85,20 +88,22 @@ impl FileSystem for MarkBaseFs {
|
|||||||
Ok(capable)
|
Ok(capable)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lookup(
|
fn lookup(&self, _ctx: &Context, parent: Self::Inode, name: &CStr) -> std::io::Result<Entry> {
|
||||||
&self,
|
|
||||||
ctx: &Context,
|
|
||||||
parent: Self::Inode,
|
|
||||||
name: &CStr,
|
|
||||||
) -> std::io::Result<Entry> {
|
|
||||||
let name_str = name.to_string_lossy();
|
let name_str = name.to_string_lossy();
|
||||||
|
|
||||||
let node_id = if parent == 1 {
|
let node_id = if parent == 1 {
|
||||||
self.db.find_node_id(&format!("/{}", name_str)).ok().flatten()
|
self.db
|
||||||
|
.find_node_id(&format!("/{}", name_str))
|
||||||
|
.ok()
|
||||||
|
.flatten()
|
||||||
} else {
|
} else {
|
||||||
let parent_id = self.find_node_id_by_inode(parent);
|
let parent_id = self.find_node_id_by_inode(parent);
|
||||||
match parent_id {
|
match parent_id {
|
||||||
Some(pid) => self.db.find_node_id_by_parent(&pid, &name_str).ok().flatten(),
|
Some(pid) => self
|
||||||
|
.db
|
||||||
|
.find_node_id_by_parent(&pid, &name_str)
|
||||||
|
.ok()
|
||||||
|
.flatten(),
|
||||||
None => None,
|
None => None,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -127,16 +132,16 @@ impl FileSystem for MarkBaseFs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forget(&self, ctx: &Context, inode: Self::Inode, count: u64) {
|
fn forget(&self, _ctx: &Context, inode: Self::Inode, _count: u64) {
|
||||||
let mut map = self.inode_map.write().unwrap();
|
let mut map = self.inode_map.write().unwrap();
|
||||||
map.remove(&inode);
|
map.remove(&inode);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn getattr(
|
fn getattr(
|
||||||
&self,
|
&self,
|
||||||
ctx: &Context,
|
_ctx: &Context,
|
||||||
inode: Self::Inode,
|
inode: Self::Inode,
|
||||||
handle: Option<Self::Handle>,
|
_handle: Option<Self::Handle>,
|
||||||
) -> std::io::Result<(stat64, Duration)> {
|
) -> std::io::Result<(stat64, Duration)> {
|
||||||
if inode == 1 {
|
if inode == 1 {
|
||||||
let attr = MarkBaseFs::make_stat64("folder", 0);
|
let attr = MarkBaseFs::make_stat64("folder", 0);
|
||||||
@@ -161,24 +166,24 @@ impl FileSystem for MarkBaseFs {
|
|||||||
|
|
||||||
fn open(
|
fn open(
|
||||||
&self,
|
&self,
|
||||||
ctx: &Context,
|
_ctx: &Context,
|
||||||
inode: Self::Inode,
|
inode: Self::Inode,
|
||||||
flags: u32,
|
_flags: u32,
|
||||||
fuse_flags: u32,
|
_fuse_flags: u32,
|
||||||
) -> std::io::Result<(Option<Self::Handle>, OpenOptions, Option<u32>)> {
|
) -> std::io::Result<(Option<Self::Handle>, OpenOptions, Option<u32>)> {
|
||||||
Ok((Some(inode), OpenOptions::empty(), None))
|
Ok((Some(inode), OpenOptions::empty(), None))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read(
|
fn read(
|
||||||
&self,
|
&self,
|
||||||
ctx: &Context,
|
_ctx: &Context,
|
||||||
inode: Self::Inode,
|
inode: Self::Inode,
|
||||||
handle: Self::Handle,
|
_handle: Self::Handle,
|
||||||
w: &mut dyn ZeroCopyWriter,
|
w: &mut dyn ZeroCopyWriter,
|
||||||
size: u32,
|
size: u32,
|
||||||
offset: u64,
|
offset: u64,
|
||||||
lock_owner: Option<u64>,
|
_lock_owner: Option<u64>,
|
||||||
flags: u32,
|
_flags: u32,
|
||||||
) -> std::io::Result<usize> {
|
) -> std::io::Result<usize> {
|
||||||
let node_id = self.find_node_id_by_inode(inode);
|
let node_id = self.find_node_id_by_inode(inode);
|
||||||
match node_id {
|
match node_id {
|
||||||
@@ -202,32 +207,32 @@ impl FileSystem for MarkBaseFs {
|
|||||||
|
|
||||||
fn release(
|
fn release(
|
||||||
&self,
|
&self,
|
||||||
ctx: &Context,
|
_ctx: &Context,
|
||||||
inode: Self::Inode,
|
_inode: Self::Inode,
|
||||||
flags: u32,
|
_flags: u32,
|
||||||
handle: Self::Handle,
|
_handle: Self::Handle,
|
||||||
flush: bool,
|
_flush: bool,
|
||||||
flock_release: bool,
|
_flock_release: bool,
|
||||||
lock_owner: Option<u64>,
|
_lock_owner: Option<u64>,
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn opendir(
|
fn opendir(
|
||||||
&self,
|
&self,
|
||||||
ctx: &Context,
|
_ctx: &Context,
|
||||||
inode: Self::Inode,
|
inode: Self::Inode,
|
||||||
flags: u32,
|
_flags: u32,
|
||||||
) -> std::io::Result<(Option<Self::Handle>, OpenOptions)> {
|
) -> std::io::Result<(Option<Self::Handle>, OpenOptions)> {
|
||||||
Ok((Some(inode), OpenOptions::empty()))
|
Ok((Some(inode), OpenOptions::empty()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readdir(
|
fn readdir(
|
||||||
&self,
|
&self,
|
||||||
ctx: &Context,
|
_ctx: &Context,
|
||||||
inode: Self::Inode,
|
inode: Self::Inode,
|
||||||
handle: Self::Handle,
|
_handle: Self::Handle,
|
||||||
size: u32,
|
_size: u32,
|
||||||
offset: u64,
|
offset: u64,
|
||||||
add_entry: &mut dyn FnMut(DirEntry) -> std::io::Result<usize>,
|
add_entry: &mut dyn FnMut(DirEntry) -> std::io::Result<usize>,
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
@@ -252,19 +257,15 @@ impl FileSystem for MarkBaseFs {
|
|||||||
|
|
||||||
fn releasedir(
|
fn releasedir(
|
||||||
&self,
|
&self,
|
||||||
ctx: &Context,
|
_ctx: &Context,
|
||||||
inode: Self::Inode,
|
_inode: Self::Inode,
|
||||||
flags: u32,
|
_flags: u32,
|
||||||
handle: Self::Handle,
|
_handle: Self::Handle,
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn statfs(
|
fn statfs(&self, _ctx: &Context, _inode: Self::Inode) -> std::io::Result<statvfs64> {
|
||||||
&self,
|
|
||||||
ctx: &Context,
|
|
||||||
inode: Self::Inode,
|
|
||||||
) -> std::io::Result<statvfs64> {
|
|
||||||
let mut st = unsafe { std::mem::zeroed::<statvfs64>() };
|
let mut st = unsafe { std::mem::zeroed::<statvfs64>() };
|
||||||
st.f_bsize = 4096;
|
st.f_bsize = 4096;
|
||||||
st.f_blocks = 1000000;
|
st.f_blocks = 1000000;
|
||||||
|
|||||||
@@ -30,7 +30,11 @@ fn main() -> Result<()> {
|
|||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
|
|
||||||
match cli.command {
|
match cli.command {
|
||||||
Commands::Mount { user, dir, tree_type } => {
|
Commands::Mount {
|
||||||
|
user,
|
||||||
|
dir,
|
||||||
|
tree_type,
|
||||||
|
} => {
|
||||||
mount_user(user, tree_type, dir)?;
|
mount_user(user, tree_type, dir)?;
|
||||||
}
|
}
|
||||||
Commands::Unmount { dir } => {
|
Commands::Unmount { dir } => {
|
||||||
@@ -84,7 +88,9 @@ fn mount_user(user: String, tree_type: String, dir: PathBuf) -> Result<()> {
|
|||||||
if let Some((reader, writer)) = channel.get_request()? {
|
if let Some((reader, writer)) = channel.get_request()? {
|
||||||
if let Err(e) = server.handle_message(reader, writer.into(), None, None) {
|
if let Err(e) = server.handle_message(reader, writer.into(), None, None) {
|
||||||
match e {
|
match e {
|
||||||
fuse_backend_rs::Error::EncodeMessage(e) if e.kind() == std::io::ErrorKind::Other => {
|
fuse_backend_rs::Error::EncodeMessage(e)
|
||||||
|
if e.kind() == std::io::ErrorKind::Other =>
|
||||||
|
{
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ fn get_block_device_size(device: &str) -> Result<u64> {
|
|||||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||||
for line in stdout.lines() {
|
for line in stdout.lines() {
|
||||||
if let Some(size_str) = line.strip_prefix(" Disk Size:") {
|
if let Some(size_str) = line.strip_prefix(" Disk Size:") {
|
||||||
if let Some(bytes) = size_str.trim().split_whitespace().next() {
|
if let Some(bytes) = size_str.split_whitespace().next() {
|
||||||
if let Ok(size) = bytes.replace(',', "").parse::<u64>() {
|
if let Ok(size) = bytes.replace(',', "").parse::<u64>() {
|
||||||
return Ok(size);
|
return Ok(size);
|
||||||
}
|
}
|
||||||
@@ -120,7 +120,7 @@ pub fn generate_config(
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let (storage_path, lun_size_bytes) = if let Some(dev) = device {
|
let (storage_path, _lun_size_bytes) = if let Some(dev) = device {
|
||||||
let dev_path = Path::new(dev);
|
let dev_path = Path::new(dev);
|
||||||
if !dev_path.exists() {
|
if !dev_path.exists() {
|
||||||
anyhow::bail!("Block device not found: {}", dev);
|
anyhow::bail!("Block device not found: {}", dev);
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
pub mod nfs;
|
pub mod nfs;
|
||||||
|
|
||||||
pub use nfs::markbase_fs::MarkBaseFS;
|
|
||||||
pub use nfs::backend::MarkBaseNFSBackend;
|
pub use nfs::backend::MarkBaseNFSBackend;
|
||||||
|
pub use nfs::markbase_fs::MarkBaseFS;
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ use std::time::SystemTime;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use nfsserve::nfs::*;
|
use nfsserve::nfs::*;
|
||||||
use nfsserve::vfs::{DirEntry, NFSFileSystem, ReadDirResult, VFSCapabilities};
|
use nfsserve::vfs::{DirEntry, NFSFileSystem, ReadDirResult, VFSCapabilities};
|
||||||
use rusqlite::Connection;
|
|
||||||
|
|
||||||
use crate::nfs::markbase_fs::MarkBaseFS;
|
use crate::nfs::markbase_fs::MarkBaseFS;
|
||||||
|
|
||||||
pub struct MarkBaseNFSBackend {
|
pub struct MarkBaseNFSBackend {
|
||||||
fs: MarkBaseFS,
|
fs: MarkBaseFS,
|
||||||
id_map: Mutex<HashMap<u64, String>>, // fileid -> node_id
|
id_map: Mutex<HashMap<u64, String>>, // fileid -> node_id
|
||||||
reverse_map: Mutex<HashMap<String, u64>>, // node_id -> fileid
|
reverse_map: Mutex<HashMap<String, u64>>, // node_id -> fileid
|
||||||
next_id: Mutex<u64>,
|
next_id: Mutex<u64>,
|
||||||
}
|
}
|
||||||
@@ -50,9 +49,12 @@ impl MarkBaseNFSBackend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_fileid_from_node(&self, node_id: &str) -> u64 {
|
fn get_fileid_from_node(&self, node_id: &str) -> u64 {
|
||||||
self.reverse_map.lock().unwrap().get(node_id).copied().unwrap_or_else(|| {
|
self.reverse_map
|
||||||
self.allocate_id(node_id)
|
.lock()
|
||||||
})
|
.unwrap()
|
||||||
|
.get(node_id)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or_else(|| self.allocate_id(node_id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,13 +72,16 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
let dir_node_id = if dirid == 1 {
|
let dir_node_id = if dirid == 1 {
|
||||||
"root".to_string()
|
"root".to_string()
|
||||||
} else {
|
} else {
|
||||||
self.get_node_id(dirid)
|
self.get_node_id(dirid).ok_or(nfsstat3::NFS3ERR_STALE)?
|
||||||
.ok_or(nfsstat3::NFS3ERR_STALE)?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let filename_str = String::from_utf8_lossy(filename).to_string();
|
let filename_str = String::from_utf8_lossy(filename).to_string();
|
||||||
|
|
||||||
let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
let conn = self
|
||||||
|
.fs
|
||||||
|
.conn
|
||||||
|
.lock()
|
||||||
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
||||||
|
|
||||||
let query = if dir_node_id == "root" {
|
let query = if dir_node_id == "root" {
|
||||||
"SELECT node_id FROM file_nodes WHERE parent_id IS NULL AND label = ?1"
|
"SELECT node_id FROM file_nodes WHERE parent_id IS NULL AND label = ?1"
|
||||||
@@ -85,10 +90,10 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let node_id: String = if dir_node_id == "root" {
|
let node_id: String = if dir_node_id == "root" {
|
||||||
conn.query_row(&query, [&filename_str], |row| row.get(0))
|
conn.query_row(query, [&filename_str], |row| row.get(0))
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?
|
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?
|
||||||
} else {
|
} else {
|
||||||
conn.query_row(&query, [dir_node_id, filename_str], |row| row.get(0))
|
conn.query_row(query, [dir_node_id, filename_str], |row| row.get(0))
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?
|
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -105,24 +110,40 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
gid: 0,
|
gid: 0,
|
||||||
size: 0,
|
size: 0,
|
||||||
used: 0,
|
used: 0,
|
||||||
rdev: specdata3 { specdata1: 0, specdata2: 0 },
|
rdev: specdata3 {
|
||||||
|
specdata1: 0,
|
||||||
|
specdata2: 0,
|
||||||
|
},
|
||||||
fsid: 0,
|
fsid: 0,
|
||||||
fileid: 1,
|
fileid: 1,
|
||||||
atime: nfstime3 { seconds: 0, nseconds: 0 },
|
atime: nfstime3 {
|
||||||
mtime: nfstime3 { seconds: 0, nseconds: 0 },
|
seconds: 0,
|
||||||
ctime: nfstime3 { seconds: 0, nseconds: 0 },
|
nseconds: 0,
|
||||||
|
},
|
||||||
|
mtime: nfstime3 {
|
||||||
|
seconds: 0,
|
||||||
|
nseconds: 0,
|
||||||
|
},
|
||||||
|
ctime: nfstime3 {
|
||||||
|
seconds: 0,
|
||||||
|
nseconds: 0,
|
||||||
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let node_id = self.get_node_id(id).ok_or(nfsstat3::NFS3ERR_STALE)?;
|
let node_id = self.get_node_id(id).ok_or(nfsstat3::NFS3ERR_STALE)?;
|
||||||
|
|
||||||
let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
let conn = self
|
||||||
|
.fs
|
||||||
|
.conn
|
||||||
|
.lock()
|
||||||
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
||||||
|
|
||||||
let (node_type, file_size): (String, i64) = conn
|
let (node_type, file_size): (String, i64) = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
"SELECT node_type, file_size FROM file_nodes WHERE node_id = ?1",
|
"SELECT node_type, file_size FROM file_nodes WHERE node_id = ?1",
|
||||||
[&node_id],
|
[&node_id],
|
||||||
|row| Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
|
|row| Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?)),
|
||||||
)
|
)
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?;
|
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?;
|
||||||
|
|
||||||
@@ -145,12 +166,24 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
gid: 0,
|
gid: 0,
|
||||||
size: file_size as u64,
|
size: file_size as u64,
|
||||||
used: file_size as u64,
|
used: file_size as u64,
|
||||||
rdev: specdata3 { specdata1: 0, specdata2: 0 },
|
rdev: specdata3 {
|
||||||
|
specdata1: 0,
|
||||||
|
specdata2: 0,
|
||||||
|
},
|
||||||
fsid: 0,
|
fsid: 0,
|
||||||
fileid: id,
|
fileid: id,
|
||||||
atime: nfstime3 { seconds: now as u32, nseconds: 0 },
|
atime: nfstime3 {
|
||||||
mtime: nfstime3 { seconds: now as u32, nseconds: 0 },
|
seconds: now as u32,
|
||||||
ctime: nfstime3 { seconds: now as u32, nseconds: 0 },
|
nseconds: 0,
|
||||||
|
},
|
||||||
|
mtime: nfstime3 {
|
||||||
|
seconds: now as u32,
|
||||||
|
nseconds: 0,
|
||||||
|
},
|
||||||
|
ctime: nfstime3 {
|
||||||
|
seconds: now as u32,
|
||||||
|
nseconds: 0,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,21 +191,30 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
Err(nfsstat3::NFS3ERR_ROFS)
|
Err(nfsstat3::NFS3ERR_ROFS)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read(&self, id: fileid3, offset: u64, count: u32) -> Result<(Vec<u8>, bool), nfsstat3> {
|
async fn read(
|
||||||
|
&self,
|
||||||
|
id: fileid3,
|
||||||
|
offset: u64,
|
||||||
|
count: u32,
|
||||||
|
) -> Result<(Vec<u8>, bool), nfsstat3> {
|
||||||
let node_id = self.get_node_id(id).ok_or(nfsstat3::NFS3ERR_STALE)?;
|
let node_id = self.get_node_id(id).ok_or(nfsstat3::NFS3ERR_STALE)?;
|
||||||
|
|
||||||
let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
let conn = self
|
||||||
|
.fs
|
||||||
|
.conn
|
||||||
|
.lock()
|
||||||
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
||||||
|
|
||||||
let aliases_json: String = conn
|
let aliases_json: String = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
"SELECT aliases_json FROM file_nodes WHERE node_id = ?1",
|
"SELECT aliases_json FROM file_nodes WHERE node_id = ?1",
|
||||||
[&node_id],
|
[&node_id],
|
||||||
|row| row.get(0)
|
|row| row.get(0),
|
||||||
)
|
)
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?;
|
.map_err(|_| nfsstat3::NFS3ERR_NOENT)?;
|
||||||
|
|
||||||
let aliases: serde_json::Value = serde_json::from_str(&aliases_json)
|
let aliases: serde_json::Value =
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
serde_json::from_str(&aliases_json).map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
||||||
|
|
||||||
let file_path = aliases["path"].as_str().ok_or(nfsstat3::NFS3ERR_NOENT)?;
|
let file_path = aliases["path"].as_str().ok_or(nfsstat3::NFS3ERR_NOENT)?;
|
||||||
|
|
||||||
@@ -192,15 +234,28 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
Err(nfsstat3::NFS3ERR_ROFS)
|
Err(nfsstat3::NFS3ERR_ROFS)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create(&self, _dirid: fileid3, _filename: &filename3, _attr: sattr3) -> Result<(fileid3, fattr3), nfsstat3> {
|
async fn create(
|
||||||
|
&self,
|
||||||
|
_dirid: fileid3,
|
||||||
|
_filename: &filename3,
|
||||||
|
_attr: sattr3,
|
||||||
|
) -> Result<(fileid3, fattr3), nfsstat3> {
|
||||||
Err(nfsstat3::NFS3ERR_ROFS)
|
Err(nfsstat3::NFS3ERR_ROFS)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_exclusive(&self, _dirid: fileid3, _filename: &filename3) -> Result<fileid3, nfsstat3> {
|
async fn create_exclusive(
|
||||||
|
&self,
|
||||||
|
_dirid: fileid3,
|
||||||
|
_filename: &filename3,
|
||||||
|
) -> Result<fileid3, nfsstat3> {
|
||||||
Err(nfsstat3::NFS3ERR_ROFS)
|
Err(nfsstat3::NFS3ERR_ROFS)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn mkdir(&self, _dirid: fileid3, _dirname: &filename3) -> Result<(fileid3, fattr3), nfsstat3> {
|
async fn mkdir(
|
||||||
|
&self,
|
||||||
|
_dirid: fileid3,
|
||||||
|
_dirname: &filename3,
|
||||||
|
) -> Result<(fileid3, fattr3), nfsstat3> {
|
||||||
Err(nfsstat3::NFS3ERR_ROFS)
|
Err(nfsstat3::NFS3ERR_ROFS)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,11 +282,14 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
let dir_node_id = if dirid == 1 {
|
let dir_node_id = if dirid == 1 {
|
||||||
"root".to_string()
|
"root".to_string()
|
||||||
} else {
|
} else {
|
||||||
self.get_node_id(dirid)
|
self.get_node_id(dirid).ok_or(nfsstat3::NFS3ERR_STALE)?
|
||||||
.ok_or(nfsstat3::NFS3ERR_STALE)?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let conn = self.fs.conn.lock().map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
let conn = self
|
||||||
|
.fs
|
||||||
|
.conn
|
||||||
|
.lock()
|
||||||
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
||||||
|
|
||||||
let query = if dir_node_id == "root" {
|
let query = if dir_node_id == "root" {
|
||||||
"SELECT node_id, label, node_type, file_size FROM file_nodes WHERE parent_id IS NULL"
|
"SELECT node_id, label, node_type, file_size FROM file_nodes WHERE parent_id IS NULL"
|
||||||
@@ -239,38 +297,34 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
"SELECT node_id, label, node_type, file_size FROM file_nodes WHERE parent_id = ?1"
|
"SELECT node_id, label, node_type, file_size FROM file_nodes WHERE parent_id = ?1"
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut stmt = conn.prepare(&query).map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
let mut stmt = conn
|
||||||
|
.prepare(query)
|
||||||
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?;
|
||||||
|
|
||||||
let rows: Vec<(String, String, String, Option<i64>)> = if dir_node_id == "root" {
|
let rows: Vec<(String, String, String, Option<i64>)> = if dir_node_id == "root" {
|
||||||
stmt.query_map([], |row| {
|
stmt.query_map([], |row| {
|
||||||
row.get::<_, String>(0)
|
row.get::<_, String>(0).and_then(|node_id| {
|
||||||
.and_then(|node_id| {
|
row.get::<_, String>(1).and_then(|label| {
|
||||||
row.get::<_, String>(1)
|
row.get::<_, String>(2).and_then(|node_type| {
|
||||||
.and_then(|label| {
|
row.get::<_, Option<i64>>(3)
|
||||||
row.get::<_, String>(2)
|
.map(|file_size| (node_id, label, node_type, file_size))
|
||||||
.and_then(|node_type| {
|
})
|
||||||
row.get::<_, Option<i64>>(3)
|
|
||||||
.map(|file_size| (node_id, label, node_type, file_size))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?
|
||||||
} else {
|
} else {
|
||||||
stmt.query_map([&dir_node_id.as_str()], |row| {
|
stmt.query_map([&dir_node_id.as_str()], |row| {
|
||||||
row.get::<_, String>(0)
|
row.get::<_, String>(0).and_then(|node_id| {
|
||||||
.and_then(|node_id| {
|
row.get::<_, String>(1).and_then(|label| {
|
||||||
row.get::<_, String>(1)
|
row.get::<_, String>(2).and_then(|node_type| {
|
||||||
.and_then(|label| {
|
row.get::<_, Option<i64>>(3)
|
||||||
row.get::<_, String>(2)
|
.map(|file_size| (node_id, label, node_type, file_size))
|
||||||
.and_then(|node_type| {
|
})
|
||||||
row.get::<_, Option<i64>>(3)
|
|
||||||
.map(|file_size| (node_id, label, node_type, file_size))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?
|
.map_err(|_| nfsstat3::NFS3ERR_SERVERFAULT)?
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
@@ -285,12 +339,11 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
let file_size = file_size_opt.unwrap_or(0);
|
let file_size = file_size_opt.unwrap_or(0);
|
||||||
let fileid = self.get_fileid_from_node(&node_id);
|
let fileid = self.get_fileid_from_node(&node_id);
|
||||||
|
|
||||||
if !started {
|
if !started
|
||||||
if fileid == start_after {
|
&& fileid == start_after {
|
||||||
started = true;
|
started = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if started && entries.len() < max_entries {
|
if started && entries.len() < max_entries {
|
||||||
let attr = fattr3 {
|
let attr = fattr3 {
|
||||||
@@ -305,12 +358,24 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
gid: 0,
|
gid: 0,
|
||||||
size: file_size as u64,
|
size: file_size as u64,
|
||||||
used: file_size as u64,
|
used: file_size as u64,
|
||||||
rdev: specdata3 { specdata1: 0, specdata2: 0 },
|
rdev: specdata3 {
|
||||||
|
specdata1: 0,
|
||||||
|
specdata2: 0,
|
||||||
|
},
|
||||||
fsid: 0,
|
fsid: 0,
|
||||||
fileid,
|
fileid,
|
||||||
atime: nfstime3 { seconds: 0, nseconds: 0 },
|
atime: nfstime3 {
|
||||||
mtime: nfstime3 { seconds: 0, nseconds: 0 },
|
seconds: 0,
|
||||||
ctime: nfstime3 { seconds: 0, nseconds: 0 },
|
nseconds: 0,
|
||||||
|
},
|
||||||
|
mtime: nfstime3 {
|
||||||
|
seconds: 0,
|
||||||
|
nseconds: 0,
|
||||||
|
},
|
||||||
|
ctime: nfstime3 {
|
||||||
|
seconds: 0,
|
||||||
|
nseconds: 0,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
entries.push(DirEntry {
|
entries.push(DirEntry {
|
||||||
@@ -321,10 +386,7 @@ impl NFSFileSystem for MarkBaseNFSBackend {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(ReadDirResult {
|
Ok(ReadDirResult { entries, end: true })
|
||||||
entries,
|
|
||||||
end: true,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn symlink(
|
async fn symlink(
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user