Files
markbase/vendor/smb-server/src/dispatch.rs
Warren 57fd6a475f macOS Time Machine AFP monitoring: backup_time update on file modification
- Added afp_monitor.rs module to track AFP_AfpInfo backup_time
- Open struct now has 'modified' flag to track file modifications
- write.rs sets modified=true on successful write
- close.rs calls AfpMonitor::update_backup_time() on modified files
- create.rs calls AfpMonitor::init_afp_info() on new file creation
- AFP_AfpInfo stored as xattr com.apple.aapl.AfpInfo
- backup_time updated to current epoch time on modification

Also includes:
- LZ4 compression using lz4_flex crate
- Case sensitivity conditional on backend capabilities
- LDAP cfg feature gate fix
- RAID rebuild reconstruction implementation
- DOS attributes xattr persistence
- Snapshot disk persistence

Tests: 201 smb-server, 452 markbase-core (653 total)
2026-06-24 00:46:33 +08:00

1081 lines
40 KiB
Rust

//! Per-frame dispatch: parse header, route to handler, sign response, encode.
use std::sync::Arc;
use crate::proto::auth::ntlm::Identity;
use crate::proto::crypto::{PreauthIntegrity, sign};
use crate::proto::crypto::encryption::{Smb3Encryption, CipherAlgorithm, TransformHeader};
use crate::proto::header::{
Command, HeaderTail, SMB2_FLAGS_ASYNC_COMMAND, SMB2_FLAGS_RELATED_OPERATIONS,
SMB2_FLAGS_SERVER_TO_REDIR, SMB2_FLAGS_SIGNED, SMB2_HEADER_LEN, Smb2Header,
};
use crate::proto::messages::ErrorResponse;
use tracing::{Instrument, debug, debug_span, error, warn};
use crate::conn::state::Connection;
use crate::handlers;
use crate::ntstatus;
use crate::server::ServerState;
/// Result of a handler: a complete (unsigned) response payload + the NTSTATUS
/// to set in the header. The dispatcher patches the header, applies signing
/// (if required), and ships the bytes.
///
/// The type is `Debug` so it can be logged or inspected in tests, and `Clone`
/// so a response can be duplicated before it is consumed by the dispatcher.
#[derive(Debug, Clone)]
pub struct HandlerResponse {
    /// Bytes after the SMB2 header — the body. The handler owns body
    /// construction.
    pub body: Vec<u8>,
    /// NTSTATUS for the response header.
    pub status: u32,
    /// Optional override for `tree_id` on the response header (e.g.
    /// TREE_CONNECT returns the freshly minted tree id).
    pub override_tree_id: Option<u32>,
    /// Optional override for `session_id` on the response header (e.g.
    /// SESSION_SETUP returns the freshly minted session id).
    pub override_session_id: Option<u64>,
    /// If true, the dispatcher will not sign the response. Used for
    /// pre-session-setup messages where no key exists yet.
    pub skip_signing: bool,
    /// If set, take the per-session 3.1.1 preauth snapshot after hashing the
    /// SESSION_SETUP request but before hashing the response. Set by
    /// SESSION_SETUP on the round that produces STATUS_SUCCESS, so the
    /// session's KDF context can use the snapshot.
    pub take_preauth_snapshot_for_session: Option<u64>,
}
impl HandlerResponse {
    /// Builds a `STATUS_SUCCESS` response around `body`.
    ///
    /// No header overrides are set, signing is left enabled, and no preauth
    /// snapshot is requested.
    pub fn ok(body: Vec<u8>) -> Self {
        Self {
            status: ntstatus::STATUS_SUCCESS,
            body,
            skip_signing: false,
            override_session_id: None,
            override_tree_id: None,
            take_preauth_snapshot_for_session: None,
        }
    }

