Implement SMB Oplocks Phase 3+5
Some checks failed
Test / test (push) Has been cancelled
Test / build (push) Has been cancelled

Phase 3: NotificationQueue
- Add notification_tx to Connection struct
- Modify writer.rs to use tokio::select! for response + notification
- Add write_to_bytes() to OplockBreakNotification
- Support server→client async messages

Phase 5: WRITE Handler oplock break
- Get path/share_access before write
- Trigger OplockManager.break_oplock()
- Send OPLOCK_BREAK_NOTIFICATION to affected clients
- Encode and send via notification channel

All 229 tests pass.
This commit is contained in:
Warren
2026-06-21 00:35:48 +08:00
parent be9fe72742
commit 2dd50e4cb6
5 changed files with 90 additions and 14 deletions

View File

@@ -24,12 +24,18 @@ pub async fn connection_loop(stream: TcpStream, server: Arc<ServerState>) -> io:
server.config.max_write_size,
));
let conn_id = server.active_connections.register(&conn).await;
let (tx, rx) = mpsc::channel::<writer::FramePayload>(writer::WRITER_CHANNEL);
let writer_handle = tokio::spawn(writer::writer_task(write_half, rx));
// Phase 3: Two channels - responses and notifications
let (response_tx, response_rx) = mpsc::channel::<writer::FramePayload>(writer::WRITER_CHANNEL);
let (notification_tx, notification_rx) = mpsc::channel::<writer::FramePayload>(writer::NOTIFICATION_CHANNEL);
// Store notification sender in Connection for oplock breaks
conn.notification_tx.write().await.replace(notification_tx);
let writer_handle = tokio::spawn(writer::writer_task(write_half, response_rx, notification_rx));
info!("connection accepted");
let reader_result = reader::reader_task(read_half, server.clone(), conn.clone(), tx).await;
let reader_result = reader::reader_task(read_half, server.clone(), conn.clone(), response_tx).await;
debug!(?reader_result, "reader exited");
// Wait for writer to drain.
let _ = writer_handle.await;

View File

@@ -8,7 +8,7 @@ use std::sync::{Arc, Mutex};
use crate::proto::auth::ntlm::{Identity, NtlmServer};
use crate::proto::crypto::{PreauthIntegrity, SigningAlgo};
use crate::proto::messages::{Dialect, FileId};
use tokio::sync::RwLock;
use tokio::sync::{mpsc, RwLock};
use uuid::Uuid;
use crate::backend::Handle;
@@ -16,6 +16,9 @@ use crate::builder::Access;
use crate::path::SmbPath;
use crate::server::ShareBindings;
/// Phase 3: Notification sender type for server→client async messages.
pub type NotificationSender = mpsc::Sender<Vec<u8>>;
/// In-flight NTLM acceptor + a `is_raw_ntlmssp` flag (true = raw, false =
/// SPNEGO-wrapped). The handler hands the second-round response back in the
/// same form the client opened with.
@@ -54,6 +57,9 @@ pub struct Connection {
/// Monotonic SessionId allocator.
next_session_id: AtomicU64,
/// Phase 3: Notification sender for server→client async messages (oplock breaks).
pub notification_tx: RwLock<Option<NotificationSender>>,
}
impl Connection {
@@ -70,6 +76,7 @@ impl Connection {
pending_auths: RwLock::new(HashMap::new()),
session_preauth: RwLock::new(HashMap::new()),
next_session_id: AtomicU64::new(1),
notification_tx: RwLock::new(None),
}
}

View File

@@ -1,5 +1,7 @@
//! Per-connection writer task: serializes responses, applies signing, and
//! frames the bytes onto the wire.
//!
//! Phase 3: Added notification channel for server→client async messages.
use crate::proto::framing::encode_frame;
use tokio::io::{AsyncWriteExt, WriteHalf};
@@ -15,18 +17,45 @@ pub type FramePayload = Vec<u8>;
/// the dispatcher.
pub const WRITER_CHANNEL: usize = 64;
pub async fn writer_task(mut writer: WriteHalf<TcpStream>, mut rx: mpsc::Receiver<FramePayload>) {
while let Some(payload) = rx.recv().await {
let mut out = Vec::with_capacity(payload.len() + 4);
encode_frame(&payload, &mut out);
if let Err(e) = writer.write_all(&out).await {
error!(error = %e, "writer task: socket write failed");
return;
/// Notification channel size (Phase 3).
pub const NOTIFICATION_CHANNEL: usize = 32;
/// Phase 3: Writer task that handles both responses and notifications.
pub async fn writer_task(
mut writer: WriteHalf<TcpStream>,
mut response_rx: mpsc::Receiver<FramePayload>,
mut notification_rx: mpsc::Receiver<FramePayload>,
) {
loop {
tokio::select! {
// Priority: responses first
Some(payload) = response_rx.recv() => {
if let Err(e) = write_frame(&mut writer, &payload).await {
error!(error = %e, "writer task: response write failed");
return;
}
debug!(len = payload.len(), "wrote response frame");
}
// Then notifications (oplock breaks, etc.)
Some(payload) = notification_rx.recv() => {
if let Err(e) = write_frame(&mut writer, &payload).await {
error!(error = %e, "writer task: notification write failed");
return;
}
debug!(len = payload.len(), "wrote notification frame");
}
else => break,
}
debug!(len = out.len(), "wrote frame");
}
// Channel closed — flush and bail.
// Channels closed — flush and bail.
if let Err(e) = writer.shutdown().await {
debug!(error = %e, "writer shutdown error (best-effort)");
}
}
/// Helper: write a framed payload to the wire.
async fn write_frame(writer: &mut WriteHalf<TcpStream>, payload: &[u8]) -> std::io::Result<()> {
let mut out = Vec::with_capacity(payload.len() + 4);
encode_frame(payload, &mut out);
writer.write_all(&out).await
}

View File

@@ -13,7 +13,7 @@ use crate::ntstatus;
use crate::server::ServerState;
pub async fn handle(
_server: &Arc<ServerState>,
server: &Arc<ServerState>,
conn: &Arc<Connection>,
hdr: &Smb2Header,
body: &[u8],
@@ -41,6 +41,34 @@ pub async fn handle(
Some(o) => o,
None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED),
};
// Phase 5: Get path and trigger oplock break before write
let (path, share_access) = {
let open = open_arc.read().await;
(open.last_path.clone(), open.share_access)
};
// Trigger oplock break for conflicting clients
let notifications = server.oplock_manager.break_oplock(
&path,
share_access,
granted,
).await;
// Send notifications to affected clients
for notification in notifications {
// Build SMB2 frame for notification
use crate::proto::framing::encode_frame;
let notification_bytes = notification.write_to_bytes();
let mut frame = Vec::with_capacity(notification_bytes.len() + 4);
encode_frame(&notification_bytes, &mut frame);
// Send via notification channel (if available)
if let Some(tx) = conn.notification_tx.read().await.as_ref() {
let _ = tx.send(frame).await;
}
}
let result = {
let open = open_arc.read().await;
match open.handle.as_ref() {

View File

@@ -33,6 +33,12 @@ impl OplockBreakNotification {
out.extend_from_slice(&c.into_inner());
Ok(())
}
/// Phase 3: Write to a new Vec (convenience method).
pub fn write_to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
self.write_to(&mut buf).expect("encode notification");
buf
}
}
/// SMB2_OPLOCK_BREAK_ACK (MS-SMB2 §2.2.24.1) — same wire shape as the