diff --git a/vendor/smb-server/src/conn/mod.rs b/vendor/smb-server/src/conn/mod.rs index b4550aa..4e7872a 100644 --- a/vendor/smb-server/src/conn/mod.rs +++ b/vendor/smb-server/src/conn/mod.rs @@ -24,12 +24,18 @@ pub async fn connection_loop(stream: TcpStream, server: Arc) -> io: server.config.max_write_size, )); let conn_id = server.active_connections.register(&conn).await; - let (tx, rx) = mpsc::channel::(writer::WRITER_CHANNEL); + + // Phase 3: Two channels - responses and notifications + let (response_tx, response_rx) = mpsc::channel::(writer::WRITER_CHANNEL); + let (notification_tx, notification_rx) = mpsc::channel::(writer::NOTIFICATION_CHANNEL); + + // Store notification sender in Connection for oplock breaks + conn.notification_tx.write().await.replace(notification_tx); - let writer_handle = tokio::spawn(writer::writer_task(write_half, rx)); + let writer_handle = tokio::spawn(writer::writer_task(write_half, response_rx, notification_rx)); info!("connection accepted"); - let reader_result = reader::reader_task(read_half, server.clone(), conn.clone(), tx).await; + let reader_result = reader::reader_task(read_half, server.clone(), conn.clone(), response_tx).await; debug!(?reader_result, "reader exited"); // Wait for writer to drain. let _ = writer_handle.await; diff --git a/vendor/smb-server/src/conn/state.rs b/vendor/smb-server/src/conn/state.rs index 37558ae..b4cc4c2 100644 --- a/vendor/smb-server/src/conn/state.rs +++ b/vendor/smb-server/src/conn/state.rs @@ -8,7 +8,7 @@ use std::sync::{Arc, Mutex}; use crate::proto::auth::ntlm::{Identity, NtlmServer}; use crate::proto::crypto::{PreauthIntegrity, SigningAlgo}; use crate::proto::messages::{Dialect, FileId}; -use tokio::sync::RwLock; +use tokio::sync::{mpsc, RwLock}; use uuid::Uuid; use crate::backend::Handle; @@ -16,6 +16,9 @@ use crate::builder::Access; use crate::path::SmbPath; use crate::server::ShareBindings; +/// Phase 3: Notification sender type for server→client async messages. +pub type NotificationSender = mpsc::Sender>; + /// In-flight NTLM acceptor + a `is_raw_ntlmssp` flag (true = raw, false = /// SPNEGO-wrapped). The handler hands the second-round response back in the /// same form the client opened with. @@ -54,6 +57,9 @@ pub struct Connection { /// Monotonic SessionId allocator. next_session_id: AtomicU64, + + /// Phase 3: Notification sender for server→client async messages (oplock breaks). + pub notification_tx: RwLock>, } impl Connection { @@ -70,6 +76,7 @@ impl Connection { pending_auths: RwLock::new(HashMap::new()), session_preauth: RwLock::new(HashMap::new()), next_session_id: AtomicU64::new(1), + notification_tx: RwLock::new(None), } } diff --git a/vendor/smb-server/src/conn/writer.rs b/vendor/smb-server/src/conn/writer.rs index 7eae534..2ca1499 100644 --- a/vendor/smb-server/src/conn/writer.rs +++ b/vendor/smb-server/src/conn/writer.rs @@ -1,5 +1,7 @@ //! Per-connection writer task: serializes responses, applies signing, and //! frames the bytes onto the wire. +//! +//! Phase 3: Added notification channel for server→client async messages. use crate::proto::framing::encode_frame; use tokio::io::{AsyncWriteExt, WriteHalf}; @@ -15,18 +17,45 @@ pub type FramePayload = Vec; /// the dispatcher. pub const WRITER_CHANNEL: usize = 64; -pub async fn writer_task(mut writer: WriteHalf, mut rx: mpsc::Receiver) { - while let Some(payload) = rx.recv().await { - let mut out = Vec::with_capacity(payload.len() + 4); - encode_frame(&payload, &mut out); - if let Err(e) = writer.write_all(&out).await { - error!(error = %e, "writer task: socket write failed"); - return; +/// Notification channel size (Phase 3). +pub const NOTIFICATION_CHANNEL: usize = 32; + +/// Phase 3: Writer task that handles both responses and notifications. +pub async fn writer_task( + mut writer: WriteHalf, + mut response_rx: mpsc::Receiver, + mut notification_rx: mpsc::Receiver, +) { + loop { + tokio::select! { + // Priority: responses first + Some(payload) = response_rx.recv() => { + if let Err(e) = write_frame(&mut writer, &payload).await { + error!(error = %e, "writer task: response write failed"); + return; + } + debug!(len = payload.len(), "wrote response frame"); + } + // Then notifications (oplock breaks, etc.) + Some(payload) = notification_rx.recv() => { + if let Err(e) = write_frame(&mut writer, &payload).await { + error!(error = %e, "writer task: notification write failed"); + return; + } + debug!(len = payload.len(), "wrote notification frame"); + } + else => break, } - debug!(len = out.len(), "wrote frame"); } - // Channel closed — flush and bail. + // Channels closed — flush and bail. if let Err(e) = writer.shutdown().await { debug!(error = %e, "writer shutdown error (best-effort)"); } } + +/// Helper: write a framed payload to the wire. +async fn write_frame(writer: &mut WriteHalf, payload: &[u8]) -> std::io::Result<()> { + let mut out = Vec::with_capacity(payload.len() + 4); + encode_frame(payload, &mut out); + writer.write_all(&out).await +} diff --git a/vendor/smb-server/src/handlers/write.rs b/vendor/smb-server/src/handlers/write.rs index 16735b1..92afbbe 100644 --- a/vendor/smb-server/src/handlers/write.rs +++ b/vendor/smb-server/src/handlers/write.rs @@ -13,7 +13,7 @@ use crate::ntstatus; use crate::server::ServerState; pub async fn handle( - _server: &Arc, + server: &Arc, conn: &Arc, hdr: &Smb2Header, body: &[u8], @@ -41,6 +41,34 @@ pub async fn handle( Some(o) => o, None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), }; + + // Phase 5: Get path and trigger oplock break before write + let (path, share_access) = { + let open = open_arc.read().await; + (open.last_path.clone(), open.share_access) + }; + + // Trigger oplock break for conflicting clients + let notifications = server.oplock_manager.break_oplock( + &path, + share_access, + granted, + ).await; + + // Send notifications to affected clients + for notification in notifications { + // Build SMB2 frame for notification + use crate::proto::framing::encode_frame; + let notification_bytes = notification.write_to_bytes(); + let mut frame = Vec::with_capacity(notification_bytes.len() + 4); + encode_frame(¬ification_bytes, &mut frame); + + // Send via notification channel (if available) + if let Some(tx) = conn.notification_tx.read().await.as_ref() { + let _ = tx.send(frame).await; + } + } + let result = { let open = open_arc.read().await; match open.handle.as_ref() { diff --git a/vendor/smb-server/src/proto/messages/oplock_break.rs b/vendor/smb-server/src/proto/messages/oplock_break.rs index 5aaa139..fa3a332 100644 --- a/vendor/smb-server/src/proto/messages/oplock_break.rs +++ b/vendor/smb-server/src/proto/messages/oplock_break.rs @@ -33,6 +33,12 @@ impl OplockBreakNotification { out.extend_from_slice(&c.into_inner()); Ok(()) } + /// Phase 3: Write to a new Vec (convenience method). + pub fn write_to_bytes(&self) -> Vec { + let mut buf = Vec::new(); + self.write_to(&mut buf).expect("encode notification"); + buf + } } /// SMB2_OPLOCK_BREAK_ACK (MS-SMB2 §2.2.24.1) — same wire shape as the