Implement SMB Oplocks Phase 3+5
Phase 3: NotificationQueue - Add notification_tx to Connection struct - Modify writer.rs to use tokio::select! for response + notification - Add write_to_bytes() to OplockBreakNotification - Support server→client async messages Phase 5: WRITE Handler oplock break - Get path/share_access before write - Trigger OplockManager.break_oplock() - Send OPLOCK_BREAK_NOTIFICATION to affected clients - Encode and send via notification channel All 229 tests pass.
This commit is contained in:
12
vendor/smb-server/src/conn/mod.rs
vendored
12
vendor/smb-server/src/conn/mod.rs
vendored
@@ -24,12 +24,18 @@ pub async fn connection_loop(stream: TcpStream, server: Arc<ServerState>) -> io:
|
|||||||
server.config.max_write_size,
|
server.config.max_write_size,
|
||||||
));
|
));
|
||||||
let conn_id = server.active_connections.register(&conn).await;
|
let conn_id = server.active_connections.register(&conn).await;
|
||||||
let (tx, rx) = mpsc::channel::<writer::FramePayload>(writer::WRITER_CHANNEL);
|
|
||||||
|
|
||||||
let writer_handle = tokio::spawn(writer::writer_task(write_half, rx));
|
// Phase 3: Two channels - responses and notifications
|
||||||
|
let (response_tx, response_rx) = mpsc::channel::<writer::FramePayload>(writer::WRITER_CHANNEL);
|
||||||
|
let (notification_tx, notification_rx) = mpsc::channel::<writer::FramePayload>(writer::NOTIFICATION_CHANNEL);
|
||||||
|
|
||||||
|
// Store notification sender in Connection for oplock breaks
|
||||||
|
conn.notification_tx.write().await.replace(notification_tx);
|
||||||
|
|
||||||
|
let writer_handle = tokio::spawn(writer::writer_task(write_half, response_rx, notification_rx));
|
||||||
|
|
||||||
info!("connection accepted");
|
info!("connection accepted");
|
||||||
let reader_result = reader::reader_task(read_half, server.clone(), conn.clone(), tx).await;
|
let reader_result = reader::reader_task(read_half, server.clone(), conn.clone(), response_tx).await;
|
||||||
debug!(?reader_result, "reader exited");
|
debug!(?reader_result, "reader exited");
|
||||||
// Wait for writer to drain.
|
// Wait for writer to drain.
|
||||||
let _ = writer_handle.await;
|
let _ = writer_handle.await;
|
||||||
|
|||||||
9
vendor/smb-server/src/conn/state.rs
vendored
9
vendor/smb-server/src/conn/state.rs
vendored
@@ -8,7 +8,7 @@ use std::sync::{Arc, Mutex};
|
|||||||
use crate::proto::auth::ntlm::{Identity, NtlmServer};
|
use crate::proto::auth::ntlm::{Identity, NtlmServer};
|
||||||
use crate::proto::crypto::{PreauthIntegrity, SigningAlgo};
|
use crate::proto::crypto::{PreauthIntegrity, SigningAlgo};
|
||||||
use crate::proto::messages::{Dialect, FileId};
|
use crate::proto::messages::{Dialect, FileId};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::{mpsc, RwLock};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::backend::Handle;
|
use crate::backend::Handle;
|
||||||
@@ -16,6 +16,9 @@ use crate::builder::Access;
|
|||||||
use crate::path::SmbPath;
|
use crate::path::SmbPath;
|
||||||
use crate::server::ShareBindings;
|
use crate::server::ShareBindings;
|
||||||
|
|
||||||
|
/// Phase 3: Notification sender type for server→client async messages.
|
||||||
|
pub type NotificationSender = mpsc::Sender<Vec<u8>>;
|
||||||
|
|
||||||
/// In-flight NTLM acceptor + a `is_raw_ntlmssp` flag (true = raw, false =
|
/// In-flight NTLM acceptor + a `is_raw_ntlmssp` flag (true = raw, false =
|
||||||
/// SPNEGO-wrapped). The handler hands the second-round response back in the
|
/// SPNEGO-wrapped). The handler hands the second-round response back in the
|
||||||
/// same form the client opened with.
|
/// same form the client opened with.
|
||||||
@@ -54,6 +57,9 @@ pub struct Connection {
|
|||||||
|
|
||||||
/// Monotonic SessionId allocator.
|
/// Monotonic SessionId allocator.
|
||||||
next_session_id: AtomicU64,
|
next_session_id: AtomicU64,
|
||||||
|
|
||||||
|
/// Phase 3: Notification sender for server→client async messages (oplock breaks).
|
||||||
|
pub notification_tx: RwLock<Option<NotificationSender>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Connection {
|
impl Connection {
|
||||||
@@ -70,6 +76,7 @@ impl Connection {
|
|||||||
pending_auths: RwLock::new(HashMap::new()),
|
pending_auths: RwLock::new(HashMap::new()),
|
||||||
session_preauth: RwLock::new(HashMap::new()),
|
session_preauth: RwLock::new(HashMap::new()),
|
||||||
next_session_id: AtomicU64::new(1),
|
next_session_id: AtomicU64::new(1),
|
||||||
|
notification_tx: RwLock::new(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
45
vendor/smb-server/src/conn/writer.rs
vendored
45
vendor/smb-server/src/conn/writer.rs
vendored
@@ -1,5 +1,7 @@
|
|||||||
//! Per-connection writer task: serializes responses, applies signing, and
|
//! Per-connection writer task: serializes responses, applies signing, and
|
||||||
//! frames the bytes onto the wire.
|
//! frames the bytes onto the wire.
|
||||||
|
//!
|
||||||
|
//! Phase 3: Added notification channel for server→client async messages.
|
||||||
|
|
||||||
use crate::proto::framing::encode_frame;
|
use crate::proto::framing::encode_frame;
|
||||||
use tokio::io::{AsyncWriteExt, WriteHalf};
|
use tokio::io::{AsyncWriteExt, WriteHalf};
|
||||||
@@ -15,18 +17,45 @@ pub type FramePayload = Vec<u8>;
|
|||||||
/// the dispatcher.
|
/// the dispatcher.
|
||||||
pub const WRITER_CHANNEL: usize = 64;
|
pub const WRITER_CHANNEL: usize = 64;
|
||||||
|
|
||||||
pub async fn writer_task(mut writer: WriteHalf<TcpStream>, mut rx: mpsc::Receiver<FramePayload>) {
|
/// Notification channel size (Phase 3).
|
||||||
while let Some(payload) = rx.recv().await {
|
pub const NOTIFICATION_CHANNEL: usize = 32;
|
||||||
let mut out = Vec::with_capacity(payload.len() + 4);
|
|
||||||
encode_frame(&payload, &mut out);
|
/// Phase 3: Writer task that handles both responses and notifications.
|
||||||
if let Err(e) = writer.write_all(&out).await {
|
pub async fn writer_task(
|
||||||
error!(error = %e, "writer task: socket write failed");
|
mut writer: WriteHalf<TcpStream>,
|
||||||
|
mut response_rx: mpsc::Receiver<FramePayload>,
|
||||||
|
mut notification_rx: mpsc::Receiver<FramePayload>,
|
||||||
|
) {
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
// Priority: responses first
|
||||||
|
Some(payload) = response_rx.recv() => {
|
||||||
|
if let Err(e) = write_frame(&mut writer, &payload).await {
|
||||||
|
error!(error = %e, "writer task: response write failed");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
debug!(len = out.len(), "wrote frame");
|
debug!(len = payload.len(), "wrote response frame");
|
||||||
}
|
}
|
||||||
// Channel closed — flush and bail.
|
// Then notifications (oplock breaks, etc.)
|
||||||
|
Some(payload) = notification_rx.recv() => {
|
||||||
|
if let Err(e) = write_frame(&mut writer, &payload).await {
|
||||||
|
error!(error = %e, "writer task: notification write failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
debug!(len = payload.len(), "wrote notification frame");
|
||||||
|
}
|
||||||
|
else => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Channels closed — flush and bail.
|
||||||
if let Err(e) = writer.shutdown().await {
|
if let Err(e) = writer.shutdown().await {
|
||||||
debug!(error = %e, "writer shutdown error (best-effort)");
|
debug!(error = %e, "writer shutdown error (best-effort)");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper: write a framed payload to the wire.
|
||||||
|
async fn write_frame(writer: &mut WriteHalf<TcpStream>, payload: &[u8]) -> std::io::Result<()> {
|
||||||
|
let mut out = Vec::with_capacity(payload.len() + 4);
|
||||||
|
encode_frame(payload, &mut out);
|
||||||
|
writer.write_all(&out).await
|
||||||
|
}
|
||||||
|
|||||||
30
vendor/smb-server/src/handlers/write.rs
vendored
30
vendor/smb-server/src/handlers/write.rs
vendored
@@ -13,7 +13,7 @@ use crate::ntstatus;
|
|||||||
use crate::server::ServerState;
|
use crate::server::ServerState;
|
||||||
|
|
||||||
pub async fn handle(
|
pub async fn handle(
|
||||||
_server: &Arc<ServerState>,
|
server: &Arc<ServerState>,
|
||||||
conn: &Arc<Connection>,
|
conn: &Arc<Connection>,
|
||||||
hdr: &Smb2Header,
|
hdr: &Smb2Header,
|
||||||
body: &[u8],
|
body: &[u8],
|
||||||
@@ -41,6 +41,34 @@ pub async fn handle(
|
|||||||
Some(o) => o,
|
Some(o) => o,
|
||||||
None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED),
|
None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Phase 5: Get path and trigger oplock break before write
|
||||||
|
let (path, share_access) = {
|
||||||
|
let open = open_arc.read().await;
|
||||||
|
(open.last_path.clone(), open.share_access)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Trigger oplock break for conflicting clients
|
||||||
|
let notifications = server.oplock_manager.break_oplock(
|
||||||
|
&path,
|
||||||
|
share_access,
|
||||||
|
granted,
|
||||||
|
).await;
|
||||||
|
|
||||||
|
// Send notifications to affected clients
|
||||||
|
for notification in notifications {
|
||||||
|
// Build SMB2 frame for notification
|
||||||
|
use crate::proto::framing::encode_frame;
|
||||||
|
let notification_bytes = notification.write_to_bytes();
|
||||||
|
let mut frame = Vec::with_capacity(notification_bytes.len() + 4);
|
||||||
|
encode_frame(¬ification_bytes, &mut frame);
|
||||||
|
|
||||||
|
// Send via notification channel (if available)
|
||||||
|
if let Some(tx) = conn.notification_tx.read().await.as_ref() {
|
||||||
|
let _ = tx.send(frame).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let result = {
|
let result = {
|
||||||
let open = open_arc.read().await;
|
let open = open_arc.read().await;
|
||||||
match open.handle.as_ref() {
|
match open.handle.as_ref() {
|
||||||
|
|||||||
@@ -33,6 +33,12 @@ impl OplockBreakNotification {
|
|||||||
out.extend_from_slice(&c.into_inner());
|
out.extend_from_slice(&c.into_inner());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
/// Phase 3: Write to a new Vec (convenience method).
|
||||||
|
pub fn write_to_bytes(&self) -> Vec<u8> {
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
self.write_to(&mut buf).expect("encode notification");
|
||||||
|
buf
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SMB2_OPLOCK_BREAK_ACK (MS-SMB2 §2.2.24.1) — same wire shape as the
|
/// SMB2_OPLOCK_BREAK_ACK (MS-SMB2 §2.2.24.1) — same wire shape as the
|
||||||
|
|||||||
Reference in New Issue
Block a user