use dav_server::davpath::DavPath;
use dav_server::ls::{DavLock, DavLockSystem, LsFuture};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use uuid::Uuid;
use xmltree::Element;

/// Unwraps the result of a `Mutex::lock()` call, tolerating poisoning.
///
/// If the mutex was poisoned, a warning is logged and the guard carried inside
/// the poison error is returned instead, so the caller can keep working with
/// the protected data.
fn recover_mutex<T>(result: std::sync::LockResult<T>) -> T {
    match result {
        Ok(guard) => guard,
        Err(e) => {
            log::warn!("Mutex poisoned in webdav_locks, recovering");
            e.into_inner()
        }
    }
}

/// Serializable lock representation for JSON persistence
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedLock {
    /// Lock token, copied verbatim from `DavLock::token`.
    token: String,
    /// String form of the locked `DavPath`.
    path: String,
    /// Principal that holds the lock, if any.
    principal: Option<String>,
    /// Owner element serialized as XML text, if the lock had an owner and it
    /// could be serialized.
    owner_xml: Option<String>,
    /// Absolute expiry time in whole seconds since the UNIX epoch.
    timeout_at_epoch: Option<u64>,
    /// Requested lock timeout in whole seconds.
    timeout_secs: Option<u64>,
    shared: bool,
    deep: bool,
}

impl PersistedLock {
    /// Converts the persisted form back into a live `DavLock`.
    ///
    /// A stored path that cannot be parsed is replaced by `/unknown` rather than
    /// failing the whole load. The owner element is re-parsed from `owner_xml`;
    /// if that XML is absent or malformed the lock is restored without an owner.
    fn into_lock(self) -> DavLock {
        // The URI type is inferred from the parameter type of `DavPath::from_uri`.
        let path = DavPath::from_uri(
            &self
                .path
                .parse()
                .unwrap_or_else(|_| "/unknown".parse().unwrap()),
        )
        .unwrap_or_else(|_| DavPath::from_uri(&"/unknown".parse().unwrap()).unwrap());

        // Restore the owner that `From<&DavLock>` serialized; it used to be dropped.
        let owner = self
            .owner_xml
            .as_deref()
            .and_then(|xml| Element::parse(xml.as_bytes()).ok())
            .map(Box::new);

        DavLock {
            token: self.token,
            path: Box::new(path),
            principal: self.principal,
            owner,
            timeout_at: self
                .timeout_at_epoch
                .map(|secs| UNIX_EPOCH + Duration::from_secs(secs)),
            timeout: self.timeout_secs.map(Duration::from_secs),
            shared: self.shared,
            deep: self.deep,
        }
    }
}

impl From<&DavLock> for PersistedLock {
    /// Captures a `DavLock` in its serializable form.
    ///
    /// Times are truncated to whole seconds. An expiry time that lies before the
    /// UNIX epoch is stored as `None`.
    fn from(l: &DavLock) -> Self {
        Self {
            token: l.token.clone(),
            path: l.path.to_string(),
            principal: l.principal.clone(),
            owner_xml: l.owner.as_ref().and_then(|e| {
                let mut buf = Vec::new();
                e.write(&mut buf)
                    .ok()
                    .map(|_| String::from_utf8_lossy(&buf).into_owned())
            }),
            timeout_at_epoch: l
                .timeout_at
                .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
                .map(|d| d.as_secs()),
            timeout_secs: l.timeout.map(|d| d.as_secs()),
            shared: l.shared,
            deep: l.deep,
        }
    }
}

