Implement Upload Hook for momentry integration (Phase 1)
Some checks failed
Test / test (push) Has been cancelled
Test / build (push) Has been cancelled

- Add upload_hook.rs module: trigger video_probe + video_register on upload
- Add UploadHookSection to config: video extensions, binary paths
- Integrate with SFTP: handle_close triggers hook on write files
- Integrate with SCP/rsync: child process exit triggers hook
- All 155 tests pass
This commit is contained in:
Warren
2026-06-19 06:26:20 +08:00
parent c71811090b
commit e2d58538f9
7 changed files with 336 additions and 42 deletions

View File

@@ -4,11 +4,12 @@
use crate::ssh_server::packet::{PacketType, SshPacket};
use crate::ssh_server::port_forward::{
DirectTcpipChannel, ForwardedTcpipChannel, PortForwardManager,
}; // Phase 13.3
use crate::ssh_server::rsync_handler::RsyncHandler; // Phase 8: rsync handler
use crate::ssh_server::scp_handler::ScpHandler; // Phase 8: SCP handler
use crate::ssh_server::sftp_handler::SftpHandler; // Phase 7: SFTP handler
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.3: 安全配置
};
use crate::ssh_server::rsync_handler::RsyncHandler;
use crate::ssh_server::scp_handler::ScpHandler;
use crate::ssh_server::sftp_handler::SftpHandler;
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
use crate::ssh_server::upload_hook::UploadHook;
use anyhow::{anyhow, Result};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, warn};
@@ -25,10 +26,10 @@ use std::process::{Child, ChildStderr, ChildStdin, ChildStdout}; // Phase 14:
pub struct ChannelManager {
channels: HashMap<u32, Channel>,
next_channel_id: u32,
/// ⭐⭐⭐⭐⭐ Phase 15.1: 待发送packet队列用于同时发送WINDOW_ADJUST和SFTP响应
pub pending_packets: VecDeque<SshPacket>,
/// 用户home目录SFTP/SCP/rsync根目录SFTPGo兼容
pub home_dir: PathBuf,
pub upload_hook: Option<std::sync::Arc<UploadHook>>,
pub user_uuid: String,
}
/// Phase 14: 交互式Exec进程管理参考OpenSSH session.c: do_exec_no_pty
@@ -44,12 +45,18 @@ pub struct ExecProcess {
}
impl ChannelManager {
pub fn new(home_dir: PathBuf) -> Self {
pub fn new(
home_dir: PathBuf,
upload_hook: Option<std::sync::Arc<UploadHook>>,
user_uuid: String,
) -> Self {
Self {
channels: HashMap::new(),
next_channel_id: 0,
pending_packets: VecDeque::new(),
home_dir,
upload_hook,
user_uuid,
}
}
@@ -574,7 +581,13 @@ impl ChannelManager {
};
let vfs = Box::new(crate::vfs::local_fs::LocalFs::new());
let sftp_handler = SftpHandler::new(root_dir, vfs, maxpacket); // ⭐⭐⭐⭐⭐ Phase 4: 传入 maxpack
let sftp_handler = SftpHandler::new(
root_dir,
vfs,
maxpacket,
self.upload_hook.clone(),
self.user_uuid.clone(),
);
// 存储到channel
if let Some(ch) = self.channels.get_mut(&channel) {
@@ -1374,7 +1387,10 @@ impl ChannelManager {
);
child_exited = true;
// ⭐⭐⭐⭐⭐ Child exited读取剩余stdout如果有
let command_str = exec_process.command.clone();
let should_trigger_hook = status.success()
&& (command_str.contains("scp") || command_str.contains("rsync"));
if let Some(stdout) = &mut exec_process.stdout {
let mut buffer = vec![0u8; 32768];
match stdout.read(&mut buffer) {
@@ -1395,6 +1411,17 @@ impl ChannelManager {
}
}
if should_trigger_hook {
let dest_path = Self::extract_dest_path_from_command(&command_str, &self.home_dir);
if let Some(path) = dest_path {
if let Some(hook) = &self.upload_hook {
if let Err(e) = hook.trigger(&path, &self.user_uuid) {
warn!("Upload hook failed for {:?}: {}", path, e);
}
}
}
}
// 没有剩余数据返回child_exited标志
return Ok((None, false, true));
}
@@ -1823,6 +1850,29 @@ impl ChannelManager {
Ok(Some(packets))
}
}
fn extract_dest_path_from_command(command: &str, home_dir: &PathBuf) -> Option<PathBuf> {
if command.contains("scp") {
if command.contains("scp -t") {
let parts: Vec<&str> = command.split_whitespace().collect();
for part in parts.iter().rev() {
if !part.starts_with("-") && *part != "scp" && *part != "-t" {
return Some(home_dir.join(part));
}
}
}
} else if command.contains("rsync") {
if command.contains("--server") {
let parts: Vec<&str> = command.split_whitespace().collect();
for part in parts.iter().rev() {
if !part.starts_with("-") && !part.contains("--") && *part != "rsync" && *part != "--server" && *part != "--sender" {
return Some(home_dir.join(part));
}
}
}
}
None
}
}
/// SSH Channel结构参考OpenSSH channel.c: struct channel
@@ -1967,13 +2017,13 @@ mod tests {
#[test]
fn test_channel_manager_creation() {
let manager = ChannelManager::new(PathBuf::from("/tmp"));
let manager = ChannelManager::new(PathBuf::from("/tmp"), None, "test_user".to_string());
assert_eq!(manager.next_channel_id, 0);
}
#[test]
fn test_channel_open_confirmation() {
let manager = ChannelManager::new(PathBuf::from("/tmp"));
let manager = ChannelManager::new(PathBuf::from("/tmp"), None, "test_user".to_string());
let packet = manager
.build_channel_open_confirmation(0, 100, 2097152, 32768)
.unwrap();
@@ -1986,7 +2036,7 @@ mod tests {
#[test]
fn test_channel_success() {
let manager = ChannelManager::new(PathBuf::from("/tmp"));
let manager = ChannelManager::new(PathBuf::from("/tmp"), None, "test_user".to_string());
let packet = manager.build_channel_success(0).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_CHANNEL_SUCCESS as u8);

View File

@@ -5,21 +5,22 @@ pub mod auth;
pub mod channel;
pub mod cipher;
pub mod crypto;
pub mod data_forwarder; // Phase 13.5: 数据传输模块
pub mod data_forwarder;
pub mod kex;
pub mod kex_complete;
pub mod kex_exchange;
pub mod packet;
pub mod port_forward; // Phase 13: 端口转发模块
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
pub mod port_forward;
pub mod port_forward_listener;
pub mod rsync_handler;
pub mod scp_handler;
pub mod server;
pub mod sftp_handler;
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理参考OpenSSH sshbuf.c
pub mod ssh_security_config;
pub mod sshbuf;
pub mod upload_hook;
pub mod version;
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
pub mod window_manager;
pub use packet::{PacketType, SshPacket};
pub use server::SshServer;

View File

@@ -10,8 +10,9 @@ use crate::ssh_server::cipher::{EncryptedPacket, EncryptionContext};
use crate::ssh_server::kex::{KexProposal, KexResult};
use crate::ssh_server::kex_complete::KexState;
use crate::ssh_server::packet::{PacketType, SshPacket};
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
use crate::ssh_server::port_forward::PortForwardManager;
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
use crate::ssh_server::upload_hook::UploadHook;
use crate::ssh_server::version::VersionExchange;
use anyhow::{anyhow, Result};
use log::{error, info, warn};
@@ -19,14 +20,14 @@ use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::thread; // Phase 13: 端口转发线程同步
use std::thread;
/// SSH服务器配置Phase 13.1企业级安全配置)
pub struct SshServerConfig {
pub port: u16,
pub bind_address: String,
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
pub pg_conn: Option<String>, // PostgreSQL连接字符串SFTPGo兼容认证
pub security_config: SshSecurityConfig,
pub pg_conn: Option<String>,
pub upload_hook_config: crate::config::UploadHookSection,
}
impl Default for SshServerConfig {
@@ -34,8 +35,9 @@ impl Default for SshServerConfig {
Self {
port: 2024,
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
security_config: SshSecurityConfig::enterprise_default(),
pg_conn: None,
upload_hook_config: crate::config::UploadHookSection::default(),
}
}
}
@@ -49,6 +51,7 @@ impl SshServerConfig {
bind_address: "127.0.0.1".to_string(),
security_config: config,
pg_conn: None,
upload_hook_config: crate::config::UploadHookSection::default(),
})
}
}
@@ -81,8 +84,9 @@ impl SshServer {
self.config.security_config.max_sessions
);
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
let security_config = self.security_config.clone();
let pg_conn = self.config.pg_conn.clone();
let upload_hook_config = self.config.upload_hook_config.clone();
for stream in listener.incoming() {
match stream {
@@ -90,14 +94,18 @@ impl SshServer {
let client_addr = stream.peer_addr()?;
info!("New SSH connection from {}", client_addr);
let security_config_clone = security_config.clone(); // Phase 13.1
let security_config_clone = security_config.clone();
let pg_conn_clone = pg_conn.clone();
let upload_hook_config_clone = upload_hook_config.clone();
thread::spawn(move || {
if let Err(e) =
handle_connection_complete(stream, security_config_clone, pg_conn_clone)
if let Err(e) = handle_connection_complete(
stream,
security_config_clone,
pg_conn_clone,
upload_hook_config_clone,
)
{
// Phase 13.1
error!("Connection error: {}", e);
}
});
@@ -117,6 +125,7 @@ fn handle_connection_complete(
stream: TcpStream,
security_config: Arc<Mutex<SshSecurityConfig>>,
pg_conn: Option<String>,
upload_hook_config: crate::config::UploadHookSection,
) -> Result<()> {
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
@@ -173,8 +182,23 @@ fn handle_connection_complete(
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
info!("SSH authentication succeeded: user={}", auth_user.username);
// Phase 6: SSH Channel管理参考OpenSSH channel.c
let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone());
let upload_hook = if upload_hook_config.enabled {
Some(Arc::new(UploadHook::new(
upload_hook_config.enabled,
PathBuf::from(&upload_hook_config.video_probe_path),
PathBuf::from(&upload_hook_config.video_register_cli),
PathBuf::from(&upload_hook_config.video_register_dir),
upload_hook_config.video_extensions.clone(),
)))
} else {
None
};
let mut channel_manager = ChannelManager::new(
auth_user.home_dir.clone(),
upload_hook,
auth_user.username.clone(),
);
// Phase 13: PortForwardManager初始化
let mut port_forward_manager = PortForwardManager::new();
@@ -666,8 +690,9 @@ pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
let config = SshServerConfig {
port: port.unwrap_or(2024),
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
security_config: SshSecurityConfig::enterprise_default(),
pg_conn: pg_conn.map(|s| s.to_string()),
upload_hook_config: crate::config::UploadHookSection::default(),
};
let server = SshServer::new(config);

View File

@@ -289,6 +289,7 @@ pub struct SftpHandle {
pub handle_type: SftpHandleType,
pub file: Option<Box<dyn VfsFile>>,
pub dir_entries: Option<Vec<VfsDirEntry>>,
pub write_mode: bool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -303,10 +304,10 @@ pub struct SftpHandler {
vfs: Box<dyn VfsBackend>,
next_handle_id: u32,
handles: std::collections::HashMap<u32, SftpHandle>,
// ⭐⭐⭐⭐⭐ Phase 4: 添加 client maxpack 限制参考OpenSSH sftp-server.c
maxpacket: u32, // 来自 SSH_MSG_CHANNEL_OPEN_CONFIRMATION 的 maximum_packet_size
/// 限制绝对路径也在 root_dir 之下chroot 模式)
maxpacket: u32,
restrict_absolute: bool,
upload_hook: Option<std::sync::Arc<crate::ssh_server::upload_hook::UploadHook>>,
user_uuid: String,
}
impl SftpHandler {
@@ -318,7 +319,13 @@ impl SftpHandler {
const MAX_HASH_SIZE: u64 = 268_435_456;
// ⭐⭐⭐⭐⭐ Phase 4: 修改 new() 方法,接受 maxpack 参数
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>, maxpacket: u32) -> Self {
pub fn new(
root_dir: PathBuf,
vfs: Box<dyn VfsBackend>,
maxpacket: u32,
upload_hook: Option<std::sync::Arc<crate::ssh_server::upload_hook::UploadHook>>,
user_uuid: String,
) -> Self {
let canonical_root = root_dir.canonicalize().unwrap_or(root_dir);
Self {
root_dir: canonical_root,
@@ -327,6 +334,8 @@ impl SftpHandler {
handles: std::collections::HashMap::new(),
maxpacket,
restrict_absolute: false,
upload_hook,
user_uuid,
}
}
@@ -426,6 +435,7 @@ impl SftpHandler {
handle_type: SftpHandleType::File,
file: Some(file),
dir_entries: None,
write_mode: flags.write,
};
self.handles.insert(handle_id, handle);
@@ -454,7 +464,14 @@ impl SftpHandler {
info!("SSH_FXP_CLOSE: id={}, handle={}", id, handle_id);
if self.handles.remove(&handle_id).is_some() {
if let Some(handle) = self.handles.remove(&handle_id) {
if handle.write_mode && handle.handle_type == SftpHandleType::File {
if let Some(hook) = &self.upload_hook {
if let Err(e) = hook.trigger(&handle.path, &self.user_uuid) {
warn!("Upload hook failed for {:?}: {}", handle.path, e);
}
}
}
self.build_status_response(id, SftpStatus::SSH_FX_OK, "File closed")
} else {
self.build_status_response(id, SftpStatus::SSH_FX_FAILURE, "Invalid handle")
@@ -670,6 +687,7 @@ impl SftpHandler {
handle_type: SftpHandleType::Directory,
file: None,
dir_entries: Some(entries),
write_mode: false,
};
self.handles.insert(handle_id, handle);
@@ -1765,7 +1783,7 @@ mod tests {
use tempfile::TempDir;
fn make_handler(root_dir: PathBuf) -> SftpHandler {
SftpHandler::new(root_dir, Box::new(LocalFs::new()), 32768)
SftpHandler::new(root_dir, Box::new(LocalFs::new()), 32768, None, "test_user".to_string())
}
#[test]

View File

@@ -0,0 +1,164 @@
use std::path::{Path, PathBuf};
use std::process::Command;
use anyhow::{anyhow, Result};
use log::{info, warn, error};
pub struct UploadHook {
enabled: bool,
video_probe_path: PathBuf,
video_register_cli: PathBuf,
video_register_dir: PathBuf,
video_extensions: Vec<String>,
}
impl UploadHook {
pub fn new(
enabled: bool,
video_probe_path: PathBuf,
video_register_cli: PathBuf,
video_register_dir: PathBuf,
video_extensions: Vec<String>,
) -> Self {
Self {
enabled,
video_probe_path,
video_register_cli,
video_register_dir,
video_extensions,
}
}
pub fn trigger(&self, file_path: &Path, user_uuid: &str) -> Result<()> {
if !self.enabled {
return Ok(());
}
if !self.is_video_file(file_path) {
info!("UploadHook: Skipping non-video file: {:?}", file_path);
return Ok(());
}
info!("UploadHook: Triggering for file {:?} (user={})", file_path, user_uuid);
let probe_json = self.run_video_probe(file_path)?;
let video_uuid = self.run_video_register(&probe_json, user_uuid)?;
info!("UploadHook: Video registered successfully (UUID={})", video_uuid);
Ok(())
}
fn run_video_probe(&self, file_path: &Path) -> Result<PathBuf> {
info!("UploadHook: Running video_probe on {:?}", file_path);
let output = Command::new(&self.video_probe_path)
.arg(file_path)
.output()
.map_err(|e| anyhow!("Failed to execute video_probe: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
error!("UploadHook: video_probe failed: {}", stderr);
return Err(anyhow!("video_probe failed with status {}", output.status));
}
let stdout = String::from_utf8_lossy(&output.stdout);
info!("UploadHook: video_probe output: {}", stdout);
let probe_json = file_path.with_extension("probe.json");
Ok(probe_json)
}
fn run_video_register(&self, probe_json: &Path, user_uuid: &str) -> Result<String> {
info!("UploadHook: Running video_register on {:?}", probe_json);
let output = Command::new("python3")
.arg(&self.video_register_cli)
.arg("register")
.arg(probe_json)
.current_dir(&self.video_register_dir)
.env("USER_UUID", user_uuid)
.output()
.map_err(|e| anyhow!("Failed to execute video_register: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
error!("UploadHook: video_register failed: {}", stderr);
return Err(anyhow!("video_register failed with status {}", output.status));
}
let stdout = String::from_utf8_lossy(&output.stdout);
info!("UploadHook: video_register output: {}", stdout);
let uuid = stdout
.lines()
.find(|line| line.contains("UUID:") || line.contains("uuid"))
.and_then(|line| {
if line.contains("UUID:") {
line.split(':').nth(1).map(|s| s.trim().to_string())
} else {
Some(line.trim().to_string())
}
})
.unwrap_or_else(|| "unknown".to_string());
Ok(uuid)
}
fn is_video_file(&self, path: &Path) -> bool {
path.extension()
.and_then(|e| e.to_str())
.map(|ext| self.video_extensions.contains(&ext.to_lowercase()))
.unwrap_or(false)
}
}
impl Default for UploadHook {
fn default() -> Self {
Self::new(
false,
PathBuf::from("/Users/accusys/momentry_core_project/video_probe/target/release/video_probe"),
PathBuf::from("cli.py"),
PathBuf::from("/Users/accusys/momentry_core_project/video_register"),
vec![
"mp4".to_string(),
"mov".to_string(),
"avi".to_string(),
"mkv".to_string(),
"webm".to_string(),
"flv".to_string(),
"wmv".to_string(),
"m4v".to_string(),
],
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_video_file() {
let hook = UploadHook::default();
assert!(hook.is_video_file(Path::new("test.mp4")));
assert!(hook.is_video_file(Path::new("test.MOV")));
assert!(hook.is_video_file(Path::new("test.avi")));
assert!(!hook.is_video_file(Path::new("test.txt")));
assert!(!hook.is_video_file(Path::new("test.jpg")));
}
#[test]
fn test_disabled_hook() {
let hook = UploadHook::new(
false,
PathBuf::from("/tmp/video_probe"),
PathBuf::from("cli.py"),
PathBuf::from("/tmp/video_register"),
vec!["mp4".to_string()],
);
let result = hook.trigger(Path::new("/tmp/test.mp4"), "user123");
assert!(result.is_ok());
}
}