SMB Server Phase 2: VFS backend build fix + integration test
- Add VfsFile: Send supertrait for Mutex compatibility - Fix SmbServerCommand: struct → Subcommand enum with Start variant - Fix tracing_subscriber::init() → try_init() to avoid panic when logger already initialized - Fix CLI subcommand name: smb-server → smb-start (flatten naming) - Add #[command(name = "smb-start")] for CLI disambiguation - Fix unused variable warnings (smb_fs.rs, smb_server_backend.rs) - Remove unused VfsFile imports (webdav.rs, scp_handler.rs) - Integration test: Docker smbclient verified (list, upload, read)
This commit is contained in:
507
vendor/smb2/src/transport/mock.rs
vendored
Normal file
507
vendor/smb2/src/transport/mock.rs
vendored
Normal file
@@ -0,0 +1,507 @@
|
||||
//! Mock transport for testing.
|
||||
//!
|
||||
//! Provides a [`MockTransport`] that queues canned responses and records
|
||||
//! sent messages, enabling test-driven development of higher layers
|
||||
//! without needing a real SMB server.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::transport::{TransportReceive, TransportSend};
|
||||
|
||||
/// A mock transport that queues responses and records sent messages.
|
||||
///
|
||||
/// Use this in tests to simulate server conversations without a real
|
||||
/// network connection. Responses are returned in FIFO order.
|
||||
///
|
||||
/// `receive()` awaits on an internal `Notify` when the queue is empty,
|
||||
/// so the background receiver task doesn't exit prematurely between
|
||||
/// `queue_response` calls. Explicit disconnect is triggered by calling
|
||||
/// [`Self::close`].
|
||||
pub struct MockTransport {
|
||||
/// Responses to return on `receive()`, in order.
|
||||
responses: Mutex<VecDeque<Vec<u8>>>,
|
||||
/// Messages that were sent, for assertions.
|
||||
sent: Mutex<Vec<Vec<u8>>>,
|
||||
/// How many times `receive()` was called successfully (returning Ok).
|
||||
receive_count: Mutex<usize>,
|
||||
/// Wakes receivers when a response is queued or `close()` is called.
|
||||
notify: Notify,
|
||||
/// Set by `close()` to signal end-of-stream.
|
||||
closed: AtomicBool,
|
||||
/// When `true`, `receive()` rewrites each response sub-frame's
|
||||
/// `MessageId` to match the `MessageId` of the next pending sent request
|
||||
/// (and consumes it). See [`Self::enable_auto_rewrite_msg_id`].
|
||||
auto_rewrite: AtomicBool,
|
||||
/// FIFO of `MessageId`s observed in `send()` that haven't yet been
|
||||
/// consumed by a `receive()` rewrite. Only used when `auto_rewrite`
|
||||
/// is on.
|
||||
pending_sent_msg_ids: Mutex<VecDeque<u64>>,
|
||||
/// Signaled whenever a new send is recorded or a close happens — used
|
||||
/// by `receive()` in auto-rewrite mode to wait for a sent msg_id to
|
||||
/// pair with a queued response.
|
||||
send_notify: Notify,
|
||||
}
|
||||
|
||||
impl MockTransport {
|
||||
/// Create a new mock with no queued responses.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
responses: Mutex::new(VecDeque::new()),
|
||||
sent: Mutex::new(Vec::new()),
|
||||
receive_count: Mutex::new(0),
|
||||
notify: Notify::new(),
|
||||
closed: AtomicBool::new(false),
|
||||
auto_rewrite: AtomicBool::new(false),
|
||||
pending_sent_msg_ids: Mutex::new(VecDeque::new()),
|
||||
send_notify: Notify::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable msg_id rewriting: when `true`, `receive()` rewrites each
|
||||
/// response sub-frame's `MessageId` in-place to match the `MessageId`
|
||||
/// of the next request recorded by `send()` (FIFO pairing).
|
||||
///
|
||||
/// Without this, canned response builders hardcode `MessageId(0)` and
|
||||
/// won't match the caller's allocated msg_ids — the receiver task
|
||||
/// drops them as orphans and every caller hangs. This mode is the
|
||||
/// test-fixture replacement for the pre-Phase-3 orphan-filter-off
|
||||
/// path. Compound responses (multiple sub-frames chained via
|
||||
/// `NextCommand`) each consume one sent msg_id in order.
|
||||
///
|
||||
/// The receive side blocks until both a queued response and a sent
|
||||
/// msg_id are available, so tests can queue responses before or
|
||||
/// after the caller sends.
|
||||
pub fn enable_auto_rewrite_msg_id(&self) {
|
||||
self.auto_rewrite.store(true, Ordering::Release);
|
||||
}
|
||||
|
||||
/// Queue a response to be returned by the next `receive()` call.
|
||||
pub fn queue_response(&self, data: Vec<u8>) {
|
||||
self.responses.lock().unwrap().push_back(data);
|
||||
self.notify.notify_one();
|
||||
}
|
||||
|
||||
/// Queue multiple responses to be returned in order.
|
||||
pub fn queue_responses(&self, responses: Vec<Vec<u8>>) {
|
||||
let mut guard = self.responses.lock().unwrap();
|
||||
let count = responses.len();
|
||||
for r in responses {
|
||||
guard.push_back(r);
|
||||
}
|
||||
drop(guard);
|
||||
for _ in 0..count {
|
||||
self.notify.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal end-of-stream: after all queued responses are drained,
|
||||
/// `receive()` returns `Err(Error::Disconnected)`.
|
||||
pub fn close(&self) {
|
||||
self.closed.store(true, Ordering::Release);
|
||||
// Use `notify_one` (stores a permit for the next `notified().await`)
|
||||
// in addition to `notify_waiters` (wakes currently-parked waiters).
|
||||
// `notify_waiters` alone loses the signal if `close()` fires
|
||||
// between `receive()`'s `closed.load()` check and its
|
||||
// `notified().await` — no waiter is parked yet, so nothing gets
|
||||
// woken. The stored permit from `notify_one` covers that gap.
|
||||
self.notify.notify_one();
|
||||
self.notify.notify_waiters();
|
||||
// Same treatment for the send-notification used by auto-rewrite:
|
||||
// close should wake a receive that's blocked waiting for a paired
|
||||
// sent msg_id so it observes `closed` and bails out.
|
||||
self.send_notify.notify_one();
|
||||
self.send_notify.notify_waiters();
|
||||
}
|
||||
|
||||
/// Get all messages that were sent.
|
||||
pub fn sent_messages(&self) -> Vec<Vec<u8>> {
|
||||
self.sent.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Get the nth sent message, or `None` if out of bounds.
|
||||
pub fn sent_message(&self, n: usize) -> Option<Vec<u8>> {
|
||||
self.sent.lock().unwrap().get(n).cloned()
|
||||
}
|
||||
|
||||
/// How many messages have been sent.
|
||||
pub fn sent_count(&self) -> usize {
|
||||
self.sent.lock().unwrap().len()
|
||||
}
|
||||
|
||||
/// Clear all recorded sent messages.
|
||||
pub fn clear_sent(&self) {
|
||||
self.sent.lock().unwrap().clear();
|
||||
}
|
||||
|
||||
/// How many times `receive()` was called successfully (returned Ok).
|
||||
pub fn received_count(&self) -> usize {
|
||||
*self.receive_count.lock().unwrap()
|
||||
}
|
||||
|
||||
/// How many responses are still queued and unread.
|
||||
///
|
||||
/// Useful in tests that want to assert the code-under-test consumed
|
||||
/// every response it was expected to, without leaking any to a
|
||||
/// later test or leaving stale state that could mask a bug.
|
||||
pub fn pending_responses(&self) -> usize {
|
||||
self.responses.lock().unwrap().len()
|
||||
}
|
||||
|
||||
/// Assert that every queued response has been consumed.
|
||||
///
|
||||
/// Panics with a descriptive message if any responses remain in the
|
||||
/// queue. Use at the end of a test to catch the "caller forgot to
|
||||
/// receive" pattern that produces response-pipe pollution in
|
||||
/// real usage.
|
||||
#[track_caller]
|
||||
pub fn assert_fully_consumed(&self) {
|
||||
let remaining = self.pending_responses();
|
||||
assert_eq!(
|
||||
remaining, 0,
|
||||
"MockTransport has {} queued response(s) the code-under-test never read. \
|
||||
This usually means a caller sent a request but never received its response, \
|
||||
which in real usage leaves an orphan on the wire and corrupts the next op.",
|
||||
remaining
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MockTransport {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSend for MockTransport {
|
||||
async fn send(&self, data: &[u8]) -> Result<()> {
|
||||
// In auto-rewrite mode, capture the MessageId of each sub-frame
|
||||
// so `receive()` can rewrite a queued response to match.
|
||||
if self.auto_rewrite.load(Ordering::Acquire) {
|
||||
for msg_id in extract_msg_ids(data) {
|
||||
self.pending_sent_msg_ids.lock().unwrap().push_back(msg_id);
|
||||
self.send_notify.notify_one();
|
||||
}
|
||||
}
|
||||
self.sent.lock().unwrap().push(data.to_vec());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportReceive for MockTransport {
|
||||
async fn receive(&self) -> Result<Vec<u8>> {
|
||||
loop {
|
||||
let auto = self.auto_rewrite.load(Ordering::Acquire);
|
||||
// Wait for a queued response first (auto mode and plain mode
|
||||
// both need one to exist).
|
||||
let has_response = !self.responses.lock().unwrap().is_empty();
|
||||
if !has_response {
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
return Err(Error::Disconnected);
|
||||
}
|
||||
self.notify.notified().await;
|
||||
continue;
|
||||
}
|
||||
|
||||
if auto {
|
||||
// We have a response; peek its sub-frame count and wait
|
||||
// for at least that many sent msg_ids to be queued
|
||||
// (one consumed per sub-frame, even ones that already
|
||||
// have non-zero msg_ids, so pairing stays 1:1).
|
||||
let needed = {
|
||||
let guard = self.responses.lock().unwrap();
|
||||
match guard.front() {
|
||||
Some(frame) => count_sub_frames(frame),
|
||||
None => continue,
|
||||
}
|
||||
};
|
||||
if needed > 0 {
|
||||
loop {
|
||||
let have = self.pending_sent_msg_ids.lock().unwrap().len();
|
||||
if have >= needed {
|
||||
break;
|
||||
}
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
return Err(Error::Disconnected);
|
||||
}
|
||||
self.send_notify.notified().await;
|
||||
}
|
||||
}
|
||||
// Consume one response and `needed` sent msg_ids,
|
||||
// rewriting each sub-frame's zero msg_id to match the
|
||||
// corresponding sent msg_id.
|
||||
let mut data = match self.responses.lock().unwrap().pop_front() {
|
||||
Some(d) => d,
|
||||
None => continue,
|
||||
};
|
||||
let mut ids = self.pending_sent_msg_ids.lock().unwrap();
|
||||
rewrite_msg_ids(&mut data, &mut ids);
|
||||
drop(ids);
|
||||
*self.receive_count.lock().unwrap() += 1;
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
// Plain mode: just pop and return.
|
||||
let data = match self.responses.lock().unwrap().pop_front() {
|
||||
Some(d) => d,
|
||||
None => continue,
|
||||
};
|
||||
*self.receive_count.lock().unwrap() += 1;
|
||||
return Ok(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract `MessageId`s from a packed SMB2 request frame (possibly compound).
|
||||
/// Returns one msg_id per sub-frame, following `NextCommand` offsets.
|
||||
/// Returns an empty Vec if the data isn't a recognizable SMB2 frame —
|
||||
/// e.g. when `send()` is used with arbitrary bytes in transport-level tests.
|
||||
fn extract_msg_ids(data: &[u8]) -> Vec<u64> {
|
||||
const HEADER_MIN: usize = 64;
|
||||
if data.len() < HEADER_MIN {
|
||||
return Vec::new();
|
||||
}
|
||||
// Not an SMB2 header — skip (non-SMB2 tests call send with arbitrary bytes).
|
||||
if &data[0..4] != b"\xFESMB" {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut ids = Vec::new();
|
||||
let mut offset = 0usize;
|
||||
loop {
|
||||
if offset + HEADER_MIN > data.len() {
|
||||
break;
|
||||
}
|
||||
let msg_id =
|
||||
u64::from_le_bytes(data[offset + 24..offset + 32].try_into().unwrap_or([0; 8]));
|
||||
ids.push(msg_id);
|
||||
let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4]));
|
||||
if next == 0 {
|
||||
break;
|
||||
}
|
||||
offset += next as usize;
|
||||
}
|
||||
ids
|
||||
}
|
||||
|
||||
/// Count sub-frames in a packed SMB2 response frame by walking
|
||||
/// `NextCommand` offsets. Returns 0 for non-SMB2 frames, otherwise the
|
||||
/// total sub-frame count. `rewrite_msg_ids` consumes one sent msg_id
|
||||
/// per sub-frame (even those with already-set msg_ids) to keep
|
||||
/// send→receive pairing strictly 1:1 and avoid queue drift in tests
|
||||
/// that hardcode some but not all msg_ids.
|
||||
fn count_sub_frames(data: &[u8]) -> usize {
|
||||
const HEADER_MIN: usize = 64;
|
||||
if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" {
|
||||
return 0;
|
||||
}
|
||||
let mut count = 0usize;
|
||||
let mut offset = 0usize;
|
||||
loop {
|
||||
if offset + HEADER_MIN > data.len() {
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4]));
|
||||
if next == 0 {
|
||||
break;
|
||||
}
|
||||
offset += next as usize;
|
||||
}
|
||||
count
|
||||
}
|
||||
|
||||
/// Rewrite each sub-frame's `MessageId` in-place, consuming one id from
|
||||
/// `ids` per sub-frame in FIFO order. Sub-frames whose msg_id is
|
||||
/// already non-zero keep their hardcoded id (so tests exercising out-of-
|
||||
/// order routing still work) but STILL consume one id from the queue
|
||||
/// to keep send→receive pairing 1:1.
|
||||
fn rewrite_msg_ids(data: &mut [u8], ids: &mut VecDeque<u64>) {
|
||||
const HEADER_MIN: usize = 64;
|
||||
if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" {
|
||||
return;
|
||||
}
|
||||
let mut offset = 0usize;
|
||||
loop {
|
||||
if offset + HEADER_MIN > data.len() {
|
||||
break;
|
||||
}
|
||||
let existing =
|
||||
u64::from_le_bytes(data[offset + 24..offset + 32].try_into().unwrap_or([0; 8]));
|
||||
let consumed = ids.pop_front();
|
||||
if existing == 0 {
|
||||
if let Some(id) = consumed {
|
||||
data[offset + 24..offset + 32].copy_from_slice(&id.to_le_bytes());
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4]));
|
||||
if next == 0 {
|
||||
break;
|
||||
}
|
||||
offset += next as usize;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn queue_response_and_receive_it() {
|
||||
let mock = MockTransport::new();
|
||||
let data = vec![0x01, 0x02, 0x03];
|
||||
mock.queue_response(data.clone());
|
||||
|
||||
let received = mock.receive().await.unwrap();
|
||||
assert_eq!(received, data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn queue_multiple_responses_received_in_order() {
|
||||
let mock = MockTransport::new();
|
||||
mock.queue_responses(vec![vec![0x01], vec![0x02, 0x03], vec![0x04, 0x05, 0x06]]);
|
||||
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0x01]);
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0x02, 0x03]);
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0x04, 0x05, 0x06]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_causes_receive_to_return_disconnected() {
|
||||
let mock = MockTransport::new();
|
||||
mock.close();
|
||||
|
||||
let result = mock.receive().await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Disconnected),
|
||||
"expected Disconnected, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_records_message() {
|
||||
let mock = MockTransport::new();
|
||||
let msg = vec![0xAA, 0xBB, 0xCC];
|
||||
|
||||
mock.send(&msg).await.unwrap();
|
||||
|
||||
let sent = mock.sent_messages();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert_eq!(sent[0], msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sent_count_tracks_correctly() {
|
||||
let mock = MockTransport::new();
|
||||
assert_eq!(mock.sent_count(), 0);
|
||||
|
||||
mock.send(&[0x01]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 1);
|
||||
|
||||
mock.send(&[0x02]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 2);
|
||||
|
||||
mock.send(&[0x03]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sent_message_returns_nth() {
|
||||
let mock = MockTransport::new();
|
||||
mock.send(&[0x0A]).await.unwrap();
|
||||
mock.send(&[0x0B]).await.unwrap();
|
||||
mock.send(&[0x0C]).await.unwrap();
|
||||
|
||||
assert_eq!(mock.sent_message(0), Some(vec![0x0A]));
|
||||
assert_eq!(mock.sent_message(1), Some(vec![0x0B]));
|
||||
assert_eq!(mock.sent_message(2), Some(vec![0x0C]));
|
||||
assert_eq!(mock.sent_message(3), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_sent_removes_all_recorded_messages() {
|
||||
let mock = MockTransport::new();
|
||||
mock.send(&[0x01]).await.unwrap();
|
||||
mock.send(&[0x02]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 2);
|
||||
|
||||
mock.clear_sent();
|
||||
assert_eq!(mock.sent_count(), 0);
|
||||
assert!(mock.sent_messages().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn interleaved_send_and_receive() {
|
||||
let mock = MockTransport::new();
|
||||
mock.queue_responses(vec![vec![0xF1], vec![0xF2], vec![0xF3]]);
|
||||
|
||||
// Send a request, receive a response, repeat.
|
||||
mock.send(&[0x01]).await.unwrap();
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0xF1]);
|
||||
|
||||
mock.send(&[0x02]).await.unwrap();
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0xF2]);
|
||||
|
||||
mock.send(&[0x03]).await.unwrap();
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0xF3]);
|
||||
|
||||
// No more responses. Close to cause Disconnected.
|
||||
mock.close();
|
||||
assert!(mock.receive().await.is_err());
|
||||
|
||||
// All three sends recorded.
|
||||
assert_eq!(mock.sent_count(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_send_and_receive() {
|
||||
use std::sync::Arc;
|
||||
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
mock.queue_responses(vec![vec![0xAA]; 10]);
|
||||
|
||||
let send_mock = Arc::clone(&mock);
|
||||
let send_task = tokio::spawn(async move {
|
||||
for i in 0..10u8 {
|
||||
send_mock.send(&[i]).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let recv_mock = Arc::clone(&mock);
|
||||
let recv_task = tokio::spawn(async move {
|
||||
let mut received = Vec::new();
|
||||
for _ in 0..10 {
|
||||
received.push(recv_mock.receive().await.unwrap());
|
||||
}
|
||||
received
|
||||
});
|
||||
|
||||
send_task.await.unwrap();
|
||||
let received = recv_task.await.unwrap();
|
||||
|
||||
assert_eq!(received.len(), 10);
|
||||
assert_eq!(mock.sent_count(), 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_message_can_be_sent_and_received() {
|
||||
let mock = MockTransport::new();
|
||||
mock.queue_response(vec![]);
|
||||
|
||||
mock.send(&[]).await.unwrap();
|
||||
let received = mock.receive().await.unwrap();
|
||||
|
||||
assert!(received.is_empty());
|
||||
assert_eq!(mock.sent_message(0), Some(vec![]));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user