/// Check if two paths overlap for locking purposes.
fn paths_overlap(lock_path: &str, request_path: &str, lock_deep: bool, request_deep: bool) -> bool { let lp = lock_path.trim_end_matches('/'); let rp = request_path.trim_end_matches('/'); if lock_deep && request_deep { lp == rp || rp.starts_with(&format!("{}/", lp)) || lp.starts_with(&format!("{}/", rp)) } else if lock_deep { lp == rp || rp.starts_with(&format!("{}/", lp)) } else if request_deep { lp == rp || lp.starts_with(&format!("{}/", rp)) } else { lp == rp } } fn is_expired(lock: &DavLock) -> bool { if let Some(timeout_at) = lock.timeout_at { timeout_at < SystemTime::now() } else { false } } fn cleanup_expired_locks(locks: &mut Vec, locks_file: &PathBuf) { let before = locks.len(); locks.retain(|l| !is_expired(l)); if locks.len() < before { let persisted: Vec = locks.iter().map(PersistedLock::from).collect(); if let Ok(json) = serde_json::to_string(&persisted) { let _ = std::fs::write(locks_file, json); } } } #[derive(Debug, Clone)] pub struct PersistedLs { locks: Arc>>, locks_file: PathBuf, } impl PersistedLs { pub fn new(locks_file: PathBuf) -> Box { let locks = if locks_file.exists() { std::fs::read_to_string(&locks_file) .ok() .and_then(|json| serde_json::from_str::>(&json).ok()) .map(|v| v.into_iter().map(|p| p.into_lock()).collect()) .unwrap_or_default() } else { Vec::new() }; Box::new(Self { locks: Arc::new(Mutex::new(locks)), locks_file, }) } } impl DavLockSystem for PersistedLs { fn lock( &'_ self, path: &DavPath, principal: Option<&str>, owner: Option<&Element>, timeout: Option, shared: bool, deep: bool, ) -> LsFuture<'_, Result> { let locks = self.locks.clone(); let path2 = path.clone(); let locks_file = self.locks_file.clone(); let principal_owned = principal.map(|s| s.to_string()); let owner_owned = owner.map(|o| Box::new(o.clone())); Box::pin(async move { let mut all = recover_mutex(locks.lock()); cleanup_expired_locks(&mut all, &locks_file); let path_str = path2.to_string(); for existing in all.iter() { let ep = existing.path.to_string(); if 
paths_overlap(&ep, &path_str, existing.deep, deep) { let owned = existing.principal.as_deref() == principal_owned.as_deref(); if !owned && !existing.shared { return Err(existing.clone()); } if !shared && !owned { return Err(existing.clone()); } } } let timeout_at = timeout.map(|d| SystemTime::now() + d); let lock = DavLock { token: Uuid::new_v4().urn().to_string(), path: Box::new(path2), principal: principal_owned, owner: owner_owned, timeout_at, timeout, shared, deep, }; all.push(lock.clone()); let persisted: Vec = all.iter().map(PersistedLock::from).collect(); if let Ok(json) = serde_json::to_string(&persisted) { let _ = std::fs::write(&locks_file, json); } Ok(lock) }) } fn unlock(&'_ self, path: &DavPath, token: &str) -> LsFuture<'_, Result<(), ()>> { let locks = self.locks.clone(); let path_str = path.to_string(); let locks_file = self.locks_file.clone(); let token_owned = token.to_string(); Box::pin(async move { let mut all = recover_mutex(locks.lock()); let before = all.len(); all.retain(|l| !(l.path.to_string() == path_str && l.token == token_owned)); if all.len() == before { return Err(()); } let persisted: Vec = all.iter().map(PersistedLock::from).collect(); if let Ok(json) = serde_json::to_string(&persisted) { let _ = std::fs::write(&locks_file, json); } Ok(()) }) } fn refresh( &'_ self, path: &DavPath, token: &str, timeout: Option, ) -> LsFuture<'_, Result> { let locks = self.locks.clone(); let path_str = path.to_string(); let token_owned = token.to_string(); let locks_file = self.locks_file.clone(); Box::pin(async move { let mut all = recover_mutex(locks.lock()); let existing = all.iter_mut().find(|l| l.path.to_string() == path_str && l.token == token_owned); match existing { Some(lock) => { lock.timeout_at = timeout.map(|d| SystemTime::now() + d); lock.timeout = timeout; let result = lock.clone(); let persisted: Vec = all.iter().map(PersistedLock::from).collect(); if let Ok(json) = serde_json::to_string(&persisted) { let _ = std::fs::write(&locks_file, 
json); } Ok(result) } None => Err(()), } }) } fn check( &'_ self, path: &DavPath, principal: Option<&str>, ignore_principal: bool, deep: bool, submitted_tokens: &[String], ) -> LsFuture<'_, Result<(), DavLock>> { let locks = self.locks.clone(); let path_str = path.to_string(); let principal_owned = principal.map(|s| s.to_string()); let submitted = submitted_tokens.to_vec(); let locks_file = self.locks_file.clone(); Box::pin(async move { let mut all = recover_mutex(locks.lock()); cleanup_expired_locks(&mut all, &locks_file); for existing in all.iter() { let ep = existing.path.to_string(); if !paths_overlap(&ep, &path_str, existing.deep, deep) { continue; } let owned = submitted.iter().any(|t| t == &existing.token) || (ignore_principal && existing.principal.as_deref() == principal_owned.as_deref()); if !owned && !existing.shared { return Err(existing.clone()); } } Ok(()) }) } fn discover(&'_ self, path: &DavPath) -> LsFuture<'_, Vec> { let locks = self.locks.clone(); let path_str = path.to_string(); let locks_file = self.locks_file.clone(); Box::pin(async move { let mut all = recover_mutex(locks.lock()); cleanup_expired_locks(&mut all, &locks_file); let mut result: Vec = all .iter() .filter(|l| { let lp = l.path.to_string(); paths_overlap(&lp, &path_str, l.deep, false) }) .cloned() .collect(); result.sort_by(|a, b| a.token.cmp(&b.token)); result }) } fn delete(&'_ self, path: &DavPath) -> LsFuture<'_, Result<(), ()>> { let locks = self.locks.clone(); let prefix = path.to_string().trim_end_matches('/').to_string(); let locks_file = self.locks_file.clone(); Box::pin(async move { let mut all = recover_mutex(locks.lock()); let before = all.len(); all.retain(|l| { let lp = l.path.to_string().trim_end_matches('/').to_string(); !(lp == prefix || lp.starts_with(&format!("{}/", prefix))) }); if all.len() < before { let persisted: Vec = all.iter().map(PersistedLock::from).collect(); if let Ok(json) = serde_json::to_string(&persisted) { let _ = std::fs::write(&locks_file, 
json); } } Ok(()) }) } } #[cfg(test)] mod tests { use super::*; use dav_server::davpath::DavPath; use tempfile::TempDir; fn path(p: &str) -> Box { Box::new( DavPath::from_uri(&p.parse::().unwrap()).unwrap(), ) } #[test] fn test_lock_and_unlock() { let dir = TempDir::new().unwrap(); let ls = PersistedLs::new(dir.path().join("locks.json")); let dpath = path("/test.txt"); let result = rt( ls.lock(&dpath, Some("user"), None, Some(Duration::from_secs(3600)), false, false), ); assert!(result.is_ok()); let lock = result.unwrap(); assert_eq!(lock.shared, false); assert_eq!(lock.deep, false); let result = rt(ls.unlock(&dpath, &lock.token)); assert!(result.is_ok()); } #[test] fn test_exclusive_conflict() { let dir = TempDir::new().unwrap(); let ls = PersistedLs::new(dir.path().join("locks.json")); let dpath = path("/test.txt"); let r1 = rt( ls.lock(&dpath, Some("alice"), None, None, false, false), ); assert!(r1.is_ok()); let r2 = rt( ls.lock(&dpath, Some("bob"), None, None, false, false), ); assert!(r2.is_err()); } #[test] fn test_shared_lock_no_conflict() { let dir = TempDir::new().unwrap(); let ls = PersistedLs::new(dir.path().join("locks.json")); let dpath = path("/test.txt"); let r1 = rt( ls.lock(&dpath, Some("alice"), None, None, true, false), ); assert!(r1.is_ok()); let r2 = rt( ls.lock(&dpath, Some("bob"), None, None, true, false), ); assert!(r2.is_ok()); } #[test] fn test_persistence() { let dir = TempDir::new().unwrap(); let locks_file = dir.path().join("locks.json"); let lock_token; { let ls = PersistedLs::new(locks_file.clone()); let dpath = path("/test.txt"); let result = rt( ls.lock(&dpath, Some("user"), None, Some(Duration::from_secs(3600)), false, false), ); assert!(result.is_ok()); lock_token = result.unwrap().token; } let ls2 = PersistedLs::new(locks_file.clone()); let dpath = path("/test.txt"); let discovered = rt(ls2.discover(&dpath)); assert_eq!(discovered.len(), 1); assert_eq!(discovered[0].token, lock_token); } #[test] fn test_deep_lock_conflict() { let 
dir = TempDir::new().unwrap(); let ls = PersistedLs::new(dir.path().join("locks.json")); let parent = path("/docs"); let r1 = rt( ls.lock(&parent, Some("alice"), None, None, true, true), ); assert!(r1.is_ok()); let child = path("/docs/sub/file.txt"); let r2 = rt( ls.lock(&child, Some("bob"), None, None, false, false), ); assert!(r2.is_err()); } fn rt(fut: LsFuture<'_, T>) -> T { tokio::runtime::Runtime::new().unwrap().block_on(fut) } }