From 5238a84972fe132caf9efe9071fd7ed594b835d2 Mon Sep 17 00:00:00 2001 From: Warren Date: Sun, 21 Jun 2026 05:11:39 +0800 Subject: [PATCH] Implement SMB Durable Handles (Phase 1): Persistent FileId + reconnect + expiration + cleanup --- vendor/smb-server/src/backend.rs | 41 +++ vendor/smb-server/src/durable_handle.rs | 402 ++++++++++++++++++++++++ vendor/smb-server/src/lib.rs | 3 +- 3 files changed, 445 insertions(+), 1 deletion(-) create mode 100644 vendor/smb-server/src/durable_handle.rs diff --git a/vendor/smb-server/src/backend.rs b/vendor/smb-server/src/backend.rs index 7c2bdcf..a99352a 100644 --- a/vendor/smb-server/src/backend.rs +++ b/vendor/smb-server/src/backend.rs @@ -236,3 +236,44 @@ impl ShareBackend for NotSupportedBackend { } } } + +/// Null handle for testing purposes. +pub struct NullHandle; + +#[async_trait] +impl Handle for NullHandle { + async fn read(&self, _offset: u64, _len: u32) -> SmbResult { + Err(SmbError::NotSupported) + } + async fn write(&self, _offset: u64, _data: &[u8]) -> SmbResult { + Err(SmbError::NotSupported) + } + async fn flush(&self) -> SmbResult<()> { + Err(SmbError::NotSupported) + } + async fn stat(&self) -> SmbResult { + Ok(FileInfo { + name: String::new(), + end_of_file: 0, + allocation_size: 0, + creation_time: 0, + last_access_time: 0, + last_write_time: 0, + change_time: 0, + is_directory: false, + file_index: 0, + }) + } + async fn set_times(&self, _times: FileTimes) -> SmbResult<()> { + Err(SmbError::NotSupported) + } + async fn truncate(&self, _len: u64) -> SmbResult<()> { + Err(SmbError::NotSupported) + } + async fn list_dir(&self, _pattern: Option<&str>) -> SmbResult> { + Err(SmbError::NotSupported) + } + async fn close(self: Box) -> SmbResult<()> { + Ok(()) + } +} diff --git a/vendor/smb-server/src/durable_handle.rs b/vendor/smb-server/src/durable_handle.rs new file mode 100644 index 0000000..c003bec --- /dev/null +++ b/vendor/smb-server/src/durable_handle.rs @@ -0,0 +1,402 @@ +use crate::conn::state::Open; +use crate::path::SmbPath; +use crate::proto::messages::FileId; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +#[derive(Debug, Clone)] +pub struct DurableHandleConfig { + pub max_durable_handles: usize, + pub handle_timeout: Duration, + pub cleanup_interval: Duration, + pub enable_persistent_ids: bool, +} + +impl Default for DurableHandleConfig { + fn default() -> Self { + Self { + max_durable_handles: 1000, + handle_timeout: Duration::from_secs(300), + cleanup_interval: Duration::from_secs(60), + enable_persistent_ids: true, + } + } +} + +#[derive(Debug, Clone)] +pub struct DurableHandleEntry { + pub persistent_id: u64, + pub volatile_id: u64, + pub session_id: u64, + pub tree_id: u32, + pub path: SmbPath, + pub granted_access: u32, + pub share_access: u32, + pub oplock_level: u8, + pub lease_key: Option<[u8; 16]>, + pub lease_state: Option, + pub created_at: Instant, + pub last_access: Instant, + pub is_directory: bool, + pub delete_on_close: bool, + pub create_contexts: Vec, +} + +impl DurableHandleEntry { + pub fn file_id(&self) -> FileId { + FileId::new(self.persistent_id, self.volatile_id) + } + + pub fn is_expired(&self, now: Instant, timeout: Duration) -> bool { + now.duration_since(self.last_access) > timeout + } +} + +pub struct DurableHandleManager { + config: DurableHandleConfig, + handles: RwLock>, + persistent_to_volatile: RwLock>, + next_persistent_id: RwLock, +} + +impl DurableHandleManager { + pub fn new(config: DurableHandleConfig) -> Self { + Self { + config, + handles: RwLock::new(HashMap::new()), + persistent_to_volatile: RwLock::new(HashMap::new()), + next_persistent_id: RwLock::new(1), + } + } + + pub fn default() -> Self { + Self::new(DurableHandleConfig::default()) + } + + pub async fn alloc_persistent_id(&self) -> u64 { + let mut next_id = self.next_persistent_id.write().await; + let id = *next_id; + *next_id += 1; + id + } + + pub async fn register_durable_handle( + &self, + open: &Open, + session_id: u64, + tree_id: u32, + create_contexts: Vec, + ) -> Result { + let handles = self.handles.read().await; + if handles.len() >= self.config.max_durable_handles { + return Err(DurableHandleError::MaxHandlesReached); + } + drop(handles); + + let persistent_id = self.alloc_persistent_id().await; + let volatile_id = open.file_id.volatile; + + let entry = DurableHandleEntry { + persistent_id, + volatile_id, + session_id, + tree_id, + path: open.last_path.clone(), + granted_access: if open.granted_access.allows_write() { 1 } else { 0 }, + share_access: open.share_access, + oplock_level: open.oplock_level, + lease_key: open.lease_key, + lease_state: open.lease_state, + created_at: Instant::now(), + last_access: Instant::now(), + is_directory: open.is_directory, + delete_on_close: open.delete_on_close, + create_contexts, + }; + + let mut handles = self.handles.write().await; + handles.insert(persistent_id, entry); + + let mut p2v = self.persistent_to_volatile.write().await; + p2v.insert(persistent_id, volatile_id); + + Ok(FileId::new(persistent_id, volatile_id)) + } + + pub async fn lookup_durable_handle( + &self, + persistent_id: u64, + ) -> Option { + let handles = self.handles.read().await; + handles.get(&persistent_id).cloned() + } + + pub async fn lookup_by_volatile(&self, volatile_id: u64) -> Option { + let handles = self.handles.read().await; + handles + .values() + .find(|e| e.volatile_id == volatile_id) + .cloned() + } + + pub async fn update_access_time(&self, persistent_id: u64) { + let mut handles = self.handles.write().await; + if let Some(entry) = handles.get_mut(&persistent_id) { + entry.last_access = Instant::now(); + } + } + + pub async fn remove_durable_handle(&self, persistent_id: u64) { + let mut handles = self.handles.write().await; + handles.remove(&persistent_id); + + let mut p2v = self.persistent_to_volatile.write().await; + p2v.remove(&persistent_id); + } + + pub async fn reconnect_handle( + &self, + persistent_id: u64, + new_session_id: u64, + new_tree_id: u32, + ) -> Result { + let mut handles = self.handles.write().await; + + let entry = handles + .get(&persistent_id) + .cloned() + .ok_or(DurableHandleError::HandleNotFound)?; + + if entry.is_expired(Instant::now(), self.config.handle_timeout) { + handles.remove(&persistent_id); + return Err(DurableHandleError::HandleExpired); + } + + let mut_entry = handles.get_mut(&persistent_id).unwrap(); + mut_entry.session_id = new_session_id; + mut_entry.tree_id = new_tree_id; + mut_entry.last_access = Instant::now(); + + Ok(mut_entry.clone()) + } + + pub async fn cleanup_expired_handles(&self) -> usize { + let now = Instant::now(); + let mut handles = self.handles.write().await; + let mut p2v = self.persistent_to_volatile.write().await; + + let expired_count = handles.len(); + handles.retain(|_, entry| !entry.is_expired(now, self.config.handle_timeout)); + let retained_count = handles.len(); + + p2v.retain(|persistent_id, _| handles.contains_key(persistent_id)); + + expired_count - retained_count + } + + pub async fn get_stats(&self) -> DurableHandleStats { + let handles = self.handles.read().await; + + let total = handles.len(); + let expired = handles + .values() + .filter(|e| e.is_expired(Instant::now(), self.config.handle_timeout)) + .count(); + + let by_session: HashMap = handles + .values() + .fold(HashMap::new(), |mut acc, e| { + *acc.entry(e.session_id).or_insert(0) += 1; + acc + }); + + DurableHandleStats { + total_handles: total, + expired_handles: expired, + max_handles: self.config.max_durable_handles, + handles_by_session: by_session, + } + } + + pub async fn get_file_id_for_reconnect( + &self, + persistent_id: u64, + ) -> Option { + let handles = self.handles.read().await; + handles.get(&persistent_id).map(|e| e.file_id()) + } +} + +#[derive(Debug, Clone)] +pub struct DurableHandleStats { + pub total_handles: usize, + pub expired_handles: usize, + pub max_handles: usize, + pub handles_by_session: HashMap, +} + +#[derive(Debug, Clone)] +pub enum DurableHandleError { + MaxHandlesReached, + HandleNotFound, + HandleExpired, + InvalidPersistentId, + SessionMismatch, +} + +impl std::fmt::Display for DurableHandleError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DurableHandleError::MaxHandlesReached => write!(f, "Maximum durable handles reached"), + DurableHandleError::HandleNotFound => write!(f, "Durable handle not found"), + DurableHandleError::HandleExpired => write!(f, "Durable handle expired"), + DurableHandleError::InvalidPersistentId => write!(f, "Invalid persistent ID"), + DurableHandleError::SessionMismatch => write!(f, "Session mismatch during reconnect"), + } + } +} + +impl std::error::Error for DurableHandleError {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::conn::state::Open; + use crate::proto::messages::FileId; + use std::time::Duration; + + fn make_test_open() -> Open { + Open::new( + FileId::new(0, 1), + Box::new(crate::backend::NullHandle), + crate::builder::Access::Read, + SmbPath::root(), + false, + false, + 0, + 0, + ) + } + + #[tokio::test] + async fn test_register_durable_handle() { + let manager = DurableHandleManager::default(); + let open = make_test_open(); + + let file_id = manager + .register_durable_handle(&open, 1, 1, vec![]) + .await + .unwrap(); + + assert_ne!(file_id.persistent, 0); + assert_eq!(file_id.volatile, 1); + } + + #[tokio::test] + async fn test_lookup_durable_handle() { + let manager = DurableHandleManager::default(); + let open = make_test_open(); + + let file_id = manager + .register_durable_handle(&open, 1, 1, vec![]) + .await + .unwrap(); + + let entry = manager.lookup_durable_handle(file_id.persistent).await; + assert!(entry.is_some()); + assert_eq!(entry.unwrap().session_id, 1); + } + + #[tokio::test] + async fn test_reconnect_handle() { + let manager = DurableHandleManager::default(); + let open = make_test_open(); + + let file_id = manager + .register_durable_handle(&open, 1, 1, vec![]) + .await + .unwrap(); + + let entry = manager + .reconnect_handle(file_id.persistent, 2, 2) + .await + .unwrap(); + + assert_eq!(entry.session_id, 2); + assert_eq!(entry.tree_id, 2); + } + + #[tokio::test] + async fn test_expired_handle() { + let config = DurableHandleConfig { + handle_timeout: Duration::from_millis(100), + ..Default::default() + }; + let manager = DurableHandleManager::new(config); + let open = make_test_open(); + + let file_id = manager + .register_durable_handle(&open, 1, 1, vec![]) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(150)).await; + + let result = manager.reconnect_handle(file_id.persistent, 2, 2).await; + assert!(matches!(result, Err(DurableHandleError::HandleExpired))); + } + + #[tokio::test] + async fn test_cleanup_expired_handles() { + let config = DurableHandleConfig { + handle_timeout: Duration::from_millis(100), + ..Default::default() + }; + let manager = DurableHandleManager::new(config); + let open = make_test_open(); + + manager.register_durable_handle(&open, 1, 1, vec![]).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(150)).await; + + let cleaned = manager.cleanup_expired_handles().await; + assert_eq!(cleaned, 1); + + let stats = manager.get_stats().await; + assert_eq!(stats.total_handles, 0); + } + + #[tokio::test] + async fn test_max_handles_limit() { + let config = DurableHandleConfig { + max_durable_handles: 2, + ..Default::default() + }; + let manager = DurableHandleManager::new(config); + let open = make_test_open(); + + manager.register_durable_handle(&open, 1, 1, vec![]).await.unwrap(); + manager.register_durable_handle(&open, 2, 1, vec![]).await.unwrap(); + + let result = manager.register_durable_handle(&open, 3, 1, vec![]).await; + assert!(matches!(result, Err(DurableHandleError::MaxHandlesReached))); + } + + #[tokio::test] + async fn test_remove_durable_handle() { + let manager = DurableHandleManager::default(); + let open = make_test_open(); + + let file_id = manager + .register_durable_handle(&open, 1, 1, vec![]) + .await + .unwrap(); + + manager.remove_durable_handle(file_id.persistent).await; + + let entry = manager.lookup_durable_handle(file_id.persistent).await; + assert!(entry.is_none()); + } +} \ No newline at end of file diff --git a/vendor/smb-server/src/lib.rs b/vendor/smb-server/src/lib.rs index 0aa39d9..368c531 100644 --- a/vendor/smb-server/src/lib.rs +++ b/vendor/smb-server/src/lib.rs @@ -20,6 +20,7 @@ mod backend; mod builder; pub(crate) mod conn; mod dispatch; +mod durable_handle; mod error; #[cfg(feature = "localfs")] mod fs; @@ -32,7 +33,7 @@ mod proto; mod server; mod utils; -pub use backend::{BackendCapabilities, DirEntry, FileInfo, FileTimes, Handle, OpenIntent, OpenOptions, ShareBackend}; +pub use backend::{BackendCapabilities, DirEntry, FileInfo, FileTimes, Handle, NullHandle, OpenIntent, OpenOptions, ShareBackend}; pub use error::SmbError; pub use path::SmbPath; pub use builder::{Access, Share};