    /// Builds a failure response carrying `status`.
    ///
    /// The body is the encoded [`ErrorResponse`] for that status; every other
    /// field has the same defaults as [`HandlerResponse::ok`].
    ///
    /// # Panics
    ///
    /// Panics if the error response cannot be encoded into the in-memory
    /// buffer.
    pub fn err(status: u32) -> Self {
        let mut encoded = Vec::new();
        ErrorResponse::status(status)
            .write_to(&mut encoded)
            .expect("error response encodes");
        // Same defaults as `ok`, with only the status replaced.
        Self {
            status,
            ..Self::ok(encoded)
        }
    }
}
/// Top-level frame dispatch. Returns the bytes to push into the writer
/// channel, or `None` if the request elicits no response (CANCEL).
///
/// Processing order:
/// 1. SMB1 multi-protocol NEGOTIATE bootstrap.
/// 2. Encrypted frames (TRANSFORM_HEADER) are handed to
///    `handle_encrypted_frame`.
/// 3. Otherwise the frame is walked as a (possibly compound) chain of SMB2
///    requests; each sub-request is dispatched in order and the responses
///    are stitched into one buffer.
///
/// A malformed chain (short tail, unparsable sub-header, bad `NextCommand`)
/// aborts the whole frame and yields `None`.
pub async fn dispatch_frame(
    server: &Arc<ServerState>,
    conn: &Arc<Connection>,
    frame: &[u8],
) -> Option<Vec<u8>> {
    // TRANSFORM_HEADER ProtocolId as it appears on the wire
    // (MS-SMB2 §2.2.41): 0xFD 'S' 'M' 'B'.
    const TRANSFORM_PROTOCOL_ID: [u8; 4] = [0xFD, b'S', b'M', b'B'];
    // Marker this dispatcher historically matched ("SMBr"). It is neither a
    // valid SMB1 nor SMB2 prefix, so continuing to accept it is harmless and
    // keeps any peer that relied on it working.
    const LEGACY_TRANSFORM_MAGIC: [u8; 4] = *b"SMBr";

    // SMB1 multi-protocol bootstrap (MS-SMB2 §3.3.5.3.1). The only SMB1 we
    // accept: a NEGOTIATE_REQUEST listing "SMB 2.???" or "SMB 2.002".
    // Reply with an SMB2 NEGOTIATE response and the client follows up with
    // a real SMB2 NEGOTIATE.
    if let Some(bytes) = handle_smb1_multi_protocol(server, conn, frame).await {
        return Some(bytes);
    }

    // SMB3 encryption check: an encrypted packet starts with the
    // TRANSFORM_HEADER protocol id and must be decrypted before dispatch.
    if frame.len() >= 4
        && (frame[..4] == TRANSFORM_PROTOCOL_ID || frame[..4] == LEGACY_TRANSFORM_MAGIC)
    {
        return handle_encrypted_frame(server, conn, frame).await;
    }

    if frame.len() < SMB2_HEADER_LEN {
        warn!(len = frame.len(), "frame too short for SMB2 header");
        return None;
    }

    let mut sub_offset = 0;
    let mut responses = Vec::new();
    // Context carried from one sub-request to the next so that related
    // operations can inherit session, tree and file ids.
    let mut prev_session_id = 0;
    let mut prev_tree_id = 0;
    let mut prev_create_file_id = None;

    while sub_offset < frame.len() {
        let available = &frame[sub_offset..];
        if available.len() < SMB2_HEADER_LEN {
            warn!(remaining = available.len(), "compound tail too short");
            return None;
        }
        let (mut req_hdr, _) = match Smb2Header::parse(available) {
            Ok(p) => p,
            Err(e) => {
                warn!(error = %e, "failed to parse compound sub-header");
                return None;
            }
        };

        // NextCommand == 0 marks the last sub-request: it owns the rest of
        // the frame. Otherwise it is the byte distance to the next header.
        let next = req_hdr.next_command as usize;
        let sub_len = if next == 0 {
            available.len()
        } else if next < SMB2_HEADER_LEN || next > available.len() {
            warn!(
                next,
                remaining = available.len(),
                "invalid compound NextCommand"
            );
            return None;
        } else {
            next
        };

        // Work on a copy so inherited ids can be patched into the bytes the
        // handler (and signature verification) will see.
        let mut sub_frame = available[..sub_len].to_vec();
        if req_hdr.flags & SMB2_FLAGS_RELATED_OPERATIONS != 0 {
            inherit_related_context(
                &mut sub_frame,
                &mut req_hdr,
                prev_session_id,
                prev_tree_id,
                prev_create_file_id,
            );
        }
        prev_session_id = req_hdr.session_id;
        prev_tree_id = req_hdr.tree_id().unwrap_or(0);

        if let Some(response) = dispatch_one(server, conn, &sub_frame).await {
            if req_hdr.command == Command::Create {
                // A failed CREATE yields `None`, which clears any id left
                // over from an earlier CREATE in the same chain.
                prev_create_file_id = capture_create_file_id(&response);
            }
            responses.push(response);
        }

        if next == 0 {
            break;
        }
        sub_offset += next;
    }

    if responses.is_empty() {
        return None;
    }
    Some(stitch_responses(conn, responses).await)
}
/// Handle SMB3 encrypted frame (TRANSFORM_HEADER)
async fn handle_encrypted_frame(
server: &Arc<ServerState>,
conn: &Arc<Connection>,
encrypted_frame: &[u8],
) -> Option<Vec<u8>> {
// Parse TRANSFORM_HEADER
let header = match TransformHeader::read_from_bytes(encrypted_frame) {
Ok(h) => h,
Err(e) => {
warn!(error = %e, "failed to parse TRANSFORM_HEADER");
return None;
}
};
// Get session encryption key
let sessions = conn.sessions.read().await;
let session_arc = match sessions.get(&header.session_id).cloned() {
Some(s) => s,
None => {
warn!(session_id = header.session_id, "session not found for encrypted packet");
return None;
}
};
let session = session_arc.read().await;
let encryption_enabled = session.encryption_enabled;
let encryption_key = session.encryption_key;
let encryption_cipher = session.encryption_cipher.unwrap_or(CipherAlgorithm::Aes128Gcm);
if !encryption_enabled {
warn!("session does not have encryption enabled");
return None;
}
let encryption_key = match encryption_key {
Some(k) => k,
None => {
warn!("session has no encryption key");
return None;
}
};
// Decrypt packet using the session's negotiated cipher
let encryption = match Smb3Encryption::new(&encryption_key, encryption_cipher) {
Ok(e) => e,
Err(e) => {
warn!(error = %e, "failed to create encryption context");
return None;
}
};
let decrypted = match encryption.decrypt_packet(encrypted_frame) {
Ok(d) => d,
Err(e) => {
warn!(error = %e, "failed to decrypt packet");
return None;
}
};
debug!(session_id = header.session_id, "decrypted SMB3 packet");
// Process decrypted frame (non-recursive: call dispatch_one directly)
if decrypted.len() < SMB2_HEADER_LEN {
warn!("decrypted frame too short");
return None;
}
let response = dispatch_one(server, conn, &decrypted).await;
// Encrypt response if needed
if let Some(resp_bytes) = response {
if encryption_enabled {
let encrypted_response = match encryption.encrypt_packet(&resp_bytes, header.session_id) {
Ok(e) => e,
Err(e) => {
warn!(error = %e, "failed to encrypt response");
return Some(resp_bytes);
}
};
debug!("encrypted response packet");
return Some(encrypted_response);
}
return Some(resp_bytes);
}
None
}
/// Fills in the "use the previous request's value" placeholders of a related
/// compound sub-request.
///
/// * A session id of all ones in the header is replaced by `prev_session_id`
///   (both in the raw bytes and in `req_hdr`).
/// * A tree id of all ones is replaced by `prev_tree_id` in the raw bytes,
///   and in `req_hdr` when its tail is the sync variant.
/// * When a previous CREATE produced a file id and the command carries a
///   FileId (see `file_id_body_offset`), a FileId of all ones in the body is
///   replaced by `prev_create_file_id`.
///
/// `sub_frame` must contain at least a full SMB2 header.
fn inherit_related_context(
    sub_frame: &mut [u8],
    req_hdr: &mut Smb2Header,
    prev_session_id: u64,
    prev_tree_id: u32,
    prev_create_file_id: Option<[u8; 16]>,
) {
    let session_is_placeholder = read_u64(sub_frame, 0x28) == u64::MAX;
    if session_is_placeholder {
        req_hdr.session_id = prev_session_id;
        sub_frame[0x28..0x30].copy_from_slice(&prev_session_id.to_le_bytes());
    }

    let tree_is_placeholder = read_u32(sub_frame, 0x24) == u32::MAX;
    if tree_is_placeholder {
        sub_frame[0x24..0x28].copy_from_slice(&prev_tree_id.to_le_bytes());
        // Only the sync header tail carries a tree id; leave others alone.
        if let HeaderTail::Sync { tree_id, .. } = &mut req_hdr.tail {
            *tree_id = prev_tree_id;
        }
    }

    let (Some(file_id), Some(body_offset)) =
        (prev_create_file_id, file_id_body_offset(req_hdr.command))
    else {
        return;
    };
    let start = SMB2_HEADER_LEN + body_offset;
    if let Some(slot) = sub_frame.get_mut(start..start + 16) {
        // Persistent and volatile halves must both be all ones.
        if slot.iter().all(|&byte| byte == 0xFF) {
            slot.copy_from_slice(&file_id);
        }
    }
}
/// Byte offset of the 16-byte FileId inside the request body (i.e. after the
/// SMB2 header) for commands that carry one, or `None` for every other
/// command.
fn file_id_body_offset(command: Command) -> Option<usize> {
    let offset = match command {
        Command::QueryInfo => 24,
        Command::Read | Command::Write | Command::SetInfo => 16,
        Command::Close
        | Command::Flush
        | Command::Lock
        | Command::Ioctl
        | Command::QueryDirectory
        | Command::ChangeNotify
        | Command::OplockBreak => 8,
        _ => return None,
    };
    Some(offset)
}
/// Extracts the 16-byte FileId from an encoded response, reading it at body
/// offset 64 (where a CREATE response stores it).
///
/// Returns `None` when the buffer is shorter than header + 80 bytes or when
/// the status field at header offset 0x08 is not `STATUS_SUCCESS`. The
/// command field is not inspected; callers decide when to invoke this.
fn capture_create_file_id(response: &[u8]) -> Option<[u8; 16]> {
    let file_id_start = SMB2_HEADER_LEN + 64;
    let file_id_end = SMB2_HEADER_LEN + 80;
    if response.len() < file_id_end {
        return None;
    }
    if read_u32(response, 0x08) != ntstatus::STATUS_SUCCESS {
        return None;
    }
    response.get(file_id_start..file_id_end).map(|raw| {
        let mut file_id = [0u8; 16];
        file_id.copy_from_slice(raw);
        file_id
    })
}
async fn stitch_responses(conn: &Arc<Connection>, responses: Vec<Vec<u8>>) -> Vec<u8> {
let mut out = Vec::new();
let mut ranges = Vec::with_capacity(responses.len());
let response_count = responses.len();
for (index, mut response) in responses.into_iter().enumerate() {
let start = out.len();
let actual_len = response.len();
if index + 1 < response_count {
let next = align_8(actual_len);
response[0x14..0x18].copy_from_slice(&(next as u32).to_le_bytes());
}
out.extend_from_slice(&response);
ranges.push((start, actual_len));
if index + 1 < response_count {
out.resize(start + align_8(actual_len), 0);
}
}
let algo = *conn.signing_algo.read().await;
for (start, len) in ranges {
let flags = read_u32(&out, start + 0x10);
if flags & SMB2_FLAGS_SIGNED == 0 {
continue;
}
let session_id = read_u64(&out, start + 0x28);
let key = {
let sessions = conn.sessions.read().await;
sessions.get(&session_id).cloned()
};
let Some(session) = key else {
continue;
};
let session = session.read().await;
if matches!(session.identity, Identity::Anonymous) {
continue;
}
let signing_key = session.signing_key;
drop(session);
if let Err(e) = sign(&mut out[start..start + len], &signing_key, algo) {
error!(error = %e, "failed to sign compound response");
}
}
out
}
/// Rounds `n` up to the nearest multiple of 8 (values already aligned are
/// returned unchanged).
const fn align_8(n: usize) -> usize {
    let bumped = n + 7;
    bumped - (bumped % 8)
}
/// Reads a little-endian `u32` from `buf` starting at `offset`.
///
/// # Panics
///
/// Panics if `buf` does not contain four bytes at `offset`.
fn read_u32(buf: &[u8], offset: usize) -> u32 {
    // Fold from the most significant (last) byte down to the first.
    buf[offset..offset + 4]
        .iter()
        .rev()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
}
/// Reads a little-endian `u64` from `buf` starting at `offset`.
///
/// # Panics
///
/// Panics if `buf` does not contain eight bytes at `offset`.
fn read_u64(buf: &[u8], offset: usize) -> u64 {
    // Fold from the most significant (last) byte down to the first.
    buf[offset..offset + 8]
        .iter()
        .rev()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte))
}
/// Dispatches a single (non-compound) SMB2 request and returns the encoded
/// response.
///
/// Steps, in order: parse the header, verify the request signature, route to
/// the command handler, optionally take the 3.1.1 preauth snapshot the
/// handler asked for, build (and sign) the response bytes, then fold the
/// response into the relevant preauth hash.
///
/// Returns `None` when the header cannot be parsed or when the command is
/// CANCEL. A failed signature check produces an error response rather than
/// `None`.
///
/// # Panics
///
/// Panics if a handler sets `take_preauth_snapshot_for_session` for a request
/// that did not go through the per-session preauth path (i.e. it was not a
/// SESSION_SETUP on an SMB 3.1.1 connection).
async fn dispatch_one(
server: &Arc<ServerState>,
conn: &Arc<Connection>,
frame: &[u8],
) -> Option<Vec<u8>> {
let (req_hdr, body_bytes) = match Smb2Header::parse(frame) {
Ok(p) => p,
Err(e) => {
warn!(error = %e, "failed to parse header");
return None;
}
};
// Pulled out up front so they can be recorded on the tracing span.
let cmd = req_hdr.command;
let mid = req_hdr.message_id;
let sid = req_hdr.session_id;
let tid = req_hdr.tree_id().unwrap_or(0);
let span = debug_span!("dispatch", cmd = ?cmd, mid, sid, tid);
async move {
debug!("dispatch start");
// Verify signature on incoming request (when applicable).
if let Err(status) = verify_request_signature(server, conn, &req_hdr, frame).await {
return Some(build_response_bytes(conn, &req_hdr, HandlerResponse::err(status)).await);
}
// CANCEL is fire-and-forget — no response.
if cmd == Command::Cancel {
debug!("CANCEL received; no response");
return None;
}
let dialect = *conn.dialect.read().await;
// Working copy of the per-session preauth hash; only populated for a
// SESSION_SETUP on a 3.1.1 connection.
let mut session_preauth = None;
// 3.1.1 preauth is connection-scoped for NEGOTIATE, then per
// SESSION_SETUP authentication exchange.
if cmd == Command::Negotiate {
// A poisoned mutex is recovered rather than propagated as a panic.
let mut p = conn
.preauth
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
p.update(frame);
} else if cmd == Command::SessionSetup
&& dialect == Some(crate::proto::messages::Dialect::Smb311)
{
// Hash the request into the session's own copy before the handler runs.
let mut p = take_session_preauth(conn, req_hdr.session_id).await;
p.update(frame);
session_preauth = Some(p);
}
let resp = handlers::dispatch_command(server, conn, &req_hdr, body_bytes).await;
debug!(
command = ?req_hdr.command,
status = resp.status,
body_len = resp.body.len(),
"SMB2 handler response"
);
// If the handler asked for a preauth snapshot (3.1.1), take it now.
// At this point the hash covers the request but not yet the response.
if let Some(sid) = resp.take_preauth_snapshot_for_session {
let snap = session_preauth
.as_ref()
.expect("SMB 3.1.1 SessionSetup snapshot requires per-session preauth")
.snapshot();
// Stash on the session — the handler already created it.
let sessions = conn.sessions.read().await;
if let Some(sess_arc) = sessions.get(&sid) {
let mut sess = sess_arc.write().await;
sess.preauth_snapshot = Some(snap);
// For 3.1.1, recompute signing key now that we have the snapshot.
let dialect = *conn.dialect.read().await;
if dialect == Some(crate::proto::messages::Dialect::Smb311) {
sess.signing_key =
crate::proto::crypto::signing_key_311(&sess.session_base_key, &snap);
}
}
}
// Built after the key update above so the response is signed with the
// recomputed key.
let bytes = build_response_bytes(conn, &req_hdr, resp).await;
if cmd == Command::Negotiate {
// The NEGOTIATE response is part of the connection-wide preauth hash.
let mut p = conn
.preauth
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
p.update(&bytes);
} else if cmd == Command::SessionSetup
&& dialect == Some(crate::proto::messages::Dialect::Smb311)
{
// 0x08 is the status field of the encoded response header.
if read_u32(&bytes, 0x08) == ntstatus::STATUS_MORE_PROCESSING_REQUIRED {
// Authentication continues: hash the response and park the state
// under the session id carried in the response header (offset 0x28).
if let Some(mut p) = session_preauth {
p.update(&bytes);
let sid = read_u64(&bytes, 0x28);
conn.session_preauth.write().await.insert(sid, p);
}
} else {
// Exchange finished (success or failure): discard any stored state
// for the session id the request named.
conn.session_preauth
.write()
.await
.remove(&req_hdr.session_id);
}
}
Some(bytes)
}
.instrument(span)
.await
}
async fn take_session_preauth(conn: &Arc<Connection>, session_id: u64) -> PreauthIntegrity {
if session_id != 0
&& let Some(preauth) = conn.session_preauth.write().await.remove(&session_id)
{
return preauth;
}
conn.preauth
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.clone()
}
/// Checks the signature policy for an incoming request.
///
/// Returns `Ok(())` when the request may proceed and `Err(ntstatus)` when it
/// must be rejected:
///
/// * NEGOTIATE and requests with session id 0 are never checked.
/// * For an unknown session, a signed request fails with
///   `STATUS_USER_SESSION_DELETED`; an unsigned one is let through.
/// * A signed request on a non-anonymous session must verify against the
///   session's signing key, else `STATUS_ACCESS_DENIED`.
/// * An unsigned request (other than SESSION_SETUP) fails with
///   `STATUS_ACCESS_DENIED` when the session requires signing and is not
///   anonymous.
async fn verify_request_signature(
    _server: &Arc<ServerState>,
    conn: &Arc<Connection>,
    hdr: &Smb2Header,
    frame: &[u8],
) -> Result<(), u32> {
    if hdr.command == Command::Negotiate || hdr.session_id == 0 {
        return Ok(());
    }
    let is_signed = hdr.flags & SMB2_FLAGS_SIGNED != 0;

    // Clone the Arc so the session map lock is released immediately.
    let lookup = conn.sessions.read().await.get(&hdr.session_id).cloned();
    let Some(sess_arc) = lookup else {
        // Unknown session.
        return if is_signed {
            Err(ntstatus::STATUS_USER_SESSION_DELETED)
        } else {
            Ok(())
        };
    };

    if is_signed {
        let (is_anonymous, key) = {
            let sess = sess_arc.read().await;
            (matches!(sess.identity, Identity::Anonymous), sess.signing_key)
        };
        if is_anonymous {
            return Ok(());
        }
        let algo = *conn.signing_algo.read().await;
        return match crate::proto::crypto::verify(frame, &key, algo) {
            Ok(_) => Ok(()),
            Err(e) => {
                warn!(error = %e, "request signature verification failed");
                Err(ntstatus::STATUS_ACCESS_DENIED)
            }
        };
    }

    // Unsigned request: SESSION_SETUP is exempt from the signing requirement.
    if hdr.command == Command::SessionSetup {
        return Ok(());
    }
    let need = {
        let sess = sess_arc.read().await;
        sess.signing_required && !matches!(sess.identity, Identity::Anonymous)
    };
    if need {
        warn!(?hdr.command, "missing required signature on request");
        return Err(ntstatus::STATUS_ACCESS_DENIED);
    }
    Ok(())
}
/// Build the final on-the-wire bytes: header + body, with signing applied
/// when the session has a key.
///
/// The response header starts as a copy of the request header and is then
/// patched: server-to-redirector flag set, async flag cleared, `NextCommand`
/// zeroed, status taken from `handler_resp`, tree/session ids replaced by the
/// handler's overrides when present, signature cleared, and at least one
/// credit granted.
///
/// The response is signed only when all of the following hold: the handler
/// did not set `skip_signing`, the response session id is non-zero, the
/// request was signed or this is a successful SESSION_SETUP, the session
/// exists and is not anonymous, the SESSION_SETUP response is not flagged as
/// guest, and the session's signing key is not all zeros.
///
/// Returns an empty vector if the header cannot be encoded.
async fn build_response_bytes(
conn: &Arc<Connection>,
req_hdr: &Smb2Header,
handler_resp: HandlerResponse,
) -> Vec<u8> {
let mut hdr = *req_hdr;
hdr.flags |= SMB2_FLAGS_SERVER_TO_REDIR;
hdr.flags &= !SMB2_FLAGS_ASYNC_COMMAND;
// Compound chaining is applied later by the caller, if at all.
hdr.next_command = 0;
// In a response this header field carries the NTSTATUS.
hdr.channel_sequence_status = handler_resp.status;
hdr.tail = HeaderTail::sync(
handler_resp
.override_tree_id
.unwrap_or_else(|| req_hdr.tree_id().unwrap_or(0)),
);
if let Some(sid) = handler_resp.override_session_id {
hdr.session_id = sid;
}
// Signature must be zero while the message is being signed.
hdr.signature = [0u8; 16];
// Grant at least 1 credit so clients (e.g. Samba smbclient) can proceed.
hdr.credit_request_response = hdr.credit_request_response.max(1);
let request_was_signed = req_hdr.flags & SMB2_FLAGS_SIGNED != 0;
// MS-SMB2 §3.3.5.5.3 step 12: SessionSetup SUCCESS must be signed for
// non-anon/non-guest sessions even though the request cannot be signed yet.
let is_session_setup_success =
req_hdr.command == Command::SessionSetup && handler_resp.status == ntstatus::STATUS_SUCCESS;
let mut should_sign = false;
let mut key = [0u8; 16];
let algo = *conn.signing_algo.read().await;
if !handler_resp.skip_signing
&& hdr.session_id != 0
&& (request_was_signed || is_session_setup_success)
{
let sessions = conn.sessions.read().await;
if let Some(sess_arc) = sessions.get(&hdr.session_id) {
let sess = sess_arc.read().await;
let is_anon = matches!(sess.identity, Identity::Anonymous);
// Bit 0 of the third body byte is tested as the guest marker
// (presumably the low byte of SessionFlags in the SESSION_SETUP
// response — confirm against the message encoder).
let is_guest_response = is_session_setup_success
&& handler_resp.body.len() >= 4
&& (handler_resp.body[2] & 0x01) != 0;
// An all-zero key is treated as "no key available".
if !is_anon && !is_guest_response && sess.signing_key != [0u8; 16] {
key = sess.signing_key;
should_sign = true;
}
}
}
// The flag copied from the request must reflect what the response does.
if should_sign {
hdr.flags |= SMB2_FLAGS_SIGNED;
} else {
hdr.flags &= !SMB2_FLAGS_SIGNED;
}
let mut out = Vec::with_capacity(SMB2_HEADER_LEN + handler_resp.body.len());
if let Err(e) = hdr.write(&mut out) {
error!(error = %e, "failed to encode response header");
return Vec::new();
}
out.extend_from_slice(&handler_resp.body);
// A signing failure is logged; the unsigned bytes are still returned.
if should_sign && let Err(e) = sign(&mut out, &key, algo) {
error!(error = %e, "failed to sign response");
}
out
}
/// Detect and answer an SMB1 multi-protocol NEGOTIATE_REQUEST.
///
/// Expected SMB1 frame layout:
/// * `[0..4]`   — magic `0xFF 'S' 'M' 'B'`
/// * `[4]`      — command (0x72 = SMB_COM_NEGOTIATE)
/// * `[5..32]`  — remainder of the SMB1 header
/// * `[32]`     — `WordCount` (0 for NEGOTIATE)
/// * `[33..35]` — `ByteCount` (u16 LE)
/// * `[35..]`   — dialect strings, each `0x02 <ASCII> 0x00`
///
/// Dialect parsing stops at the first entry that does not start with `0x02`
/// or is not NUL-terminated. "SMB 2.???" takes precedence over "SMB 2.002"
/// when both are offered.
///
/// Returns `Some(reply_bytes)` only for an SMB1 NEGOTIATE that lists at least
/// one SMB2 dialect we recognise; otherwise `None` so the caller can fall
/// through to the normal SMB2 path.
async fn handle_smb1_multi_protocol(
    server: &Arc<ServerState>,
    conn: &Arc<Connection>,
    frame: &[u8],
) -> Option<Vec<u8>> {
    let is_smb1_negotiate =
        frame.len() >= 35 && frame[0..4] == [0xFF, b'S', b'M', b'B'] && frame[4] == 0x72;
    if !is_smb1_negotiate {
        return None;
    }

    // ByteCount sits right after the 32-byte header and 1-byte WordCount.
    let byte_count_at = 33;
    let dialects_at = byte_count_at + 2;
    let declared = usize::from(u16::from_le_bytes([
        frame[byte_count_at],
        frame[byte_count_at + 1],
    ]));
    // Never trust ByteCount beyond what was actually received.
    let blob = &frame[dialects_at..frame.len().min(dialects_at + declared)];

    // Anything after the final NUL is an unterminated entry and is ignored.
    let terminated = match blob.iter().rposition(|&b| b == 0) {
        Some(last_nul) => &blob[..last_nul],
        None => &blob[..0],
    };

    let mut wants_wildcard = false;
    let mut wants_smb202 = false;
    let entries = terminated
        .split(|&b| b == 0)
        .take_while(|entry| entry.first() == Some(&0x02));
    for entry in entries {
        match std::str::from_utf8(&entry[1..]).unwrap_or("") {
            "SMB 2.???" => wants_wildcard = true,
            "SMB 2.002" => wants_smb202 = true,
            _ => {}
        }
    }

    let chosen = match (wants_wildcard, wants_smb202) {
        (true, _) => crate::proto::messages::Dialect::Smb2Wildcard.as_u16(),
        (false, true) => crate::proto::messages::Dialect::Smb202.as_u16(),
        (false, false) => return None,
    };
    debug!(
        chosen = %format_args!("0x{chosen:04X}"),
        "SMB1 multi-protocol negotiate"
    );

    // Synthesize a request header so build_response_bytes can mint the
    // SERVER_TO_REDIR response. Per MS-SMB2 §3.3.5.3.1 the response uses
    // message_id=0, tree_id=0xFFFF, session_id=0.
    let synthetic_request = Smb2Header {
        command: Command::Negotiate,
        message_id: 0,
        session_id: 0,
        tail: HeaderTail::Sync {
            reserved: 0,
            tree_id: 0xFFFF,
        },
        ..Default::default()
    };
    let reply = handlers::negotiate::multi_protocol_response(server, conn, chosen).await;
    Some(build_response_bytes(conn, &synthetic_request, reply).await)
}
// Unit tests for the dispatcher: per-session preauth bookkeeping, compound
// response stitching, FileId capture, related-context inheritance, and one
// end-to-end compound CREATE+CLOSE dispatch against an in-memory backend.
#[cfg(test)]
mod tests {
use super::*;
use crate::conn::state::{Session, TreeConnect};
use crate::proto::messages::create::{CreateResponse, FileId};
use crate::proto::header::SMB2_MAGIC;
use crate::Share;
use uuid::Uuid;
/// Fresh connection with a nil GUID and two 1 MiB size arguments.
fn test_conn() -> Arc<Connection> {
Arc::new(Connection::new(Uuid::nil(), 1024 * 1024, 1024 * 1024))
}
/// Preauth state as it would look after a NEGOTIATE request/response pair
/// (placeholder byte strings stand in for the real messages).
fn negotiated_preauth() -> PreauthIntegrity {
let mut preauth = PreauthIntegrity::new();
preauth.update(b"negotiate request");
preauth.update(b"negotiate response");
preauth
}
/// A new SESSION_SETUP (session id 0) must start from the connection-level
/// NEGOTIATE hash, not from state accumulated by another session.
#[tokio::test]
async fn new_session_setup_preauth_starts_from_negotiate_base() {
let conn = test_conn();
let base = negotiated_preauth();
*conn.preauth.lock().expect("preauth lock") = base.clone();
let mut first_session = take_session_preauth(&conn, 0).await;
first_session.update(b"session one request");
first_session.update(b"session one response");
conn.session_preauth.write().await.insert(1, first_session);
let mut second_session = take_session_preauth(&conn, 0).await;
second_session.update(b"session two request");
let mut expected = base.clone();
expected.update(b"session two request");
// What the hash would be if session one's traffic had leaked in.
let mut polluted = base;
polluted.update(b"session one request");
polluted.update(b"session one response");
polluted.update(b"session two request");
assert_eq!(second_session.snapshot(), expected.snapshot());
assert_ne!(second_session.snapshot(), polluted.snapshot());
}
/// A follow-up SESSION_SETUP (non-zero session id) takes the stored
/// per-session state and removes it from the connection.
#[tokio::test]
async fn followup_session_setup_consumes_stored_session_preauth() {
let conn = test_conn();
let mut stored = negotiated_preauth();
stored.update(b"session setup request");
stored.update(b"session setup more-processing response");
let expected = stored.snapshot();
conn.session_preauth.write().await.insert(7, stored);
let got = take_session_preauth(&conn, 7).await;
assert_eq!(got.snapshot(), expected);
assert!(!conn.session_preauth.read().await.contains_key(&7));
}
// ── Compound request response stitching ─────────────────────────────────
/// Build a minimal response frame of `body_len` bytes (64 header + body).
/// Everything except the magic is zero, so the SIGNED flag is clear and
/// `stitch_responses` never attempts to sign it.
fn make_response(body_len: usize) -> Vec<u8> {
let mut buf = vec![0u8; SMB2_HEADER_LEN + body_len];
buf[..4].copy_from_slice(&SMB2_MAGIC);
buf
}
#[tokio::test]
async fn test_stitch_responses_single() {
let conn = test_conn();
let r1 = make_response(100);
let responses = vec![r1.clone()];
let stitched = stitch_responses(&conn, responses).await;
// Single response: no padding, NextCommand=0
assert_eq!(stitched.len(), 100 + SMB2_HEADER_LEN);
assert_eq!(&stitched[..4], &SMB2_MAGIC);
// NextCommand at offset 20 should be 0
let next = u32::from_le_bytes(stitched[20..24].try_into().unwrap());
assert_eq!(next, 0);
}
#[tokio::test]
async fn test_stitch_responses_aligned() {
let conn = test_conn();
// Two responses: 100 bytes + 80 bytes body
let r1 = make_response(100);
let r2 = make_response(80);
let responses = vec![r1, r2];
let stitched = stitch_responses(&conn, responses).await;
// First response: 64+100 = 164, aligned to 168 (next multiple of 8)
let total1 = SMB2_HEADER_LEN + 100;
let aligned1 = (total1 + 7) & !7;
assert_eq!(aligned1, 168);
// First response's NextCommand should point to second
let next1 = u32::from_le_bytes(stitched[20..24].try_into().unwrap());
assert_eq!(next1 as usize, aligned1);
// Second response header starts at offset `aligned1`
let next2 = u32::from_le_bytes(stitched[aligned1 + 20..aligned1 + 24].try_into().unwrap());
assert_eq!(next2, 0); // Last response, no NextCommand
// Total length = aligned1 + 64 (second header) + 80 (second body);
// the last response is not padded.
assert_eq!(stitched.len(), aligned1 + SMB2_HEADER_LEN + 80);
}
#[tokio::test]
async fn test_stitch_responses_three_responses() {
let conn = test_conn();
let r1 = make_response(100);
let r2 = make_response(80);
let r3 = make_response(60);
let responses = vec![r1, r2, r3];
let stitched = stitch_responses(&conn, responses).await;
// First header at 0
let total1 = SMB2_HEADER_LEN + 100;
let aligned1 = (total1 + 7) & !7;
let next1 = u32::from_le_bytes(stitched[20..24].try_into().unwrap());
assert_eq!(next1 as usize, aligned1);
// Second header at aligned1; `aligned2` is the absolute offset of the
// third header.
let total2 = SMB2_HEADER_LEN + 80;
let aligned2 = aligned1 + ((total2 + 7) & !7);
let next2 = u32::from_le_bytes(stitched[aligned1 + 20..aligned1 + 24].try_into().unwrap());
assert_eq!(next2 as usize, aligned2 - aligned1);
// Third header at aligned2
let next3 = u32::from_le_bytes(stitched[aligned2 + 20..aligned2 + 24].try_into().unwrap());
assert_eq!(next3, 0); // Last
assert_eq!(stitched.len(), aligned2 + SMB2_HEADER_LEN + 60);
}
#[tokio::test]
async fn test_stitch_responses_empty_returns_empty() {
let conn = test_conn();
let stitched = stitch_responses(&conn, vec![]).await;
assert!(stitched.is_empty());
}
// ── FileId capture from CREATE response ─────────────────────────────────
/// Build CREATE response bytes with known FileId.
fn make_create_response(persistent: u64, volatile: u64) -> Vec<u8> {
let mut resp = Vec::new();
let cr = CreateResponse {
structure_size: 89,
oplock_level: 0,
flags: 0,
create_action: 1,
creation_time: 0,
last_access_time: 0,
last_write_time: 0,
change_time: 0,
allocation_size: 100,
end_of_file: 100,
file_attributes: 0,
reserved2: 0,
file_id: FileId::new(persistent, volatile),
create_contexts_offset: 0,
create_contexts_length: 0,
create_contexts: vec![],
};
cr.write_to(&mut resp).unwrap();
resp
}
#[test]
fn test_capture_create_file_id_found() {
// Build a full response: 64-byte header + CREATE body.
// The status field stays zero, which the capture helper must treat as
// success for this test to pass.
let body = make_create_response(0xAAAABBBBCCCCDDDD, 0x1111222233334444);
let mut response = vec![0u8; SMB2_HEADER_LEN];
response[..4].copy_from_slice(&SMB2_MAGIC);
response[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes());
response.extend_from_slice(&body);
let file_id = capture_create_file_id(&response);
// FileId::new(0xAAAABBBBCCCCDDDD, 0x1111222233334444)
// LE: persistent=[0xDD,0xDD,0xCC,0xCC,0xBB,0xBB,0xAA,0xAA]
// volatile=[0x44,0x44,0x33,0x33,0x22,0x22,0x11,0x11]
assert_eq!(file_id, Some([0xDD, 0xDD, 0xCC, 0xCC, 0xBB, 0xBB, 0xAA, 0xAA,
0x44, 0x44, 0x33, 0x33, 0x22, 0x22, 0x11, 0x11]));
}
#[test]
fn test_capture_create_file_id_not_create_command() {
let response = vec![0u8; SMB2_HEADER_LEN];
// A header-only buffer is too short to hold a FileId, so nothing is
// captured. NOTE(review): capture_create_file_id does not look at the
// command field, so despite its name this test exercises the length
// check rather than a command check.
assert!(capture_create_file_id(&response).is_none());
}
#[test]
fn test_capture_create_file_id_too_short() {
let mut response = vec![0u8; SMB2_HEADER_LEN + 10];
response[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes());
assert!(capture_create_file_id(&response).is_none());
}
// ── Related context inheritance ─────────────────────────────────────────
/// Build a request frame: a 64-byte header with the given command, flags,
/// NextCommand, session id and tree id, followed by a 24-byte body whose
/// bytes 8..24 hold `file_id_bytes` (the CLOSE layout).
fn build_frame(command: Command, flags: u32, next: u32, session: u64, tree: u32, file_id_bytes: &[u8; 16]) -> Vec<u8> {
let mut f = vec![0u8; SMB2_HEADER_LEN];
f[..4].copy_from_slice(&SMB2_MAGIC);
f[4..6].copy_from_slice(&64u16.to_le_bytes()); // structure_size must be 64
f[12..14].copy_from_slice(&(command as u16).to_le_bytes());
f[16..20].copy_from_slice(&flags.to_le_bytes());
f[20..24].copy_from_slice(&next.to_le_bytes());
f[24..32].copy_from_slice(&0u64.to_le_bytes()); // message_id
// tail: Sync { reserved: 0, tree_id }
let tail_data = &tree.to_le_bytes();
f[36..40].copy_from_slice(tail_data);
f[40..48].copy_from_slice(&session.to_le_bytes());
// Append file_id at body_offset for CLOSE (offset 8)
let mut body = vec![0u8; 24]; // CLOSE body size
body[8..24].copy_from_slice(file_id_bytes);
f.extend_from_slice(&body);
f
}
#[tokio::test]
async fn test_inherit_related_context_close() {
// Use MAX values so inherit_related_context overrides them
let mut frame = build_frame(Command::Close, SMB2_FLAGS_RELATED_OPERATIONS, 0, u64::MAX, u32::MAX, &[0xFF; 16]);
let mut hdr = Smb2Header::parse(&frame).unwrap().0;
let prev_file_id = Some([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10]);
inherit_related_context(&mut frame, &mut hdr, 7, 3, prev_file_id);
// After inheritance: session_id should be 7
let session = u64::from_le_bytes(frame[40..48].try_into().unwrap());
assert_eq!(session, 7, "session_id should be inherited from 0xFFFF to 7");
// tree_id should be 3
let tree = u32::from_le_bytes(frame[36..40].try_into().unwrap());
assert_eq!(tree, 3, "tree_id should be inherited from 0xFFFFFFFF to 3");
// FileId should be inherited
let file_id_field = &frame[SMB2_HEADER_LEN + 8..SMB2_HEADER_LEN + 24];
assert_eq!(file_id_field, &prev_file_id.unwrap(), "FileId should be inherited from 0xFFFF to prev_file_id");
}
// ── Full compound chain dispatch test ───────────────────────────────────
/// Dispatches a hand-built compound CREATE+CLOSE frame and checks only the
/// framing of the result: two chained responses, both flagged
/// server-to-redirector, the second with NextCommand = 0. Response status
/// codes are not asserted.
#[tokio::test]
async fn test_compound_chain_create_then_close() {
use crate::server::SmbServer;
use crate::tests::memfs::MemFsBackend;
use crate::Access;
let server = SmbServer::builder()
.listen("127.0.0.1:0".parse().unwrap())
.user("alice", "password")
.share(
Share::new("home", MemFsBackend::new().with_file("test.txt", b"hello world"))
.user("alice", Access::ReadWrite),
)
.build()
.expect("build");
let state = server.state();
let conn = Arc::new(Connection::new(
state.config.server_guid,
state.config.max_read_size,
state.config.max_write_size,
));
state.active_connections.register(&conn).await;
// Pre-authenticated session 1 with tree connect 1 on the "home" share,
// installed directly so no NEGOTIATE/SESSION_SETUP round trip is needed.
let identity = Identity::User {
user: "alice".to_string(),
domain: String::new(),
};
let session = Session::new(1, identity, [0; 16], [0; 16], None, false, false, None, None);
let session = Arc::new(tokio::sync::RwLock::new(session));
let share = state.find_share("home").await.expect("share");
let tree = Arc::new(tokio::sync::RwLock::new(TreeConnect::new(
1,
share,
Access::ReadWrite,
)));
{
let sess = session.read().await;
sess.trees.write().await.insert(1, tree);
}
conn.sessions.write().await.insert(1, session);
// Build a compound CREATE+CLOSE frame
// Sub-frame 1: CREATE "test.txt"
let name_utf16: Vec<u8> = "test.txt".encode_utf16().flat_map(|c| c.to_le_bytes()).collect();
let create_body_len = 56 + name_utf16.len();
let create_frame_len = SMB2_HEADER_LEN + create_body_len;
let mut frame = vec![0u8; create_frame_len + SMB2_HEADER_LEN + 24]; // CREATE + CLOSE
// ── CREATE sub-frame ──
// Header
frame[..4].copy_from_slice(&SMB2_MAGIC);
frame[4..6].copy_from_slice(&64u16.to_le_bytes()); // structure_size
frame[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes());
frame[16..20].copy_from_slice(&SMB2_FLAGS_RELATED_OPERATIONS.to_le_bytes());
frame[20..24].copy_from_slice(&(create_frame_len as u32).to_le_bytes()); // NextCommand
frame[24..32].copy_from_slice(&0u64.to_le_bytes()); // message_id
frame[36..40].copy_from_slice(&1u32.to_le_bytes()); // tree_id
frame[40..48].copy_from_slice(&1u64.to_le_bytes()); // session_id
// CREATE body
let body_start = SMB2_HEADER_LEN;
frame[body_start..body_start+2].copy_from_slice(&57u16.to_le_bytes()); // structure_size
frame[body_start+3] = 0; // security_flags
frame[body_start+4..body_start+8].copy_from_slice(&2u32.to_le_bytes()); // impersonation_level
// desired_access = 0x00120089 (READ_CONTROL|SYNCHRONIZE|READ_ATTR|READ_DATA)
frame[body_start+24..body_start+28].copy_from_slice(&0x00120089u32.to_le_bytes());
frame[body_start+32..body_start+36].copy_from_slice(&7u32.to_le_bytes()); // share_access
frame[body_start+36..body_start+40].copy_from_slice(&1u32.to_le_bytes()); // create_disposition: OPEN
frame[body_start+40..body_start+44].copy_from_slice(&0x00000040u32.to_le_bytes()); // create_options: FILE_NON_DIRECTORY_FILE
// name_offset relative to header start: 64 + 56 = 120
frame[body_start+44..body_start+46].copy_from_slice(&120u16.to_le_bytes());
frame[body_start+46..body_start+48].copy_from_slice(&(name_utf16.len() as u16).to_le_bytes());
frame[body_start+56..body_start+56+name_utf16.len()].copy_from_slice(&name_utf16);
// ── CLOSE sub-frame ──
let close_start = create_frame_len;
frame[close_start..close_start+4].copy_from_slice(&SMB2_MAGIC);
frame[close_start+4..close_start+6].copy_from_slice(&64u16.to_le_bytes()); // structure_size
frame[close_start+12..close_start+14].copy_from_slice(&(Command::Close as u16).to_le_bytes());
frame[close_start+16..close_start+20].copy_from_slice(&(SMB2_FLAGS_RELATED_OPERATIONS).to_le_bytes());
frame[close_start+20..close_start+24].copy_from_slice(&0u32.to_le_bytes()); // NextCommand = last
frame[close_start+24..close_start+32].copy_from_slice(&0u64.to_le_bytes()); // message_id
// NOTE(review): inherit_related_context only substitutes ids that are all
// ones (u32::MAX / u64::MAX); the zeros written below are therefore passed
// through unchanged rather than inherited from the CREATE.
frame[close_start+36..close_start+40].copy_from_slice(&0u32.to_le_bytes()); // tree_id = 0
frame[close_start+40..close_start+48].copy_from_slice(&0u64.to_le_bytes()); // session_id = 0
// CLOSE body: FileId = 0xFFFF...FFFF (auto-inherit from CREATE)
let close_body_start = close_start + SMB2_HEADER_LEN;
frame[close_body_start..close_body_start+2].copy_from_slice(&24u16.to_le_bytes()); // structure_size
fill_all_ones(&mut frame[close_body_start+8..close_body_start+24]); // FileId = FFFF...
let result = dispatch_frame(&state, &conn, &frame).await;
assert!(result.is_some(), "dispatch_frame returned None");
let response = result.unwrap();
// Response should contain two sub-responses
assert!(response.len() >= SMB2_HEADER_LEN * 2 + 64,
"response too short for compound chain: {} bytes", response.len());
// First response should be CREATE
let next1 = u32::from_le_bytes(response[20..24].try_into().unwrap());
assert!(next1 > 0, "first response should have NextCommand for compound chain");
// First response header should have SERVER_TO_REDIR
let flags1 = u32::from_le_bytes(response[16..20].try_into().unwrap());
assert_ne!(flags1 & SMB2_FLAGS_SERVER_TO_REDIR, 0);
// Second response should be CLOSE (at offset next1)
let hdr2_start = next1 as usize;
let flags2 = u32::from_le_bytes(response[hdr2_start+16..hdr2_start+20].try_into().unwrap());
assert_ne!(flags2 & SMB2_FLAGS_SERVER_TO_REDIR, 0);
// Second response should be last (NextCommand = 0)
let next2 = u32::from_le_bytes(response[hdr2_start+20..hdr2_start+24].try_into().unwrap());
assert_eq!(next2, 0, "second (last) response should have NextCommand = 0");
}
/// Sets every byte of `buf` to 0xFF (the "inherit from previous request"
/// placeholder pattern).
fn fill_all_ones(buf: &mut [u8]) {
for b in buf.iter_mut() {
*b = 0xFF;
}
}
}