SMB Server Phase 2: VFS backend build fix + integration test
Some checks failed
Test / build (push) Has been cancelled
Test / test (push) Has been cancelled

- Add VfsFile: Send supertrait for Mutex compatibility
- Fix SmbServerCommand: struct → Subcommand enum with Start variant
- Fix tracing_subscriber::init() → try_init() to avoid panic when
  logger already initialized
- Fix CLI subcommand name: smb-server → smb-start (flatten naming)
- Add #[command(name = "smb-start")] for CLI disambiguation
- Fix unused variable warnings (smb_fs.rs, smb_server_backend.rs)
- Remove unused VfsFile imports (webdav.rs, scp_handler.rs)
- Integration test: Docker smbclient verified (list, upload, read)
This commit is contained in:
Warren
2026-06-20 19:42:29 +08:00
parent 45d050c0b3
commit 7eb528d35f
167 changed files with 59897 additions and 12 deletions

182
vendor/smb2/src/client/CLAUDE.md vendored Normal file
View File

@@ -0,0 +1,182 @@
# Client -- high-level SMB2 API
Entry point for most users. `SmbClient` wraps `Connection` + `Session` and provides convenience methods for file operations.
## Key files
| File | Purpose |
|---|---|
| `mod.rs` | `SmbClient`, `ClientConfig`, `connect()` shorthand |
| `connection.rs` | `Connection` -- credit tracking, message sequencing, signing, encryption, `execute` / `execute_compound` |
| `session.rs` | `Session::setup()` -- NTLM auth, key derivation, signing/encryption activation |
| `tree.rs` | `Tree` -- share connection, file CRUD, compound and pipelined I/O |
| `stream.rs` | `FileDownload` / `FileUpload` / `FileWriter` (owns `Connection` + `Arc<Tree>`, `'static`) / `open_file_writer` -- streaming I/O with progress |
| `watcher.rs` | `Watcher` -- directory change notifications via CHANGE_NOTIFY long-poll |
| `pipeline.rs` | `Pipeline` / `Op` / `OpResult` -- batched concurrent operations (the core feature) |
| `shares.rs` | Share enumeration via IPC$ + srvsvc RPC |
| `dfs.rs` | DFS referral IOCTL helper, `DfsResolver` with TTL-based referral cache |
## Layering
```
SmbClient (owns Connection + Session, stores credentials for reconnect)
Connection (TCP transport, credits, message IDs, signing, encryption)
Session (NTLM auth, key derivation -- setup mutates Connection)
Tree (share-level ops, borrows &mut Connection for each call)
extra_connections (HashMap<String, ConnectionEntry> for DFS cross-server)
dfs_resolver (DfsResolver with TTL-based referral cache)
```
All `Tree` methods take `&mut Connection` as a parameter. `SmbClient` convenience methods use `connection_for_tree(tree)` to route through the correct connection (primary or DFS extra connection) based on the tree's `server` field.
## Connection and credits
- Connection starts with 1 credit (from negotiate). Requests 256 credits in every message.
- Multi-credit requests (reads/writes > 64 KB) consume `ceil(payload_size / 65536)` credits and use that many consecutive `MessageId` values. Gaps in `MessageId` sequences cause the server to drop the connection.
- Credits flow back from responses via `CreditResponse` header field. The connection tracks available credits and blocks if exhausted.
- `STATUS_PENDING` interim responses carry credits but the request isn't done -- keep waiting.
## Compound requests
`Connection::execute_compound(&[CompoundOp])` packs multiple operations into a single transport frame. Each sub-request is 8-byte aligned, linked via `NextCommand`. Subsequent related operations use `FileId::SENTINEL` (the server substitutes the real handle from the first CREATE).
- **Read compound**: CREATE + READ + CLOSE (3 ops, 1 round-trip). Default for `read_file`.
- **Write compound**: CREATE + WRITE + FLUSH + CLOSE (4 ops, 1 round-trip). Default for `write_file`.
- **Delete compound**: CREATE (DELETE_ON_CLOSE) + CLOSE (2 ops, 1 round-trip). Default for `delete_file` / `delete_directory`.
- **Rename compound**: CREATE + SET_INFO + CLOSE (3 ops, 1 round-trip). Default for `rename`.
- **Stat compound**: CREATE + QUERY_INFO (basic) + QUERY_INFO (standard) + CLOSE (4 ops, 1 round-trip). Default for `stat`.
- **Fs-info compound**: CREATE + QUERY_INFO (FileFsFullSizeInformation) + CLOSE (3 ops, 1 round-trip). Default for `fs_info`.
- If CREATE succeeds but a later op fails, the client issues a standalone CLOSE to avoid leaking the handle.
### Receiving compound responses
`execute_compound` returns `Result<Vec<Result<Frame>>>`. The outer `Result` is "did the compound hit the wire"; the inner one is per-sub-op (waiter-level: session expired, signature verify, connection dropped mid-await). Sub-op protocol status codes (`STATUS_OBJECT_NAME_NOT_FOUND` etc.) ride in the inner frame's `header.status`, not the inner `Result`. Per MS-SMB2 3.3.4.1.3 the server MAY split the compound response across multiple transport frames (Samba, QNAP, Windows Server in some cases); the receiver task routes each sub-response by `MessageId` so the per-waiter `oneshot::Receiver`s resolve independently and `execute_compound` reassembles the result vector in submission order.
Most callers use a small `all_or_first_err` helper (see `tree.rs`) that propagates the first inner `Err` as the outer `Err` (matching the pre-Phase-3 shortcircuit behavior) and hands back a `Vec<Frame>` indexable per sub-op. Tolerating partial failure (for example, CREATE ok, READ fails → issue standalone CLOSE with the create's returned `FileId`) keeps the individual inner `Result`s.
## Batch operations
`delete_files`, `rename_files`, and `stat_files` issue one `execute_compound` per file. Partial failures are independent — if 3 of 50 files fail, the other 47 still succeed. Each method returns `Vec<Result<T>>` in the same order as the input.
Decision/Why — sequential execute vs parallel: pre-Phase-3 these methods did "phase 1 send all compounds, phase 2 receive all" for wire-level pipelining. With the new API a caller can re-create that shape by spawning `tokio::spawn` tasks over `conn.clone()`s, each calling `execute_compound`. For cmdr's "delete 50 files" flows the sequential-compound cost is small (one round-trip per file) so we chose simplicity. If a workload needs the extra parallelism later, the refactor is local to each batch method.
## DFS (Distributed File System) resolution
Reactive DFS resolution with multi-target failover. When a convenience method gets `STATUS_PATH_NOT_COVERED` (mapped to `ErrorKind::DfsReferral`), it:
1. Calls `handle_dfs_redirect()` which resolves the referral via `DfsResolver` (cache or IOCTL)
2. Tries each target in the referral response (multi-target failover)
3. Creates a new connection + session for cross-server targets via `ensure_connection()`
4. Tree-connects to the target share via `ensure_tree()`
5. Updates the caller's `&mut Tree` in-place to point to the new server/share
6. Retries the operation with the resolved remaining path
**Key design decisions:**
- Convenience methods take `&mut Tree` (not `&Tree`) so DFS can update the tree in-place
- `disconnect_share` stays as `&Tree` (no redirect on teardown)
- Streaming methods (`download`, `upload`) keep `&Tree` because they return handles that borrow the tree for their lifetime
- `watch` now returns an *owned* `Watcher` (no lifetime); see the [Watcher pipelining](#watcher-pipelining) section
- Batch methods (`delete_files`, `rename_files`, `stat_files`) don't retry per-file; the caller should trigger one single-file operation first to resolve the redirect
- `dfs_enabled` flag on `ClientConfig` (default `true`) gates all DFS resolution
- Borrow checker requires inlining the connection lookup in `handle_dfs_redirect` to avoid double `&mut self` borrows
## Watcher pipelining
`Watcher` keeps **one CHANGE_NOTIFY request pre-issued on the wire at all times** after the first `next_events()` call. The wire never sits idle between responses. This closes the response→re-arm loss window that strict servers (older Samba builds, NAS firmware) drop events through.
Shape: `Watcher` owns a cloned `Connection` (cheap `Arc::clone`, all clones multiplex over the same SMB session) and a `Tree` clone — no lifetime parameter, no borrow against the caller's `Connection`. `next_events` dispatches the next request via `Connection::dispatch` (a sibling to `execute` that returns once `transport.send().await` completes, handing back the `oneshot::Receiver` for the response) *before* awaiting the previous response. So when control returns to the consumer, the server already has somewhere to put new events.
Decision/Why — eager-send `dispatch` vs `tokio::spawn(conn.execute(...))`: the spawn-based approach defers the send to when the spawned task is polled, which under tokio's `current_thread` scheduler may not happen until the spawning task yields. That left a gap where the simulator-modeled strict server dropped events. `dispatch` awaits transport.send() inline, so the eager-send guarantee is "after `.await` returns, the request is on the wire" — independent of scheduler.
Pinned by `client::watcher::loss_window_tests::watcher_does_not_lose_events_between_consecutive_requests`: a strict-server simulator drops events that arrive with no outstanding request. Pre-fix: 5/5 gap events dropped. Post-fix: 0/5 dropped.
## Pipelined I/O
For large files, `read_file_pipelined` / `write_file_pipelined` issue multiple `execute_with_credits` calls concurrently on cloned connections via `futures_util::stream::FuturesUnordered`. The sliding window stays at 32 in-flight requests, credits are checked per launch via `conn.credits()`. Chunk size is `min(512 KB, max_read_size)`. This is the core performance feature -- without it, throughput is ~10x worse.
`FileWriter` owns its `Connection` (cheap `Arc::clone`) and `Arc<Tree>` — no lifetime parameter, no borrow against the `SmbClient` that built it. It keeps an owned `FuturesUnordered<BoxedWriteFut>` field — `launch_wire_chunk` pushes a boxed `execute_with_credits` future, `drain_one` awaits `in_flight.next()`, and the public `write_chunk` / `finish` / `abort` drive that state machine.
FileWriter provides push-based pipelined writes. The consumer pushes chunks at their own pace via `write_chunk`, with the sliding window handling backpressure. Complement to FileDownload (read streaming). Build one via `open_file_writer(tree, conn, path)` (free function), `Tree::create_file_writer(&Arc<Self>, conn, path)`, or `SmbClient::create_file_writer(&self, tree, path)` — the last clones the client's primary connection internally for convenience.
## Streaming download entry points
Two symmetric ways to start a `FileDownload`:
- `SmbClient::download(&mut self, &Tree, path)` — convenience wrapper that borrows the client's internal `Connection`.
- `Tree::download(&self, &mut Connection, path)` — takes the `Connection` directly. Use this when you hold a
`conn.clone()` and want to drive concurrent downloads on the same SMB session (each clone pairs with one outstanding
download; the receiver task multiplexes responses by `MessageId`). `SmbClient::download` delegates here.
For full control, `Tree::open_file` (returns `(FileId, u64)`) plus `FileDownload::new` let callers build custom chunk
loops with non-default `chunk_size`. Most users shouldn't need this — `read_file_compound` (1 RTT) handles small files
and `Tree::download` / `SmbClient::download` handle the streaming case.
FileWriter has two terminal operations:
- `finish()` — send all buffered data, drain in-flight WRITEs, FLUSH (fsync on the server), CLOSE. Use on normal completion.
- `abort()` — discard unsent data, drain in-flight WRITEs to keep credits/message-ids in sync, skip FLUSH, best-effort CLOSE. Use on cancellation or error paths where the partial remote file is going to be deleted anyway — `abort()` saves the fsync round-trip. The caller is responsible for deleting the partial remote file.
Both consume `self` so write-after-close/abort is a compile error. `Drop` logs a debug warning if neither was called (handle leaks).
## Session setup flow
1. Send NTLM NEGOTIATE in SESSION_SETUP
2. Receive STATUS_MORE_PROCESSING_REQUIRED with challenge, update preauth hash
3. Send NTLM AUTHENTICATE in SESSION_SETUP, update preauth hash with request only
4. Receive STATUS_SUCCESS (do NOT include in preauth hash)
5. Derive signing/encryption keys via SP800-108 KDF
6. Activate signing on the connection
7. If session or share requires encryption, activate encryption (TRANSFORM_HEADER wrapping with AEAD)
## Encryption
Encryption is activated when the session flags include `ENCRYPT_DATA` or a share has `SMB2_SHAREFLAG_ENCRYPT_DATA`. When active:
- Outgoing messages are wrapped in TRANSFORM_HEADER (protocol ID 0xFD) with a monotonic nonce
- Incoming messages with 0xFD are decrypted before processing
- Signing is skipped (AEAD provides authentication)
- Compound chains are encrypted as one unit (pitfall #9)
Tree-level encryption: `connect_share()` checks the share's encrypt flag and activates encryption on the connection if needed, even if the session didn't require it.
## Reconnection
`SmbClient::reconnect()` creates a fresh TCP connection, re-negotiates, and re-authenticates using stored credentials. All previous `Tree` handles and `FileId` values are invalidated. The caller must `connect_share` again.
## Connection internals: receiver task + `oneshot` routing
`Connection::execute` / `execute_compound` is the primary API. A background receiver task (spawned per `Connection` at `from_transport`) owns the transport's read half and routes each sub-frame to a per-request `oneshot::Sender` by `MessageId`.
- `Connection` is `Clone` and holds just `Arc<Inner>`. `Inner` owns `waiters: Mutex<HashMap<MessageId, oneshot::Sender<Result<Frame>>>>`, `credits: AtomicU32`, `next_message_id: AtomicU64`, the transport send half (via `Arc<dyn TransportSend>`), the receiver task's `JoinHandle`, and crypto state. All state is behind atomics or short-critical-section `std::sync::Mutex`.
- `execute(command, body, tree_id)` allocates a `MessageId` (`AtomicU64::fetch_add(credit_charge)`), registers a `oneshot::Sender` in `waiters` atomically under the waiters lock (re-checks `disconnected` there to rule out a TOCTOU where the receiver task has already shut down and drained the map), packs the frame, signs/encrypts/compresses as needed, and writes through `TransportSend::send`. Then it awaits the local `oneshot::Receiver`. Returns `Result<Frame { header, body, raw }>`.
- `execute_compound(&[CompoundOp])` does the same per sub-op, building one compound transport frame with `NextCommand` offsets, then awaits each per-sub-op receiver sequentially. Each receiver resolves independently (the receiver task splits the server's response by `NextCommand` and routes each sub-response by its `MessageId`). The outer `Result` is "did the compound hit the wire"; the inner `Vec<Result<Frame>>` has one entry per sub-op.
- **Cancellation-by-drop is safe by construction.** If a caller's future is aborted (`tokio::spawn` + `JoinHandle::abort()` is the common path in consumers), the locally-owned `oneshot::Receiver` drops; the receiver task's `Sender::send` then fails silently when the late frame arrives; the frame is discarded. Credits are still applied in the receiver task so dropped-caller frames don't starve throughput.
- **Transport drop** fans `Err(Disconnected)` to every pending `oneshot::Sender` and sets `disconnected=true` under the waiters lock. Subsequent `execute` / `execute_compound` sees `disconnected=true` and returns `Err(Disconnected)` without inserting (no leaked waiters).
Gotcha/Why — pre-Phase-3 `send_request` / `receive_response` split API was removed in Phase 3 Stage A.3. The test-mode `set_orphan_filter_enabled(false)` escape hatch is gone too; tests that build mocks without going through `setup_connection` call `mock.enable_auto_rewrite_msg_id()` instead, which rewrites each queued response's zero-msg_id to match the next pending sent msg_id in FIFO order.
Full design in [docs/specs/connection-actor.md](../../docs/specs/connection-actor.md).
## Key decisions
- **Owned `FileWriter`: N concurrent streamed writes over one Connection without external locking**: `FileWriter` owns its `Connection` (cheap `Arc::clone`) and `Arc<Tree>` instead of borrowing `&'a mut Connection` from the `SmbClient`. Built via the free `open_file_writer(tree: Arc<Tree>, conn: Connection, path: &str)` or one of the two convenience wrappers (`Tree::create_file_writer`, `SmbClient::create_file_writer`). Multiple writers built from clones of the same `Connection` pipeline their WRITEs over one SMB session — the receiver task multiplexes responses by `MessageId`. The borrowed variant was the root cause of a production-reproducing deadlock in the cmdr SMB volume's `write_from_stream` (Phase C QNAP test, 200 × 7 MB concurrent overwrites): the consumer had to hold its session mutex for the entire upload because the writer borrowed `&'a mut Connection`. Owning the connection removes the lock from the hot path entirely.
- **`execute` / `execute_compound` take `&self`**: `Connection: Clone` supports concurrent ops per connection — clone freely across tasks, the receiver task multiplexes responses by `MessageId`. `Tree::*` methods still take `&mut Connection` because session-setup mutators (`activate_signing`, `set_session_id`) keep `&mut self`; Tree code calls both, so `&mut` at that layer is the least-churn choice.
- **Sender work stays on the caller thread, only the receiver is a task**: The send path already uses an internal Mutex on the transport write half for ordering; adding a second task just to drive sends would add latency without correctness gain. The receiver bug (orphan/dropped-caller frames corrupting the wire) only existed on the receive side, so only the receive side needed a task.
- **Compound reads as default**: One round-trip for small files. Saves 2 RTTs vs sequential CREATE/READ/CLOSE.
- **512 KB pipeline chunks**: Balances between too many small requests (overhead) and too few large ones (credit starvation). Gives ~20 chunks per 10 MB file.
- **Password stored in `SmbClient`**: Enables reconnect without re-prompting. Not encrypted in memory. Drop when done.
## Gotchas
- **Preauth hash excludes the final success response**: Only STATUS_MORE_PROCESSING_REQUIRED responses are hashed. Including the success response produces wrong keys. (MS-SMB2 3.2.5.3.1)
- **Oplock break notifications arrive with MessageId 0xFFFFFFFFFFFFFFFF**: The receiver task detects these and skips them without invoking a waiter lookup.
- **Register-waiter must be atomic with `disconnected` check**: The waiters lock covers both reading `disconnected` and inserting the `oneshot::Sender`. If the check and insert were racy, a receiver-task failure mid-send could leave an orphan `Sender` in the map that never gets routed — caller would hang on `rx.await` forever. Same goes for `fan_error_to_waiters`: it sets `disconnected=true` UNDER the same waiters lock before draining, so new sends strictly either succeed-and-get-drained or fail at the insert check.
- **Unrecoverable frame errors tear down the connection** (Phase 3 P3.4): decrypt failure, decompress failure, or a malformed sub-frame header that survives `split_compound` all cause the receiver task to call `fan_error_to_waiters(Err(Disconnected))` and exit. The alternative — log-and-continue — would leave the matching waiter hanging forever, because the msg_id isn't recoverable from an unparseable frame. The connection is also out of sync after one bad frame, so reconnect is the right move anyway. Counted via `MetricsSnapshot::{decrypt_failures, decompress_failures, malformed_frames}`.
- **STATUS_PENDING loop**: CHANGE_NOTIFY and other long-poll operations get STATUS_PENDING first. The receiver task keeps the waiter registered on PENDING and does NOT forward the interim response. Credits from PENDING are still applied so the caller's `conn.credits()` reflects them. Counted via `MetricsSnapshot::status_pending_loops`.
- **Signing and encryption are mutually exclusive on the wire**: When encrypting, zero the signature field (AEAD provides integrity). On receive, skip signature verification if decryption succeeded.
- **Compound encryption wraps the entire chain**: One TRANSFORM_HEADER for all sub-requests concatenated, not per sub-request.
- **Share-level encryption**: If a share has `SMB2_SHAREFLAG_ENCRYPT_DATA`, encryption is activated even if the session didn't require it.
- **FileDownload/FileUpload can leak handles on drop**: Rust has no async drop. If not consumed fully, the file handle leaks. The types log a warning.
- **FileWriter can leak handles on drop**: Same as FileDownload/FileUpload. Rust has no async drop. If not consumed via `finish()` or `abort()`, the file handle leaks. The type logs a debug warning.
- **DFS paths must include server\share prefix**: When `SMB2_FLAGS_DFS_OPERATIONS` is set, the server expects the path to start with `server\share\` (MS-SMB2 3.2.4.3). `Tree::format_path()` handles this automatically for DFS shares. Without the prefix, Samba strips the first two path components, leading to wrong file opens.
- **DFS redirect changes the tree in-place**: After a DFS redirect, `tree.server`, `tree.share_name`, and `tree.tree_id` all change. Subsequent operations on the same tree use the target server directly -- they must use target-relative paths, not the original DFS paths.
- **tree.server stores addr:port**: The `server` field on `Tree` stores the full `addr:port` string (not just hostname) so `connection_for_tree` can distinguish servers that share the same hostname but use different ports.
- **Servers MAY split compound responses**: MS-SMB2 section 3.3.4.1.3 says the server SHOULD compound responses but is not required to. Samba (and QNAP firmware built on it) is known to split compound chains into separate frames in some scenarios; Windows Server does too under certain conditions. Compound-using methods (`read_file_compound`, `write_file_compound`, `fs_info`, `stat`, `rename`, `delete_file`, batch `*_files`) call `Connection::receive_compound_expected(n)` instead of `receive_compound()`, which transparently gathers additional frames if the server splits. Logged at DEBUG, not WARN -- it's a spec edge case, not a problem.

3413
vendor/smb2/src/client/connection.rs vendored Normal file

File diff suppressed because it is too large Load Diff

884
vendor/smb2/src/client/dfs.rs vendored Normal file
View File

@@ -0,0 +1,884 @@
//! DFS referral IOCTL helper and path resolver with referral cache.
//!
//! Sends `FSCTL_DFS_GET_REFERRALS` via IOCTL to resolve DFS paths. Connects
//! to IPC$ for the IOCTL exchange, similar to how `shares.rs` does for RPC.
//!
//! The [`DfsResolver`] caches referral responses with TTL and resolves UNC
//! paths using longest-prefix matching. All string comparisons are
//! case-insensitive (DFS paths are case-insensitive per MS-DFSC).
// DFS resolver is used by SmbClient for reactive DFS path resolution.
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use log::debug;
use crate::client::connection::Connection;
use crate::error::Result;
use crate::msg::dfs::{ReqGetDfsReferral, RespGetDfsReferral};
use crate::msg::ioctl::{
IoctlRequest, IoctlResponse, FSCTL_DFS_GET_REFERRALS, SMB2_0_IOCTL_IS_FSCTL,
};
use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse};
use crate::msg::tree_disconnect::TreeDisconnectRequest;
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
use crate::types::status::NtStatus;
use crate::types::{Command, FileId, TreeId};
use crate::Error;
/// Maximum output buffer size for DFS referral responses (8 KiB).
const DFS_MAX_OUTPUT_RESPONSE: u32 = 8192;
/// Send a DFS referral request and return the parsed response.
///
/// Connects to IPC$ (or reuses an existing tree), sends
/// `FSCTL_DFS_GET_REFERRALS` via IOCTL with `FileId::SENTINEL`, and
/// parses the response.
///
/// The `path` should be a UNC-style path with a single leading backslash
/// (for example, `\server\share\dir`).
pub(crate) async fn get_dfs_referral(
conn: &mut Connection,
path: &str,
) -> Result<RespGetDfsReferral> {
// 1. Tree-connect to IPC$
let tree_id = tree_connect_ipc(conn).await?;
// Send the IOCTL, then clean up regardless of outcome
let result = send_dfs_ioctl(conn, tree_id, path).await;
// Tree-disconnect IPC$ (best-effort -- don't mask the real error)
let _ = tree_disconnect(conn, tree_id).await;
result
}
/// Connect to the IPC$ share, returning the tree ID.
async fn tree_connect_ipc(conn: &mut Connection) -> Result<TreeId> {
let server = conn.server_name().to_string();
let unc_path = format!(r"\\{}\IPC$", server);
let req = TreeConnectRequest {
flags: TreeConnectRequestFlags::default(),
path: unc_path,
};
let frame = conn.execute(Command::TreeConnect, &req, None).await?;
if frame.header.command != Command::TreeConnect {
return Err(Error::invalid_data(format!(
"expected TreeConnect response, got {:?}",
frame.header.command
)));
}
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::TreeConnect,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let _resp = TreeConnectResponse::unpack(&mut cursor)?;
let tree_id = frame
.header
.tree_id
.ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?;
debug!("dfs: connected to IPC$, tree_id={}", tree_id);
Ok(tree_id)
}
/// Build and send the FSCTL_DFS_GET_REFERRALS IOCTL, parse the response.
async fn send_dfs_ioctl(
conn: &mut Connection,
tree_id: TreeId,
path: &str,
) -> Result<RespGetDfsReferral> {
// Build the referral request payload
let referral_req = ReqGetDfsReferral {
max_referral_level: 4,
request_file_name: path.to_string(),
};
let mut req_cursor = WriteCursor::new();
referral_req.pack(&mut req_cursor);
let input_data = req_cursor.into_inner();
debug!(
"dfs: sending FSCTL_DFS_GET_REFERRALS for {:?} ({} bytes input)",
path,
input_data.len()
);
// Build the IOCTL request
let ioctl_req = IoctlRequest {
ctl_code: FSCTL_DFS_GET_REFERRALS,
file_id: FileId::SENTINEL,
max_input_response: 0,
max_output_response: DFS_MAX_OUTPUT_RESPONSE,
flags: SMB2_0_IOCTL_IS_FSCTL,
input_data,
};
let frame = conn
.execute(Command::Ioctl, &ioctl_req, Some(tree_id))
.await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Ioctl,
});
}
// Parse the IOCTL response envelope
let mut cursor = ReadCursor::new(&frame.body);
let ioctl_resp = IoctlResponse::unpack(&mut cursor)?;
debug!(
"dfs: received IOCTL response ({} bytes output)",
ioctl_resp.output_data.len()
);
// Parse the DFS referral from the output buffer
let mut ref_cursor = ReadCursor::new(&ioctl_resp.output_data);
let referral_resp = RespGetDfsReferral::unpack(&mut ref_cursor)?;
debug!(
"dfs: parsed {} referral entries (path_consumed={})",
referral_resp.entries.len(),
referral_resp.path_consumed
);
Ok(referral_resp)
}
/// Disconnect from a tree.
async fn tree_disconnect(conn: &mut Connection, tree_id: TreeId) -> Result<()> {
let body = TreeDisconnectRequest;
let frame = conn
.execute(Command::TreeDisconnect, &body, Some(tree_id))
.await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::TreeDisconnect,
});
}
debug!("dfs: disconnected from IPC$");
Ok(())
}
// ── DFS resolver types ───────────────────────────────────────────────
/// A resolved DFS path ready for connection.
#[derive(Debug, Clone)]
pub(crate) struct ResolvedPath {
/// Server hostname (or IP) to connect to.
pub server: String,
/// Port to connect on (default 445).
pub port: u16,
/// Share name to tree-connect.
pub share: String,
/// Remaining path within the share (may be empty).
pub remaining_path: String,
}
/// A single DFS target from a referral response.
#[derive(Debug, Clone)]
struct DfsTarget {
/// Server hostname from the network_address field.
server: String,
/// Share name from the network_address field.
share: String,
/// Any remaining path suffix from the network_address.
remaining_prefix: String,
}
/// A cached DFS referral entry with TTL.
#[derive(Debug, Clone)]
struct CachedReferral {
/// The DFS path prefix this referral covers (lowercase for matching).
dfs_path_prefix: String,
/// Available targets (first is preferred).
targets: Vec<DfsTarget>,
/// When this entry expires.
expires_at: Instant,
}
/// DFS referral cache and path resolver.
///
/// Maintains a cache of DFS referral responses keyed by path prefix.
/// Resolves UNC paths by longest-prefix matching against the cache,
/// falling back to an IOCTL referral request on cache miss.
pub(crate) struct DfsResolver {
cache: HashMap<String, CachedReferral>,
/// Counters surfaced through [`SmbClient::diagnostics`].
cache_hits: AtomicU64,
referrals_resolved: AtomicU64,
}
impl DfsResolver {
/// Create a new empty resolver.
pub fn new() -> Self {
Self {
cache: HashMap::new(),
cache_hits: AtomicU64::new(0),
referrals_resolved: AtomicU64::new(0),
}
}
/// `(cache_hits, referrals_resolved)` for diagnostics.
pub(crate) fn counters(&self) -> (u64, u64) {
(
self.cache_hits.load(Ordering::Relaxed),
self.referrals_resolved.load(Ordering::Relaxed),
)
}
/// Iterate the cache entries (including expired ones — eviction is
/// lazy). Used by [`SmbClient::diagnostics`].
pub(crate) fn cache_entries(&self) -> Vec<crate::client::diagnostics::DfsCacheEntry> {
let now = Instant::now();
self.cache
.values()
.map(|e| crate::client::diagnostics::DfsCacheEntry {
path_prefix: e.dfs_path_prefix.clone(),
target_count: e.targets.len(),
expires_in: if e.expires_at > now {
Some(e.expires_at - now)
} else {
None
},
})
.collect()
}
/// Resolve a UNC path by checking the cache first, then querying the server.
///
/// `unc_path` should be like `\\server\share\path\to\file`.
/// `conn` is the connection to the server that returned `STATUS_PATH_NOT_COVERED`.
pub async fn resolve(
&mut self,
conn: &mut Connection,
unc_path: &str,
) -> Result<Vec<ResolvedPath>> {
// 1. Check cache (longest prefix match)
if let Some(resolved) = self.resolve_from_cache(unc_path) {
self.cache_hits.fetch_add(1, Ordering::Relaxed);
debug!("dfs: cache hit for {:?}", unc_path);
return Ok(resolved);
}
// 2. Send referral request.
// Convert \\server\share\path to \server\share\path (single leading
// backslash for the IOCTL).
let referral_path = if unc_path.starts_with("\\\\") {
&unc_path[1..] // strip one leading backslash
} else {
unc_path
};
debug!("dfs: cache miss, sending referral for {:?}", referral_path);
let resp = get_dfs_referral(conn, referral_path).await?;
self.referrals_resolved.fetch_add(1, Ordering::Relaxed);
// 3. Cache the result
self.cache_referral(&resp);
// 4. Resolve from the freshly cached entry
self.resolve_from_cache(unc_path).ok_or_else(|| {
Error::invalid_data("DFS referral response did not match the requested path")
})
}
/// Try to resolve a path from the cache. Returns `None` on cache miss or
/// expiry. Returns a `Vec` of [`ResolvedPath`]s (multiple targets for
/// failover).
pub(crate) fn resolve_from_cache(&self, unc_path: &str) -> Option<Vec<ResolvedPath>> {
let normalized = unc_path.to_lowercase().replace('/', "\\");
// Longest prefix match
let mut best_match: Option<&CachedReferral> = None;
for entry in self.cache.values() {
if normalized.starts_with(&entry.dfs_path_prefix)
&& entry.expires_at > Instant::now()
&& best_match.is_none_or(|b| entry.dfs_path_prefix.len() > b.dfs_path_prefix.len())
{
best_match = Some(entry);
}
}
let entry = best_match?;
// Strip the consumed prefix and build ResolvedPaths
let remaining = &normalized[entry.dfs_path_prefix.len()..];
let remaining = remaining.trim_start_matches('\\');
let resolved: Vec<ResolvedPath> = entry
.targets
.iter()
.map(|target| {
let full_remaining = if target.remaining_prefix.is_empty() {
remaining.to_string()
} else if remaining.is_empty() {
target.remaining_prefix.clone()
} else {
format!("{}\\{}", target.remaining_prefix, remaining)
};
ResolvedPath {
server: target.server.clone(),
port: 445,
share: target.share.clone(),
remaining_path: full_remaining,
}
})
.collect();
Some(resolved)
}
/// Store a referral response in the cache.
fn cache_referral(&mut self, resp: &RespGetDfsReferral) {
if resp.entries.is_empty() {
return;
}
// Use the dfs_path from the first entry as the cache key.
// Normalize to lowercase backslash form with `\\` prefix (UNC canonical).
let mut dfs_path_prefix = resp.entries[0].dfs_path.to_lowercase().replace('/', "\\");
if !dfs_path_prefix.starts_with("\\\\") {
if let Some(stripped) = dfs_path_prefix.strip_prefix('\\') {
dfs_path_prefix = format!("\\\\{stripped}");
}
}
// Parse targets from entries
let targets: Vec<DfsTarget> = resp
.entries
.iter()
.filter_map(|e| parse_unc_target(&e.network_address))
.collect();
if targets.is_empty() {
return;
}
let ttl = resp.entries[0].ttl.max(1); // At least 1 second
debug!(
"dfs: caching {:?} with {} targets, ttl={}s",
dfs_path_prefix,
targets.len(),
ttl
);
self.cache.insert(
dfs_path_prefix.clone(),
CachedReferral {
dfs_path_prefix,
targets,
expires_at: Instant::now() + Duration::from_secs(ttl as u64),
},
);
}
}
/// Parse a UNC network_address into server, share, and remaining path.
///
/// Input: `\\server\share` or `\\server\share\path`.
/// Returns `None` if the format is invalid.
fn parse_unc_target(network_address: &str) -> Option<DfsTarget> {
let path = network_address.trim_start_matches('\\');
let mut parts = path.splitn(3, '\\');
let server = parts.next()?.to_string();
let share = parts.next()?.to_string();
let remaining_prefix = parts.next().unwrap_or("").to_string();
if server.is_empty() || share.is_empty() {
return None;
}
Some(DfsTarget {
server,
share,
remaining_prefix,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::connection::pack_message;
use crate::client::test_helpers::{build_tree_connect_response, setup_connection};
use crate::msg::header::{ErrorResponse, Header};
use crate::msg::ioctl::IoctlResponse as IoctlResp;
use crate::msg::tree_connect::ShareType;
use crate::msg::tree_disconnect::TreeDisconnectResponse;
use crate::transport::MockTransport;
use crate::types::TreeId;
use std::sync::Arc;
/// Build an IOCTL response containing the given output data.
fn build_ioctl_response(output_data: Vec<u8>) -> Vec<u8> {
let mut h = Header::new_request(Command::Ioctl);
h.flags.set_response();
h.credits = 32;
let body = IoctlResp {
ctl_code: FSCTL_DFS_GET_REFERRALS,
file_id: FileId::SENTINEL,
flags: SMB2_0_IOCTL_IS_FSCTL,
output_data,
};
pack_message(&h, &body)
}
/// Build an IOCTL error response with the given status.
fn build_ioctl_error_response(status: NtStatus) -> Vec<u8> {
let mut h = Header::new_request(Command::Ioctl);
h.flags.set_response();
h.credits = 32;
h.status = status;
let body = ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
pack_message(&h, &body)
}
/// Build a TREE_DISCONNECT response.
fn build_tree_disconnect_response() -> Vec<u8> {
let mut h = Header::new_request(Command::TreeDisconnect);
h.flags.set_response();
h.credits = 32;
pack_message(&h, &TreeDisconnectResponse)
}
/// Pack a known DFS referral response into bytes.
///
/// Builds a V3 referral with the given entries.
fn pack_dfs_referral_response(
path_consumed: u16,
header_flags: u32,
entries: &[(&str, &str, &str, u32)], // (dfs_path, alt_path, net_addr, ttl)
) -> Vec<u8> {
// We build a V3 referral response manually.
// Entry fixed size: 4 (version+size) + 2+2+4 (server_type+flags+ttl)
// + 2+2+2 (offsets) + 16 (guid) = 34 bytes
let entry_fixed_size: u16 = 34;
let num_entries = entries.len() as u16;
let total_fixed = entry_fixed_size * num_entries;
// Pre-compute all string bytes
let entry_strings: Vec<(Vec<u8>, Vec<u8>, Vec<u8>)> = entries
.iter()
.map(|(dfs, alt, net, _)| {
(
encode_null_utf16(dfs),
encode_null_utf16(alt),
encode_null_utf16(net),
)
})
.collect();
// Compute cumulative string offsets relative to each entry's start.
// All strings come after all fixed entries. The offset for entry i
// is relative to entry i's start position.
let mut buf = Vec::new();
// Response header (8 bytes)
buf.extend_from_slice(&path_consumed.to_le_bytes());
buf.extend_from_slice(&num_entries.to_le_bytes());
buf.extend_from_slice(&header_flags.to_le_bytes());
// Calculate where strings start (after all fixed entries, but
// offsets are measured from the start of the entry data, not from
// the response header -- since RespGetDfsReferral::unpack reads
// the header first and then works with the remaining bytes).
//
// Actually, offsets in V3 entries are relative to the entry start
// within the entry data buffer.
// Accumulate string buffer contents and compute per-entry offsets.
let mut string_buf = Vec::new();
let mut per_entry_offsets = Vec::new();
for (i, (dfs_bytes, alt_bytes, net_bytes)) in entry_strings.iter().enumerate() {
let entry_start = i as u16 * entry_fixed_size;
let strings_base = total_fixed + string_buf.len() as u16;
let dfs_offset = strings_base - entry_start;
let alt_offset = dfs_offset + dfs_bytes.len() as u16;
let net_offset = alt_offset + alt_bytes.len() as u16;
per_entry_offsets.push((dfs_offset, alt_offset, net_offset));
string_buf.extend_from_slice(dfs_bytes);
string_buf.extend_from_slice(alt_bytes);
string_buf.extend_from_slice(net_bytes);
}
// Write fixed entries
for (i, (_, _, _, ttl)) in entries.iter().enumerate() {
let (dfs_off, alt_off, net_off) = per_entry_offsets[i];
buf.extend_from_slice(&3u16.to_le_bytes()); // version = 3
buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size
buf.extend_from_slice(&0u16.to_le_bytes()); // server_type
buf.extend_from_slice(&0u16.to_le_bytes()); // referral_entry_flags
buf.extend_from_slice(&ttl.to_le_bytes()); // ttl
buf.extend_from_slice(&dfs_off.to_le_bytes());
buf.extend_from_slice(&alt_off.to_le_bytes());
buf.extend_from_slice(&net_off.to_le_bytes());
buf.extend_from_slice(&[0u8; 16]); // service_site_guid
}
// Write string buffer
buf.extend_from_slice(&string_buf);
buf
}
/// Encode a string as null-terminated UTF-16LE bytes.
fn encode_null_utf16(s: &str) -> Vec<u8> {
let mut out = Vec::new();
for cu in s.encode_utf16() {
out.extend_from_slice(&cu.to_le_bytes());
}
out.extend_from_slice(&[0x00, 0x00]);
out
}
#[tokio::test]
async fn dfs_referral_ioctl_flow() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
let tree_id = TreeId(99);
// Build the DFS referral payload
let referral_bytes = pack_dfs_referral_response(
48, // path_consumed
0x02, // header_flags (StorageServers)
&[
(
r"\domain\dfs\docs",
r"\domain\dfs\docs",
r"\server1\share",
600,
),
(
r"\domain\dfs\docs",
r"\domain\dfs\docs",
r"\server2\share",
300,
),
],
);
// Queue responses: TreeConnect, IOCTL, TreeDisconnect
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
mock.queue_response(build_ioctl_response(referral_bytes));
mock.queue_response(build_tree_disconnect_response());
let resp = get_dfs_referral(&mut conn, r"\domain\dfs\docs")
.await
.unwrap();
assert_eq!(resp.path_consumed, 48);
assert_eq!(resp.header_flags, 0x02);
assert_eq!(resp.entries.len(), 2);
assert_eq!(resp.entries[0].version, 3);
assert_eq!(resp.entries[0].dfs_path, r"\domain\dfs\docs");
assert_eq!(resp.entries[0].network_address, r"\server1\share");
assert_eq!(resp.entries[0].ttl, 600);
assert_eq!(resp.entries[1].network_address, r"\server2\share");
assert_eq!(resp.entries[1].ttl, 300);
// Should have sent 3 messages: TreeConnect, IOCTL, TreeDisconnect
assert_eq!(mock.sent_count(), 3);
}
#[tokio::test]
async fn dfs_referral_ioctl_error() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
let tree_id = TreeId(99);
// Queue responses: TreeConnect, IOCTL error, TreeDisconnect
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
mock.queue_response(build_ioctl_error_response(NtStatus::NOT_FOUND));
mock.queue_response(build_tree_disconnect_response());
let result = get_dfs_referral(&mut conn, r"\nonexistent\path").await;
assert!(result.is_err());
let err = result.unwrap_err();
match &err {
Error::Protocol { status, command } => {
assert_eq!(*status, NtStatus::NOT_FOUND);
assert_eq!(*command, Command::Ioctl);
}
other => panic!("expected Protocol error, got: {other:?}"),
}
// Should still send TreeDisconnect even after IOCTL error
assert_eq!(mock.sent_count(), 3);
}
// ── parse_unc_target tests ───────────────────────────────────────
#[test]
fn parse_unc_target_basic() {
let t = parse_unc_target(r"\\server\share").unwrap();
assert_eq!(t.server, "server");
assert_eq!(t.share, "share");
assert_eq!(t.remaining_prefix, "");
}
#[test]
fn parse_unc_target_with_path() {
let t = parse_unc_target(r"\\server\share\path\to").unwrap();
assert_eq!(t.server, "server");
assert_eq!(t.share, "share");
assert_eq!(t.remaining_prefix, r"path\to");
}
#[test]
fn parse_unc_target_invalid() {
assert!(parse_unc_target(r"\\").is_none());
assert!(parse_unc_target("").is_none());
assert!(parse_unc_target(r"\\server").is_none());
// Single backslash + server but no share
assert!(parse_unc_target(r"\server").is_none());
}
#[test]
fn parse_unc_target_single_backslash_prefix() {
// Network addresses with single backslash prefix should also work.
let t = parse_unc_target(r"\server\share").unwrap();
assert_eq!(t.server, "server");
assert_eq!(t.share, "share");
assert_eq!(t.remaining_prefix, "");
}
#[test]
fn parse_unc_target_triple_backslash() {
// Extra leading backslashes are stripped.
let t = parse_unc_target(r"\\\server\share\path").unwrap();
assert_eq!(t.server, "server");
assert_eq!(t.share, "share");
assert_eq!(t.remaining_prefix, "path");
}
#[test]
fn parse_unc_target_ip_address() {
// IP addresses as server names.
let t = parse_unc_target(r"\\192.168.1.100\data").unwrap();
assert_eq!(t.server, "192.168.1.100");
assert_eq!(t.share, "data");
assert_eq!(t.remaining_prefix, "");
}
#[test]
fn parse_unc_target_deep_path() {
// The remaining prefix captures everything after server\share.
let t = parse_unc_target(r"\\server\share\a\b\c\d").unwrap();
assert_eq!(t.server, "server");
assert_eq!(t.share, "share");
assert_eq!(t.remaining_prefix, r"a\b\c\d");
}
#[test]
fn parse_unc_target_empty_components() {
// Empty server or share should return None.
assert!(parse_unc_target(r"\\\\share").is_none()); // empty server
assert!(parse_unc_target(r"\\\").is_none()); // server is empty after strip
}
// ── DfsResolver tests ────────────────────────────────────────────
/// Helper: build a RespGetDfsReferral for cache tests.
fn make_referral(
dfs_path: &str,
entries: &[(&str, u32)], // (network_address, ttl)
) -> RespGetDfsReferral {
use crate::msg::dfs::DfsReferralEntry;
let referral_entries: Vec<DfsReferralEntry> = entries
.iter()
.map(|(net_addr, ttl)| DfsReferralEntry {
version: 3,
server_type: 0,
referral_entry_flags: 0,
ttl: *ttl,
dfs_path: dfs_path.to_string(),
dfs_alternate_path: dfs_path.to_string(),
network_address: net_addr.to_string(),
})
.collect();
RespGetDfsReferral {
path_consumed: 0,
header_flags: 0,
entries: referral_entries,
}
}
#[test]
fn resolver_cache_hit() {
let mut resolver = DfsResolver::new();
let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server1\share", 600)]);
resolver.cache_referral(&resp);
let result = resolver.resolve_from_cache(r"\\domain\dfs\docs\file.txt");
assert!(result.is_some());
let paths = result.unwrap();
assert_eq!(paths.len(), 1);
assert_eq!(paths[0].server, "server1");
assert_eq!(paths[0].share, "share");
assert_eq!(paths[0].port, 445);
assert_eq!(paths[0].remaining_path, "file.txt");
}
#[test]
fn resolver_cache_miss() {
let resolver = DfsResolver::new();
let result = resolver.resolve_from_cache(r"\\server\share\file.txt");
assert!(result.is_none());
}
#[test]
fn resolver_cache_expired() {
let mut resolver = DfsResolver::new();
// Insert with TTL=0 -- cache_referral clamps to 1s, so we need to
// manually insert an already-expired entry.
let targets = vec![DfsTarget {
server: "srv".to_string(),
share: "data".to_string(),
remaining_prefix: String::new(),
}];
resolver.cache.insert(
r"\domain\dfs".to_string(),
CachedReferral {
dfs_path_prefix: r"\domain\dfs".to_string(),
targets,
expires_at: Instant::now() - Duration::from_secs(1),
},
);
let result = resolver.resolve_from_cache(r"\\domain\dfs\file.txt");
assert!(result.is_none(), "expired entry should not match");
}
#[test]
fn resolver_cache_longest_prefix() {
let mut resolver = DfsResolver::new();
// Insert a short prefix
let short = make_referral(r"\domain\dfs", &[(r"\\server1\root", 600)]);
resolver.cache_referral(&short);
// Insert a longer prefix
let long = make_referral(r"\domain\dfs\docs", &[(r"\\server2\docs", 600)]);
resolver.cache_referral(&long);
// Should match the longer prefix
let result = resolver
.resolve_from_cache(r"\\domain\dfs\docs\file.txt")
.unwrap();
assert_eq!(result[0].server, "server2");
assert_eq!(result[0].share, "docs");
assert_eq!(result[0].remaining_path, "file.txt");
// A path that only matches the short prefix
let result2 = resolver
.resolve_from_cache(r"\\domain\dfs\other\file.txt")
.unwrap();
assert_eq!(result2[0].server, "server1");
assert_eq!(result2[0].share, "root");
assert_eq!(result2[0].remaining_path, r"other\file.txt");
}
#[test]
fn resolver_multiple_targets() {
let mut resolver = DfsResolver::new();
let resp = make_referral(
r"\domain\dfs\docs",
&[(r"\\server1\share", 600), (r"\\server2\share", 300)],
);
resolver.cache_referral(&resp);
let result = resolver
.resolve_from_cache(r"\\domain\dfs\docs\file.txt")
.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].server, "server1");
assert_eq!(result[1].server, "server2");
// Both should have the same remaining path
assert_eq!(result[0].remaining_path, "file.txt");
assert_eq!(result[1].remaining_path, "file.txt");
}
#[test]
fn resolver_path_normalization() {
let mut resolver = DfsResolver::new();
// Cache with backslash-separated DFS path
let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server\share", 600)]);
resolver.cache_referral(&resp);
// Resolve with double-backslash prefix and mixed case
let result = resolver
.resolve_from_cache(r"\\DOMAIN\DFS\DOCS\Sub\File.txt")
.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].server, "server");
assert_eq!(result[0].share, "share");
// remaining_path is lowercased because we normalize the full input
assert_eq!(result[0].remaining_path, r"sub\file.txt");
// Forward slashes should also work
let result2 = resolver
.resolve_from_cache(r"\\domain/dfs/docs/other.txt")
.unwrap();
assert_eq!(result2[0].remaining_path, "other.txt");
}
#[test]
fn resolver_remaining_prefix_from_target() {
let mut resolver = DfsResolver::new();
// Target has a remaining prefix (network_address includes a subpath)
let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server\share\subdir", 600)]);
resolver.cache_referral(&resp);
// With additional path after the DFS prefix
let result = resolver
.resolve_from_cache(r"\\domain\dfs\docs\file.txt")
.unwrap();
assert_eq!(result[0].remaining_path, r"subdir\file.txt");
// Without additional path -- just the target's remaining prefix
let result2 = resolver.resolve_from_cache(r"\\domain\dfs\docs").unwrap();
assert_eq!(result2[0].remaining_path, "subdir");
}
}

1048
vendor/smb2/src/client/diagnostics.rs vendored Normal file

File diff suppressed because it is too large Load Diff

1495
vendor/smb2/src/client/mod.rs vendored Normal file

File diff suppressed because it is too large Load Diff

670
vendor/smb2/src/client/pipeline.rs vendored Normal file
View File

@@ -0,0 +1,670 @@
//! Unified operation pipeline for concurrent SMB2 operations.
//!
//! The [`Pipeline`] sends multiple SMB2 requests without waiting for each
//! response, filling the credit window. Results are collected and returned
//! once all operations complete.
//!
//! This is a first-iteration pipeline that executes a batch of operations.
//! Future iterations will add a channel-based streaming interface, compound
//! request construction, and chunk-level interleaving for large files.
use log::debug;
use crate::client::connection::Connection;
use crate::client::tree::Tree;
/// An operation to execute through the pipeline.
#[derive(Debug, Clone)]
pub enum Op {
/// Read a file, returning its contents.
ReadFile(String),
/// Write data to a file (create or overwrite).
WriteFile(String, Vec<u8>),
/// Delete a file.
Delete(String),
/// List a directory.
ListDirectory(String),
/// Get file metadata.
Stat(String),
}
/// Result of a pipeline operation.
#[derive(Debug)]
pub enum OpResult {
/// File data read successfully.
FileData {
/// The path that was read.
path: String,
/// The file contents.
data: Vec<u8>,
},
/// File written successfully.
Written {
/// The path that was written.
path: String,
/// Number of bytes written.
bytes_written: u64,
},
/// File deleted successfully.
Deleted {
/// The path that was deleted.
path: String,
},
/// Directory listing.
DirEntries {
/// The path that was listed.
path: String,
/// The directory entries.
entries: Vec<crate::client::tree::DirectoryEntry>,
},
/// File metadata.
Stat {
/// The path that was queried.
path: String,
/// The file information.
info: crate::client::tree::FileInfo,
},
/// Operation failed.
Error {
/// The path that failed.
path: String,
/// The error that occurred.
error: crate::Error,
},
}
/// A pipeline for executing multiple SMB operations as a batch.
///
/// The pipeline executes operations sequentially in this first iteration.
/// Each multi-step operation (for example, read = CREATE + READ + CLOSE) runs
/// to completion before the next operation starts. Future iterations will
/// interleave steps from different operations to fill the credit window.
pub struct Pipeline<'a> {
conn: &'a mut Connection,
tree: &'a Tree,
}
impl<'a> Pipeline<'a> {
/// Create a new pipeline bound to a connection and tree.
pub fn new(conn: &'a mut Connection, tree: &'a Tree) -> Self {
Self { conn, tree }
}
/// Execute a batch of operations and return the results.
///
/// Results are returned in the same order as the input operations.
/// Each operation that fails produces an [`OpResult::Error`] rather
/// than aborting the entire batch.
pub async fn execute(&mut self, ops: Vec<Op>) -> Vec<OpResult> {
let mut results = Vec::with_capacity(ops.len());
for op in ops {
let result = self.execute_one(op).await;
results.push(result);
}
results
}
/// Execute a single operation.
async fn execute_one(&mut self, op: Op) -> OpResult {
match op {
Op::ReadFile(path) => {
debug!("pipeline: read_file path={}", path);
match self.tree.read_file(self.conn, &path).await {
Ok(data) => OpResult::FileData { path, data },
Err(e) => OpResult::Error { path, error: e },
}
}
Op::WriteFile(path, data) => {
debug!("pipeline: write_file path={}", path);
match self.tree.write_file(self.conn, &path, &data).await {
Ok(bytes_written) => OpResult::Written {
path,
bytes_written,
},
Err(e) => OpResult::Error { path, error: e },
}
}
Op::Delete(path) => {
debug!("pipeline: delete path={}", path);
match self.tree.delete_file(self.conn, &path).await {
Ok(()) => OpResult::Deleted { path },
Err(e) => OpResult::Error { path, error: e },
}
}
Op::ListDirectory(path) => {
debug!("pipeline: list_directory path={}", path);
match self.tree.list_directory(self.conn, &path).await {
Ok(entries) => OpResult::DirEntries { path, entries },
Err(e) => OpResult::Error { path, error: e },
}
}
Op::Stat(path) => {
debug!("pipeline: stat path={}", path);
match self.tree.stat(self.conn, &path).await {
Ok(info) => OpResult::Stat { path, info },
Err(e) => OpResult::Error { path, error: e },
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::connection::pack_message;
use crate::client::test_helpers::{
build_close_response, build_create_response, setup_connection,
};
use crate::client::tree::Tree;
use crate::msg::create::{CreateAction, CreateResponse};
use crate::msg::header::{ErrorResponse, Header};
use crate::msg::query_directory::QueryDirectoryResponse;
use crate::msg::query_info::QueryInfoResponse;
use crate::msg::read::ReadResponse;
use crate::msg::write::WriteResponse;
use crate::pack::FileTime;
use crate::transport::MockTransport;
use crate::types::status::NtStatus;
use crate::types::{Command, FileId, OplockLevel, TreeId};
use std::sync::Arc;
fn test_tree() -> Tree {
Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
}
}
fn build_create_response_directory(file_id: FileId) -> Vec<u8> {
let mut h = Header::new_request(Command::Create);
h.flags.set_response();
h.credits = 32;
let body = CreateResponse {
oplock_level: OplockLevel::None,
flags: 0,
create_action: CreateAction::FileOpened,
creation_time: FileTime(132_000_000_000_000_000),
last_access_time: FileTime(132_000_000_000_000_000),
last_write_time: FileTime(133_000_000_000_000_000),
change_time: FileTime(133_000_000_000_000_000),
allocation_size: 0,
end_of_file: 0,
file_attributes: 0x10, // DIRECTORY
file_id,
create_contexts: vec![],
};
pack_message(&h, &body)
}
fn build_flush_response() -> Vec<u8> {
let mut h = Header::new_request(Command::Flush);
h.flags.set_response();
h.credits = 32;
let body = crate::msg::flush::FlushResponse;
pack_message(&h, &body)
}
fn build_read_response(data: Vec<u8>) -> Vec<u8> {
let mut h = Header::new_request(Command::Read);
h.flags.set_response();
h.credits = 32;
let body = ReadResponse {
data_offset: 0x50,
data_remaining: 0,
flags: 0,
data,
};
pack_message(&h, &body)
}
fn build_write_response(count: u32) -> Vec<u8> {
let mut h = Header::new_request(Command::Write);
h.flags.set_response();
h.credits = 32;
let body = WriteResponse {
count,
remaining: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
};
pack_message(&h, &body)
}
fn build_query_info_response(output_buffer: Vec<u8>) -> Vec<u8> {
let mut h = Header::new_request(Command::QueryInfo);
h.flags.set_response();
h.credits = 32;
let body = QueryInfoResponse { output_buffer };
pack_message(&h, &body)
}
fn build_query_directory_response(status: NtStatus, entries_data: Vec<u8>) -> Vec<u8> {
let mut h = Header::new_request(Command::QueryDirectory);
h.flags.set_response();
h.credits = 32;
h.status = status;
if status == NtStatus::NO_MORE_FILES {
let body = ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
return pack_message(&h, &body);
}
let body = QueryDirectoryResponse {
output_buffer: entries_data,
};
pack_message(&h, &body)
}
/// Build a FileBasicInformation buffer (40 bytes).
fn build_file_basic_info(
creation_time: u64,
last_access_time: u64,
last_write_time: u64,
change_time: u64,
file_attributes: u32,
) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&creation_time.to_le_bytes());
buf.extend_from_slice(&last_access_time.to_le_bytes());
buf.extend_from_slice(&last_write_time.to_le_bytes());
buf.extend_from_slice(&change_time.to_le_bytes());
buf.extend_from_slice(&file_attributes.to_le_bytes());
// Padding to 40 bytes (Reserved)
buf.extend_from_slice(&0u32.to_le_bytes());
buf
}
/// Build a FileStandardInformation buffer (24 bytes).
fn build_file_standard_info(
allocation_size: u64,
end_of_file: u64,
number_of_links: u32,
delete_pending: bool,
directory: bool,
) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&allocation_size.to_le_bytes());
buf.extend_from_slice(&end_of_file.to_le_bytes());
buf.extend_from_slice(&number_of_links.to_le_bytes());
buf.push(if delete_pending { 1 } else { 0 });
buf.push(if directory { 1 } else { 0 });
buf.extend_from_slice(&0u16.to_le_bytes()); // Reserved
buf
}
/// Build a single FileBothDirectoryInformation entry.
fn build_file_both_dir_info(
name: &str,
size: u64,
is_directory: bool,
next_offset: u32,
) -> Vec<u8> {
let name_u16: Vec<u16> = name.encode_utf16().collect();
let name_bytes_len = name_u16.len() * 2;
let mut buf = Vec::new();
buf.extend_from_slice(&next_offset.to_le_bytes());
buf.extend_from_slice(&0u32.to_le_bytes()); // FileIndex
buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); // CreationTime
buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); // LastAccessTime
buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); // LastWriteTime
buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); // ChangeTime
buf.extend_from_slice(&size.to_le_bytes());
buf.extend_from_slice(&((size + 4095) & !4095).to_le_bytes()); // AllocationSize
let attrs: u32 = if is_directory { 0x10 } else { 0x20 };
buf.extend_from_slice(&attrs.to_le_bytes());
buf.extend_from_slice(&(name_bytes_len as u32).to_le_bytes());
buf.extend_from_slice(&0u32.to_le_bytes()); // EaSize
buf.push(0); // ShortNameLength
buf.push(0); // Reserved
buf.extend_from_slice(&[0u8; 24]); // ShortName
for &u in &name_u16 {
buf.extend_from_slice(&u.to_le_bytes());
}
buf
}
/// Build a compound response frame with proper NextCommand offsets and padding.
fn build_compound_response_frame(responses: &[Vec<u8>]) -> Vec<u8> {
let mut padded: Vec<Vec<u8>> = Vec::new();
for (i, resp) in responses.iter().enumerate() {
let mut r = resp.clone();
let is_last = i == responses.len() - 1;
if !is_last {
// Pad to 8-byte alignment.
let remainder = r.len() % 8;
if remainder != 0 {
r.resize(r.len() + (8 - remainder), 0);
}
// Set NextCommand.
let next_cmd = r.len() as u32;
r[20..24].copy_from_slice(&next_cmd.to_le_bytes());
}
padded.push(r);
}
let mut frame = Vec::new();
for r in &padded {
frame.extend_from_slice(r);
}
frame
}
/// Build a compound read response frame (CREATE + READ + CLOSE) for pipeline tests.
fn build_compound_read_response(file_id: FileId, data: Vec<u8>) -> Vec<u8> {
let create_resp = build_create_response(file_id, data.len() as u64);
let read_resp = build_read_response(data);
let close_resp = build_close_response();
build_compound_response_frame(&[create_resp, read_resp, close_resp])
}
#[tokio::test]
async fn pipeline_batch_of_three_reads() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 1,
volatile: 2,
};
// Three read operations, each needs a compound CREATE + READ + CLOSE frame.
for i in 0..3 {
let data = format!("content_{}", i);
mock.queue_response(build_compound_read_response(file_id, data.into_bytes()));
}
let mut conn = setup_connection(&mock);
let tree = test_tree();
let mut pipeline = Pipeline::new(&mut conn, &tree);
let results = pipeline
.execute(vec![
Op::ReadFile("file1.txt".to_string()),
Op::ReadFile("file2.txt".to_string()),
Op::ReadFile("file3.txt".to_string()),
])
.await;
assert_eq!(results.len(), 3);
for (i, result) in results.into_iter().enumerate() {
match result {
OpResult::FileData { path, data } => {
assert_eq!(path, format!("file{}.txt", i + 1));
assert_eq!(data, format!("content_{}", i).into_bytes());
}
other => panic!("expected FileData, got {:?}", other),
}
}
}
#[tokio::test]
async fn pipeline_mixed_ops() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 1,
volatile: 2,
};
// Op 1: ReadFile -- compound CREATE + READ + CLOSE
mock.queue_response(build_compound_read_response(file_id, b"hello".to_vec()));
// Op 2: Delete -- compound CREATE(DELETE_ON_CLOSE) + CLOSE
let del_create = build_create_response(file_id, 0);
let del_close = build_close_response();
mock.queue_response(build_compound_response_frame(&[del_create, del_close]));
// Op 3: ListDirectory -- CREATE + QUERY_DIR + QUERY_DIR(NO_MORE) + CLOSE
mock.queue_response(build_create_response_directory(file_id));
let entry = build_file_both_dir_info("test.txt", 100, false, 0);
mock.queue_response(build_query_directory_response(NtStatus::SUCCESS, entry));
mock.queue_response(build_query_directory_response(
NtStatus::NO_MORE_FILES,
vec![],
));
mock.queue_response(build_close_response());
let mut conn = setup_connection(&mock);
let tree = test_tree();
let mut pipeline = Pipeline::new(&mut conn, &tree);
let results = pipeline
.execute(vec![
Op::ReadFile("data.bin".to_string()),
Op::Delete("old.txt".to_string()),
Op::ListDirectory("docs".to_string()),
])
.await;
assert_eq!(results.len(), 3);
match &results[0] {
OpResult::FileData { data, .. } => assert_eq!(data, b"hello"),
other => panic!("expected FileData, got {:?}", other),
}
match &results[1] {
OpResult::Deleted { path } => assert_eq!(path, "old.txt"),
other => panic!("expected Deleted, got {:?}", other),
}
match &results[2] {
OpResult::DirEntries { entries, .. } => {
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].name, "test.txt");
}
other => panic!("expected DirEntries, got {:?}", other),
}
}
#[tokio::test]
async fn pipeline_delete_file() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 1,
volatile: 2,
};
// DELETE = compound CREATE(DELETE_ON_CLOSE) + CLOSE
let create_resp = build_create_response(file_id, 0);
let close_resp = build_close_response();
let frame = build_compound_response_frame(&[create_resp, close_resp]);
mock.queue_response(frame);
let mut conn = setup_connection(&mock);
let tree = test_tree();
let mut pipeline = Pipeline::new(&mut conn, &tree);
let results = pipeline
.execute(vec![Op::Delete("remove_me.txt".to_string())])
.await;
assert_eq!(results.len(), 1);
match &results[0] {
OpResult::Deleted { path } => assert_eq!(path, "remove_me.txt"),
other => panic!("expected Deleted, got {:?}", other),
}
// One compound frame sent.
let sent = mock.sent_messages();
assert_eq!(sent.len(), 1);
}
#[tokio::test]
async fn pipeline_write_file() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 1,
volatile: 2,
};
// WRITE uses compound: CREATE+WRITE+FLUSH+CLOSE in one frame.
let create_resp = build_create_response(file_id, 0);
let write_resp = build_write_response(11);
let flush_resp = build_flush_response();
let close_resp = build_close_response();
let frame =
build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]);
mock.queue_response(frame);
let mut conn = setup_connection(&mock);
let tree = test_tree();
let mut pipeline = Pipeline::new(&mut conn, &tree);
let results = pipeline
.execute(vec![Op::WriteFile(
"output.txt".to_string(),
b"hello world".to_vec(),
)])
.await;
assert_eq!(results.len(), 1);
match &results[0] {
OpResult::Written {
path,
bytes_written,
} => {
assert_eq!(path, "output.txt");
assert_eq!(*bytes_written, 11);
}
other => panic!("expected Written, got {:?}", other),
}
}
#[tokio::test]
async fn pipeline_stat() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 1,
volatile: 2,
};
// STAT = compound CREATE + QUERY_INFO(basic) + QUERY_INFO(standard) + CLOSE
let create_resp = build_create_response(file_id, 0);
let basic_info = build_file_basic_info(
132_000_000_000_000_000,
132_100_000_000_000_000,
133_000_000_000_000_000,
133_000_000_000_000_000,
0x20, // ARCHIVE (not a directory)
);
let basic_resp = build_query_info_response(basic_info);
let std_info = build_file_standard_info(
4096, // allocation_size
2048, // end_of_file (actual size)
1, // number_of_links
false, // delete_pending
false, // directory
);
let std_resp = build_query_info_response(std_info);
let close_resp = build_close_response();
let frame = build_compound_response_frame(&[create_resp, basic_resp, std_resp, close_resp]);
mock.queue_response(frame);
let mut conn = setup_connection(&mock);
let tree = test_tree();
let mut pipeline = Pipeline::new(&mut conn, &tree);
let results = pipeline
.execute(vec![Op::Stat("info.txt".to_string())])
.await;
assert_eq!(results.len(), 1);
match &results[0] {
OpResult::Stat { path, info } => {
assert_eq!(path, "info.txt");
assert_eq!(info.size, 2048);
assert!(!info.is_directory);
assert_eq!(info.created, FileTime(132_000_000_000_000_000));
assert_eq!(info.modified, FileTime(133_000_000_000_000_000));
}
other => panic!("expected Stat, got {:?}", other),
}
}
#[tokio::test]
async fn pipeline_error_does_not_abort_batch() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 1,
volatile: 2,
};
// Op 1: ReadFile that fails at CREATE -- compound frame with cascaded errors.
let error_body = ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
let mut h1 = Header::new_request(Command::Create);
h1.flags.set_response();
h1.credits = 32;
h1.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_err = pack_message(&h1, &error_body);
let mut h2 = Header::new_request(Command::Read);
h2.flags.set_response();
h2.credits = 32;
h2.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let read_err = pack_message(&h2, &error_body);
let mut h3 = Header::new_request(Command::Close);
h3.flags.set_response();
h3.credits = 32;
h3.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_err = pack_message(&h3, &error_body);
mock.queue_response(build_compound_response_frame(&[
create_err, read_err, close_err,
]));
// Op 2: ReadFile that succeeds -- compound frame.
mock.queue_response(build_compound_read_response(file_id, b"abc".to_vec()));
let mut conn = setup_connection(&mock);
let tree = test_tree();
let mut pipeline = Pipeline::new(&mut conn, &tree);
let results = pipeline
.execute(vec![
Op::ReadFile("missing.txt".to_string()),
Op::ReadFile("exists.txt".to_string()),
])
.await;
assert_eq!(results.len(), 2);
match &results[0] {
OpResult::Error { path, .. } => assert_eq!(path, "missing.txt"),
other => panic!("expected Error, got {:?}", other),
}
match &results[1] {
OpResult::FileData { path, data } => {
assert_eq!(path, "exists.txt");
assert_eq!(data, b"abc");
}
other => panic!("expected FileData, got {:?}", other),
}
}
}

769
vendor/smb2/src/client/session.rs vendored Normal file
View File

@@ -0,0 +1,769 @@
//! Authenticated SMB2 session.
//!
//! The [`Session`] type manages the multi-round-trip SESSION_SETUP exchange
//! (NTLM authentication), key derivation, and signing activation.
use log::{debug, info, trace, warn};
use crate::auth::ntlm::{NtlmAuthenticator, NtlmCredentials};
use crate::client::connection::Connection;
use crate::crypto::kdf::derive_session_keys;
use crate::crypto::signing::{algorithm_for_dialect, SigningAlgorithm};
use crate::error::Result;
use crate::msg::session_setup::{SessionSetupRequest, SessionSetupResponse};
use crate::pack::{ReadCursor, Unpack};
use crate::types::flags::{Capabilities, SecurityMode};
use crate::types::status::NtStatus;
use crate::types::{Command, Dialect, SessionId};
use crate::Error;
use crate::msg::session_setup::SessionSetupRequestFlags;
/// An authenticated SMB2 session with derived keys.
#[derive(Debug)]
pub struct Session {
/// The session ID assigned by the server.
pub session_id: SessionId,
/// Key used to sign outgoing messages.
pub signing_key: Vec<u8>,
/// Key used to encrypt outgoing messages (SMB 3.x).
pub encryption_key: Option<Vec<u8>>,
/// Key used to decrypt incoming messages (SMB 3.x).
pub decryption_key: Option<Vec<u8>>,
/// The signing algorithm to use.
pub signing_algorithm: SigningAlgorithm,
/// Whether outgoing messages should be signed.
pub should_sign: bool,
/// Whether outgoing messages should be encrypted.
pub should_encrypt: bool,
}
impl Session {
/// Perform the multi-round-trip SESSION_SETUP exchange.
///
/// Steps:
/// 1. Send NTLM NEGOTIATE_MESSAGE in SESSION_SETUP.
/// 2. Receive STATUS_MORE_PROCESSING_REQUIRED with CHALLENGE_MESSAGE.
/// 3. Update preauth hash with request+response.
/// 4. Send NTLM AUTHENTICATE_MESSAGE in SESSION_SETUP.
/// 5. Receive STATUS_SUCCESS with session flags.
/// 6. Update preauth hash with request+response.
/// 7. Derive signing/encryption keys.
/// 8. Activate signing on the connection.
pub async fn setup(
conn: &mut Connection,
username: &str,
password: &str,
domain: &str,
) -> Result<Session> {
let params = conn
.params()
.ok_or_else(|| Error::invalid_data("negotiate must complete before session setup"))?
.clone();
let mut auth = NtlmAuthenticator::new(NtlmCredentials {
username: username.to_string(),
password: password.to_string(),
domain: domain.to_string(),
});
// Clone the preauth hasher for this session (spec: per-session hash).
let mut session_hasher = conn.preauth_hasher().clone();
// ── Round 1: NEGOTIATE_MESSAGE ──
debug!("session: round 1, sending NTLM negotiate");
let type1_bytes = auth.negotiate();
let req1 = SessionSetupRequest {
flags: SessionSetupRequestFlags(0),
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
capabilities: Capabilities::default(),
channel: 0,
previous_session_id: 0,
security_buffer: type1_bytes,
};
let (frame1, req1_raw) = conn
.execute_capturing_request(Command::SessionSetup, &req1, None)
.await?;
// Update session preauth hash with request.
session_hasher.update(&req1_raw);
let resp1_header = frame1.header;
let resp1_body = frame1.body;
// Update session preauth hash with response.
session_hasher.update(&frame1.raw);
if resp1_header.command != Command::SessionSetup {
return Err(Error::invalid_data(format!(
"expected SessionSetup response, got {:?}",
resp1_header.command
)));
}
if !resp1_header.status.is_more_processing_required() {
if resp1_header.status.is_error() {
return Err(Error::Protocol {
status: resp1_header.status,
command: Command::SessionSetup,
});
}
return Err(Error::invalid_data(
"expected STATUS_MORE_PROCESSING_REQUIRED, got success on first round",
));
}
// The server assigned a session ID -- use it for subsequent requests.
debug!(
"session: round 1 complete, status={:?}, session_id={}",
resp1_header.status, resp1_header.session_id
);
conn.set_session_id(resp1_header.session_id);
// Parse the challenge response.
let mut cursor1 = ReadCursor::new(&resp1_body);
let setup_resp1 = SessionSetupResponse::unpack(&mut cursor1)?;
// ── Round 2: AUTHENTICATE_MESSAGE ──
debug!("session: round 2, sending NTLM authenticate");
let type3_bytes = auth.authenticate(&setup_resp1.security_buffer)?;
let req2 = SessionSetupRequest {
flags: SessionSetupRequestFlags(0),
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
capabilities: Capabilities::default(),
channel: 0,
previous_session_id: 0,
security_buffer: type3_bytes,
};
let (frame2, req2_raw) = conn
.execute_capturing_request(Command::SessionSetup, &req2, None)
.await?;
// Update session preauth hash with the request ONLY.
// The final SESSION_SETUP response (STATUS_SUCCESS) is NOT
// included in the preauth hash (spec section 3.2.5.3.1).
// Only STATUS_MORE_PROCESSING_REQUIRED responses are hashed.
session_hasher.update(&req2_raw);
let resp2_header = frame2.header;
let resp2_body = frame2.body;
// Do NOT hash the success response -- the preauth hash used for
// key derivation contains only messages up to (and including)
// the final authenticate request, not the success response.
if resp2_header.command != Command::SessionSetup {
return Err(Error::invalid_data(format!(
"expected SessionSetup response, got {:?}",
resp2_header.command
)));
}
if resp2_header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: resp2_header.status,
command: Command::SessionSetup,
});
}
// Parse the final response.
let mut cursor2 = ReadCursor::new(&resp2_body);
let setup_resp2 = SessionSetupResponse::unpack(&mut cursor2)?;
let session_id = resp2_header.session_id;
conn.set_session_id(session_id);
// Get the session key from NTLM.
let session_key = auth
.session_key()
.ok_or_else(|| Error::Auth {
message: "NTLM did not produce a session key".to_string(),
})?
.to_vec();
// Determine signing algorithm.
let gmac_negotiated = params.gmac_negotiated;
let signing_algorithm = algorithm_for_dialect(params.dialect, gmac_negotiated);
debug!(
"session: signing_algo={:?}, dialect={}",
signing_algorithm, params.dialect
);
// Derive keys for SMB 3.x, or use session key directly for SMB 2.x.
trace!(
"session: deriving keys, session_key_len={}",
session_key.len()
);
let (signing_key, encryption_key, decryption_key) = match params.dialect {
Dialect::Smb3_0 | Dialect::Smb3_0_2 => {
let keys = derive_session_keys(&session_key, params.dialect, None, 128);
(
keys.signing_key,
Some(keys.encryption_key),
Some(keys.decryption_key),
)
}
Dialect::Smb3_1_1 => {
// Key length: 256 bits only for AES-256 ciphers. GMAC signing
// uses AES-128-GCM internally, so it needs 128-bit (16-byte) keys.
let key_len_bits = match params.cipher {
Some(crate::crypto::encryption::Cipher::Aes256Ccm)
| Some(crate::crypto::encryption::Cipher::Aes256Gcm) => 256,
_ => 128,
};
let keys = derive_session_keys(
&session_key,
Dialect::Smb3_1_1,
Some(session_hasher.value()),
key_len_bits,
);
(
keys.signing_key,
Some(keys.encryption_key),
Some(keys.decryption_key),
)
}
_ => {
// SMB 2.x: use session key directly for signing.
(session_key.clone(), None, None)
}
};
// Determine if we should sign.
let should_sign = params.signing_required
|| !setup_resp2.session_flags.is_guest() && !setup_resp2.session_flags.is_null();
let should_encrypt = setup_resp2.session_flags.encrypt_data();
// Activate signing on the connection.
if should_sign {
conn.activate_signing(signing_key.clone(), signing_algorithm);
}
// Activate encryption on the connection if the session requires it.
// The cipher comes from negotiate contexts (SMB 3.1.1). If the server
// didn't send one (for example, Samba with `smb encrypt = required` sometimes
// omits the encryption context), fall back to AES-128-CCM which is
// universally supported by all SMB 3.x servers.
if should_encrypt {
let cipher = params
.cipher
.unwrap_or(crate::crypto::encryption::Cipher::Aes128Ccm);
if let (Some(ref enc_key), Some(ref dec_key)) = (&encryption_key, &decryption_key) {
conn.activate_encryption(enc_key.clone(), dec_key.clone(), cipher);
} else {
warn!(
"session: encryption requested but missing keys, \
enc_key={}, dec_key={}",
encryption_key.is_some(),
decryption_key.is_some(),
);
}
}
info!(
"session: established, session_id={}, sign={}, encrypt={}",
session_id, should_sign, should_encrypt
);
Ok(Session {
session_id,
signing_key,
encryption_key,
decryption_key,
signing_algorithm,
should_sign,
should_encrypt,
})
}
/// Perform Kerberos-based SESSION_SETUP.
///
/// Authenticates against the KDC first (AS + TGS), then sends the
/// SPNEGO-wrapped AP-REQ in SESSION_SETUP. Handles both single-round
/// (STATUS_SUCCESS) and mutual-auth (STATUS_MORE_PROCESSING_REQUIRED)
/// flows.
///
/// The session key comes from the Kerberos TGS exchange, not from the
/// SMB server response.
/// Perform Kerberos-based SESSION_SETUP using a credential cache.
///
/// Reads cached tickets from the ccache. If a service ticket for
/// `cifs/<server_hostname>` is cached, uses it directly (no KDC needed).
/// If only a TGT is cached, does a TGS exchange for the service ticket.
pub async fn setup_kerberos_from_ccache(
conn: &mut Connection,
credentials: &crate::auth::kerberos::KerberosCredentials,
server_hostname: &str,
ccache: &crate::auth::kerberos::ccache::CCache,
) -> Result<Session> {
let mut auth = crate::auth::kerberos::KerberosAuthenticator::new(credentials.clone());
auth.authenticate_from_ccache(ccache, server_hostname)
.await?;
Self::setup_kerberos_with_auth(conn, &mut auth).await
}
/// Perform Kerberos-based SESSION_SETUP.
///
/// Authenticates against the KDC first (AS + TGS), then sends the
/// SPNEGO-wrapped AP-REQ in SESSION_SETUP. Handles both single-round
/// (STATUS_SUCCESS) and mutual-auth (STATUS_MORE_PROCESSING_REQUIRED)
/// flows.
///
/// The session key comes from the Kerberos TGS exchange, not from the
/// SMB server response.
pub async fn setup_kerberos(
conn: &mut Connection,
credentials: &crate::auth::kerberos::KerberosCredentials,
server_hostname: &str,
) -> Result<Session> {
let mut auth = crate::auth::kerberos::KerberosAuthenticator::new(credentials.clone());
auth.authenticate(server_hostname).await?;
Self::setup_kerberos_with_auth(conn, &mut auth).await
}
/// Shared Kerberos SESSION_SETUP logic used by both password-based
/// and ccache-based authentication paths.
async fn setup_kerberos_with_auth(
conn: &mut Connection,
auth: &mut crate::auth::kerberos::KerberosAuthenticator,
) -> Result<Session> {
let params = conn
.params()
.ok_or_else(|| Error::invalid_data("negotiate must complete before session setup"))?
.clone();
let token = auth
.token()
.ok_or_else(|| Error::Auth {
message: "Kerberos authentication produced no token".to_string(),
})?
.to_vec();
debug!("session: Kerberos auth complete, token_len={}", token.len());
// Clone the preauth hasher for this session.
let mut session_hasher = conn.preauth_hasher().clone();
// Step 2: Send SPNEGO-wrapped AP-REQ in SESSION_SETUP.
let req = SessionSetupRequest {
flags: SessionSetupRequestFlags(0),
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
capabilities: Capabilities::default(),
channel: 0,
previous_session_id: 0,
security_buffer: token,
};
let (frame, req_raw) = conn
.execute_capturing_request(Command::SessionSetup, &req, None)
.await?;
// Hash the request (same as NTLM round 1).
session_hasher.update(&req_raw);
let resp_header = frame.header;
let resp_body = frame.body;
let resp_raw = frame.raw;
if resp_header.command != Command::SessionSetup {
return Err(Error::invalid_data(format!(
"expected SessionSetup response, got {:?}",
resp_header.command
)));
}
if resp_header.status != NtStatus::SUCCESS
&& !resp_header.status.is_more_processing_required()
{
return Err(Error::Protocol {
status: resp_header.status,
command: Command::SessionSetup,
});
}
// The server assigned a session ID.
let session_id = resp_header.session_id;
conn.set_session_id(session_id);
let mut cursor = ReadCursor::new(&resp_body);
let setup_resp = SessionSetupResponse::unpack(&mut cursor)?;
if resp_header.status.is_more_processing_required() {
debug!(
"session: Kerberos got MORE_PROCESSING_REQUIRED, session_id={}",
session_id
);
// Hash the response per MS-SMB2 3.2.5.3.1.
session_hasher.update(&resp_raw);
}
// Process the SPNEGO response token (AP-REP or KRB-ERROR).
// This applies to both STATUS_SUCCESS and MORE_PROCESSING_REQUIRED —
// the server may include an AP-REP with a sub-session key in either.
if !setup_resp.security_buffer.is_empty() {
let spnego_resp =
crate::auth::spnego::parse_neg_token_resp(&setup_resp.security_buffer)?;
debug!(
"session: SPNEGO state={:?}, has_token={}, supported_mech={:02x?}",
spnego_resp.neg_state,
spnego_resp.response_token.is_some(),
spnego_resp.supported_mech.as_deref().unwrap_or(&[]),
);
if let Some(ref token_bytes) = spnego_resp.response_token {
auth.process_mutual_auth_token(token_bytes)?;
}
}
// Get the session key AFTER processing the AP-REP (the server's
// subkey may have overridden ours).
//
// Per MS-SMB2 3.2.5.3: "Session.SessionKey MUST be set to the first
// 16 bytes of the cryptographic key queried from the GSS protocol."
let full_key = auth.session_key().ok_or_else(|| Error::Auth {
message: "Kerberos authentication produced no session key".to_string(),
})?;
let session_key = if full_key.len() > 16 {
full_key[..16].to_vec()
} else {
full_key.to_vec()
};
debug!(
"session: Kerberos session_key_len={} (truncated from {})",
session_key.len(),
full_key.len()
);
// Determine signing algorithm.
let signing_algorithm = algorithm_for_dialect(params.dialect, params.gmac_negotiated);
debug!(
"session: Kerberos signing_algo={:?}, dialect={}",
signing_algorithm, params.dialect
);
// Derive keys for SMB 3.x using the Kerberos session key.
let (signing_key, encryption_key, decryption_key) = match params.dialect {
Dialect::Smb3_0 | Dialect::Smb3_0_2 => {
let keys = derive_session_keys(&session_key, params.dialect, None, 128);
(
keys.signing_key,
Some(keys.encryption_key),
Some(keys.decryption_key),
)
}
Dialect::Smb3_1_1 => {
let key_len_bits = match params.cipher {
Some(crate::crypto::encryption::Cipher::Aes256Ccm)
| Some(crate::crypto::encryption::Cipher::Aes256Gcm) => 256,
_ => 128,
};
let keys = derive_session_keys(
&session_key,
Dialect::Smb3_1_1,
Some(session_hasher.value()),
key_len_bits,
);
(
keys.signing_key,
Some(keys.encryption_key),
Some(keys.decryption_key),
)
}
_ => (session_key.clone(), None, None),
};
let should_sign = params.signing_required
|| !setup_resp.session_flags.is_guest() && !setup_resp.session_flags.is_null();
let should_encrypt = setup_resp.session_flags.encrypt_data();
if should_sign {
conn.activate_signing(signing_key.clone(), signing_algorithm);
}
if should_encrypt {
let cipher = params
.cipher
.unwrap_or(crate::crypto::encryption::Cipher::Aes128Ccm);
if let (Some(ref enc_key), Some(ref dec_key)) = (&encryption_key, &decryption_key) {
conn.activate_encryption(enc_key.clone(), dec_key.clone(), cipher);
}
}
info!(
"session: Kerberos established, session_id={}, sign={}, encrypt={}",
session_id, should_sign, should_encrypt
);
Ok(Session {
session_id,
signing_key,
encryption_key,
decryption_key,
signing_algorithm,
should_sign,
should_encrypt,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::connection::{pack_message, Connection, NegotiatedParams};
use crate::msg::header::Header;
use crate::msg::session_setup::{SessionFlags, SessionSetupResponse};
use crate::pack::Guid;
use crate::transport::MockTransport;
use crate::types::flags::Capabilities;
use crate::types::status::NtStatus;
use crate::types::{Command, Dialect, SessionId};
use std::sync::Arc;
/// Build a session setup response with the given status and session ID.
fn build_session_setup_response(
status: NtStatus,
session_id: SessionId,
security_buffer: Vec<u8>,
session_flags: SessionFlags,
) -> Vec<u8> {
let mut h = Header::new_request(Command::SessionSetup);
h.flags.set_response();
h.credits = 32;
h.status = status;
h.session_id = session_id;
let body = SessionSetupResponse {
session_flags,
security_buffer,
};
pack_message(&h, &body)
}
/// Build a minimal NTLM challenge message (Type 2).
///
/// This is a stripped-down challenge that the NtlmAuthenticator can parse.
fn build_ntlm_challenge() -> Vec<u8> {
let mut buf = Vec::new();
// Signature (8 bytes)
buf.extend_from_slice(b"NTLMSSP\0");
// MessageType = 2 (4 bytes)
buf.extend_from_slice(&2u32.to_le_bytes());
// TargetNameFields: Len=0, MaxLen=0, Offset=56
buf.extend_from_slice(&0u16.to_le_bytes()); // Len
buf.extend_from_slice(&0u16.to_le_bytes()); // MaxLen
buf.extend_from_slice(&56u32.to_le_bytes()); // Offset
// NegotiateFlags
let flags: u32 = 0x0000_0001 // UNICODE
| 0x0000_0200 // NTLM
| 0x0008_0000 // EXTENDED_SESSIONSECURITY
| 0x0080_0000 // TARGET_INFO
| 0x2000_0000 // 128
| 0x4000_0000 // KEY_EXCH
| 0x8000_0000 // 56
| 0x0000_0010 // SIGN
| 0x0000_0020; // SEAL
buf.extend_from_slice(&flags.to_le_bytes());
// ServerChallenge (8 bytes)
buf.extend_from_slice(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]);
// Reserved (8 bytes)
buf.extend_from_slice(&[0u8; 8]);
// TargetInfoFields: Len, MaxLen, Offset (will be at offset 56 + target_name_len)
// Build target info: just MsvAvEOL
let target_info = build_av_eol();
let ti_offset = 56u32; // right after the fixed header
buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); // Len
buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); // MaxLen
buf.extend_from_slice(&ti_offset.to_le_bytes()); // Offset
// Ensure we're at offset 56 (pad if needed).
while buf.len() < 56 {
buf.push(0);
}
// Target info data
buf.extend_from_slice(&target_info);
buf
}
/// Build an AV_PAIR list with just MsvAvEOL.
fn build_av_eol() -> Vec<u8> {
let mut buf = Vec::new();
// MsvAvEOL: AvId=0, AvLen=0
buf.extend_from_slice(&0u16.to_le_bytes());
buf.extend_from_slice(&0u16.to_le_bytes());
buf
}
#[tokio::test]
async fn session_setup_stores_session_id() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let session_id = SessionId(0xDEAD_BEEF);
// Queue the two session setup responses.
let challenge = build_ntlm_challenge();
mock.queue_response(build_session_setup_response(
NtStatus::MORE_PROCESSING_REQUIRED,
session_id,
challenge,
SessionFlags(0),
));
mock.queue_response(build_session_setup_response(
NtStatus::SUCCESS,
session_id,
vec![],
SessionFlags(0),
));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
// Set up negotiate params (pretend we already negotiated).
// We need to call negotiate or set params manually.
// Let's also queue a negotiate response first.
// Actually, let's set params directly.
set_test_params(&mut conn, Dialect::Smb2_0_2);
let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap();
assert_eq!(session.session_id, session_id);
}
#[tokio::test]
async fn session_setup_derives_signing_key() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let session_id = SessionId(0x1234);
let challenge = build_ntlm_challenge();
mock.queue_response(build_session_setup_response(
NtStatus::MORE_PROCESSING_REQUIRED,
session_id,
challenge,
SessionFlags(0),
));
mock.queue_response(build_session_setup_response(
NtStatus::SUCCESS,
session_id,
vec![],
SessionFlags(0),
));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
set_test_params(&mut conn, Dialect::Smb2_0_2);
let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap();
assert!(!session.signing_key.is_empty());
}
#[tokio::test]
async fn session_setup_activates_signing() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let session_id = SessionId(0x5678);
let challenge = build_ntlm_challenge();
mock.queue_response(build_session_setup_response(
NtStatus::MORE_PROCESSING_REQUIRED,
session_id,
challenge,
SessionFlags(0),
));
mock.queue_response(build_session_setup_response(
NtStatus::SUCCESS,
session_id,
vec![],
SessionFlags(0),
));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
set_test_params(&mut conn, Dialect::Smb2_0_2);
let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap();
assert!(session.should_sign);
assert_eq!(session.signing_algorithm, SigningAlgorithm::HmacSha256);
}
#[tokio::test]
async fn session_setup_error_on_auth_failure() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let session_id = SessionId(0x9999);
let challenge = build_ntlm_challenge();
mock.queue_response(build_session_setup_response(
NtStatus::MORE_PROCESSING_REQUIRED,
session_id,
challenge,
SessionFlags(0),
));
// Auth fails on second round.
mock.queue_response(build_session_setup_response(
NtStatus::LOGON_FAILURE,
session_id,
vec![],
SessionFlags(0),
));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
set_test_params(&mut conn, Dialect::Smb2_0_2);
let result = Session::setup(&mut conn, "user", "badpass", "").await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
matches!(
err,
Error::Protocol {
status: NtStatus::LOGON_FAILURE,
..
}
),
"expected LOGON_FAILURE, got: {err}"
);
}
/// Helper: set fake negotiated params on a connection.
fn set_test_params(conn: &mut Connection, dialect: Dialect) {
conn.set_test_params(NegotiatedParams {
dialect,
max_read_size: 65536,
max_write_size: 65536,
max_transact_size: 65536,
server_guid: Guid::ZERO,
signing_required: false,
capabilities: Capabilities::default(),
gmac_negotiated: false,
cipher: None,
compression_supported: false,
});
}
}

764
vendor/smb2/src/client/shares.rs vendored Normal file
View File

@@ -0,0 +1,764 @@
//! Share enumeration via IPC$ + srvsvc RPC.
//!
//! Lists available shares on an SMB server by connecting to the IPC$ share,
//! opening the srvsvc named pipe, and performing the NetShareEnumAll RPC
//! exchange.
use log::{debug, info};
use crate::client::connection::Connection;
use crate::error::Result;
use crate::msg::close::CloseRequest;
use crate::msg::create::{
CreateDisposition, CreateRequest, CreateResponse, ImpersonationLevel, ShareAccess,
};
use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE};
use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse};
use crate::msg::tree_disconnect::TreeDisconnectRequest;
use crate::msg::write::{WriteRequest, WriteResponse};
use crate::pack::{ReadCursor, Unpack};
use crate::rpc;
use crate::rpc::srvsvc::{self, ShareInfo};
use crate::types::flags::FileAccessMask;
use crate::types::status::NtStatus;
use crate::types::{Command, FileId, OplockLevel, TreeId};
use crate::Error;
/// Read buffer size for pipe reads (64 KiB is plenty for share listings).
const PIPE_READ_BUFFER_SIZE: u32 = 65536;
/// List available shares on the server.
///
/// Connects to the IPC$ share, opens the srvsvc named pipe, performs
/// the RPC exchange, and returns filtered disk shares.
///
/// This is a self-contained operation -- it opens and closes its own
/// tree connection to IPC$.
pub async fn list_shares(conn: &mut Connection) -> Result<Vec<ShareInfo>> {
// 1. Tree connect to IPC$
let tree_id = tree_connect_ipc(conn).await?;
// Run the pipe operations, then clean up regardless of outcome
let result = pipe_rpc_exchange(conn, tree_id).await;
// 8. Tree disconnect (best-effort -- don't mask the real error)
let _ = tree_disconnect(conn, tree_id).await;
let all_shares = result?;
// 9. Filter to disk shares
let filtered = srvsvc::filter_disk_shares(all_shares);
info!("shares: found {} disk shares", filtered.len());
Ok(filtered)
}
/// Connect to the IPC$ share, returning the tree ID.
async fn tree_connect_ipc(conn: &mut Connection) -> Result<TreeId> {
let server = conn.server_name().to_string();
let unc_path = format!(r"\\{}\IPC$", server);
let req = TreeConnectRequest {
flags: TreeConnectRequestFlags::default(),
path: unc_path,
};
let frame = conn.execute(Command::TreeConnect, &req, None).await?;
if frame.header.command != Command::TreeConnect {
return Err(Error::invalid_data(format!(
"expected TreeConnect response, got {:?}",
frame.header.command
)));
}
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::TreeConnect,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let _resp = TreeConnectResponse::unpack(&mut cursor)?;
let tree_id = frame
.header
.tree_id
.ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?;
info!("shares: connected to IPC$, tree_id={}", tree_id);
Ok(tree_id)
}
/// Open the srvsvc pipe, perform the RPC bind and request, then close.
async fn pipe_rpc_exchange(conn: &mut Connection, tree_id: TreeId) -> Result<Vec<ShareInfo>> {
// 2. Create \pipe\srvsvc
let file_id = open_srvsvc_pipe(conn, tree_id).await?;
// Run RPC exchange, then close regardless of outcome
let result = rpc_bind_and_request(conn, tree_id, file_id).await;
// 7. Close the pipe handle (best-effort)
let _ = close_handle(conn, tree_id, file_id).await;
result
}
/// Perform the RPC bind + NetShareEnumAll request over the pipe.
async fn rpc_bind_and_request(
conn: &mut Connection,
tree_id: TreeId,
file_id: FileId,
) -> Result<Vec<ShareInfo>> {
// 3. Write RPC BIND
let bind_data = rpc::build_srvsvc_bind(1);
write_pipe(conn, tree_id, file_id, &bind_data).await?;
debug!("shares: sent RPC BIND ({} bytes)", bind_data.len());
// 4. Read RPC BIND_ACK
let bind_ack_data = read_pipe_message(conn, tree_id, file_id).await?;
rpc::parse_bind_ack(&bind_ack_data)?;
debug!("shares: received BIND_ACK, context accepted");
// 5. Write RPC REQUEST (NetShareEnumAll)
let server_name = format!(r"\\{}", conn.server_name());
let request_data = srvsvc::build_net_share_enum_all(2, &server_name);
write_pipe(conn, tree_id, file_id, &request_data).await?;
debug!(
"shares: sent NetShareEnumAll request ({} bytes)",
request_data.len()
);
// 6. Read RPC RESPONSE, reassembling DCE/RPC fragments (MS-RPCE 2.2.2.6).
// A large NetShareEnum reply may arrive as several fragment PDUs, each its
// own pipe message, with PFC_LAST_FRAG set only on the last.
let mut stub = Vec::new();
let mut fragments = 0;
loop {
let pdu = read_pipe_message(conn, tree_id, file_id).await?;
let (frag_stub, is_last) = rpc::parse_response_fragment(&pdu)?;
stub.extend_from_slice(frag_stub);
fragments += 1;
if is_last {
break;
}
}
let shares = srvsvc::parse_net_share_enum_all_stub(&stub)?;
debug!(
"shares: received {} shares in response ({} RPC fragment(s))",
shares.len(),
fragments
);
Ok(shares)
}
/// Open the `\pipe\srvsvc` named pipe via CREATE.
async fn open_srvsvc_pipe(conn: &mut Connection, tree_id: TreeId) -> Result<FileId> {
let req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::FILE_READ_DATA | FileAccessMask::FILE_WRITE_DATA,
),
file_attributes: 0,
share_access: ShareAccess(ShareAccess::FILE_SHARE_READ | ShareAccess::FILE_SHARE_WRITE),
create_disposition: CreateDisposition::FileOpen,
create_options: 0,
name: r"srvsvc".to_string(),
create_contexts: vec![],
};
let frame = conn.execute(Command::Create, &req, Some(tree_id)).await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Create,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = CreateResponse::unpack(&mut cursor)?;
debug!("shares: opened srvsvc pipe, file_id={:?}", resp.file_id);
Ok(resp.file_id)
}
/// Write data to the pipe.
async fn write_pipe(
conn: &mut Connection,
tree_id: TreeId,
file_id: FileId,
data: &[u8],
) -> Result<()> {
// DataOffset: header (64) + fixed write body (48) = 112 = 0x70
let req = WriteRequest {
data_offset: 0x70,
offset: 0,
file_id,
channel: 0,
remaining_bytes: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
flags: 0,
data: data.to_vec(),
};
let frame = conn.execute(Command::Write, &req, Some(tree_id)).await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Write,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = WriteResponse::unpack(&mut cursor)?;
debug!("shares: wrote {} bytes to pipe", resp.count);
Ok(())
}
/// Read one complete pipe message, following `STATUS_BUFFER_OVERFLOW`.
///
/// A pipe message larger than our read buffer comes back as one or more
/// `STATUS_BUFFER_OVERFLOW` reads carrying partial data, terminated by a
/// `STATUS_SUCCESS` read with the remainder (MS-SMB2 3.3.5.10). We append each
/// chunk until a `SUCCESS` read completes the message.
async fn read_pipe_message(
conn: &mut Connection,
tree_id: TreeId,
file_id: FileId,
) -> Result<Vec<u8>> {
let mut message = Vec::new();
loop {
let req = ReadRequest {
padding: 0x50,
flags: 0,
length: PIPE_READ_BUFFER_SIZE,
offset: 0,
file_id,
minimum_count: 0,
channel: SMB2_CHANNEL_NONE,
remaining_bytes: 0,
read_channel_info: vec![],
};
let frame = conn.execute(Command::Read, &req, Some(tree_id)).await?;
let status = frame.header.status;
// BUFFER_OVERFLOW is a warning meaning "partial data, read again", not a
// failure -- accept it alongside SUCCESS.
if !status.is_success_or_partial() {
return Err(Error::Protocol {
status,
command: Command::Read,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = ReadResponse::unpack(&mut cursor)?;
let chunk_len = resp.data.len();
message.extend_from_slice(&resp.data);
// SUCCESS completes the message; BUFFER_OVERFLOW means read more.
if status != NtStatus::BUFFER_OVERFLOW {
break;
}
// Guard against a server that signals overflow but sends no data, which
// would otherwise spin forever.
if chunk_len == 0 {
return Err(Error::invalid_data(
"pipe read returned BUFFER_OVERFLOW with no data",
));
}
}
debug!("shares: read {} bytes from pipe", message.len());
Ok(message)
}
/// Close a file handle.
async fn close_handle(conn: &mut Connection, tree_id: TreeId, file_id: FileId) -> Result<()> {
let req = CloseRequest { flags: 0, file_id };
let frame = conn.execute(Command::Close, &req, Some(tree_id)).await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Close,
});
}
Ok(())
}
/// Disconnect from a tree.
async fn tree_disconnect(conn: &mut Connection, tree_id: TreeId) -> Result<()> {
let body = TreeDisconnectRequest;
let frame = conn
.execute(Command::TreeDisconnect, &body, Some(tree_id))
.await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::TreeDisconnect,
});
}
info!("shares: disconnected from IPC$");
Ok(())
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::client::connection::{pack_message, NegotiatedParams};
use crate::client::test_helpers::{
build_close_response, build_create_response, build_tree_connect_response, setup_connection,
};
use crate::msg::header::Header;
use crate::msg::read::ReadResponse as ReadResp;
use crate::msg::tree_connect::ShareType;
use crate::msg::tree_disconnect::TreeDisconnectResponse;
use crate::msg::write::WriteResponse as WriteResp;
use crate::pack::Guid;
use crate::rpc::srvsvc::{STYPE_DISKTREE, STYPE_IPC, STYPE_SPECIAL};
use crate::transport::MockTransport;
use crate::types::flags::Capabilities;
use crate::types::{Dialect, SessionId, TreeId};
use std::sync::Arc;
fn build_write_response(count: u32) -> Vec<u8> {
let mut h = Header::new_request(Command::Write);
h.flags.set_response();
h.credits = 32;
let body = WriteResp {
count,
remaining: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
};
pack_message(&h, &body)
}
fn build_read_response(data: Vec<u8>) -> Vec<u8> {
build_read_response_with_status(data, NtStatus::SUCCESS)
}
/// Build a READ response with an explicit NTSTATUS.
///
/// Pipe reads use `STATUS_BUFFER_OVERFLOW` to mean "this read returned a
/// partial message; read again for the rest."
fn build_read_response_with_status(data: Vec<u8>, status: NtStatus) -> Vec<u8> {
let mut h = Header::new_request(Command::Read);
h.flags.set_response();
h.credits = 32;
h.status = status;
let body = ReadResp {
data_offset: 0x50,
data_remaining: 0,
flags: 0,
data,
};
pack_message(&h, &body)
}
fn build_tree_disconnect_response() -> Vec<u8> {
let mut h = Header::new_request(Command::TreeDisconnect);
h.flags.set_response();
h.credits = 32;
pack_message(&h, &TreeDisconnectResponse)
}
/// Build a canned RPC BIND_ACK response.
fn build_bind_ack() -> Vec<u8> {
use crate::pack::WriteCursor;
let mut w = WriteCursor::with_capacity(64);
// Common header
w.write_u8(5); // version
w.write_u8(0); // version minor
w.write_u8(12); // BIND_ACK type
w.write_u8(0x03); // flags (first + last)
w.write_bytes(&[0x10, 0x00, 0x00, 0x00]); // data rep
let frag_len_pos = w.position();
w.write_u16_le(0); // frag length placeholder
w.write_u16_le(0); // auth length
w.write_u32_le(1); // call id
// BIND_ACK specific
w.write_u16_le(4280); // max xmit frag
w.write_u16_le(4280); // max recv frag
w.write_u32_le(0x12345); // assoc group
// Secondary address (empty)
w.write_u16_le(0);
w.write_bytes(&[0, 0]); // padding
// Result list
w.write_u8(1); // num results
w.write_bytes(&[0, 0, 0]); // reserved
w.write_u16_le(0); // result = accepted
w.write_u16_le(0); // reason
// Transfer syntax UUID + version (20 bytes)
use crate::pack::Pack;
let ndr_uuid = Guid {
data1: 0x8A885D04,
data2: 0x1CEB,
data3: 0x11C9,
data4: [0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60],
};
ndr_uuid.pack(&mut w);
w.write_u32_le(2);
let total_len = w.position();
w.set_u16_le_at(frag_len_pos, total_len as u16);
w.into_inner()
}
/// Build the NDR stub for a NetShareEnumAll RESPONSE (no RPC envelope).
fn build_share_enum_stub(shares: &[(&str, u32, &str)]) -> Vec<u8> {
use crate::pack::WriteCursor;
// Build NDR stub
let mut w = WriteCursor::with_capacity(512);
let count = shares.len() as u32;
// Level = 1
w.write_u32_le(1);
// Union discriminant = 1
w.write_u32_le(1);
if count == 0 {
w.write_u32_le(0); // null container
w.write_u32_le(0); // total entries
w.write_u32_le(0); // resume handle
w.write_u32_le(0); // return value
} else {
// Container pointer
w.write_u32_le(0x0002_0000);
// EntriesRead
w.write_u32_le(count);
// Array pointer
w.write_u32_le(0x0002_0004);
// MaxCount
w.write_u32_le(count);
// Fixed entries
for (i, &(_, share_type, _)) in shares.iter().enumerate() {
w.write_u32_le(0x0002_0008 + (i as u32) * 2); // name ref
w.write_u32_le(share_type);
w.write_u32_le(0x0002_0108 + (i as u32) * 2); // comment ref
}
// Deferred strings
for &(name, _, comment) in shares {
write_ndr_string(&mut w, name);
write_ndr_string(&mut w, comment);
}
w.write_u32_le(count); // total entries
w.write_u32_le(0); // resume handle
w.write_u32_le(0); // return value
}
w.into_inner()
}
/// Wrap NDR stub bytes in an RPC RESPONSE PDU with the given PFC flags.
///
/// `pfc_flags` lets a caller emit a fragment (for example, `PFC_FIRST_FRAG`
/// alone for a non-final fragment) instead of the usual `FIRST | LAST`.
fn wrap_rpc_response_pdu(stub_chunk: &[u8], pfc_flags: u8) -> Vec<u8> {
use crate::pack::WriteCursor;
let mut w = WriteCursor::with_capacity(24 + stub_chunk.len());
w.write_u8(5);
w.write_u8(0);
w.write_u8(2); // RESPONSE
w.write_u8(pfc_flags);
w.write_bytes(&[0x10, 0x00, 0x00, 0x00]);
let frag_len_pos = w.position();
w.write_u16_le(0);
w.write_u16_le(0);
w.write_u32_le(2); // call id
w.write_u32_le(stub_chunk.len() as u32); // alloc hint
w.write_u16_le(0); // context id
w.write_u8(0); // cancel count
w.write_u8(0); // reserved
w.write_bytes(stub_chunk);
let total_len = w.position();
w.set_u16_le_at(frag_len_pos, total_len as u16);
w.into_inner()
}
/// Build a canned single-fragment RPC RESPONSE with NetShareEnumAll data.
fn build_share_enum_response(shares: &[(&str, u32, &str)]) -> Vec<u8> {
// 0x03 = PFC_FIRST_FRAG | PFC_LAST_FRAG (a complete, single-fragment PDU).
wrap_rpc_response_pdu(&build_share_enum_stub(shares), 0x03)
}
fn write_ndr_string(w: &mut crate::pack::WriteCursor, s: &str) {
let utf16: Vec<u16> = s.encode_utf16().chain(std::iter::once(0)).collect();
let char_count = utf16.len() as u32;
w.write_u32_le(char_count);
w.write_u32_le(0);
w.write_u32_le(char_count);
for &code_unit in &utf16 {
w.write_u16_le(code_unit);
}
w.align_to(4);
}
/// Queue all the responses needed for a full list_shares flow.
pub(crate) fn queue_share_listing_responses(
mock: &MockTransport,
shares: &[(&str, u32, &str)],
) {
let tree_id = TreeId(42);
let file_id = FileId {
persistent: 0xAAAA,
volatile: 0xBBBB,
};
// 1. TREE_CONNECT response
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
// 2. CREATE response (open srvsvc pipe)
mock.queue_response(build_create_response(file_id, 0));
// 3. WRITE response (RPC BIND)
mock.queue_response(build_write_response(72));
// 4. READ response (BIND_ACK)
mock.queue_response(build_read_response(build_bind_ack()));
// 5. WRITE response (NetShareEnumAll request)
mock.queue_response(build_write_response(100));
// 6. READ response (NetShareEnumAll response)
mock.queue_response(build_read_response(build_share_enum_response(shares)));
// 7. CLOSE response
mock.queue_response(build_close_response());
// 8. TREE_DISCONNECT response
mock.queue_response(build_tree_disconnect_response());
}
/// Like `queue_share_listing_responses`, but the server splits a single
/// RPC RESPONSE PDU across two pipe reads: the first read returns
/// `STATUS_BUFFER_OVERFLOW` with the leading bytes, the second returns
/// `SUCCESS` with the rest. The client must stitch them before parsing.
fn queue_overflow_share_listing_responses(mock: &MockTransport, shares: &[(&str, u32, &str)]) {
let tree_id = TreeId(42);
let file_id = FileId {
persistent: 0xAAAA,
volatile: 0xBBBB,
};
let pdu = build_share_enum_response(shares);
let split = pdu.len() / 2;
let (first, rest) = pdu.split_at(split);
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(72));
mock.queue_response(build_read_response(build_bind_ack()));
mock.queue_response(build_write_response(100));
// The response PDU arrives in two chunks: overflow then success.
mock.queue_response(build_read_response_with_status(
first.to_vec(),
NtStatus::BUFFER_OVERFLOW,
));
mock.queue_response(build_read_response_with_status(
rest.to_vec(),
NtStatus::SUCCESS,
));
mock.queue_response(build_close_response());
mock.queue_response(build_tree_disconnect_response());
}
/// Like `queue_share_listing_responses`, but the RPC RESPONSE is split into
/// two DCE/RPC fragments (each its own pipe message): the first carries
/// `PFC_FIRST_FRAG`, the second `PFC_LAST_FRAG`. The client must reassemble
/// the stub across fragments before parsing.
fn queue_fragmented_share_listing_responses(
mock: &MockTransport,
shares: &[(&str, u32, &str)],
) {
let tree_id = TreeId(42);
let file_id = FileId {
persistent: 0xAAAA,
volatile: 0xBBBB,
};
let stub = build_share_enum_stub(shares);
let split = stub.len() / 2;
let (first, rest) = stub.split_at(split);
let frag1 = wrap_rpc_response_pdu(first, 0x01); // PFC_FIRST_FRAG only
let frag2 = wrap_rpc_response_pdu(rest, 0x02); // PFC_LAST_FRAG only
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(72));
mock.queue_response(build_read_response(build_bind_ack()));
mock.queue_response(build_write_response(100));
mock.queue_response(build_read_response(frag1));
mock.queue_response(build_read_response(frag2));
mock.queue_response(build_close_response());
mock.queue_response(build_tree_disconnect_response());
}
#[tokio::test]
async fn list_shares_reassembles_buffer_overflow_reads() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
queue_overflow_share_listing_responses(
&mock,
&[
("Documents", STYPE_DISKTREE, "Shared docs"),
("Photos", STYPE_DISKTREE, "Family photos"),
],
);
let shares = list_shares(&mut conn).await.unwrap();
assert_eq!(shares.len(), 2);
assert_eq!(shares[0].name, "Documents");
assert_eq!(shares[1].name, "Photos");
}
#[tokio::test]
async fn list_shares_reassembles_rpc_fragments() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
queue_fragmented_share_listing_responses(
&mock,
&[
("Documents", STYPE_DISKTREE, "Shared docs"),
("Photos", STYPE_DISKTREE, "Family photos"),
],
);
let shares = list_shares(&mut conn).await.unwrap();
assert_eq!(shares.len(), 2);
assert_eq!(shares[0].name, "Documents");
assert_eq!(shares[1].name, "Photos");
}
#[tokio::test]
async fn list_shares_returns_disk_shares() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
queue_share_listing_responses(
&mock,
&[
("Documents", STYPE_DISKTREE, "Shared docs"),
("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"),
("C$", STYPE_DISKTREE | STYPE_SPECIAL, "Default share"),
("Photos", STYPE_DISKTREE, "Family photos"),
],
);
let shares = list_shares(&mut conn).await.unwrap();
// Only disk shares without $ suffix and without STYPE_SPECIAL
assert_eq!(shares.len(), 2);
assert_eq!(shares[0].name, "Documents");
assert_eq!(shares[0].comment, "Shared docs");
assert_eq!(shares[1].name, "Photos");
assert_eq!(shares[1].comment, "Family photos");
}
#[tokio::test]
async fn list_shares_sends_correct_number_of_messages() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
queue_share_listing_responses(&mock, &[("TestShare", STYPE_DISKTREE, "A test share")]);
let _shares = list_shares(&mut conn).await.unwrap();
// Should have sent 8 messages:
// TREE_CONNECT, CREATE, WRITE(bind), READ(bind_ack),
// WRITE(request), READ(response), CLOSE, TREE_DISCONNECT
assert_eq!(mock.sent_count(), 8);
}
#[tokio::test]
async fn list_shares_empty_server() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
queue_share_listing_responses(&mock, &[]);
let shares = list_shares(&mut conn).await.unwrap();
assert!(shares.is_empty());
}
#[tokio::test]
async fn list_shares_filters_non_disk_shares() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
// All non-disk or special shares
queue_share_listing_responses(
&mock,
&[
("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"),
("ADMIN$", STYPE_DISKTREE | STYPE_SPECIAL, "Remote Admin"),
],
);
let shares = list_shares(&mut conn).await.unwrap();
assert!(shares.is_empty());
}
#[tokio::test]
async fn list_shares_uses_correct_server_name() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn =
Connection::from_transport(Box::new(mock.clone()), Box::new(mock.clone()), "my-nas");
conn.set_test_params(NegotiatedParams {
dialect: Dialect::Smb2_0_2,
max_read_size: 65536,
max_write_size: 65536,
max_transact_size: 65536,
server_guid: Guid::ZERO,
signing_required: false,
capabilities: Capabilities::default(),
gmac_negotiated: false,
cipher: None,
compression_supported: false,
});
conn.set_session_id(SessionId(0x1234));
queue_share_listing_responses(&mock, &[("share1", STYPE_DISKTREE, "")]);
let shares = list_shares(&mut conn).await.unwrap();
assert_eq!(shares.len(), 1);
// Verify the TREE_CONNECT request contains \\my-nas\IPC$
let sent = mock.sent_messages();
let tree_connect_bytes = &sent[0];
// The UNC path is UTF-16LE in the request body
let unc_utf8 = String::from_utf8_lossy(tree_connect_bytes);
// Verify the server name appears somewhere in the raw bytes
assert!(
tree_connect_bytes.windows(2).any(|w| w == b"m\0"), // 'm' in UTF-16LE from "my-nas"
"TREE_CONNECT should reference the server name"
);
drop(unc_utf8);
}
}

1499
vendor/smb2/src/client/stream.rs vendored Normal file

File diff suppressed because it is too large Load Diff

182
vendor/smb2/src/client/test_helpers.rs vendored Normal file
View File

@@ -0,0 +1,182 @@
//! Shared test helper functions for `client` module tests.
//!
//! These build mock SMB2 responses used across pipeline, shares, and tree tests.
use std::sync::Arc;
use crate::client::connection::{pack_message, Connection, NegotiatedParams};
use crate::msg::close::CloseResponse;
use crate::msg::create::{CreateAction, CreateResponse};
use crate::msg::header::Header;
use crate::msg::tree_connect::{ShareType, TreeConnectResponse};
use crate::pack::{FileTime, Guid};
use crate::transport::MockTransport;
use crate::types::flags::{Capabilities, ShareCapabilities, ShareFlags};
use crate::types::{Command, Dialect, FileId, OplockLevel, SessionId, TreeId};
/// Create a mock-backed connection with standard negotiated params.
///
/// Enables the mock's auto-msg_id-rewrite so canned `build_*_response`
/// helpers (which hardcode `MessageId(0)` and don't know the caller's
/// allocated msg_ids) still route through the Phase 3 receiver task: on
/// each `receive()` the mock patches sub-frame msg_ids to match the next
/// pending sent msg_id in FIFO order. Replaces the pre-Phase-3
/// `set_orphan_filter_enabled(false)` path.
pub(crate) fn setup_connection(mock: &Arc<MockTransport>) -> Connection {
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_test_params(NegotiatedParams {
dialect: Dialect::Smb2_0_2,
max_read_size: 65536,
max_write_size: 65536,
max_transact_size: 65536,
server_guid: Guid::ZERO,
signing_required: false,
capabilities: Capabilities::default(),
gmac_negotiated: false,
cipher: None,
compression_supported: false,
});
conn.set_session_id(SessionId(0x1234));
conn
}
/// Build a CREATE response with the given file ID and end-of-file size.
pub(crate) fn build_create_response(file_id: FileId, end_of_file: u64) -> Vec<u8> {
let mut h = Header::new_request(Command::Create);
h.flags.set_response();
h.credits = 32;
let body = CreateResponse {
oplock_level: OplockLevel::None,
flags: 0,
create_action: CreateAction::FileOpened,
creation_time: FileTime::ZERO,
last_access_time: FileTime::ZERO,
last_write_time: FileTime::ZERO,
change_time: FileTime::ZERO,
allocation_size: 0,
end_of_file,
file_attributes: 0,
file_id,
create_contexts: vec![],
};
pack_message(&h, &body)
}
/// Build a CREATE response with a non-success status (for error tests).
pub(crate) fn build_create_error_response(status: crate::types::status::NtStatus) -> Vec<u8> {
use crate::msg::header::ErrorResponse;
let mut h = Header::new_request(Command::Create);
h.flags.set_response();
h.credits = 32;
h.status = status;
let body = ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
pack_message(&h, &body)
}
/// Build a CLOSE response with zeroed fields.
pub(crate) fn build_close_response() -> Vec<u8> {
let mut h = Header::new_request(Command::Close);
h.flags.set_response();
h.credits = 32;
let body = CloseResponse {
flags: 0,
creation_time: FileTime::ZERO,
last_access_time: FileTime::ZERO,
last_write_time: FileTime::ZERO,
change_time: FileTime::ZERO,
allocation_size: 0,
end_of_file: 0,
file_attributes: 0,
};
pack_message(&h, &body)
}
/// Build a WRITE response with the given byte count.
pub(crate) fn build_write_response(count: u32) -> Vec<u8> {
use crate::msg::write::WriteResponse;
let mut h = Header::new_request(Command::Write);
h.flags.set_response();
h.credits = 32;
let body = WriteResponse {
count,
remaining: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
};
pack_message(&h, &body)
}
/// Build a WRITE response with a non-success status (for error tests).
pub(crate) fn build_write_error_response(status: crate::types::status::NtStatus) -> Vec<u8> {
use crate::msg::header::ErrorResponse;
let mut h = Header::new_request(Command::Write);
h.flags.set_response();
h.credits = 32;
h.status = status;
let body = ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
pack_message(&h, &body)
}
/// Build a CLOSE response with a non-success status (for error tests).
pub(crate) fn build_close_error_response(status: crate::types::status::NtStatus) -> Vec<u8> {
use crate::msg::header::ErrorResponse;
let mut h = Header::new_request(Command::Close);
h.flags.set_response();
h.credits = 32;
h.status = status;
let body = ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
pack_message(&h, &body)
}
/// Build a FLUSH response.
pub(crate) fn build_flush_response() -> Vec<u8> {
let mut h = Header::new_request(Command::Flush);
h.flags.set_response();
h.credits = 32;
let body = crate::msg::flush::FlushResponse;
pack_message(&h, &body)
}
/// Build a TREE_CONNECT response with the given tree ID and share type.
pub(crate) fn build_tree_connect_response(tree_id: TreeId, share_type: ShareType) -> Vec<u8> {
let mut h = Header::new_request(Command::TreeConnect);
h.flags.set_response();
h.credits = 32;
h.tree_id = Some(tree_id);
let body = TreeConnectResponse {
share_type,
share_flags: ShareFlags::default(),
capabilities: ShareCapabilities::default(),
maximal_access: 0x001F_01FF,
};
pack_message(&h, &body)
}

6691
vendor/smb2/src/client/tree.rs vendored Normal file

File diff suppressed because it is too large Load Diff

780
vendor/smb2/src/client/watcher.rs vendored Normal file
View File

@@ -0,0 +1,780 @@
//! Directory change notification via SMB2 CHANGE_NOTIFY.
//!
//! The [`Watcher`] type registers for change notifications on a directory
//! and returns [`FileNotifyEvent`] entries describing changes as they happen.
//! The server holds the request until a change occurs, making this a long-poll
//! operation.
use log::debug;
use crate::client::connection::{await_frame, Connection, Frame};
use crate::client::tree::Tree;
use crate::error::Result;
use crate::msg::change_notify::{
ChangeNotifyRequest, ChangeNotifyResponse, FILE_NOTIFY_CHANGE_ATTRIBUTES,
FILE_NOTIFY_CHANGE_CREATION, FILE_NOTIFY_CHANGE_DIR_NAME, FILE_NOTIFY_CHANGE_FILE_NAME,
FILE_NOTIFY_CHANGE_LAST_WRITE, FILE_NOTIFY_CHANGE_SIZE, SMB2_WATCH_TREE,
};
use crate::pack::{ReadCursor, Unpack};
use crate::types::status::NtStatus;
use crate::types::{Command, FileId};
use crate::Error;
use tokio::sync::oneshot;
/// Default completion filter: watch for most common changes.
const DEFAULT_COMPLETION_FILTER: u32 = FILE_NOTIFY_CHANGE_FILE_NAME
| FILE_NOTIFY_CHANGE_DIR_NAME
| FILE_NOTIFY_CHANGE_ATTRIBUTES
| FILE_NOTIFY_CHANGE_SIZE
| FILE_NOTIFY_CHANGE_LAST_WRITE
| FILE_NOTIFY_CHANGE_CREATION;
/// Default output buffer length for CHANGE_NOTIFY responses (64 KB).
const OUTPUT_BUFFER_LENGTH: u32 = 65536;
/// The type of change that occurred on a file or directory.
///
/// These correspond to the `Action` field in `FILE_NOTIFY_INFORMATION`
/// (MS-FSCC section 2.4.42).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileNotifyAction {
/// A file was added to the directory.
Added,
/// A file was removed from the directory.
Removed,
/// A file was modified.
Modified,
/// A file was renamed (this is the old name).
RenamedOldName,
/// A file was renamed (this is the new name).
RenamedNewName,
}
impl FileNotifyAction {
/// Parse an action value from the wire format.
fn from_u32(value: u32) -> Result<Self> {
match value {
0x0000_0001 => Ok(FileNotifyAction::Added),
0x0000_0002 => Ok(FileNotifyAction::Removed),
0x0000_0003 => Ok(FileNotifyAction::Modified),
0x0000_0004 => Ok(FileNotifyAction::RenamedOldName),
0x0000_0005 => Ok(FileNotifyAction::RenamedNewName),
other => Err(Error::invalid_data(format!(
"unknown FILE_NOTIFY_INFORMATION action: {other:#010X}"
))),
}
}
}
impl std::fmt::Display for FileNotifyAction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FileNotifyAction::Added => write!(f, "added"),
FileNotifyAction::Removed => write!(f, "removed"),
FileNotifyAction::Modified => write!(f, "modified"),
FileNotifyAction::RenamedOldName => write!(f, "renamed (old name)"),
FileNotifyAction::RenamedNewName => write!(f, "renamed (new name)"),
}
}
}
/// A single file change notification.
///
/// Represents one `FILE_NOTIFY_INFORMATION` entry from the server.
#[derive(Debug, Clone)]
pub struct FileNotifyEvent {
/// What kind of change occurred.
pub action: FileNotifyAction,
/// The relative file name within the watched directory.
pub filename: String,
}
/// Watches a directory for changes via SMB2 CHANGE_NOTIFY.
///
/// The server holds the request until something changes, then responds
/// with one or more [`FileNotifyEvent`] entries. Each call to
/// [`next_events()`](Watcher::next_events) blocks until the server
/// reports a change.
///
/// ```no_run
/// # async fn example(client: &mut smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> {
/// let mut watcher = client.watch(&share, "_test/", true).await?;
/// loop {
/// let events = watcher.next_events().await?;
/// for event in &events {
/// println!("{}: {}", event.filename, event.action);
/// }
/// }
/// # Ok(())
/// # }
/// ```
///
/// **Pipelining**: `Watcher` keeps one CHANGE_NOTIFY request pre-issued on
/// the wire at all times after the first call to
/// [`next_events`](Self::next_events). The wire never sits idle between
/// consecutive responses, so server-side events that arrive while the
/// consumer is processing the previous batch are still delivered to an
/// outstanding request — they don't fall in a response→re-arm gap where
/// strict servers (older Samba, NAS firmware) drop them silently.
///
/// The watcher owns a cloned [`Connection`] (cheap `Arc::clone`, all
/// clones multiplex over the same SMB session), so the caller doesn't
/// need a second `SmbClient` to perform other operations while watching.
pub struct Watcher {
tree: Tree,
conn: Connection,
file_id: FileId,
recursive: bool,
/// In-flight CHANGE_NOTIFY response receiver. Populated lazily on the
/// first `next_events()` call and re-populated before awaiting each
/// response, so there is always exactly one outstanding request on
/// the wire from that point on.
pending: Option<oneshot::Receiver<Result<Frame>>>,
}
impl Watcher {
/// Create a new watcher (called by `Tree::watch`).
pub(crate) fn new(tree: Tree, conn: Connection, file_id: FileId, recursive: bool) -> Self {
Watcher {
tree,
conn,
file_id,
recursive,
pending: None,
}
}
/// Wait for the next batch of change events.
///
/// Dispatches a CHANGE_NOTIFY request (if one isn't already pre-issued
/// from the previous call), then — before awaiting the response —
/// dispatches the *next* CHANGE_NOTIFY. This keeps the wire
/// continuously armed: from the moment the first call returns until
/// the watcher is dropped, the server always has an outstanding
/// request to deliver events into. Closes the response→re-arm loss
/// window that strict servers (older Samba, NAS firmware) drop events
/// through.
///
/// The server holds each request until changes occur, so this call
/// may block for a long time.
///
/// Returns `Ok(events)` with one or more events when changes are detected.
///
/// # Errors
///
/// Returns `Error::Protocol` with `STATUS_NOTIFY_ENUM_DIR` if too many
/// changes occurred and the server could not fit them in the response
/// buffer. In this case, the caller should re-scan the directory and
/// keep watching — by the time control returns, the pipelined-next
/// request is already on the wire so no events arriving during the
/// re-scan get lost.
pub async fn next_events(&mut self) -> Result<Vec<FileNotifyEvent>> {
// Cold start: no request has been issued yet. Dispatch the first.
if self.pending.is_none() {
let rx = self.dispatch_next().await?;
self.pending = Some(rx);
}
// Take the currently in-flight receiver, then immediately
// pre-issue the next request before awaiting this one. The
// `dispatch` call below `.await`s only the transport.send(), so
// when it returns, the next CHANGE_NOTIFY is on the wire and the
// server has somewhere to put new events even while we process
// the response for the previous one.
let in_flight = self.pending.take().expect("pending populated above");
let next_rx = self.dispatch_next().await?;
self.pending = Some(next_rx);
let frame = await_frame(in_flight).await?;
if frame.header.status == NtStatus::NOTIFY_ENUM_DIR {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::ChangeNotify,
});
}
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::ChangeNotify,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = ChangeNotifyResponse::unpack(&mut cursor)?;
let events = parse_notify_information(&resp.output_data)?;
debug!("watcher: received {} change event(s)", events.len());
Ok(events)
}
/// Build a CHANGE_NOTIFY request and dispatch it on the cloned
/// connection, returning the response receiver. `Connection::dispatch`
/// awaits only up to and including `transport.send()`, so when this
/// returns the request is on the wire — the caller can rely on the
/// "outstanding on the wire" invariant for whatever comes next.
async fn dispatch_next(&self) -> Result<oneshot::Receiver<Result<Frame>>> {
let flags = if self.recursive { SMB2_WATCH_TREE } else { 0 };
let req = ChangeNotifyRequest {
flags,
output_buffer_length: OUTPUT_BUFFER_LENGTH,
file_id: self.file_id,
completion_filter: DEFAULT_COMPLETION_FILTER,
};
self.conn
.dispatch(Command::ChangeNotify, &req, Some(self.tree.tree_id))
.await
}
/// Close the directory handle.
///
/// Drops the pre-issued CHANGE_NOTIFY receiver (the `Connection`
/// receiver task discards the late response silently when it
/// arrives — same contract `Connection::execute` already documents),
/// then issues a CLOSE on the file handle. If `close` is not called
/// explicitly, the `Drop` impl drops the pre-issued receiver but the
/// server-side handle leaks until the session ends (there is no
/// async drop in Rust).
pub async fn close(mut self) -> Result<()> {
self.pending.take();
self.tree.close_handle(&mut self.conn, self.file_id).await
}
}
impl Drop for Watcher {
fn drop(&mut self) {
// The pre-issued response receiver (if any) drops with the
// Watcher. The `Connection` receiver task discards the late
// frame silently when it arrives, matching the contract on
// `Connection::execute`. The directory handle itself leaks
// server-side until the session ends — the docstring on `close`
// already warns about this.
}
}
/// Parse a chain of FILE_NOTIFY_INFORMATION entries from the response buffer.
///
/// Each entry has:
/// - `NextEntryOffset` (u32): offset to next entry, 0 for last
/// - `Action` (u32): the change type
/// - `FileNameLength` (u32): length of filename in bytes (UTF-16LE)
/// - `FileName` (variable): UTF-16LE, NOT null-terminated
///
/// Entries are 4-byte aligned.
fn parse_notify_information(data: &[u8]) -> Result<Vec<FileNotifyEvent>> {
let mut events = Vec::new();
let mut offset = 0usize;
if data.is_empty() {
return Ok(events);
}
loop {
// Need at least 12 bytes for the fixed fields.
if offset + 12 > data.len() {
return Err(Error::invalid_data(
"FILE_NOTIFY_INFORMATION truncated: not enough bytes for fixed fields",
));
}
let next_entry_offset =
u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
let action_raw = u32::from_le_bytes(data[offset + 4..offset + 8].try_into().unwrap());
let filename_length =
u32::from_le_bytes(data[offset + 8..offset + 12].try_into().unwrap()) as usize;
// Filename starts right after the 12-byte fixed header.
let filename_start = offset + 12;
let filename_end = filename_start + filename_length;
if filename_end > data.len() {
return Err(Error::invalid_data(format!(
"FILE_NOTIFY_INFORMATION filename extends beyond buffer: \
need {} bytes at offset {}, buffer is {} bytes",
filename_length,
filename_start,
data.len()
)));
}
let filename_bytes = &data[filename_start..filename_end];
// Decode UTF-16LE filename.
let filename = decode_utf16le(filename_bytes)?;
let action = FileNotifyAction::from_u32(action_raw)?;
events.push(FileNotifyEvent { action, filename });
if next_entry_offset == 0 {
break;
}
offset += next_entry_offset;
}
Ok(events)
}
/// Decode a UTF-16LE byte slice into a Rust String.
fn decode_utf16le(bytes: &[u8]) -> Result<String> {
if bytes.len() % 2 != 0 {
return Err(Error::invalid_data("UTF-16LE filename has odd byte count"));
}
let u16s: Vec<u16> = bytes
.chunks_exact(2)
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
.collect();
String::from_utf16(&u16s)
.map_err(|e| Error::invalid_data(format!("invalid UTF-16LE filename: {e}")))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_single_notify_entry() {
// Build a single FILE_NOTIFY_INFORMATION entry.
let filename = "test.txt";
let utf16: Vec<u16> = filename.encode_utf16().collect();
let filename_bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
let filename_len = filename_bytes.len() as u32;
let mut data = Vec::new();
// NextEntryOffset = 0 (last entry)
data.extend_from_slice(&0u32.to_le_bytes());
// Action = FILE_ACTION_ADDED (0x00000001)
data.extend_from_slice(&1u32.to_le_bytes());
// FileNameLength
data.extend_from_slice(&filename_len.to_le_bytes());
// FileName (UTF-16LE)
data.extend_from_slice(&filename_bytes);
let events = parse_notify_information(&data).unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0].action, FileNotifyAction::Added);
assert_eq!(events[0].filename, "test.txt");
}
#[test]
fn parse_multiple_notify_entries() {
// Build two FILE_NOTIFY_INFORMATION entries.
let build_entry = |name: &str, action: u32, is_last: bool| -> Vec<u8> {
let utf16: Vec<u16> = name.encode_utf16().collect();
let filename_bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
let filename_len = filename_bytes.len() as u32;
let mut entry = Vec::new();
// Fixed header is 12 bytes + filename. Align to 4 bytes.
let entry_size = 12 + filename_bytes.len();
let aligned_size = (entry_size + 3) & !3;
let next_offset = if is_last { 0u32 } else { aligned_size as u32 };
entry.extend_from_slice(&next_offset.to_le_bytes());
entry.extend_from_slice(&action.to_le_bytes());
entry.extend_from_slice(&filename_len.to_le_bytes());
entry.extend_from_slice(&filename_bytes);
// Pad to 4-byte alignment.
while entry.len() < aligned_size {
entry.push(0);
}
entry
};
let mut data = Vec::new();
data.extend_from_slice(&build_entry("added.txt", 1, false));
data.extend_from_slice(&build_entry("removed.txt", 2, true));
let events = parse_notify_information(&data).unwrap();
assert_eq!(events.len(), 2);
assert_eq!(events[0].action, FileNotifyAction::Added);
assert_eq!(events[0].filename, "added.txt");
assert_eq!(events[1].action, FileNotifyAction::Removed);
assert_eq!(events[1].filename, "removed.txt");
}
#[test]
fn parse_empty_buffer_returns_no_events() {
let events = parse_notify_information(&[]).unwrap();
assert!(events.is_empty());
}
#[test]
fn parse_truncated_buffer_returns_error() {
// Only 8 bytes, need at least 12 for fixed fields.
let data = vec![0u8; 8];
let result = parse_notify_information(&data);
assert!(result.is_err());
}
#[test]
fn decode_utf16le_basic() {
let input = "hello";
let utf16: Vec<u16> = input.encode_utf16().collect();
let bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
let result = decode_utf16le(&bytes).unwrap();
assert_eq!(result, "hello");
}
#[test]
fn decode_utf16le_non_ascii() {
let input = "photos/\u{00E9}t\u{00E9}";
let utf16: Vec<u16> = input.encode_utf16().collect();
let bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
let result = decode_utf16le(&bytes).unwrap();
assert_eq!(result, input);
}
#[test]
fn decode_utf16le_odd_bytes_is_error() {
let result = decode_utf16le(&[0x41, 0x00, 0x42]);
assert!(result.is_err());
}
#[test]
fn file_notify_action_display() {
assert_eq!(format!("{}", FileNotifyAction::Added), "added");
assert_eq!(format!("{}", FileNotifyAction::Removed), "removed");
assert_eq!(format!("{}", FileNotifyAction::Modified), "modified");
assert_eq!(
format!("{}", FileNotifyAction::RenamedOldName),
"renamed (old name)"
);
assert_eq!(
format!("{}", FileNotifyAction::RenamedNewName),
"renamed (new name)"
);
}
#[test]
fn file_notify_action_from_u32_unknown_is_error() {
let result = FileNotifyAction::from_u32(0x9999);
assert!(result.is_err());
}
}
/// Loss-window tests using a strict-server simulator.
///
/// These probe the architectural property the watcher contract should
/// guarantee: every event the server observes is eventually delivered
/// to the consumer, even when the server drops events that arrive
/// while no `CHANGE_NOTIFY` request is outstanding (the naspi / older
/// Samba behavior that triggered cmdr's field reproduction).
///
/// **TDD-red on `main`**: `LossySim` drops events when no request is
/// outstanding; current `next_events()` issues one CHANGE_NOTIFY per
/// call, so there's always a gap between response delivery and the
/// next request. Events pushed during that gap are dropped, and the
/// test fails. The pipelined-watcher fix (always keep one CHANGE_NOTIFY
/// pre-issued on the wire) closes the gap, the simulator never drops,
/// and the test passes.
#[cfg(test)]
mod loss_window_tests {
use super::*;
use crate::client::connection::{pack_message, Connection, NegotiatedParams};
use crate::client::tree::Tree;
use crate::msg::change_notify::ChangeNotifyResponse;
use crate::msg::header::Header;
use crate::pack::Guid;
use crate::transport::{TransportReceive, TransportSend};
use crate::types::flags::Capabilities;
use crate::types::{Command, Dialect, MessageId, SessionId, TreeId};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::Notify;
/// Simulates a CHANGE_NOTIFY server that DROPS events that arrive
/// while no request is outstanding. Models naspi / older Samba
/// firmware (the server side of cmdr's 9-files → 4-events field
/// reproduction). Forgiving servers like Docker Samba buffer
/// generously and won't trigger this; the simulator's job is to
/// surface the architectural bug regardless of how forgiving any
/// real server happens to be.
struct LossySim {
/// Outstanding CHANGE_NOTIFY request msg_ids (FIFO).
outstanding: Mutex<VecDeque<u64>>,
/// Events the server has observed but not yet delivered.
pending_events: Mutex<Vec<(String, u32)>>,
/// Response queue read by `receive()`.
responses: Mutex<VecDeque<Vec<u8>>>,
/// Count of events the server saw with no request outstanding.
dropped: Mutex<usize>,
send_notify: Notify,
recv_notify: Notify,
closed: AtomicBool,
}
impl LossySim {
fn new() -> Self {
Self {
outstanding: Mutex::new(VecDeque::new()),
pending_events: Mutex::new(Vec::new()),
responses: Mutex::new(VecDeque::new()),
dropped: Mutex::new(0),
send_notify: Notify::new(),
recv_notify: Notify::new(),
closed: AtomicBool::new(false),
}
}
/// Block until at least one CHANGE_NOTIFY request is outstanding.
async fn wait_outstanding(&self) {
loop {
if !self.outstanding.lock().unwrap().is_empty() {
return;
}
if self.closed.load(Ordering::Acquire) {
return;
}
self.send_notify.notified().await;
}
}
/// Push an event. If a CHANGE_NOTIFY request is outstanding, buffer
/// the event for the next `deliver_pending()`. Else, drop silently
/// and bump the dropped counter.
fn push_event(&self, name: &str) {
let outstanding = !self.outstanding.lock().unwrap().is_empty();
if outstanding {
self.pending_events
.lock()
.unwrap()
.push((name.to_string(), 1 /* FILE_ACTION_ADDED */));
} else {
*self.dropped.lock().unwrap() += 1;
}
}
/// Wrap all buffered events into a single CHANGE_NOTIFY response,
/// consuming one outstanding msg_id.
fn deliver_pending(&self) {
let msg_id = self.outstanding.lock().unwrap().pop_front();
let events = std::mem::take(&mut *self.pending_events.lock().unwrap());
if let Some(id) = msg_id {
let resp = build_response(id, &events);
self.responses.lock().unwrap().push_back(resp);
self.recv_notify.notify_one();
}
}
fn dropped_count(&self) -> usize {
*self.dropped.lock().unwrap()
}
fn close(&self) {
self.closed.store(true, Ordering::Release);
self.recv_notify.notify_waiters();
self.send_notify.notify_waiters();
}
}
#[async_trait]
impl TransportSend for LossySim {
async fn send(&self, data: &[u8]) -> crate::error::Result<()> {
if let Some(msg_id) = extract_change_notify_msg_id(data) {
self.outstanding.lock().unwrap().push_back(msg_id);
self.send_notify.notify_waiters();
}
Ok(())
}
}
#[async_trait]
impl TransportReceive for LossySim {
async fn receive(&self) -> crate::error::Result<Vec<u8>> {
loop {
if let Some(data) = self.responses.lock().unwrap().pop_front() {
return Ok(data);
}
if self.closed.load(Ordering::Acquire) {
return Err(crate::Error::Disconnected);
}
self.recv_notify.notified().await;
}
}
}
/// Pull `MessageId` out of a request frame, but only for CHANGE_NOTIFY.
/// Non-CHANGE_NOTIFY sends are ignored by the simulator (the test
/// pre-configures the connection so no other requests should hit this
/// transport — but if any do, we won't track them).
fn extract_change_notify_msg_id(data: &[u8]) -> Option<u64> {
const HEADER_MIN: usize = 64;
if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" {
return None;
}
let cmd = u16::from_le_bytes([data[12], data[13]]);
if cmd != Command::ChangeNotify as u16 {
return None;
}
Some(u64::from_le_bytes(data[24..32].try_into().unwrap()))
}
/// Pack a CHANGE_NOTIFY response carrying the given (name, action) pairs.
fn build_response(msg_id: u64, events: &[(String, u32)]) -> Vec<u8> {
let mut output_data = Vec::new();
for (i, (name, action)) in events.iter().enumerate() {
let is_last = i == events.len() - 1;
let utf16: Vec<u16> = name.encode_utf16().collect();
let filename_bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
let filename_len = filename_bytes.len() as u32;
let entry_size = 12 + filename_bytes.len();
let aligned_size = (entry_size + 3) & !3;
let next_offset = if is_last { 0u32 } else { aligned_size as u32 };
let start = output_data.len();
output_data.extend_from_slice(&next_offset.to_le_bytes());
output_data.extend_from_slice(&action.to_le_bytes());
output_data.extend_from_slice(&filename_len.to_le_bytes());
output_data.extend_from_slice(&filename_bytes);
while output_data.len() - start < aligned_size {
output_data.push(0);
}
}
let mut h = Header::new_request(Command::ChangeNotify);
h.flags.set_response();
h.message_id = MessageId(msg_id);
h.credits = 32;
let body = ChangeNotifyResponse { output_data };
pack_message(&h, &body)
}
fn setup_connection(sim: &Arc<LossySim>) -> Connection {
let mut conn =
Connection::from_transport(Box::new(sim.clone()), Box::new(sim.clone()), "test-server");
conn.set_test_params(NegotiatedParams {
dialect: Dialect::Smb2_0_2,
max_read_size: 65536,
max_write_size: 65536,
max_transact_size: 65536,
server_guid: Guid::ZERO,
signing_required: false,
capabilities: Capabilities::default(),
gmac_negotiated: false,
cipher: None,
compression_supported: false,
});
conn.set_session_id(SessionId(0x1234));
conn
}
fn test_tree() -> Tree {
Tree {
tree_id: TreeId(1),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
}
}
/// Cycle, repeated N times:
/// 1. wait for outstanding (watcher armed)
/// 2. push event A → buffered
/// 3. deliver_pending → response queued, msg_id consumed
/// 4. push GAP event → on `main`, no outstanding → DROPPED;
/// on the pipelined-watcher fix, the next request is already
/// issued → buffered.
///
/// Final flush: one more wait_outstanding + push + deliver to make
/// sure any buffered gap events on the fix path get out.
///
/// On `main`: `dropped_count() > 0`, `delivered.len() < expected`.
/// On the fix: `dropped_count() == 0`, all events delivered.
#[tokio::test]
async fn watcher_does_not_lose_events_between_consecutive_requests() {
let _ = env_logger::try_init();
const N_CYCLES: usize = 5;
let sim = Arc::new(LossySim::new());
let conn = setup_connection(&sim);
let tree = test_tree();
let scenario_sim = sim.clone();
let scenario = tokio::spawn(async move {
let sim = scenario_sim;
for round in 0..N_CYCLES {
sim.wait_outstanding().await;
sim.push_event(&format!("a_{round:02}"));
sim.deliver_pending();
// Inline push (no .await) — outstanding queue was just
// emptied by deliver_pending. On `main`, no request has
// been re-issued yet, so this lands in the "drop" branch.
// On the fix, a pre-issued request is still outstanding,
// so it lands in the "buffer" branch.
sim.push_event(&format!("gap_{round:02}"));
// Models "time passes between server-side events". Real
// workloads have at least a syscall worth of latency
// between events, which is enough for the watcher task
// to wake up, process the previous response, and
// re-dispatch. The pipelining fix only guarantees one
// outstanding through the response-processing window,
// not through arbitrary back-to-back synchronous
// delivers within a single scheduler quantum.
tokio::task::yield_now().await;
}
// Flush: drive one more cycle to push any buffered gap events
// out the door for the fix path.
sim.wait_outstanding().await;
sim.push_event("flush_marker");
sim.deliver_pending();
// Brief grace period for the watcher to drain the response,
// then close so its next next_events() returns Disconnected
// and the consumer loop exits.
tokio::time::sleep(Duration::from_millis(50)).await;
sim.close();
});
let mut watcher = Watcher::new(
tree,
conn,
crate::types::FileId {
persistent: 0x1111,
volatile: 0x2222,
},
true,
);
let mut delivered: Vec<String> = Vec::new();
while let Ok(events) = watcher.next_events().await {
for e in &events {
delivered.push(e.filename.clone());
}
}
scenario.await.unwrap();
let dropped = sim.dropped_count();
// `a_*` events always land in the outstanding window. `flush_marker`
// ditto. `gap_*` events expose the bug: dropped today, delivered
// after the fix.
let expected_min = N_CYCLES /* a_* */ + 1 /* flush_marker */;
let expected_max = expected_min + N_CYCLES /* gap_* */;
assert!(
delivered.len() >= expected_min,
"watcher dropped 'a_*' or 'flush_marker' events: got {:?}",
delivered
);
assert_eq!(
dropped, 0,
"{} server-side event(s) arrived with no outstanding CHANGE_NOTIFY \
request and were dropped. The pipelined-watcher fix should keep \
one CHANGE_NOTIFY request continuously outstanding so no event \
ever lands in the drop branch. Delivered to consumer: {:?}",
dropped, delivered
);
assert_eq!(
delivered.len(),
expected_max,
"expected every 'a_*', 'gap_*', and 'flush_marker' event delivered; \
got {:?}",
delivered
);
}
}