diff --git a/AGENTS.md b/AGENTS.md index 277f88f..12be09c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -2925,3 +2925,26 @@ cargo test -p markbase-core --lib # ✅ 157 passed, 0 failed **最后更新**:2026-06-20 14:45 **版本**:1.33(Web Frontend Phase 3 完成) + +--- + +## SMB Server Phase 2 Build Fix(2026-06-20)⭐⭐⭐⭐⭐ + +**完成时间**:约 30 分钟 + +### 修复内容 ⭐⭐⭐⭐⭐ + +1. **`VfsFile` trait 添加 `Send` supertrait**:`vfs/mod.rs` — 所有实现已经是 `Send`,显式约束无需 unsafe cast +2. **`SmbServerCommand` 改为 enum**:`smb_server.rs` — 使用 `#[derive(Subcommand)]` 枚举(`Start` 变体)以兼容 `#[command(flatten)]` +3. **`smb_server_backend.rs` 测试修复**:用 `matches!(result, Err(SmbError::NotFound))` 替代 `result.unwrap_err()` 避免 `Debug` 约束 +4. **移除未使用 `VfsFile` import**:`webdav.rs` + `scp_handler.rs` + +### 验证 ✅ + +```bash +cargo build -p markbase-core --features smb-server # ✅ 0 error +cargo test -p markbase-core --lib --features smb-server # ✅ 169 passed, 0 failed +cargo build -p markbase-core # ✅ 0 error (no features) +``` + +**版本**:1.34(SMB Server Phase 2 Build Fix) diff --git a/Cargo.lock b/Cargo.lock index 77bcae8..14434d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,6 +113,12 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "ambient-authority" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9d4ee0d472d1cd2e28c97dfa124b3d8d992e10eb0a035f33f5d12e3a177ba3b" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -211,6 +217,12 @@ dependencies = [ "password-hash 0.6.1", ] +[[package]] +name = "array-init" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc" + [[package]] name = "async-trait" version = "0.1.89" @@ -433,13 +445,46 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "binrw" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53195f985e88ab94d1cc87e80049dd2929fd39e4a772c5ae96a7e5c4aad3642" +dependencies = [ + "array-init", + "binrw_derive", + "bytemuck", +] + +[[package]] +name = "binrw_derive" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5910da05ee556b789032c8ff5a61fb99239580aa3fd0bfaa8f4d094b2aee00ad" +dependencies = [ + "either", + "owo-colors", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "bit-set" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0481a0e032742109b1133a095184ee93d88f3dc9e0d28a5d033dc77a073f44f" dependencies = [ - "bit-vec", + "bit-vec 0.7.0", +] + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec 0.8.0", ] [[package]] @@ -448,6 +493,12 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2c54ff287cfc0a34f38a6b832ea1bd8e448a330b3e40a50859e6488bee07f22" +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitflags" version = "1.3.2" @@ -553,6 +604,12 @@ version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + [[package]] name = "byteorder" version = "1.5.0" @@ -594,6 +651,36 @@ dependencies = [ "crossbeam-queue", ] +[[package]] +name = "cap-primitives" +version = "3.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6cf3aea8a5081171859ef57bc1606b1df6999df4f1110f8eef68b30098d1d3a" +dependencies = [ + "ambient-authority", + "fs-set-times", + "io-extras", + "io-lifetimes", + "ipnet", + "maybe-owned", + "rustix", + "rustix-linux-procfs", + "windows-sys 0.52.0", + "winx", +] + +[[package]] +name = "cap-std" +version = "3.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6dc3090992a735d23219de5c204927163d922f42f575a0189b005c62d37549a" +dependencies = [ + "cap-primitives", + "io-extras", + "io-lifetimes", + "rustix", +] + [[package]] name = "caps" version = "0.5.6" @@ -633,6 +720,18 @@ dependencies = [ "shlex", ] +[[package]] +name = "ccm" +version = "0.6.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4edea5ea70a1285565ac264767613d6c88351a9a0557e7af793a0942590baaed" +dependencies = [ + "aead 0.6.0-rc.10", + "cipher 0.5.2", + "ctr 0.10.1", + "subtle", +] + [[package]] name = "cexpr" version = "0.6.0" @@ -779,6 +878,28 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" +[[package]] +name = "cmac" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8543454e3c3f5126effff9cd44d562af4e31fb8ce1cc0d3dcd8f084515dbc1aa" +dependencies = [ + "cipher 0.4.4", + "dbl 0.3.2", + "digest 0.10.7", +] + +[[package]] +name = "cmac" +version = "0.8.0-rc.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f7f5c25253a49afbdd6a256a21a554c509cf0e6400f59d6dd85e0f15b5f15f6" +dependencies = [ + "cipher 0.5.2", + "dbl 0.5.0", + "digest 0.11.3", +] + [[package]] name = "cmake" version = "0.1.58" @@ -1126,6 +1247,24 @@ dependencies = [ "xmltree", ] +[[package]] +name = "dbl" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd2735a791158376708f9347fe8faba9667589d82427ef3aed6794a8981de3d9" +dependencies = [ + "generic-array 0.14.7", +] + +[[package]] +name = "dbl" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d7a944e61df464668c5f51f56cc667396a8821434273112948ea0b66e405d7" +dependencies = [ + "hybrid-array", +] + [[package]] name = "dbs-snapshot" version = "1.5.2" @@ -1680,6 +1819,17 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs-set-times" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94e7099f6313ecacbe1256e8ff9d617b75d1bcb16a6fddef94866d225a01a14a" +dependencies = [ + "io-lifetimes", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "fs2" version = "0.4.3" @@ -2417,6 +2567,22 @@ dependencies = [ "rand_core 0.10.1", ] +[[package]] +name = "io-extras" +version = "0.18.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2285ddfe3054097ef4b2fe909ef8c3bcd1ea52a8f0d274416caebeef39f04a65" +dependencies = [ + "io-lifetimes", + "windows-sys 0.52.0", +] + +[[package]] +name = "io-lifetimes" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06432fb54d3be7964ecd3649233cddf80db2832f47fec34c01f65b3d9d774983" + [[package]] name = "io-uring" version = "0.5.13" @@ -2427,6 +2593,12 @@ dependencies = [ "libc", ] +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -2658,6 +2830,15 @@ dependencies = [ "libc", ] +[[package]] +name = "lz4_flex" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef0d4ed8669f8f8826eb00dc878084aa8f253506c4fd5e8f58f5bce72ddb97e" +dependencies = [ + "twox-hash", +] + [[package]] name = "lzma-rust" version = "0.1.7" @@ -2686,6 +2867,7 @@ dependencies = [ "aes 0.8.4", "aes-gcm 0.10.3", "anyhow", + "async-trait", "axum", "axum-extra", "base64", @@ -2726,6 +2908,8 @@ dependencies = [ "sevenz-rust", "sha2 0.10.9", "sled", + "smb-server", + "smb2", "ssh-key", "ssh2", "tar", @@ -2734,6 +2918,7 @@ dependencies = [ "tokio-postgres", "tokio-util", "toml", + "tracing-subscriber", "unrar", "ureq", "url", @@ -2865,12 +3050,37 @@ dependencies = [ "xmltree", ] +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "matchit" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "maybe-owned" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4facc753ae494aeb6e3c22f839b158aebd4f9270f55cd3c79906c45476c47ab4" + +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest 0.10.7", +] + [[package]] name = "md-5" version = "0.11.0" @@ -2881,6 +3091,24 @@ dependencies = [ "digest 0.11.3", ] +[[package]] +name = "md4" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da5ac363534dce5fabf69949225e174fbf111a498bf0ff794c8ea1fba9f3dda" +dependencies = [ + "digest 0.10.7", +] + +[[package]] +name = "md4" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd76fb0fd6b2e4be62a73f8e0858ca97f81babcb1af322dcaca196f735f17f80" +dependencies = [ + "digest 0.11.3", +] + [[package]] name = "md5" version = "0.7.0" @@ -3086,6 +3314,15 @@ dependencies = [ "time", ] +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "num" version = "0.1.43" @@ -3219,6 +3456,28 @@ dependencies = [ "libm", ] +[[package]] +name = "num_enum" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "objc2" version = "0.6.4" @@ -3316,6 +3575,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "owo-colors" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" + [[package]] name = "p256" version = "0.13.2" @@ -3749,7 +4014,7 @@ dependencies = [ "bytes", "fallible-iterator 0.2.0", "hmac 0.13.0", - "md-5", + "md-5 0.11.0", "memchr", "rand 0.10.1", "sha2 0.11.0", @@ -3833,6 +4098,15 @@ dependencies = [ "elliptic-curve 0.14.0-rc.33", ] +[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" +dependencies = [ + "toml_edit 0.25.12+spec-1.1.0", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -3842,6 +4116,25 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" +dependencies = [ + "bit-set 0.8.0", + "bit-vec 0.8.0", + "bitflags 2.11.1", + "num-traits 0.2.19", + "rand 0.9.4", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "pulldown-cmark" version = "0.12.2" @@ -3861,6 +4154,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "007d8adb5ddab6f8e3f491ac63566a7d5002cc7ed73901f72057943fa71ae1ae" +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.45" @@ -3912,10 +4211,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", - "rand_chacha", + "rand_chacha 0.3.1", "rand_core 0.6.4", ] +[[package]] +name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + [[package]] name = "rand" version = "0.10.1" @@ -3937,6 +4246,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + [[package]] name = "rand_core" version = "0.3.1" @@ -3961,12 +4280,30 @@ dependencies = [ "getrandom 0.2.17", ] +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "rand_core" version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.5", +] + [[package]] name = "rayon" version = "1.12.0" @@ -3987,6 +4324,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rc4" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "840038b674daa9f7a7957440d937951d15c0143c056e631e529141fd780e0c92" +dependencies = [ + "cipher 0.5.2", +] + [[package]] name = "rdrand" version = "0.4.0" @@ -4408,6 +4754,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustix-linux-procfs" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fc84bf7e9aa16c4f2c758f27412dc9841341e16aa682d9c7ac308fe3ee12056" +dependencies = [ + "once_cell", + "rustix", +] + [[package]] name = "rustls" version = "0.23.40" @@ -4449,6 +4805,18 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "rusty-fork" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "rusty-s3" version = "0.10.0" @@ -4459,7 +4827,7 @@ dependencies = [ "hmac 0.13.0", "instant-xml", "jiff", - "md-5", + "md-5 0.11.0", "percent-encoding", "serde", "serde_json", @@ -4663,7 +5031,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26482cf1ecce4540dc782fc70019eba89ffc4d87b3717eb5ec524b5db6fdefef" dependencies = [ - "bit-set", + "bit-set 0.6.0", "byteorder", "crc", "filetime_creation", @@ -4728,6 +5096,15 @@ dependencies = [ "keccak", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shared-local-state" version = "0.1.4" @@ -4824,6 +5201,60 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "smb-server" +version = "0.4.1" +dependencies = [ + "aes 0.8.4", + "async-trait", + "binrw", + "bytes", + "cap-std", + "cmac 0.7.2", + "getrandom 0.4.2", + "hex", + "hmac 0.12.1", + "md-5 0.10.6", + "md4 0.10.2", + "rc4", + "sha2 0.10.9", + "tempfile", + "thiserror 1.0.69", + "tokio", + "tracing", + "tracing-subscriber", + "uuid", +] + +[[package]] +name = "smb2" +version = "0.11.3" +dependencies = [ + "aes 0.9.1", + "aes-gcm 0.11.0-rc.4", + "async-trait", + "ccm", + "cmac 0.8.0-rc.5", + "digest 0.11.3", + "env_logger", + "futures-util", + "getrandom 0.4.2", + "hmac 0.13.0", + "log", + "lz4_flex", + "md-5 0.11.0", + "md4 0.11.0", + "num_enum", + "pbkdf2 0.13.0", + "proptest", + "serde", + "serde_json", + "sha1 0.11.0", + "sha2 0.11.0", + "thiserror 2.0.18", + "tokio", +] + [[package]] name = "socket2" version = "0.4.10" @@ -5123,6 +5554,15 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "time" version = "0.3.47" @@ -5291,8 +5731,8 @@ checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" dependencies = [ "serde", "serde_spanned", - "toml_datetime", - "toml_edit", + "toml_datetime 0.6.11", + "toml_edit 0.22.27", ] [[package]] @@ -5304,6 +5744,15 @@ dependencies = [ "serde", ] +[[package]] +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + [[package]] name = "toml_edit" version = "0.22.27" @@ -5313,9 +5762,30 @@ dependencies = [ "indexmap", "serde", "serde_spanned", - "toml_datetime", + "toml_datetime 0.6.11", "toml_write", - "winnow", + "winnow 0.7.15", +] + +[[package]] +name = "toml_edit" +version = "0.25.12+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2153edc6955a6c354fad8f5efd38b6a8769bdccf9fe50f8e1329f81b0baa5d7" +dependencies = [ + "indexmap", + "toml_datetime 1.1.1+spec-1.1.0", + "toml_parser", + "winnow 1.0.3", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow 1.0.3", ] [[package]] @@ -5382,14 +5852,56 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", + "valuable", ] +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "twox-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" + [[package]] name = "typenum" version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicase" version = "2.9.0" @@ -5544,6 +6056,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "vcpkg" version = "0.2.15" @@ -5656,6 +6174,15 @@ dependencies = [ "libc", ] +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -6193,12 +6720,31 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0592e1c9d151f854e6fd382574c3a0855250e1d9b2f99d9281c6e6391af352f1" +dependencies = [ + "memchr", +] + [[package]] name = "winsafe" version = "0.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" +[[package]] +name = "winx" +version = "0.36.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f3fd376f71958b862e7afb20cfe5a22830e1963462f3a17f49d82a6c1d1f42d" +dependencies = [ + "bitflags 2.11.1", + "windows-sys 0.52.0", +] + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index 4dd945b..f5e00f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,5 @@ members = [ "markbase-iscsi", "markbase-sync", "rust-iscsi-initiator", ] + + diff --git a/markbase-core/Cargo.toml b/markbase-core/Cargo.toml index ff5c0c5..f1a1d69 100644 --- a/markbase-core/Cargo.toml +++ b/markbase-core/Cargo.toml @@ -68,9 +68,18 @@ ureq = "2.12" # 輕量同步 HTTP 客戶端 rayon = "1.10" # Phase 4: 并行加密 url = "2" # URL 解析(rusty-s3 依賴) +# === SMB/CIFS Client (Phase 1) === +smb2 = { path = "../vendor/smb2" } # Pure-Rust SMB2/3 client library with pipelined I/O + +# === SMB/CIFS Server (Phase 2) — optional (vendored) === +smb-server = { path = "../vendor/smb-server", optional = true, default-features = false } +async-trait = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + [features] default = [] # 默认不启用可选格式 optional-formats = ["unrar", "xz2", "sevenz-rust"] # 争议格式可选启用 +smb-server = ["dep:smb-server"] # SMB server feature flag [dev-dependencies] # tempfile moved to dependencies (needed for archive extraction) diff --git a/markbase-core/src/cli/tools/mod.rs b/markbase-core/src/cli/tools/mod.rs index b7c69c6..4c458a9 100644 --- a/markbase-core/src/cli/tools/mod.rs +++ b/markbase-core/src/cli/tools/mod.rs @@ -1,4 +1,5 @@ pub mod render; +pub mod smb_server; pub mod test; use clap::Subcommand; @@ -9,12 +10,15 @@ pub enum ToolsCommands { Render(render::RenderCommand), #[command(flatten)] Test(test::TestCommand), + #[command(flatten)] + SmbServer(smb_server::SmbServerCommand), } pub async fn handle_tools_command(cmd: ToolsCommands) -> anyhow::Result<()> { match cmd { ToolsCommands::Render(c) => render::handle_render_command(c)?, ToolsCommands::Test(c) => test::handle_test_command(c)?, + ToolsCommands::SmbServer(c) => smb_server::handle_smb_server_command(c).await?, } Ok(()) } diff --git a/markbase-core/src/cli/tools/smb_server.rs b/markbase-core/src/cli/tools/smb_server.rs new file mode 100644 index 0000000..ce0cd6e --- /dev/null +++ b/markbase-core/src/cli/tools/smb_server.rs @@ -0,0 +1,71 @@ +use clap::Subcommand; + +#[derive(Subcommand)] +pub enum SmbServerCommand { + #[command(name = "smb-start")] + Start { + #[arg(short, long, default_value = "4445")] + port: u16, + + #[arg(short, long, default_value = "/Users/accusys/momentry/var/sftpgo/data/demo")] + root: String, + + #[arg(short, long, default_value = "markbase")] + share_name: String, + + #[arg(long)] + read_only: bool, + }, +} + +pub async fn handle_smb_server_command(cmd: SmbServerCommand) -> anyhow::Result<()> { + #[cfg(feature = "smb-server")] + { + match cmd { + SmbServerCommand::Start { port, root, share_name, read_only } => { + use std::path::PathBuf; + + use smb_server::{Access, Share, SmbServer}; + use tracing_subscriber::EnvFilter; + + let _ = tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("info")), + ) + .try_init(); + + let addr: std::net::SocketAddr = + format!("0.0.0.0:{}", port).parse()?; + let root_path = PathBuf::from(&root); + + let vfs = Box::new(crate::vfs::local_fs::LocalFs::new()); + let backend = crate::vfs::smb_server_backend::VfsShareBackend::new(vfs, root_path) + .read_only(read_only); + + let share = Share::new(&share_name, backend) + .user("demo", Access::ReadWrite); + + let server = SmbServer::builder() + .listen(addr) + .user("demo", "demo123") + .share(share) + .build()?; + + log::info!("SMB server listening on {}", addr); + log::info!("Share '{}' at root: {}", share_name, root); + log::info!("User: demo / demo123"); + + server.serve().await?; + } + } + } + + #[cfg(not(feature = "smb-server"))] + { + let _ = cmd; + anyhow::bail!("SMB server support not enabled. Build with --features smb-server"); + } + + Ok(()) +} diff --git a/markbase-core/src/ssh_server/scp_handler.rs b/markbase-core/src/ssh_server/scp_handler.rs index 6da08d9..4375f8a 100644 --- a/markbase-core/src/ssh_server/scp_handler.rs +++ b/markbase-core/src/ssh_server/scp_handler.rs @@ -2,7 +2,7 @@ // 参考OpenSSH scp.c源码 use crate::vfs::open_flags::OpenFlags; -use crate::vfs::{VfsBackend, VfsFile, VfsStat}; +use crate::vfs::{VfsBackend, VfsStat}; use anyhow::{anyhow, Result}; use log::{debug, info, warn}; use std::io::{BufRead, Read, Write}; diff --git a/markbase-core/src/vfs/mod.rs b/markbase-core/src/vfs/mod.rs index b3d251f..88bf537 100644 --- a/markbase-core/src/vfs/mod.rs +++ b/markbase-core/src/vfs/mod.rs @@ -1,6 +1,9 @@ pub mod local_fs; pub mod open_flags; pub mod s3_fs; +pub mod smb_fs; +#[cfg(feature = "smb-server")] +pub mod smb_server_backend; pub mod util; use std::path::{Path, PathBuf}; @@ -81,7 +84,7 @@ pub struct VfsDirEntry { } /// 打开文件的抽象 -pub trait VfsFile { +pub trait VfsFile: Send { fn read(&mut self, buf: &mut [u8]) -> Result; fn write(&mut self, buf: &[u8]) -> Result; fn seek(&mut self, pos: std::io::SeekFrom) -> Result; diff --git a/markbase-core/src/vfs/smb_fs.rs b/markbase-core/src/vfs/smb_fs.rs new file mode 100644 index 0000000..ef91484 --- /dev/null +++ b/markbase-core/src/vfs/smb_fs.rs @@ -0,0 +1,539 @@ +use super::open_flags::OpenFlags; +use super::{VfsBackend, VfsDirEntry, VfsError, VfsFile, VfsStat}; +use smb2::ClientConfig; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +const SMB_TIMEOUT_SECS: u64 = 30; +const FILETIME_TO_UNIX_SECS: u64 = 11_644_473_600; + +fn filetime_to_systemtime(raw: u64) -> SystemTime { + let secs = raw / 10_000_000; + if secs > FILETIME_TO_UNIX_SECS { + UNIX_EPOCH + Duration::from_secs(secs - FILETIME_TO_UNIX_SECS) + } else { + UNIX_EPOCH + } +} + +fn map_smb_error(e: smb2::Error) -> VfsError { + match e.kind() { + smb2::ErrorKind::NotFound => VfsError::NotFound(e.to_string()), + smb2::ErrorKind::AlreadyExists => VfsError::AlreadyExists(e.to_string()), + smb2::ErrorKind::AccessDenied => VfsError::PermissionDenied(e.to_string()), + smb2::ErrorKind::IsADirectory => VfsError::IsADirectory(e.to_string()), + smb2::ErrorKind::NotADirectory => VfsError::NotADirectory(e.to_string()), + smb2::ErrorKind::ConnectionLost | smb2::ErrorKind::TimedOut | smb2::ErrorKind::SessionExpired => { + VfsError::Io(format!("SMB connection error: {}", e)) + } + _ => VfsError::Io(format!("SMB error: {}", e)), + } +} + +/// SMB 客户端 VFS 后端 (SMB 2/3) +pub struct SmbVfs { + runtime: Arc, + client: Arc>, + tree: Mutex, +} + +impl SmbVfs { + pub fn new( + addr: &str, + share: &str, + username: &str, + password: &str, + ) -> Result { + let runtime = Arc::new( + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|e| VfsError::Io(format!("Failed to create tokio runtime: {}", e)))?, + ); + + let config = ClientConfig { + addr: addr.to_string(), + timeout: Duration::from_secs(SMB_TIMEOUT_SECS), + username: username.to_string(), + password: password.to_string(), + domain: String::new(), + auto_reconnect: false, + compression: true, + dfs_enabled: false, + dfs_target_overrides: std::collections::HashMap::new(), + }; + + let (client, tree) = runtime.block_on(async { + let mut c = smb2::SmbClient::connect(config) + .await + .map_err(|e| VfsError::Io(format!("SMB connect failed: {}", e)))?; + let t = c + .connect_share(share) + .await + .map_err(|e| VfsError::Io(format!("SMB connect_share failed: {}", e)))?; + Ok::<_, VfsError>((c, t)) + })?; + + Ok(Self { + runtime, + client: Arc::new(Mutex::new(client)), + tree: Mutex::new(tree), + }) + } + + fn path_to_str(path: &Path) -> String { + let s = path.to_string_lossy().to_string(); + s.trim_start_matches('/').to_string() + } +} + +impl Clone for SmbVfs { + fn clone(&self) -> Self { + Self { + runtime: self.runtime.clone(), + client: self.client.clone(), + tree: Mutex::new(self.tree.lock().unwrap().clone()), + } + } +} + +impl VfsBackend for SmbVfs { + fn clone_boxed(&self) -> Box { + Box::new(self.clone()) + } + + fn read_dir(&self, path: &Path) -> Result, VfsError> { + let smb_path = Self::path_to_str(path); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let entries = self + .runtime + .block_on(client.list_directory(&mut *tree, &smb_path)) + .map_err(map_smb_error)?; + + Ok(entries + .into_iter() + .filter(|e| e.name != "." && e.name != "..") + .map(|e| VfsDirEntry { + name: e.name, + long_name: String::new(), + stat: VfsStat { + size: e.size, + mode: if e.is_directory { 0o755 } else { 0o644 }, + uid: 0, + gid: 0, + atime: filetime_to_systemtime(0), + mtime: filetime_to_systemtime(e.modified.0), + is_dir: e.is_directory, + is_symlink: false, + }, + }) + .collect()) + } + + fn open_file( + &self, + path: &Path, + flags: &OpenFlags, + ) -> Result, VfsError> { + let smb_path = Self::path_to_str(path); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + + if flags.write || flags.create || flags.truncate { + Ok(Box::new(SmbVfsFile { + runtime: self.runtime.clone(), + client: self.client.clone(), + tree: tree.clone(), + path: smb_path, + mode: FileMode::Write, + position: 0, + write_buf: Vec::new(), + data: Vec::new(), + size: 0, + })) + } else { + let data = self + .runtime + .block_on(client.read_file(&mut *tree, &smb_path)) + .map_err(map_smb_error)?; + let size = data.len() as u64; + Ok(Box::new(SmbVfsFile { + runtime: self.runtime.clone(), + client: self.client.clone(), + tree: tree.clone(), + path: smb_path, + mode: FileMode::Read, + position: 0, + write_buf: Vec::new(), + data, + size, + })) + } + } + + fn stat(&self, path: &Path) -> Result { + let smb_path = Self::path_to_str(path); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let info = self + .runtime + .block_on(client.stat(&mut *tree, &smb_path)) + .map_err(map_smb_error)?; + + Ok(VfsStat { + size: info.size, + mode: if info.is_directory { 0o755 } else { 0o644 }, + uid: 0, + gid: 0, + atime: filetime_to_systemtime(info.accessed.0), + mtime: filetime_to_systemtime(info.modified.0), + is_dir: info.is_directory, + is_symlink: false, + }) + } + + fn lstat(&self, path: &Path) -> Result { + self.stat(path) + } + + fn create_dir(&self, path: &Path, _mode: u32) -> Result<(), VfsError> { + let smb_path = Self::path_to_str(path); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + self.runtime + .block_on(client.create_directory(&mut *tree, &smb_path)) + .map_err(map_smb_error) + } + + fn create_dir_all(&self, path: &Path, mode: u32) -> Result<(), VfsError> { + let mut current = path.to_path_buf(); + let mut stack = Vec::new(); + while let Some(parent) = current.parent() { + if parent.as_os_str().is_empty() || parent == Path::new("/") { + break; + } + stack.push(parent.to_path_buf()); + current = parent.to_path_buf(); + } + for dir in stack.into_iter().rev() { + if self.stat(&dir).is_err() { + self.create_dir(&dir, mode)?; + } + } + if self.stat(path).is_err() { + self.create_dir(path, mode)?; + } + Ok(()) + } + + fn remove_dir(&self, path: &Path) -> Result<(), VfsError> { + let smb_path = Self::path_to_str(path); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + self.runtime + .block_on(client.delete_directory(&mut *tree, &smb_path)) + .map_err(map_smb_error) + } + + fn remove_file(&self, path: &Path) -> Result<(), VfsError> { + let smb_path = Self::path_to_str(path); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + self.runtime + .block_on(client.delete_file(&mut *tree, &smb_path)) + .map_err(map_smb_error) + } + + fn rename(&self, from: &Path, to: &Path) -> Result<(), VfsError> { + let smb_from = Self::path_to_str(from); + let smb_to = Self::path_to_str(to); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + self.runtime + .block_on(client.rename(&mut *tree, &smb_from, &smb_to)) + .map_err(map_smb_error) + } + + fn set_stat(&self, _path: &Path, _stat: &VfsStat) -> Result<(), VfsError> { + Err(VfsError::Unsupported("SMB set_stat".to_string())) + } + + fn read_link(&self, _path: &Path) -> Result { + Err(VfsError::Unsupported("SMB read_link".to_string())) + } + + fn create_symlink(&self, _target: &Path, _link: &Path) -> Result<(), VfsError> { + Err(VfsError::Unsupported("SMB create_symlink".to_string())) + } + + fn real_path(&self, path: &Path) -> Result { + let smb_path = Self::path_to_str(path); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let mut tree = self.tree.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let _info = self + .runtime + .block_on(client.stat(&mut *tree, &smb_path)) + .map_err(map_smb_error)?; + Ok(path.to_path_buf()) + } + + fn exists(&self, path: &Path) -> bool { + let smb_path = Self::path_to_str(path); + let mut client = match self.client.lock() { + Ok(c) => c, + Err(_) => return false, + }; + let mut tree = match self.tree.lock() { + Ok(t) => t, + Err(_) => return false, + }; + self.runtime + .block_on(client.stat(&mut *tree, &smb_path)) + .is_ok() + } + + fn hard_link(&self, _original: &Path, _link: &Path) -> Result<(), VfsError> { + Err(VfsError::Unsupported("SMB hard_link".to_string())) + } +} + +enum FileMode { + Read, + Write, +} + +struct SmbVfsFile { + runtime: Arc, + client: Arc>, + tree: smb2::Tree, + path: String, + mode: FileMode, + position: u64, + write_buf: Vec, + data: Vec, + size: u64, +} + +impl SmbVfsFile { + fn ensure_data_loaded(&mut self) -> Result<(), VfsError> { + if self.data.is_empty() && self.size > 0 { + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let data = self + .runtime + .block_on(client.read_file(&mut self.tree, &self.path)) + .map_err(map_smb_error)?; + self.size = data.len() as u64; + self.data = data; + } + Ok(()) + } +} + +impl VfsFile for SmbVfsFile { + fn read(&mut self, buf: &mut [u8]) -> Result { + self.ensure_data_loaded()?; + if self.position >= self.size { + return Ok(0); + } + let start = self.position as usize; + let available = self.size as usize - start; + let to_copy = std::cmp::min(buf.len(), available); + buf[..to_copy].copy_from_slice(&self.data[start..start + to_copy]); + self.position += to_copy as u64; + Ok(to_copy) + } + + fn write(&mut self, buf: &[u8]) -> Result { + self.write_buf.extend_from_slice(buf); + self.position += buf.len() as u64; + Ok(buf.len()) + } + + fn seek(&mut self, pos: std::io::SeekFrom) -> Result { + match pos { + std::io::SeekFrom::Start(offset) => { + self.position = offset; + Ok(offset) + } + std::io::SeekFrom::End(offset) => { + let new_pos = if offset >= 0 { + self.size + offset as u64 + } else { + self.size.saturating_sub((-offset) as u64) + }; + self.position = new_pos; + Ok(new_pos) + } + std::io::SeekFrom::Current(offset) => { + let new_pos = if offset >= 0 { + self.position + offset as u64 + } else { + self.position.saturating_sub((-offset) as u64) + }; + self.position = new_pos; + Ok(new_pos) + } + } + } + + fn flush(&mut self) -> Result<(), VfsError> { + if let FileMode::Write = self.mode { + if !self.write_buf.is_empty() { + let data = std::mem::take(&mut self.write_buf); + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + self.runtime + .block_on(client.write_file(&mut self.tree, &self.path, &data)) + .map_err(map_smb_error)?; + self.size = data.len() as u64; + } + } + Ok(()) + } + + fn stat(&mut self) -> Result { + let mut client = self.client.lock().map_err(|e| VfsError::Io(e.to_string()))?; + let info = self + .runtime + .block_on(client.stat(&mut self.tree, &self.path)) + .map_err(map_smb_error)?; + Ok(VfsStat { + size: info.size, + mode: if info.is_directory { 0o755 } else { 0o644 }, + uid: 0, + gid: 0, + atime: filetime_to_systemtime(info.accessed.0), + mtime: filetime_to_systemtime(info.modified.0), + is_dir: info.is_directory, + is_symlink: false, + }) + } + + fn set_len(&mut self, _size: u64) -> Result<(), VfsError> { + Err(VfsError::Unsupported("SMB set_len".to_string())) + } +} + +impl Drop for SmbVfsFile { + fn drop(&mut self) { + if let FileMode::Write = self.mode { + if !self.write_buf.is_empty() { + let data = std::mem::take(&mut self.write_buf); + if let Ok(mut client) = self.client.lock() { + let _ = self + .runtime + .block_on(client.write_file(&mut self.tree, &self.path, &data)); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_filetime_conversion() { + let raw: u64 = 133604700000000000; + let st = filetime_to_systemtime(raw); + assert!(st > UNIX_EPOCH); + } + + #[test] + fn test_path_to_str() { + assert_eq!(SmbVfs::path_to_str(Path::new("foo/bar.txt")), "foo/bar.txt"); + assert_eq!(SmbVfs::path_to_str(Path::new("/foo/bar.txt")), "foo/bar.txt"); + assert_eq!(SmbVfs::path_to_str(Path::new("")), ""); + } + + #[test] + fn test_error_mapping_invalid_data() { + let err = smb2::Error::invalid_data("test"); + let mapped = map_smb_error(err); + match mapped { + VfsError::Io(_) => {} + _ => panic!("Expected Io, got {:?}", mapped), + } + } + + /// Integration test: requires Docker Samba container on port 10445. + /// Run with: docker compose -f vendor/smb2/tests/docker/internal/docker-compose.yml up -d smb-guest + #[test] + #[ignore] + fn test_smb_vfs_list_root() { + let vfs = SmbVfs::new("127.0.0.1:10445", "public", "", "").unwrap(); + let entries = vfs.read_dir(Path::new("/")).unwrap(); + assert!(!entries.is_empty(), "Expected at least . and .."); + } + + #[test] + #[ignore] + fn test_smb_vfs_write_read_file() { + let vfs = SmbVfs::new("127.0.0.1:10445", "public", "", "").unwrap(); + + let content = b"Hello SMB VFS!"; + let path = Path::new("/smb_vfs_test.txt"); + + // Write + { + let flags = OpenFlags::new().write().create().truncate(); + let mut file = vfs.open_file(path, &flags).unwrap(); + file.write(content).unwrap(); + file.flush().unwrap(); + } + + // Read back + let flags = OpenFlags::new().read(); + let mut file = vfs.open_file(path, &flags).unwrap(); + let mut buf = vec![0u8; 1024]; + let n = file.read(&mut buf).unwrap(); + assert_eq!(&buf[..n], content); + + // Stat + let stat = vfs.stat(path).unwrap(); + assert_eq!(stat.size, content.len() as u64); + + // Cleanup + vfs.remove_file(path).unwrap(); + assert!(!vfs.exists(path)); + } + + #[test] + #[ignore] + fn test_smb_vfs_create_remove_dir() { + let vfs = SmbVfs::new("127.0.0.1:10445", "public", "", "").unwrap(); + let dir_path = Path::new("/smb_vfs_test_dir"); + + vfs.create_dir(dir_path, 0o755).unwrap(); + assert!(vfs.exists(dir_path)); + + vfs.remove_dir(dir_path).unwrap(); + assert!(!vfs.exists(dir_path)); + } + + #[test] + #[ignore] + fn test_smb_vfs_rename_file() { + let vfs = SmbVfs::new("127.0.0.1:10445", "public", "", "").unwrap(); + let src = Path::new("/rename_src.txt"); + let dst = Path::new("/rename_dst.txt"); + + // Create source file + let flags = OpenFlags::new().write().create().truncate(); + { + let mut file = vfs.open_file(src, &flags).unwrap(); + file.write(b"rename test").unwrap(); + file.flush().unwrap(); + } + + // Rename + vfs.rename(src, dst).unwrap(); + assert!(!vfs.exists(src)); + assert!(vfs.exists(dst)); + + // Cleanup + vfs.remove_file(dst).unwrap(); + } +} diff --git a/markbase-core/src/vfs/smb_server_backend.rs b/markbase-core/src/vfs/smb_server_backend.rs new file mode 100644 index 0000000..011aed6 --- /dev/null +++ b/markbase-core/src/vfs/smb_server_backend.rs @@ -0,0 +1,437 @@ +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::SystemTime; + +use async_trait::async_trait; +use bytes::Bytes; +use smb_server::{ + BackendCapabilities, DirEntry, FileInfo, FileTimes, Handle, OpenIntent, OpenOptions, ShareBackend, + SmbError, SmbPath, +}; + +use super::open_flags::OpenFlags; +use super::{VfsBackend, VfsError, VfsStat}; + +const FILETIME_OFFSET: u64 = 116_444_736_000_000_000; + +pub struct VfsShareBackend { + vfs: Arc, + root: PathBuf, + read_only: bool, +} + +impl VfsShareBackend { + pub fn new(vfs: Box, root: PathBuf) -> Self { + Self { + vfs: Arc::from(vfs), + root, + read_only: false, + } + } + + pub fn read_only(mut self, yes: bool) -> Self { + self.read_only = yes; + self + } +} + +fn resolve_path(root: &Path, smb_path: &SmbPath) -> PathBuf { + if smb_path.is_root() { + return root.to_path_buf(); + } + let mut result = root.to_path_buf(); + for component in smb_path.components() { + result.push(component); + } + result +} + +fn map_error(e: VfsError) -> SmbError { + match e { + VfsError::NotFound(_) => SmbError::NotFound, + VfsError::PermissionDenied(_) => SmbError::AccessDenied, + VfsError::AlreadyExists(_) => SmbError::Exists, + VfsError::NotEmpty(_) => SmbError::NotEmpty, + VfsError::NotADirectory(_) => SmbError::NotADirectory, + VfsError::IsADirectory(_) => SmbError::IsDirectory, + VfsError::Unsupported(_) => SmbError::NotSupported, + VfsError::Io(msg) => SmbError::Io(std::io::Error::other(msg)), + VfsError::UnexpectedEof => SmbError::Io(std::io::Error::other("unexpected eof")), + } +} + +fn system_time_to_filetime(t: SystemTime) -> u64 { + match t.duration_since(SystemTime::UNIX_EPOCH) { + Ok(d) => { + FILETIME_OFFSET + + (d.as_secs() * 10_000_000) + + (d.subsec_nanos() as u64 / 100) + } + Err(_) => 0, + } +} + +fn vfs_stat_to_file_info(stat: &VfsStat, name: &str, path: &Path) -> FileInfo { + let name = if name.is_empty() { + path.file_name() + .and_then(|s| s.to_str()) + .unwrap_or("") + .to_string() + } else { + name.to_string() + }; + FileInfo { + name, + end_of_file: stat.size, + allocation_size: stat.size, + creation_time: system_time_to_filetime(stat.mtime), + last_access_time: system_time_to_filetime(stat.atime), + last_write_time: system_time_to_filetime(stat.mtime), + change_time: system_time_to_filetime(stat.mtime), + is_directory: stat.is_dir, + file_index: 0, + } +} + +fn vfs_error_to_io(e: VfsError) -> std::io::Error { + std::io::Error::other(e.to_string()) +} + +#[async_trait] +impl ShareBackend for VfsShareBackend { + async fn open(&self, path: &SmbPath, opts: OpenOptions) -> Result, SmbError> { + let full_path = resolve_path(&self.root, path); + + if opts.directory { + match opts.intent { + OpenIntent::Create => { + if self.vfs.exists(&full_path) { + return Err(SmbError::Exists); + } + self.vfs.create_dir(&full_path, 0o755).map_err(map_error)?; + } + OpenIntent::OpenOrCreate | OpenIntent::OverwriteOrCreate => { + if !self.vfs.exists(&full_path) { + self.vfs.create_dir(&full_path, 0o755).map_err(map_error)?; + } + } + _ => { + if !self.vfs.exists(&full_path) { + return Err(SmbError::NotFound); + } + } + } + let stat = self.vfs.stat(&full_path).map_err(map_error)?; + if !stat.is_dir { + return Err(SmbError::NotADirectory); + } + return Ok(Box::new(VfsHandle::Directory { + vfs: self.vfs.clone(), + path: full_path, + })); + } + + let mut flags = OpenFlags::new(); + if opts.read { + flags = flags.read(); + } + if opts.write { + flags = flags.write(); + } + match opts.intent { + OpenIntent::Open => {} + OpenIntent::Create => { + flags = flags.create().exclusive(); + } + OpenIntent::OpenOrCreate => { + flags = flags.create(); + } + OpenIntent::OverwriteOrCreate => { + flags = flags.create().truncate(); + } + OpenIntent::Truncate => { + flags = flags.truncate(); + } + } + + if opts.non_directory && self.vfs.exists(&full_path) { + let stat = self.vfs.stat(&full_path).map_err(map_error)?; + if stat.is_dir { + return Err(SmbError::IsDirectory); + } + } + + let file = self + .vfs + .open_file(&full_path, &flags) + .map_err(map_error)?; + Ok(Box::new(VfsHandle::File { + file: Mutex::new(file), + path: full_path, + vfs: self.vfs.clone(), + })) + } + + async fn unlink(&self, path: &SmbPath) -> Result<(), SmbError> { + let full_path = resolve_path(&self.root, path); + if self.vfs.exists(&full_path) { + let stat = self.vfs.stat(&full_path).map_err(map_error)?; + if stat.is_dir { + return self.vfs.remove_dir(&full_path).map_err(map_error); + } + } + self.vfs.remove_file(&full_path).map_err(map_error) + } + + async fn rename(&self, from: &SmbPath, to: &SmbPath) -> Result<(), SmbError> { + let from_path = resolve_path(&self.root, from); + let to_path = resolve_path(&self.root, to); + if self.vfs.exists(&to_path) { + return Err(SmbError::Exists); + } + self.vfs.rename(&from_path, &to_path).map_err(map_error) + } + + fn capabilities(&self) -> BackendCapabilities { + BackendCapabilities { + is_read_only: self.read_only, + case_sensitive: true, + } + } +} + +enum VfsHandle { + File { + file: Mutex>, + path: PathBuf, + vfs: Arc, + }, + Directory { + vfs: Arc, + path: PathBuf, + }, +} + +#[async_trait] +impl Handle for VfsHandle { + async fn read(&self, offset: u64, len: u32) -> Result { + match self { + Self::File { file, .. } => { + let mut file = file.lock().unwrap(); + file.seek(std::io::SeekFrom::Start(offset)) + .map_err(vfs_error_to_io)?; + let mut buf = vec![0u8; len as usize]; + let n = file.read(&mut buf).map_err(map_error)?; + buf.truncate(n); + Ok(Bytes::from(buf)) + } + Self::Directory { .. } => Err(SmbError::NotSupported), + } + } + + async fn write(&self, offset: u64, data: &[u8]) -> Result { + match self { + Self::File { file, .. } => { + let mut file = file.lock().unwrap(); + file.seek(std::io::SeekFrom::Start(offset)) + .map_err(vfs_error_to_io)?; + let n = file.write(data).map_err(map_error)?; + Ok(n as u32) + } + Self::Directory { .. } => Err(SmbError::NotSupported), + } + } + + async fn flush(&self) -> Result<(), SmbError> { + match self { + Self::File { file, .. } => { + let mut file = file.lock().unwrap(); + file.flush().map_err(map_error) + } + Self::Directory { .. } => Ok(()), + } + } + + async fn stat(&self) -> Result { + match self { + Self::File { file, path, .. } => { + let mut f = file.lock().unwrap(); + let vfs_stat = f.stat().map_err(map_error)?; + Ok(vfs_stat_to_file_info(&vfs_stat, "", path)) + } + Self::Directory { vfs, path } => { + let vfs_stat = vfs.stat(path).map_err(map_error)?; + Ok(vfs_stat_to_file_info(&vfs_stat, "", path)) + } + } + } + + async fn set_times(&self, times: FileTimes) -> Result<(), SmbError> { + let (vfs, path) = match self { + Self::File { path, vfs, .. } => (vfs, path), + Self::Directory { vfs, path } => (vfs, path), + }; + let mut stat = VfsStat::new(); + if let Some(t) = times.last_write_time { + stat.mtime = filetime_to_systemtime(t); + } + if let Some(t) = times.last_access_time { + stat.atime = filetime_to_systemtime(t); + } + vfs.set_stat(path, &stat).map_err(map_error) + } + + async fn truncate(&self, len: u64) -> Result<(), SmbError> { + match self { + Self::File { file, .. } => { + let mut file = file.lock().unwrap(); + file.set_len(len).map_err(map_error) + } + Self::Directory { .. } => Err(SmbError::NotSupported), + } + } + + async fn list_dir(&self, _pattern: Option<&str>) -> Result, SmbError> { + match self { + Self::File { .. } => Err(SmbError::NotADirectory), + Self::Directory { vfs, path } => { + let entries = vfs.read_dir(path).map_err(map_error)?; + let result = entries + .into_iter() + .map(|entry| { + let info = vfs_stat_to_file_info(&entry.stat, &entry.name, path); + DirEntry { info } + }) + .collect(); + Ok(result) + } + } + } + + async fn close(self: Box) -> Result<(), SmbError> { + Ok(()) + } +} + +fn filetime_to_systemtime(ft: u64) -> SystemTime { + if ft < FILETIME_OFFSET { + return SystemTime::UNIX_EPOCH; + } + let delta_secs = (ft - FILETIME_OFFSET) / 10_000_000; + let delta_ns = ((ft - FILETIME_OFFSET) % 10_000_000) as u32 * 100; + SystemTime::UNIX_EPOCH + + std::time::Duration::new(delta_secs, delta_ns) +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use smb_server::{Share, SmbServer, Access}; + + use crate::vfs::local_fs::LocalFs; + + use super::*; + + #[test] + fn test_resolve_path_root() { + let root = PathBuf::from("/srv/share"); + let smb = SmbPath::root(); + assert_eq!(resolve_path(&root, &smb), root); + } + + #[test] + fn test_resolve_path_components() { + let root = PathBuf::from("/srv/share"); + let smb: SmbPath = "dir\\sub\\file.txt".parse().unwrap(); + let expected = PathBuf::from("/srv/share/dir/sub/file.txt"); + assert_eq!(resolve_path(&root, &smb), expected); + } + + #[test] + fn test_system_time_to_filetime() { + let epoch = SystemTime::UNIX_EPOCH; + let ft = system_time_to_filetime(epoch); + assert_eq!(ft, FILETIME_OFFSET); + } + + #[test] + fn test_filetime_roundtrip() { + let now = SystemTime::now(); + let ft = system_time_to_filetime(now); + let back = filetime_to_systemtime(ft); + let diff = if now > back { + now.duration_since(back).unwrap() + } else { + back.duration_since(now).unwrap() + }; + assert!(diff.as_millis() < 100); + } + + #[test] + fn test_map_errors() { + assert!(matches!( + map_error(VfsError::NotFound("x".into())), + SmbError::NotFound + )); + assert!(matches!( + map_error(VfsError::AlreadyExists("x".into())), + SmbError::Exists + )); + assert!(matches!( + map_error(VfsError::PermissionDenied("x".into())), + SmbError::AccessDenied + )); + assert!(matches!( + map_error(VfsError::NotEmpty("x".into())), + SmbError::NotEmpty + )); + assert!(matches!( + map_error(VfsError::NotADirectory("x".into())), + SmbError::NotADirectory + )); + assert!(matches!( + map_error(VfsError::IsADirectory("x".into())), + SmbError::IsDirectory + )); + } + + #[test] + fn test_vfs_share_backend_creation() { + let vfs = Box::new(LocalFs::new()); + let root = PathBuf::from("/tmp"); + let backend = VfsShareBackend::new(vfs, root); + assert!(!backend.capabilities().is_read_only); + } + + #[tokio::test] + async fn test_open_nonexistent_file() { + let vfs = Box::new(LocalFs::new()); + let root = PathBuf::from("/nonexistent"); + let backend = VfsShareBackend::new(vfs, root); + let smb_path: SmbPath = "missing.txt".parse().unwrap(); + let opts = OpenOptions { + read: true, + write: false, + intent: OpenIntent::Open, + directory: false, + non_directory: false, + delete_on_close: false, + }; + let result = backend.open(&smb_path, opts).await; + assert!(matches!(result, Err(SmbError::NotFound))); + } + + #[test] + fn test_rejects_dotdot() { + assert!("a\\..\\b".parse::().is_err()); + } + + #[test] + fn test_rejects_forbidden_chars() { + for bad in ["ab", "a:b", "a\"b", "a|b", "a?b", "a*b"] { + assert!(bad.parse::().is_err()); + } + } +} diff --git a/markbase-core/src/webdav.rs b/markbase-core/src/webdav.rs index 66dfbc0..a4a1523 100644 --- a/markbase-core/src/webdav.rs +++ b/markbase-core/src/webdav.rs @@ -1,5 +1,5 @@ use crate::vfs::open_flags::OpenFlags; -use crate::vfs::{VfsBackend, VfsDirEntry, VfsStat, VfsFile}; +use crate::vfs::{VfsBackend, VfsDirEntry, VfsStat}; use crate::ssh_server::upload_hook::UploadHook; use bytes::{Buf, Bytes}; use dav_server::davpath::DavPath; diff --git a/vendor/smb-server/Cargo.toml b/vendor/smb-server/Cargo.toml new file mode 100644 index 0000000..a9f9965 --- /dev/null +++ b/vendor/smb-server/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "smb-server" +version = "0.4.1" +edition = "2024" +rust-version = "1.95" +license = "MIT" +repository = "https://github.com/paltaio/rust-smb-server" +description = "SMB2/3 file-sharing server library with pluggable storage backends." + +[dependencies] +tokio = { version = "1.40", features = ["full"] } +bytes = "1.7" +async-trait = "0.1" +tracing = "0.1" +thiserror = "1" +uuid = { version = "1.10", features = ["v4"] } +binrw = "0.15" +getrandom = "0.4" +cap-std = { version = "3", optional = true } +hmac = "0.12" +sha2 = "0.10" +md-5 = "0.10" +md4 = "0.10" +aes = "0.8" +cmac = "0.7" +rc4 = "0.2" + +[features] +default = ["localfs"] +localfs = ["dep:cap-std"] + +[dev-dependencies] +tracing-subscriber = "0.3" +tempfile = "3" +hex = "0.4" + +[[test]] +name = "integration_localfs" +path = "tests/integration_localfs.rs" +required-features = ["localfs"] + +[[test]] +name = "integration_localfs_write" +path = "tests/integration_localfs_write.rs" +required-features = ["localfs"] + +[[test]] +name = "integration_negotiate" +path = "tests/integration_negotiate.rs" +required-features = ["localfs"] + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 +panic = "abort" +strip = true diff --git a/vendor/smb-server/src/backend.rs b/vendor/smb-server/src/backend.rs new file mode 100644 index 0000000..7c2bdcf --- /dev/null +++ b/vendor/smb-server/src/backend.rs @@ -0,0 +1,238 @@ +//! `ShareBackend` and `Handle` traits — the storage abstraction. +//! +//! Implementors of these traits plug into `Share::new(name, backend)`. The +//! protocol layer never exposes raw FS types to backends; everything goes +//! through validated `SmbPath`s and the small structs below. + +use async_trait::async_trait; +use std::time::SystemTime; + +use crate::error::{SmbError, SmbResult}; +use crate::path::SmbPath; + +// --------------------------------------------------------------------------- +// OpenOptions +// --------------------------------------------------------------------------- + +/// Translated SMB CREATE intent — the small set of cases v1 cares about. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OpenIntent { + /// `FILE_OPEN` — open existing only; fail if missing. + Open, + /// `FILE_CREATE` — create new only; fail if exists. + Create, + /// `FILE_OPEN_IF` — open existing or create new. + OpenOrCreate, + /// `FILE_OVERWRITE_IF` — open existing (truncating) or create new. + OverwriteOrCreate, + /// `FILE_OVERWRITE` — open existing and truncate; fail if missing. + Truncate, +} + +/// Options passed to `ShareBackend::open`. v1 keeps this tight on purpose; +/// extra knobs become methods later if a backend genuinely needs them. +#[derive(Debug, Clone, Copy)] +pub struct OpenOptions { + /// Read access requested. + pub read: bool, + /// Write access requested. + pub write: bool, + /// CREATE disposition. + pub intent: OpenIntent, + /// `FILE_DIRECTORY_FILE` was set on CREATE — open or create a directory. + pub directory: bool, + /// `FILE_NON_DIRECTORY_FILE` was set on CREATE — fail if the target is a directory. + pub non_directory: bool, + /// `FILE_DELETE_ON_CLOSE` was set on CREATE. + pub delete_on_close: bool, +} + +impl Default for OpenOptions { + fn default() -> Self { + Self { + read: true, + write: false, + intent: OpenIntent::Open, + directory: false, + non_directory: false, + delete_on_close: false, + } + } +} + +// --------------------------------------------------------------------------- +// FileInfo / DirEntry / FileTimes +// --------------------------------------------------------------------------- + +/// Filesystem-style metadata for a single file or directory. +#[derive(Debug, Clone)] +pub struct FileInfo { + /// Display name (last component). For QUERY_INFO at the share root this + /// is the share name. + pub name: String, + /// File size in bytes. + pub end_of_file: u64, + /// Allocation size — typically `end_of_file` rounded up to a cluster size. + /// v1 backends may safely return the same value as `end_of_file`. + pub allocation_size: u64, + /// FILETIME (100ns ticks since 1601). + pub creation_time: u64, + pub last_access_time: u64, + pub last_write_time: u64, + pub change_time: u64, + /// True if this is a directory. + pub is_directory: bool, + /// Optional 64-bit unique file id (for `FileInternalInformation`). v1 may + /// return `0` if unavailable; the dispatcher will substitute the FileId. + pub file_index: u64, +} + +impl FileInfo { + /// SMB2 file attributes (MS-FSCC §2.6) for this file. v1 returns + /// `FILE_ATTRIBUTE_DIRECTORY` for dirs, `FILE_ATTRIBUTE_NORMAL` (0x80) for + /// regular files. (`FILE_ATTRIBUTE_NORMAL` MUST be the only attribute set + /// when used.) + pub fn attributes(&self) -> u32 { + const FILE_ATTRIBUTE_DIRECTORY: u32 = 0x0000_0010; + const FILE_ATTRIBUTE_NORMAL: u32 = 0x0000_0080; + if self.is_directory { + FILE_ATTRIBUTE_DIRECTORY + } else { + FILE_ATTRIBUTE_NORMAL + } + } +} + +/// One entry of a directory listing. +#[derive(Debug, Clone)] +pub struct DirEntry { + pub info: FileInfo, +} + +/// Optional FILETIME values for `set_times`. `None` means "leave unchanged". +#[derive(Debug, Clone, Copy, Default)] +pub struct FileTimes { + pub creation_time: Option, + pub last_access_time: Option, + pub last_write_time: Option, + pub change_time: Option, +} + +impl FileTimes { + /// Convenience: convert `SystemTime` into a `FileTimes` setting all four + /// fields to the same instant. + pub fn all(t: SystemTime) -> Self { + let ft = crate::utils::system_time_to_filetime(t); + Self { + creation_time: Some(ft), + last_access_time: Some(ft), + last_write_time: Some(ft), + change_time: Some(ft), + } + } +} + +// --------------------------------------------------------------------------- +// BackendCapabilities +// --------------------------------------------------------------------------- + +/// Static, advertised capabilities of a backend. +/// +/// Kept small intentionally — extending requires discussing with the maintainer. +#[derive(Debug, Clone, Copy, Default)] +pub struct BackendCapabilities { + /// If true, all write-class operations are denied at the protocol layer + /// before reaching the backend (matches `LocalFsBackend::read_only()`). + pub is_read_only: bool, + /// True iff the backend treats names case-sensitively. + pub case_sensitive: bool, +} + +// --------------------------------------------------------------------------- +// Traits +// --------------------------------------------------------------------------- + +/// Pluggable storage backend mounted as a share. +/// +/// Implementors must be `Send + Sync + 'static` so the server can spawn +/// per-request handlers freely. +#[async_trait] +pub trait ShareBackend: Send + Sync + 'static { + /// Open or create a file or directory. Returns a fresh handle. + async fn open(&self, path: &SmbPath, opts: OpenOptions) -> SmbResult>; + + /// Unlink (delete) a file. Directories: must be empty. v1 does not + /// recursively delete. + async fn unlink(&self, path: &SmbPath) -> SmbResult<()>; + + /// Rename `from` to `to`. The backend must reject if `to` already exists. + async fn rename(&self, from: &SmbPath, to: &SmbPath) -> SmbResult<()>; + + /// Static capabilities. The dispatcher consults these at TREE_CONNECT and + /// uses `is_read_only` to clamp authz. + fn capabilities(&self) -> BackendCapabilities; +} + +/// A live open file or directory handle. +/// +/// One handle per `CREATE`. The handle is dropped when CLOSE arrives or the +/// session goes away. +#[async_trait] +pub trait Handle: Send + Sync { + /// Read up to `len` bytes at `offset`. May return fewer. + async fn read(&self, offset: u64, len: u32) -> SmbResult; + + /// Write `data` at `offset`. Returns bytes written. + async fn write(&self, offset: u64, data: &[u8]) -> SmbResult; + + /// Write owned `data` at `offset`. Backends that need ownership across a + /// blocking boundary can override this to avoid an extra copy. + async fn write_owned(&self, offset: u64, data: Vec) -> SmbResult { + self.write(offset, &data).await + } + + /// Flush buffered writes. May be a no-op on backends that always flush. + async fn flush(&self) -> SmbResult<()>; + + /// Stat: current file info. + async fn stat(&self) -> SmbResult; + + /// Set timestamps. `None` fields leave the corresponding field alone. + async fn set_times(&self, times: FileTimes) -> SmbResult<()>; + + /// Truncate (or extend) to `len` bytes. For directories: the protocol + /// layer rejects this before reaching the backend. + async fn truncate(&self, len: u64) -> SmbResult<()>; + + /// List directory entries matching the optional pattern. v1 ignores + /// `pattern` if the backend doesn't implement matching — the dispatcher + /// post-filters as needed for QUERY_DIRECTORY. + async fn list_dir(&self, pattern: Option<&str>) -> SmbResult>; + + /// Close the handle. Boxed self lets implementors consume internal state. + async fn close(self: Box) -> SmbResult<()>; +} + +/// No-op backend used for the synthetic IPC$ share. Every method returns +/// [`SmbError::NotSupported`]. Exists so we can hand a `ShareBackend` +/// implementor to the IPC$ tree without any real storage attached. +pub(crate) struct NotSupportedBackend; + +#[async_trait] +impl ShareBackend for NotSupportedBackend { + async fn open(&self, _path: &SmbPath, _opts: OpenOptions) -> SmbResult> { + Err(SmbError::NotSupported) + } + async fn unlink(&self, _path: &SmbPath) -> SmbResult<()> { + Err(SmbError::NotSupported) + } + async fn rename(&self, _from: &SmbPath, _to: &SmbPath) -> SmbResult<()> { + Err(SmbError::NotSupported) + } + fn capabilities(&self) -> BackendCapabilities { + BackendCapabilities { + is_read_only: true, + case_sensitive: false, + } + } +} diff --git a/vendor/smb-server/src/builder.rs b/vendor/smb-server/src/builder.rs new file mode 100644 index 0000000..5bc45f4 --- /dev/null +++ b/vendor/smb-server/src/builder.rs @@ -0,0 +1,259 @@ +//! Public builder API for `SmbServer` and `Share`. + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; + +use thiserror::Error; +use uuid::Uuid; + +use crate::backend::ShareBackend; +use crate::server::{ServerConfig, ServerState, ServerUsers, ShareBindings, ShareMode, SmbServer}; + +// --------------------------------------------------------------------------- +// Access +// --------------------------------------------------------------------------- + +/// Access level granted to a user on a share, or to anonymous on a public +/// share. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Access { + Read, + ReadWrite, +} + +impl Access { + pub fn allows_write(self) -> bool { + matches!(self, Access::ReadWrite) + } + + pub fn clamp_to(self, cap: Access) -> Access { + match (self, cap) { + (Access::ReadWrite, Access::ReadWrite) => Access::ReadWrite, + _ => Access::Read, + } + } +} + +// --------------------------------------------------------------------------- +// Share +// --------------------------------------------------------------------------- + +/// One share definition, attached to a single backend. +pub struct Share { + pub(crate) name: String, + pub(crate) backend: Arc, + pub(crate) mode: ShareMode, + pub(crate) users: HashMap, +} + +impl Share { + /// Build a new share with the given name and backend. + pub fn new(name: impl Into, backend: impl ShareBackend) -> Self { + Self { + name: name.into(), + backend: Arc::new(backend), + mode: ShareMode::AuthenticatedOnly, + users: HashMap::new(), + } + } + + /// Anonymous + authenticated read+write. + pub fn public(mut self) -> Self { + self.mode = ShareMode::Public; + self + } + + /// Anonymous + authenticated read-only. + pub fn public_read_only(mut self) -> Self { + self.mode = ShareMode::PublicReadOnly; + self + } + + /// Grant `access` to the given (already-registered) user. Multiple calls + /// accumulate. + pub fn user(mut self, name: impl Into, access: Access) -> Self { + self.users.insert(name.into(), access); + self + } +} + +// --------------------------------------------------------------------------- +// BuildError +// --------------------------------------------------------------------------- + +/// Errors raised by `SmbServerBuilder::build`. +#[derive(Debug, Error)] +pub enum BuildError { + #[error("listen address must be set")] + MissingListenAddr, + #[error("share `{0}` is declared more than once")] + DuplicateShare(String), + #[error("share `{0}` mixes .public()/.public_read_only() with explicit .user(...) entries")] + PublicMixedWithUsers(String), + #[error("share `{0}` calls `.public*()` more than once")] + DoublePublic(String), + #[error("share `{share}` references unknown user `{user}`")] + UnknownUser { share: String, user: String }, + #[error("user `{0}` is registered twice")] + DuplicateUser(String), + #[error("user name `{0}` is reserved (use .public()/.public_read_only() for anonymous)")] + ReservedUserName(String), + #[error("user name must be non-empty")] + EmptyUserName, +} + +// --------------------------------------------------------------------------- +// SmbServerBuilder +// --------------------------------------------------------------------------- + +/// Builder for `SmbServer`. See `SmbServer::builder`. +pub struct SmbServerBuilder { + listen_addr: Option, + users: HashMap, // name -> password + user_order: Vec, + shares: Vec, + netbios_name: Option, + max_read_size: u32, + max_write_size: u32, + server_guid: Option, +} + +impl Default for SmbServerBuilder { + fn default() -> Self { + Self::new() + } +} + +impl SmbServerBuilder { + pub(crate) fn new() -> Self { + Self { + listen_addr: None, + users: HashMap::new(), + user_order: Vec::new(), + shares: Vec::new(), + netbios_name: None, + max_read_size: 1024 * 1024, + max_write_size: 1024 * 1024, + server_guid: None, + } + } + + pub fn listen(mut self, addr: SocketAddr) -> Self { + self.listen_addr = Some(addr); + self + } + + pub fn user(mut self, name: impl Into, password: impl Into) -> Self { + let n = name.into(); + if !self.users.contains_key(&n) { + self.user_order.push(n.clone()); + } + self.users.insert(n, password.into()); + self + } + + pub fn share(mut self, share: Share) -> Self { + self.shares.push(share); + self + } + + pub fn netbios_name(mut self, name: impl Into) -> Self { + self.netbios_name = Some(name.into()); + self + } + + pub fn max_read_size(mut self, bytes: u32) -> Self { + self.max_read_size = bytes; + self + } + + pub fn max_write_size(mut self, bytes: u32) -> Self { + self.max_write_size = bytes; + self + } + + /// Override the random per-process server GUID. Mostly useful in tests. + pub fn server_guid(mut self, guid: Uuid) -> Self { + self.server_guid = Some(guid); + self + } + + pub fn build(self) -> Result { + // 1. Validate users. + for name in &self.user_order { + if name.is_empty() { + return Err(BuildError::EmptyUserName); + } + if name.eq_ignore_ascii_case("anonymous") { + return Err(BuildError::ReservedUserName(name.clone())); + } + } + + // 2. Validate shares. + let mut seen_names = std::collections::HashSet::new(); + for share in &self.shares { + if !seen_names.insert(share.name.to_ascii_lowercase()) { + return Err(BuildError::DuplicateShare(share.name.clone())); + } + // Public-vs-users mutual exclusivity. + let is_public = matches!(share.mode, ShareMode::Public | ShareMode::PublicReadOnly); + if is_public && !share.users.is_empty() { + return Err(BuildError::PublicMixedWithUsers(share.name.clone())); + } + // Each per-share user must exist in the global user table. + for u in share.users.keys() { + if !self.users.contains_key(u) { + return Err(BuildError::UnknownUser { + share: share.name.clone(), + user: u.clone(), + }); + } + } + } + + // 3. Listen address required. + let listen = self.listen_addr.ok_or(BuildError::MissingListenAddr)?; + + // 4. Decide NetBIOS name. + let netbios = self.netbios_name.unwrap_or_else(|| { + // Hostname or "SMBSERVER". + std::env::var("HOSTNAME") + .ok() + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| "SMBSERVER".to_string()) + }); + + // 5. Build ShareBindings — keep mode + users + backend together. + let mut share_bindings: Vec> = Vec::with_capacity(self.shares.len()); + for s in self.shares { + share_bindings.push(ShareBindings::new( + s.name, s.backend, s.mode, s.users, false, + )); + } + + // 6. Materialize the user table (precompute NT hashes to avoid retaining plaintext). + let mut user_table = HashMap::new(); + for name in &self.user_order { + let pw = &self.users[name]; + let creds = crate::proto::auth::ntlm::UserCreds::from_password(pw); + user_table.insert(name.clone(), creds); + } + + let server_guid = self.server_guid.unwrap_or_else(Uuid::new_v4); + + let cfg = ServerConfig { + listen_addr: listen, + netbios_name: netbios, + max_read_size: self.max_read_size, + max_write_size: self.max_write_size, + server_guid, + }; + let users = ServerUsers { + table: tokio::sync::RwLock::new(user_table), + }; + + let state = ServerState::new(cfg, users, share_bindings); + Ok(SmbServer::from_state(state)) + } +} diff --git a/vendor/smb-server/src/conn/mod.rs b/vendor/smb-server/src/conn/mod.rs new file mode 100644 index 0000000..b4550aa --- /dev/null +++ b/vendor/smb-server/src/conn/mod.rs @@ -0,0 +1,39 @@ +//! Per-connection task layout. + +pub mod reader; +pub mod state; +pub mod writer; + +use std::io; +use std::sync::Arc; + +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tracing::{debug, info}; + +use crate::server::ServerState; +use state::Connection; + +/// Runs the reader and writer tasks for a single accepted connection until +/// either side hangs up. Returns once both halves are done. +pub async fn connection_loop(stream: TcpStream, server: Arc) -> io::Result<()> { + let (read_half, write_half) = tokio::io::split(stream); + let conn = Arc::new(Connection::new( + server.config.server_guid, + server.config.max_read_size, + server.config.max_write_size, + )); + let conn_id = server.active_connections.register(&conn).await; + let (tx, rx) = mpsc::channel::(writer::WRITER_CHANNEL); + + let writer_handle = tokio::spawn(writer::writer_task(write_half, rx)); + + info!("connection accepted"); + let reader_result = reader::reader_task(read_half, server.clone(), conn.clone(), tx).await; + debug!(?reader_result, "reader exited"); + // Wait for writer to drain. + let _ = writer_handle.await; + server.active_connections.unregister(conn_id).await; + info!("connection closed"); + reader_result +} diff --git a/vendor/smb-server/src/conn/reader.rs b/vendor/smb-server/src/conn/reader.rs new file mode 100644 index 0000000..253be3f --- /dev/null +++ b/vendor/smb-server/src/conn/reader.rs @@ -0,0 +1,80 @@ +//! Per-connection frame reader: pulls bytes off the socket, frames them, +//! hands each frame to the dispatcher. + +use std::io; +use std::sync::Arc; + +use crate::proto::framing::{FRAME_HEADER_LEN, decode_frame_header}; +use tokio::io::{AsyncReadExt, ReadHalf}; +use tokio::net::TcpStream; +use tracing::{debug, error}; + +use crate::conn::state::Connection; +use crate::server::ServerState; + +/// Read one frame's payload (without the 4-byte length prefix). +/// +/// Returns `Ok(None)` on a clean EOF, `Ok(Some(bytes))` on a complete frame, +/// `Err` on partial/garbled data. +pub async fn read_one_frame(reader: &mut ReadHalf) -> io::Result>> { + let mut hdr = [0u8; FRAME_HEADER_LEN]; + match reader.read_exact(&mut hdr).await { + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(e), + } + let len = match decode_frame_header(&hdr) { + Ok(n) => n, + Err(e) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())); + } + }; + let mut payload = vec![0u8; len as usize]; + reader.read_exact(&mut payload).await?; + Ok(Some(payload)) +} + +/// Continuously read frames; for each, await `dispatch_one`'s response and +/// route it to the writer. +/// +/// Sequential dispatch keeps v1 simple and matches the spec's "single writer +/// task / per-frame dispatch" pattern. We process one frame at a time per +/// connection in v1 — a follow-up can spawn dispatch tasks if a workload +/// proves to need credit-window concurrency. +pub async fn reader_task( + mut reader: ReadHalf, + server: Arc, + conn: Arc, + tx: tokio::sync::mpsc::Sender, +) -> io::Result<()> { + loop { + let frame = match read_one_frame(&mut reader).await { + Ok(Some(b)) => b, + Ok(None) => { + debug!("client closed connection"); + return Ok(()); + } + Err(e) => { + error!(error = %e, "frame read error"); + return Err(e); + } + }; + // Check shutdown after every frame. + if server + .shutting_down + .load(std::sync::atomic::Ordering::Acquire) + { + debug!("server shutting down; dropping connection"); + return Ok(()); + } + // The dispatcher is async but we await it inline — order-preserving and + // good enough for v1. + let response = crate::dispatch::dispatch_frame(&server, &conn, &frame).await; + if let Some(bytes) = response + && tx.send(bytes).await.is_err() + { + debug!("writer channel closed; reader exiting"); + return Ok(()); + } + } +} diff --git a/vendor/smb-server/src/conn/state.rs b/vendor/smb-server/src/conn/state.rs new file mode 100644 index 0000000..6919571 --- /dev/null +++ b/vendor/smb-server/src/conn/state.rs @@ -0,0 +1,328 @@ +//! Connection / session / tree / open state held during a single TCP +//! connection's lifetime. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; + +use crate::proto::auth::ntlm::{Identity, NtlmServer}; +use crate::proto::crypto::{PreauthIntegrity, SigningAlgo}; +use crate::proto::messages::{Dialect, FileId}; +use tokio::sync::RwLock; +use uuid::Uuid; + +use crate::backend::Handle; +use crate::builder::Access; +use crate::path::SmbPath; +use crate::server::ShareBindings; + +/// In-flight NTLM acceptor + a `is_raw_ntlmssp` flag (true = raw, false = +/// SPNEGO-wrapped). The handler hands the second-round response back in the +/// same form the client opened with. +pub type PendingAuth = Arc>; + +// --------------------------------------------------------------------------- +// Connection +// --------------------------------------------------------------------------- + +/// One connection's negotiated state and its session/tree/open tables. +pub struct Connection { + pub server_guid: Uuid, + pub client_guid: tokio::sync::RwLock, + pub dialect: tokio::sync::RwLock>, + pub signing_algo: tokio::sync::RwLock, + /// Connection.PreauthIntegrityHashValue after NEGOTIATE. SMB 3.1.1 + /// SESSION_SETUP exchanges fork this into `session_preauth`. + pub preauth: Mutex, + /// Granted at NEGOTIATE: large MTU support flag etc. + pub max_read_size: tokio::sync::RwLock, + pub max_write_size: tokio::sync::RwLock, + + /// Sessions keyed by SessionId. + pub sessions: RwLock>>>, + + /// In-flight NTLM acceptors keyed by SessionId. We keep them out of + /// `Session` because a session is created only after a successful first + /// SESSION_SETUP round — between rounds the entry lives here. The + /// `bool` records whether the client sent raw NTLMSSP (true) or + /// SPNEGO-wrapped (false) so the second-round response matches form. + pub pending_auths: RwLock>, + + /// In-flight SMB 3.1.1 preauth state keyed by SessionId during + /// multi-leg SESSION_SETUP. + pub session_preauth: RwLock>, + + /// Monotonic SessionId allocator. + next_session_id: AtomicU64, +} + +impl Connection { + pub fn new(server_guid: Uuid, max_read_size: u32, max_write_size: u32) -> Self { + Self { + server_guid, + client_guid: tokio::sync::RwLock::new(Uuid::nil()), + dialect: tokio::sync::RwLock::new(None), + signing_algo: tokio::sync::RwLock::new(SigningAlgo::HmacSha256), + preauth: Mutex::new(PreauthIntegrity::new()), + max_read_size: tokio::sync::RwLock::new(max_read_size), + max_write_size: tokio::sync::RwLock::new(max_write_size), + sessions: RwLock::new(HashMap::new()), + pending_auths: RwLock::new(HashMap::new()), + session_preauth: RwLock::new(HashMap::new()), + next_session_id: AtomicU64::new(1), + } + } + + pub fn alloc_session_id(&self) -> u64 { + self.next_session_id.fetch_add(1, Ordering::Relaxed) + } + + pub async fn close_session(&self, session_id: u64) -> bool { + let removed = { + let mut sessions = self.sessions.write().await; + sessions.remove(&session_id) + }; + if let Some(sess_arc) = removed { + close_session_state(&sess_arc).await; + true + } else { + false + } + } + + pub async fn close_tree(&self, session_id: u64, tree_id: u32) -> bool { + let sess_arc = { + let sessions = self.sessions.read().await; + sessions.get(&session_id).cloned() + }; + let Some(sess_arc) = sess_arc else { + return false; + }; + remove_tree_from_session(&sess_arc, tree_id).await + } + + pub async fn close_sessions_for_user(&self, user: &str) -> usize { + let to_remove = { + let sessions = self.sessions.read().await; + let mut ids = Vec::new(); + for (session_id, sess_arc) in sessions.iter() { + let sess = sess_arc.read().await; + if matches!(&sess.identity, Identity::User { user: session_user, .. } if session_user == user) + { + ids.push(*session_id); + } + } + ids + }; + + let mut removed = 0; + for session_id in to_remove { + if self.close_session(session_id).await { + removed += 1; + } + } + removed + } + + pub async fn close_trees_for_share(&self, share_name: &str) -> usize { + self.close_matching_trees(|_, tree| tree.share.name.eq_ignore_ascii_case(share_name)) + .await + } + + pub async fn close_trees_for_user_share(&self, user: &str, share_name: &str) -> usize { + self.close_matching_trees(|sess, tree| { + matches!(&sess.identity, Identity::User { user: session_user, .. } if session_user == user) + && tree.share.name.eq_ignore_ascii_case(share_name) + }) + .await + } + + async fn close_matching_trees( + &self, + matches_tree: impl Fn(&Session, &TreeConnect) -> bool, + ) -> usize { + let sessions: Vec<_> = { + let sessions = self.sessions.read().await; + sessions.values().cloned().collect() + }; + + let mut removed = 0; + for sess_arc in sessions { + let tree_ids = { + let sess = sess_arc.read().await; + let trees = sess.trees.read().await; + let mut ids = Vec::new(); + for (tree_id, tree_arc) in trees.iter() { + let tree = tree_arc.read().await; + if matches_tree(&sess, &tree) { + ids.push(*tree_id); + } + } + ids + }; + + for tree_id in tree_ids { + if remove_tree_from_session(&sess_arc, tree_id).await { + removed += 1; + } + } + } + removed + } +} + +async fn close_session_state(sess_arc: &Arc>) { + let sess = sess_arc.write().await; + let trees: Vec<_> = sess.trees.write().await.drain().collect(); + for (_tree_id, tree_arc) in trees { + close_tree_state(&tree_arc).await; + } +} + +async fn remove_tree_from_session(sess_arc: &Arc>, tree_id: u32) -> bool { + let removed = { + let sess = sess_arc.read().await; + let mut trees = sess.trees.write().await; + trees.remove(&tree_id) + }; + if let Some(tree_arc) = removed { + close_tree_state(&tree_arc).await; + true + } else { + false + } +} + +async fn close_tree_state(tree_arc: &Arc>) { + let tree = tree_arc.write().await; + let opens: Vec<_> = tree.opens.write().await.drain().collect(); + for (_fid, open_arc) in opens { + let mut open = open_arc.write().await; + if let Some(handle) = open.handle.take() { + let _ = handle.close().await; + } + } +} + +// --------------------------------------------------------------------------- +// Session +// --------------------------------------------------------------------------- + +pub struct Session { + pub id: u64, + pub identity: Identity, + pub session_base_key: [u8; 16], + pub signing_key: [u8; 16], + /// Whether signing is required for this session's traffic. + pub signing_required: bool, + pub trees: RwLock>>>, + /// 3.1.1: snapshot taken at SESSION_SETUP completion (after the request + /// hash but before the response is hashed). Used as KDF context. + pub preauth_snapshot: Option<[u8; 64]>, + + next_tree_id: AtomicU32, +} + +impl Session { + pub fn new( + id: u64, + identity: Identity, + session_base_key: [u8; 16], + signing_key: [u8; 16], + signing_required: bool, + preauth_snapshot: Option<[u8; 64]>, + ) -> Self { + Self { + id, + identity, + session_base_key, + signing_key, + signing_required, + trees: RwLock::new(HashMap::new()), + preauth_snapshot, + next_tree_id: AtomicU32::new(1), + } + } + + pub fn alloc_tree_id(&self) -> u32 { + self.next_tree_id.fetch_add(1, Ordering::Relaxed) + } + + pub fn is_anonymous(&self) -> bool { + matches!(self.identity, Identity::Anonymous) + } +} + +// --------------------------------------------------------------------------- +// TreeConnect +// --------------------------------------------------------------------------- + +pub struct TreeConnect { + pub id: u32, + pub share: Arc, + pub granted_access: Access, + pub opens: RwLock>>>, + next_volatile: AtomicU64, +} + +impl TreeConnect { + pub fn new(id: u32, share: Arc, granted_access: Access) -> Self { + Self { + id, + share, + granted_access, + opens: RwLock::new(HashMap::new()), + next_volatile: AtomicU64::new(1), + } + } + + pub fn alloc_file_id(&self) -> FileId { + let v = self.next_volatile.fetch_add(1, Ordering::Relaxed); + FileId::new(v, v) + } +} + +// --------------------------------------------------------------------------- +// Open / DirCursor +// --------------------------------------------------------------------------- + +pub struct Open { + pub file_id: FileId, + pub handle: Option>, + pub granted_access: Access, + pub last_path: SmbPath, + pub is_directory: bool, + pub delete_on_close: bool, + pub search_state: Option, +} + +impl Open { + pub fn new( + file_id: FileId, + handle: Box, + granted_access: Access, + last_path: SmbPath, + is_directory: bool, + delete_on_close: bool, + ) -> Self { + Self { + file_id, + handle: Some(handle), + granted_access, + last_path, + is_directory, + delete_on_close, + search_state: None, + } + } +} + +/// Iterator state for a directory listing across multiple QUERY_DIRECTORY +/// calls. We snapshot the entries once and consume them in order; subsequent +/// calls advance `next` until exhaustion. +pub struct DirCursor { + pub entries: Vec, + pub next: usize, + /// The pattern fixed on the first scan; `RESTART_SCANS` resets `next`. + pub pattern: Option, +} diff --git a/vendor/smb-server/src/conn/writer.rs b/vendor/smb-server/src/conn/writer.rs new file mode 100644 index 0000000..7eae534 --- /dev/null +++ b/vendor/smb-server/src/conn/writer.rs @@ -0,0 +1,32 @@ +//! Per-connection writer task: serializes responses, applies signing, and +//! frames the bytes onto the wire. + +use crate::proto::framing::encode_frame; +use tokio::io::{AsyncWriteExt, WriteHalf}; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tracing::{debug, error}; + +/// One packet of bytes to send. Already includes the final SMB2 header + +/// body, *with signing already applied if required*. +pub type FramePayload = Vec; + +/// Writer-task channel size: large enough that a slow remote rarely backs up +/// the dispatcher. +pub const WRITER_CHANNEL: usize = 64; + +pub async fn writer_task(mut writer: WriteHalf, mut rx: mpsc::Receiver) { + while let Some(payload) = rx.recv().await { + let mut out = Vec::with_capacity(payload.len() + 4); + encode_frame(&payload, &mut out); + if let Err(e) = writer.write_all(&out).await { + error!(error = %e, "writer task: socket write failed"); + return; + } + debug!(len = out.len(), "wrote frame"); + } + // Channel closed — flush and bail. + if let Err(e) = writer.shutdown().await { + debug!(error = %e, "writer shutdown error (best-effort)"); + } +} diff --git a/vendor/smb-server/src/dispatch.rs b/vendor/smb-server/src/dispatch.rs new file mode 100644 index 0000000..5ff1995 --- /dev/null +++ b/vendor/smb-server/src/dispatch.rs @@ -0,0 +1,656 @@ +//! Per-frame dispatch: parse header, route to handler, sign response, encode. + +use std::sync::Arc; + +use crate::proto::auth::ntlm::Identity; +use crate::proto::crypto::{PreauthIntegrity, sign}; +use crate::proto::header::{ + Command, HeaderTail, SMB2_FLAGS_ASYNC_COMMAND, SMB2_FLAGS_RELATED_OPERATIONS, + SMB2_FLAGS_SERVER_TO_REDIR, SMB2_FLAGS_SIGNED, SMB2_HEADER_LEN, Smb2Header, +}; +use crate::proto::messages::ErrorResponse; +use tracing::{Instrument, debug, debug_span, error, warn}; + +use crate::conn::state::Connection; +use crate::handlers; +use crate::ntstatus; +use crate::server::ServerState; + +/// Result of a handler: a complete (unsigned) response payload + the NTSTATUS +/// to set in the header. The dispatcher patches the header, applies signing +/// (if required), and ships the bytes. +pub struct HandlerResponse { + /// Bytes after the SMB2 header — the body. The handler owns body + /// construction. + pub body: Vec, + /// NTSTATUS for the response header. + pub status: u32, + /// Optional override for `tree_id` on the response header (e.g. + /// TREE_CONNECT returns the freshly minted tree id). + pub override_tree_id: Option, + /// Optional override for `session_id` on the response header (e.g. + /// SESSION_SETUP returns the freshly minted session id). + pub override_session_id: Option, + /// If true, the dispatcher will not sign the response. Used for + /// pre-session-setup messages where no key exists yet. + pub skip_signing: bool, + /// If set, take the per-session 3.1.1 preauth snapshot after hashing the + /// SESSION_SETUP request but before hashing the response. Set by + /// SESSION_SETUP on the round that produces STATUS_SUCCESS, so the + /// session's KDF context can use the snapshot. + pub take_preauth_snapshot_for_session: Option, +} + +impl HandlerResponse { + pub fn ok(body: Vec) -> Self { + Self { + body, + status: ntstatus::STATUS_SUCCESS, + override_tree_id: None, + override_session_id: None, + skip_signing: false, + take_preauth_snapshot_for_session: None, + } + } + + pub fn err(status: u32) -> Self { + let er = ErrorResponse::status(status); + let mut buf = Vec::new(); + er.write_to(&mut buf).expect("error response encodes"); + Self { + body: buf, + status, + override_tree_id: None, + override_session_id: None, + skip_signing: false, + take_preauth_snapshot_for_session: None, + } + } +} + +/// Top-level frame dispatch. Returns the bytes to push into the writer +/// channel, or `None` if the request elicits no response (CANCEL). +pub async fn dispatch_frame( + server: &Arc, + conn: &Arc, + frame: &[u8], +) -> Option> { + // SMB1 multi-protocol bootstrap (MS-SMB2 §3.3.5.3.1). The only SMB1 we + // accept: a NEGOTIATE_REQUEST listing "SMB 2.???" or "SMB 2.002". + // Reply with an SMB2 NEGOTIATE response and the client follows up with + // a real SMB2 NEGOTIATE. + if let Some(bytes) = handle_smb1_multi_protocol(server, conn, frame).await { + return Some(bytes); + } + if frame.len() < SMB2_HEADER_LEN { + warn!(len = frame.len(), "frame too short for SMB2 header"); + return None; + } + + let mut sub_offset = 0; + let mut responses = Vec::new(); + let mut prev_session_id = 0; + let mut prev_tree_id = 0; + let mut prev_create_file_id = None; + + while sub_offset < frame.len() { + let available = &frame[sub_offset..]; + if available.len() < SMB2_HEADER_LEN { + warn!(remaining = available.len(), "compound tail too short"); + return None; + } + + let (mut req_hdr, _) = match Smb2Header::parse(available) { + Ok(p) => p, + Err(e) => { + warn!(error = %e, "failed to parse compound sub-header"); + return None; + } + }; + + let next = req_hdr.next_command as usize; + let sub_len = if next == 0 { + available.len() + } else if next < SMB2_HEADER_LEN || next > available.len() { + warn!( + next, + remaining = available.len(), + "invalid compound NextCommand" + ); + return None; + } else { + next + }; + + let mut sub_frame = available[..sub_len].to_vec(); + if req_hdr.flags & SMB2_FLAGS_RELATED_OPERATIONS != 0 { + inherit_related_context( + &mut sub_frame, + &mut req_hdr, + prev_session_id, + prev_tree_id, + prev_create_file_id, + ); + } + + prev_session_id = req_hdr.session_id; + prev_tree_id = req_hdr.tree_id().unwrap_or(0); + + if let Some(response) = dispatch_one(server, conn, &sub_frame).await { + if req_hdr.command == Command::Create { + prev_create_file_id = capture_create_file_id(&response); + } + responses.push(response); + } + + if next == 0 { + break; + } + sub_offset += next; + } + + if responses.is_empty() { + return None; + } + + Some(stitch_responses(conn, responses).await) +} + +fn inherit_related_context( + sub_frame: &mut [u8], + req_hdr: &mut Smb2Header, + prev_session_id: u64, + prev_tree_id: u32, + prev_create_file_id: Option<[u8; 16]>, +) { + if read_u64(sub_frame, 0x28) == u64::MAX { + sub_frame[0x28..0x30].copy_from_slice(&prev_session_id.to_le_bytes()); + req_hdr.session_id = prev_session_id; + } + + if read_u32(sub_frame, 0x24) == u32::MAX { + sub_frame[0x24..0x28].copy_from_slice(&prev_tree_id.to_le_bytes()); + if let HeaderTail::Sync { reserved, .. } = req_hdr.tail { + req_hdr.tail = HeaderTail::Sync { + reserved, + tree_id: prev_tree_id, + }; + } + } + + let Some(file_id) = prev_create_file_id else { + return; + }; + let Some(body_offset) = file_id_body_offset(req_hdr.command) else { + return; + }; + let offset = SMB2_HEADER_LEN + body_offset; + if offset + 16 <= sub_frame.len() + && read_u64(sub_frame, offset) == u64::MAX + && read_u64(sub_frame, offset + 8) == u64::MAX + { + sub_frame[offset..offset + 16].copy_from_slice(&file_id); + } +} + +fn file_id_body_offset(command: Command) -> Option { + match command { + Command::Close + | Command::Flush + | Command::Lock + | Command::Ioctl + | Command::QueryDirectory + | Command::ChangeNotify + | Command::OplockBreak => Some(8), + Command::Read | Command::Write => Some(16), + Command::QueryInfo => Some(24), + Command::SetInfo => Some(16), + _ => None, + } +} + +fn capture_create_file_id(response: &[u8]) -> Option<[u8; 16]> { + if response.len() < SMB2_HEADER_LEN + 80 || read_u32(response, 0x08) != ntstatus::STATUS_SUCCESS + { + return None; + } + + let mut file_id = [0u8; 16]; + let offset = SMB2_HEADER_LEN + 64; + file_id.copy_from_slice(&response[offset..offset + 16]); + Some(file_id) +} + +async fn stitch_responses(conn: &Arc, responses: Vec>) -> Vec { + let mut out = Vec::new(); + let mut ranges = Vec::with_capacity(responses.len()); + let response_count = responses.len(); + + for (index, mut response) in responses.into_iter().enumerate() { + let start = out.len(); + let actual_len = response.len(); + if index + 1 < response_count { + let next = align_8(actual_len); + response[0x14..0x18].copy_from_slice(&(next as u32).to_le_bytes()); + } + out.extend_from_slice(&response); + ranges.push((start, actual_len)); + + if index + 1 < response_count { + out.resize(start + align_8(actual_len), 0); + } + } + + let algo = *conn.signing_algo.read().await; + for (start, len) in ranges { + let flags = read_u32(&out, start + 0x10); + if flags & SMB2_FLAGS_SIGNED == 0 { + continue; + } + + let session_id = read_u64(&out, start + 0x28); + let key = { + let sessions = conn.sessions.read().await; + sessions.get(&session_id).cloned() + }; + let Some(session) = key else { + continue; + }; + let session = session.read().await; + if matches!(session.identity, Identity::Anonymous) { + continue; + } + let signing_key = session.signing_key; + drop(session); + + if let Err(e) = sign(&mut out[start..start + len], &signing_key, algo) { + error!(error = %e, "failed to sign compound response"); + } + } + + out +} + +const fn align_8(n: usize) -> usize { + (n + 7) & !7 +} + +fn read_u32(buf: &[u8], offset: usize) -> u32 { + let mut bytes = [0u8; 4]; + bytes.copy_from_slice(&buf[offset..offset + 4]); + u32::from_le_bytes(bytes) +} + +fn read_u64(buf: &[u8], offset: usize) -> u64 { + let mut bytes = [0u8; 8]; + bytes.copy_from_slice(&buf[offset..offset + 8]); + u64::from_le_bytes(bytes) +} + +async fn dispatch_one( + server: &Arc, + conn: &Arc, + frame: &[u8], +) -> Option> { + let (req_hdr, body_bytes) = match Smb2Header::parse(frame) { + Ok(p) => p, + Err(e) => { + warn!(error = %e, "failed to parse header"); + return None; + } + }; + + let cmd = req_hdr.command; + let mid = req_hdr.message_id; + let sid = req_hdr.session_id; + let tid = req_hdr.tree_id().unwrap_or(0); + + let span = debug_span!("dispatch", cmd = ?cmd, mid, sid, tid); + async move { + debug!("dispatch start"); + + // Verify signature on incoming request (when applicable). + if let Err(status) = verify_request_signature(server, conn, &req_hdr, frame).await { + return Some(build_response_bytes(conn, &req_hdr, HandlerResponse::err(status)).await); + } + + // CANCEL is fire-and-forget — no response. + if cmd == Command::Cancel { + debug!("CANCEL received; no response"); + return None; + } + + let dialect = *conn.dialect.read().await; + let mut session_preauth = None; + + // 3.1.1 preauth is connection-scoped for NEGOTIATE, then per + // SESSION_SETUP authentication exchange. + if cmd == Command::Negotiate { + let mut p = conn + .preauth + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + p.update(frame); + } else if cmd == Command::SessionSetup + && dialect == Some(crate::proto::messages::Dialect::Smb311) + { + let mut p = take_session_preauth(conn, req_hdr.session_id).await; + p.update(frame); + session_preauth = Some(p); + } + + let resp = handlers::dispatch_command(server, conn, &req_hdr, body_bytes).await; + + // If the handler asked for a preauth snapshot (3.1.1), take it now. + if let Some(sid) = resp.take_preauth_snapshot_for_session { + let snap = session_preauth + .as_ref() + .expect("SMB 3.1.1 SessionSetup snapshot requires per-session preauth") + .snapshot(); + // Stash on the session — the handler already created it. + let sessions = conn.sessions.read().await; + if let Some(sess_arc) = sessions.get(&sid) { + let mut sess = sess_arc.write().await; + sess.preauth_snapshot = Some(snap); + // For 3.1.1, recompute signing key now that we have the snapshot. + let dialect = *conn.dialect.read().await; + if dialect == Some(crate::proto::messages::Dialect::Smb311) { + sess.signing_key = + crate::proto::crypto::signing_key_311(&sess.session_base_key, &snap); + } + } + } + + let bytes = build_response_bytes(conn, &req_hdr, resp).await; + + if cmd == Command::Negotiate { + let mut p = conn + .preauth + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + p.update(&bytes); + } else if cmd == Command::SessionSetup + && dialect == Some(crate::proto::messages::Dialect::Smb311) + { + if read_u32(&bytes, 0x08) == ntstatus::STATUS_MORE_PROCESSING_REQUIRED { + if let Some(mut p) = session_preauth { + p.update(&bytes); + let sid = read_u64(&bytes, 0x28); + conn.session_preauth.write().await.insert(sid, p); + } + } else { + conn.session_preauth + .write() + .await + .remove(&req_hdr.session_id); + } + } + + Some(bytes) + } + .instrument(span) + .await +} + +async fn take_session_preauth(conn: &Arc, session_id: u64) -> PreauthIntegrity { + if session_id != 0 + && let Some(preauth) = conn.session_preauth.write().await.remove(&session_id) + { + return preauth; + } + + conn.preauth + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .clone() +} + +async fn verify_request_signature( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + frame: &[u8], +) -> Result<(), u32> { + if hdr.command == Command::Negotiate { + return Ok(()); + } + if hdr.session_id == 0 { + return Ok(()); + } + let sessions = conn.sessions.read().await; + let sess_arc = match sessions.get(&hdr.session_id) { + Some(s) => s.clone(), + None => { + // Unknown session. + if hdr.flags & SMB2_FLAGS_SIGNED == 0 { + return Ok(()); + } + return Err(ntstatus::STATUS_USER_SESSION_DELETED); + } + }; + drop(sessions); + + if hdr.flags & SMB2_FLAGS_SIGNED != 0 { + let sess = sess_arc.read().await; + if matches!(sess.identity, Identity::Anonymous) { + return Ok(()); + } + let key = sess.signing_key; + drop(sess); + let algo = *conn.signing_algo.read().await; + if let Err(e) = crate::proto::crypto::verify(frame, &key, algo) { + warn!(error = %e, "request signature verification failed"); + return Err(ntstatus::STATUS_ACCESS_DENIED); + } + } else if hdr.command != Command::SessionSetup { + let sess = sess_arc.read().await; + let need = sess.signing_required && !matches!(sess.identity, Identity::Anonymous); + drop(sess); + if need { + warn!(?hdr.command, "missing required signature on request"); + return Err(ntstatus::STATUS_ACCESS_DENIED); + } + } + Ok(()) +} + +/// Build the final on-the-wire bytes: header + body, with signing applied +/// when the session has a key. +async fn build_response_bytes( + conn: &Arc, + req_hdr: &Smb2Header, + handler_resp: HandlerResponse, +) -> Vec { + let mut hdr = *req_hdr; + hdr.flags |= SMB2_FLAGS_SERVER_TO_REDIR; + hdr.flags &= !SMB2_FLAGS_ASYNC_COMMAND; + hdr.next_command = 0; + hdr.channel_sequence_status = handler_resp.status; + hdr.tail = HeaderTail::sync( + handler_resp + .override_tree_id + .unwrap_or_else(|| req_hdr.tree_id().unwrap_or(0)), + ); + if let Some(sid) = handler_resp.override_session_id { + hdr.session_id = sid; + } + hdr.signature = [0u8; 16]; + + let request_was_signed = req_hdr.flags & SMB2_FLAGS_SIGNED != 0; + // MS-SMB2 §3.3.5.5.3 step 12: SessionSetup SUCCESS must be signed for + // non-anon/non-guest sessions even though the request cannot be signed yet. + let is_session_setup_success = + req_hdr.command == Command::SessionSetup && handler_resp.status == ntstatus::STATUS_SUCCESS; + let mut should_sign = false; + let mut key = [0u8; 16]; + let algo = *conn.signing_algo.read().await; + if !handler_resp.skip_signing + && hdr.session_id != 0 + && (request_was_signed || is_session_setup_success) + { + let sessions = conn.sessions.read().await; + if let Some(sess_arc) = sessions.get(&hdr.session_id) { + let sess = sess_arc.read().await; + let is_anon = matches!(sess.identity, Identity::Anonymous); + let is_guest_response = is_session_setup_success + && handler_resp.body.len() >= 4 + && (handler_resp.body[2] & 0x01) != 0; + if !is_anon && !is_guest_response && sess.signing_key != [0u8; 16] { + key = sess.signing_key; + should_sign = true; + } + } + } + if should_sign { + hdr.flags |= SMB2_FLAGS_SIGNED; + } else { + hdr.flags &= !SMB2_FLAGS_SIGNED; + } + let mut out = Vec::with_capacity(SMB2_HEADER_LEN + handler_resp.body.len()); + if let Err(e) = hdr.write(&mut out) { + error!(error = %e, "failed to encode response header"); + return Vec::new(); + } + out.extend_from_slice(&handler_resp.body); + + if should_sign && let Err(e) = sign(&mut out, &key, algo) { + error!(error = %e, "failed to sign response"); + } + out +} + +/// Detect and answer an SMB1 multi-protocol NEGOTIATE_REQUEST. +/// +/// SMB1 frame layout for the request we accept: +/// * `[0..4]` — magic `0xFF 'S' 'M' 'B'` +/// * `[4]` — command (0x72 = SMB_COM_NEGOTIATE) +/// * `[5..32]` — rest of SMB1 header (status, flags, pid, tid, mid …) +/// * `[32]` — `WordCount` (0 for NEGOTIATE) +/// * `[33..35]`— `ByteCount` (u16 LE) +/// * `[35..]` — dialect strings, each `0x02 0x00`. +/// +/// Returns `Some(reply_bytes)` only for a SMB1 NEGOTIATE that lists at least +/// one SMB2 dialect we recognise; otherwise `None` so the caller can fall +/// through to the normal SMB2 path. +async fn handle_smb1_multi_protocol( + server: &Arc, + conn: &Arc, + frame: &[u8], +) -> Option> { + if frame.len() < 35 || frame[0..4] != [0xFF, b'S', b'M', b'B'] || frame[4] != 0x72 { + return None; + } + let body_start = 33; // 32-byte header + 1-byte WordCount(=0) + let byte_count = u16::from_le_bytes([frame[body_start], frame[body_start + 1]]) as usize; + let blob_start = body_start + 2; + let blob_end = (blob_start + byte_count).min(frame.len()); + let blob = &frame[blob_start..blob_end]; + + let mut wants_wildcard = false; + let mut wants_smb202 = false; + let mut i = 0; + while i < blob.len() { + if blob[i] != 0x02 { + break; + } + i += 1; + let nul = match blob[i..].iter().position(|&b| b == 0) { + Some(p) => p, + None => break, + }; + let s = std::str::from_utf8(&blob[i..i + nul]).unwrap_or(""); + match s { + "SMB 2.???" => wants_wildcard = true, + "SMB 2.002" => wants_smb202 = true, + _ => {} + } + i += nul + 1; + } + + let chosen = if wants_wildcard { + crate::proto::messages::Dialect::Smb2Wildcard.as_u16() + } else if wants_smb202 { + crate::proto::messages::Dialect::Smb202.as_u16() + } else { + return None; + }; + + debug!( + chosen = %format_args!("0x{chosen:04X}"), + "SMB1 multi-protocol negotiate" + ); + + // Synthesize a request header so build_response_bytes can mint the + // SERVER_TO_REDIR response. Per MS-SMB2 §3.3.5.3.1 the response uses + // message_id=0, tree_id=0xFFFF, session_id=0. + let req_hdr = Smb2Header { + command: Command::Negotiate, + message_id: 0, + session_id: 0, + tail: HeaderTail::Sync { + reserved: 0, + tree_id: 0xFFFF, + }, + ..Default::default() + }; + let resp = handlers::negotiate::multi_protocol_response(server, conn, chosen).await; + Some(build_response_bytes(conn, &req_hdr, resp).await) +} + +#[cfg(test)] +mod tests { + use super::*; + use uuid::Uuid; + + fn test_conn() -> Arc { + Arc::new(Connection::new(Uuid::nil(), 1024 * 1024, 1024 * 1024)) + } + + fn negotiated_preauth() -> PreauthIntegrity { + let mut preauth = PreauthIntegrity::new(); + preauth.update(b"negotiate request"); + preauth.update(b"negotiate response"); + preauth + } + + #[tokio::test] + async fn new_session_setup_preauth_starts_from_negotiate_base() { + let conn = test_conn(); + let base = negotiated_preauth(); + *conn.preauth.lock().expect("preauth lock") = base.clone(); + + let mut first_session = take_session_preauth(&conn, 0).await; + first_session.update(b"session one request"); + first_session.update(b"session one response"); + conn.session_preauth.write().await.insert(1, first_session); + + let mut second_session = take_session_preauth(&conn, 0).await; + second_session.update(b"session two request"); + + let mut expected = base.clone(); + expected.update(b"session two request"); + + let mut polluted = base; + polluted.update(b"session one request"); + polluted.update(b"session one response"); + polluted.update(b"session two request"); + + assert_eq!(second_session.snapshot(), expected.snapshot()); + assert_ne!(second_session.snapshot(), polluted.snapshot()); + } + + #[tokio::test] + async fn followup_session_setup_consumes_stored_session_preauth() { + let conn = test_conn(); + let mut stored = negotiated_preauth(); + stored.update(b"session setup request"); + stored.update(b"session setup more-processing response"); + let expected = stored.snapshot(); + conn.session_preauth.write().await.insert(7, stored); + + let got = take_session_preauth(&conn, 7).await; + + assert_eq!(got.snapshot(), expected); + assert!(!conn.session_preauth.read().await.contains_key(&7)); + } +} diff --git a/vendor/smb-server/src/error.rs b/vendor/smb-server/src/error.rs new file mode 100644 index 0000000..1e2e7a1 --- /dev/null +++ b/vendor/smb-server/src/error.rs @@ -0,0 +1,86 @@ +//! Public error type for the server, plus the NTSTATUS mapping per spec §8. + +use thiserror::Error; + +use crate::ntstatus; + +pub type SmbResult = Result; + +/// Errors returned by `ShareBackend` and surfaced through the SMB protocol. +/// +/// `to_nt_status` maps each variant onto a single NTSTATUS code per the spec +/// §8 table. Internal protocol-layer failures (malformed frames, signing +/// errors) never become `SmbError`; the connection loop logs them and aborts. +#[derive(Debug, Error)] +pub enum SmbError { + #[error("not found")] + NotFound, + #[error("path not found")] + PathNotFound, + #[error("access denied")] + AccessDenied, + #[error("exists")] + Exists, + #[error("not empty")] + NotEmpty, + #[error("is a directory")] + IsDirectory, + #[error("not a directory")] + NotADirectory, + #[error("name too long / invalid")] + NameInvalid, + #[error("sharing violation")] + Sharing, + #[error("not supported")] + NotSupported, + #[error("io: {0}")] + Io(#[from] std::io::Error), +} + +impl SmbError { + /// Map this error onto an NTSTATUS code per the v1 spec §8 table. + pub fn to_nt_status(&self) -> u32 { + match self { + SmbError::NotFound => ntstatus::STATUS_OBJECT_NAME_NOT_FOUND, + SmbError::PathNotFound => ntstatus::STATUS_OBJECT_PATH_NOT_FOUND, + SmbError::AccessDenied => ntstatus::STATUS_ACCESS_DENIED, + SmbError::Exists => ntstatus::STATUS_OBJECT_NAME_COLLISION, + SmbError::NotEmpty => ntstatus::STATUS_DIRECTORY_NOT_EMPTY, + SmbError::IsDirectory => ntstatus::STATUS_FILE_IS_A_DIRECTORY, + SmbError::NotADirectory => ntstatus::STATUS_NOT_A_DIRECTORY, + SmbError::NameInvalid => ntstatus::STATUS_OBJECT_NAME_INVALID, + SmbError::Sharing => ntstatus::STATUS_SHARING_VIOLATION, + SmbError::NotSupported => ntstatus::STATUS_NOT_SUPPORTED, + SmbError::Io(_) => ntstatus::STATUS_UNEXPECTED_IO_ERROR, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn nt_status_table_matches_spec() { + assert_eq!(SmbError::NotFound.to_nt_status(), 0xC000_000F); + assert_eq!(SmbError::PathNotFound.to_nt_status(), 0xC000_003A); + assert_eq!(SmbError::AccessDenied.to_nt_status(), 0xC000_0022); + assert_eq!(SmbError::Exists.to_nt_status(), 0xC000_0035); + assert_eq!(SmbError::NotEmpty.to_nt_status(), 0xC000_0101); + assert_eq!(SmbError::IsDirectory.to_nt_status(), 0xC000_00BA); + assert_eq!(SmbError::NotADirectory.to_nt_status(), 0xC000_0103); + assert_eq!(SmbError::NameInvalid.to_nt_status(), 0xC000_0033); + assert_eq!(SmbError::Sharing.to_nt_status(), 0xC000_0043); + assert_eq!(SmbError::NotSupported.to_nt_status(), 0xC000_00BB); + + let io_err = SmbError::Io(std::io::Error::other("boom")); + assert_eq!(io_err.to_nt_status(), 0xC000_009C); + } + + #[test] + fn io_err_from_blanket_works() { + let io: std::io::Error = std::io::Error::other("x"); + let smb: SmbError = io.into(); + assert_eq!(smb.to_nt_status(), 0xC000_009C); + } +} diff --git a/vendor/smb-server/src/fs/local.rs b/vendor/smb-server/src/fs/local.rs new file mode 100644 index 0000000..3e4eff5 --- /dev/null +++ b/vendor/smb-server/src/fs/local.rs @@ -0,0 +1,921 @@ +//! `LocalFsBackend` — a `ShareBackend` backed by a real on-disk directory. +//! +//! The share root is opened once via `cap_std::fs::Dir::open_ambient_dir` and +//! kept as the sole authority handle. All subsequent path operations are +//! resolved relative to that handle, so a malicious symlink or `..` smuggled +//! through `SmbPath` cannot escape the sandbox — `cap-std` enforces this at +//! every step. +//! +//! Per the v1 design (spec §3.4) this backend is intentionally minimal: +//! +//! - Sync FS calls are wrapped in `tokio::task::spawn_blocking` so the async +//! `ShareBackend`/`Handle` methods integrate cleanly with the dispatcher. +//! - `read_only()` flips a flag that makes write-class opens reject early +//! with `SmbError::AccessDenied`. +//! - DOS-style glob matching for `list_dir` is handled here (case-insensitive, +//! `?` and `*`), since cap-std only provides raw `entries()`. + +use std::io; +use std::os::unix::fs::FileExt as _; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use async_trait::async_trait; +use bytes::Bytes; +use cap_std::ambient_authority; +use cap_std::fs::{Dir, OpenOptions as CapOpenOptions}; +use tokio::task::spawn_blocking; + +use crate::backend::{ + BackendCapabilities, DirEntry as SmbDirEntry, FileInfo, FileTimes, Handle, OpenIntent, + OpenOptions, ShareBackend, +}; +use crate::error::{SmbError, SmbResult}; +use crate::path::SmbPath; + +// --------------------------------------------------------------------------- +// Backend +// --------------------------------------------------------------------------- + +/// Local-filesystem backend, sandboxed at a single root directory. +/// +/// Cheap to clone: internally an `Arc` plus a flag. +pub struct LocalFsBackend { + root: Arc, + read_only: bool, +} + +impl LocalFsBackend { + /// Open `path` as the share root. Errors if the path does not exist or is + /// not a directory. + pub fn new(path: impl AsRef) -> io::Result { + let dir = Dir::open_ambient_dir(path, ambient_authority())?; + Ok(Self { + root: Arc::new(dir), + read_only: false, + }) + } + + /// Mark the backend as read-only. All write-class opens and writes return + /// an access-denied SMB error. + #[must_use] + pub fn read_only(mut self) -> Self { + self.read_only = true; + self + } +} + +// --------------------------------------------------------------------------- +// Path translation +// --------------------------------------------------------------------------- + +/// Convert a validated `SmbPath` into a relative `PathBuf` suitable for +/// `cap_std::fs::Dir` lookups. +/// +/// `SmbPath` is already validated (no `..`, no forbidden chars, no doubled +/// separators), so this is purely a join. The empty `SmbPath` (root) yields +/// `PathBuf::from(".")` — cap-std accepts this for `metadata` etc. +fn to_rel_path(path: &SmbPath) -> PathBuf { + if path.is_root() { + return PathBuf::from("."); + } + let mut out = PathBuf::new(); + for c in path.components() { + out.push(c); + } + out +} + +// --------------------------------------------------------------------------- +// Error mapping +// --------------------------------------------------------------------------- + +fn io_to_smb(err: io::Error) -> SmbError { + use io::ErrorKind::*; + match err.kind() { + NotFound => SmbError::NotFound, + PermissionDenied => SmbError::AccessDenied, + AlreadyExists => SmbError::Exists, + DirectoryNotEmpty => SmbError::NotEmpty, + IsADirectory => SmbError::IsDirectory, + NotADirectory => SmbError::NotADirectory, + InvalidInput | InvalidFilename => SmbError::NameInvalid, + _ => SmbError::Io(err), + } +} + +/// Convert a panic from `spawn_blocking` into an `io::Error`. Panics in the +/// blocking pool are exotic; we surface them as a generic `Other` rather than +/// re-panicking on the async side. +fn join_to_io(_e: tokio::task::JoinError) -> io::Error { + io::Error::other("blocking task panicked or was cancelled") +} + +// --------------------------------------------------------------------------- +// FILETIME conversion +// --------------------------------------------------------------------------- + +/// Number of 100-nanosecond intervals between 1601-01-01 (Windows FILETIME +/// epoch) and 1970-01-01 (UNIX epoch). +const FILETIME_OFFSET: u64 = 116_444_736_000_000_000; + +fn system_time_to_filetime(t: SystemTime) -> u64 { + match t.duration_since(UNIX_EPOCH) { + Ok(d) => FILETIME_OFFSET + (d.as_secs() * 10_000_000) + u64::from(d.subsec_nanos() / 100), + Err(_) => 0, + } +} + +fn filetime_to_system_time(ft: u64) -> Option { + if ft < FILETIME_OFFSET { + return None; + } + let unix_100ns = ft - FILETIME_OFFSET; + let secs = unix_100ns / 10_000_000; + let nanos = ((unix_100ns % 10_000_000) * 100) as u32; + UNIX_EPOCH.checked_add(Duration::new(secs, nanos)) +} + +// --------------------------------------------------------------------------- +// FileInfo construction +// --------------------------------------------------------------------------- + +fn file_info_from_metadata(name: String, md: &cap_std::fs::Metadata) -> FileInfo { + let len = md.len(); + let modified = md.modified().ok().map(|t| t.into_std()); + let accessed = md.accessed().ok().map(|t| t.into_std()); + let created = md.created().ok().map(|t| t.into_std()); + + // Fall back: if a particular timestamp isn't available on the platform, + // use whichever timestamp is available, then `now()` as last resort. SMB + // clients tolerate equal timestamps fine. + let modified = modified + .or(created) + .or(accessed) + .unwrap_or(SystemTime::UNIX_EPOCH); + let accessed = accessed.unwrap_or(modified); + let created = created.unwrap_or(modified); + + FileInfo { + name, + end_of_file: len, + allocation_size: len, + creation_time: system_time_to_filetime(created), + last_access_time: system_time_to_filetime(accessed), + last_write_time: system_time_to_filetime(modified), + change_time: system_time_to_filetime(modified), + is_directory: md.is_dir(), + // `cap-std` does not expose a stable inode-style identifier in its + // public API; the dispatcher substitutes the FileId where needed. + file_index: 0, + } +} + +// --------------------------------------------------------------------------- +// DOS glob matching +// --------------------------------------------------------------------------- + +/// Match `name` against a DOS-style pattern. `?` matches any single char, +/// `*` matches any sequence (possibly empty). Comparison is case-insensitive +/// (ASCII fold) — sufficient for the v1 use-case where names are validated to +/// be free of weird Unicode tricks. +fn glob_match(pattern: &str, name: &str) -> bool { + // Walk both strings as char vectors so `?` matches a char rather than a + // byte, without going through grapheme territory. + let p: Vec = pattern.chars().collect(); + let n: Vec = name.chars().collect(); + glob_match_inner(&p, &n) +} + +fn glob_match_inner(p: &[char], n: &[char]) -> bool { + let mut pi = 0usize; + let mut ni = 0usize; + let mut star: Option<(usize, usize)> = None; // (pi after '*', ni at the time) + + while ni < n.len() { + if pi < p.len() && (p[pi] == '?' || ascii_eq_ci(p[pi], n[ni])) { + pi += 1; + ni += 1; + } else if pi < p.len() && p[pi] == '*' { + star = Some((pi + 1, ni)); + pi += 1; + } else if let Some((sp, sn)) = star { + pi = sp; + ni = sn + 1; + star = Some((sp, sn + 1)); + } else { + return false; + } + } + while pi < p.len() && p[pi] == '*' { + pi += 1; + } + pi == p.len() +} + +fn ascii_eq_ci(a: char, b: char) -> bool { + a.eq_ignore_ascii_case(&b) +} + +// --------------------------------------------------------------------------- +// ShareBackend impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl ShareBackend for LocalFsBackend { + async fn open(&self, path: &SmbPath, opts: OpenOptions) -> SmbResult> { + // 1. Read-only check: any open that requests creation, write access, + // truncation, or overwrite is rejected up front. Pure read opens + // pass through. + let writes = opts.write + || matches!( + opts.intent, + OpenIntent::Create + | OpenIntent::OpenOrCreate + | OpenIntent::OverwriteOrCreate + | OpenIntent::Truncate + ); + if self.read_only && writes { + return Err(SmbError::AccessDenied); + } + + let rel = to_rel_path(path); + let root = Arc::clone(&self.root); + let read_only = self.read_only; + let directory = opts.directory; + let non_directory = opts.non_directory; + + // For directories, cap-std exposes `open_dir` separately; we don't + // need an OpenOptions translation in that case. + if directory { + // Directory CREATE intents: Create / OpenOrCreate / OverwriteOrCreate + // imply mkdir; Open / Truncate require existing. + let intent = opts.intent; + let dir_handle = spawn_blocking(move || -> io::Result { + match intent { + OpenIntent::Open => root.open_dir(&rel), + OpenIntent::Create => { + root.create_dir(&rel)?; + root.open_dir(&rel) + } + OpenIntent::OpenOrCreate => { + if !root.exists(&rel) { + root.create_dir(&rel)?; + } + root.open_dir(&rel) + } + OpenIntent::Truncate | OpenIntent::OverwriteOrCreate => { + // Truncating a directory has no meaning; reject. + Err(io::Error::from(io::ErrorKind::InvalidInput)) + } + } + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb)?; + + return Ok(Box::new(LocalHandle::Dir { + name: file_name_for(path), + dir_handle: Arc::new(dir_handle), + })); + } + + let existing_is_dir = { + let root = Arc::clone(&self.root); + let rel = rel.clone(); + spawn_blocking(move || -> io::Result { + match root.metadata(&rel) { + Ok(md) => Ok(md.is_dir()), + Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(false), + Err(e) => Err(e), + } + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb)? + }; + if existing_is_dir { + if non_directory { + return Err(SmbError::IsDirectory); + } + match opts.intent { + OpenIntent::Open | OpenIntent::OpenOrCreate => { + let root = Arc::clone(&self.root); + let rel = rel.clone(); + let dir_handle = spawn_blocking(move || root.open_dir(&rel)) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb)?; + return Ok(Box::new(LocalHandle::Dir { + name: file_name_for(path), + dir_handle: Arc::new(dir_handle), + })); + } + OpenIntent::Create => return Err(SmbError::Exists), + OpenIntent::Truncate | OpenIntent::OverwriteOrCreate => { + return Err(SmbError::IsDirectory); + } + } + } + + // 2. Translate OpenIntent → cap-std OpenOptions. + let mut cap_opts = CapOpenOptions::new(); + match opts.intent { + OpenIntent::Open => { + cap_opts.read(true).write(opts.write); + } + OpenIntent::Create => { + cap_opts.read(opts.read).write(true).create_new(true); + } + OpenIntent::Truncate => { + cap_opts.read(opts.read).write(true).truncate(true); + } + OpenIntent::OpenOrCreate => { + cap_opts.read(opts.read).write(true).create(true); + } + OpenIntent::OverwriteOrCreate => { + cap_opts + .read(opts.read) + .write(true) + .create(true) + .truncate(true); + } + } + + let cap_file = spawn_blocking(move || root.open_with(&rel, &cap_opts)) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb)?; + + // Convert to a `std::fs::File`. We only need cap-std for the safe + // *open*; once we hold a verified file handle, std's API gives us + // `set_times`, `set_len`, `sync_data`, and `FileExt::{read,write}_at` + // without pulling in extra crates. + let std_file: std::fs::File = cap_file.into_std(); + + Ok(Box::new(LocalHandle::File { + name: file_name_for(path), + file: Arc::new(std_file), + read_only, + })) + } + + async fn unlink(&self, path: &SmbPath) -> SmbResult<()> { + if self.read_only { + return Err(SmbError::AccessDenied); + } + if path.is_root() { + // Refusing to delete the share root itself. + return Err(SmbError::AccessDenied); + } + let rel = to_rel_path(path); + let root = Arc::clone(&self.root); + + spawn_blocking(move || -> io::Result<()> { + match root.remove_file(&rel) { + Ok(()) => Ok(()), + Err(e) if e.kind() == io::ErrorKind::IsADirectory => { + // Caller's intent was "delete this name"; if it turned + // out to be a directory, fall back to remove_dir which + // refuses non-empty dirs (mapped to NotEmpty above). + root.remove_dir(&rel) + } + Err(e) => Err(e), + } + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + + async fn rename(&self, from: &SmbPath, to: &SmbPath) -> SmbResult<()> { + if self.read_only { + return Err(SmbError::AccessDenied); + } + if from.is_root() || to.is_root() { + return Err(SmbError::NameInvalid); + } + let from = to_rel_path(from); + let to_path = to_rel_path(to); + let root = Arc::clone(&self.root); + let root2 = Arc::clone(&self.root); + + spawn_blocking(move || -> io::Result<()> { + // Reject overwrite — SMB rename semantics require explicit + // replace-if-exists which we do not implement in v1. + if root2.exists(&to_path) { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + root.rename(&from, &root2, &to_path) + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + + fn capabilities(&self) -> BackendCapabilities { + BackendCapabilities { + is_read_only: self.read_only, + // POSIX filesystems are typically case-sensitive. We don't try to + // emulate case-insensitive lookup in v1 (see spec §3.4). + case_sensitive: cfg!(any(target_os = "linux", target_os = "freebsd")), + } + } +} + +// --------------------------------------------------------------------------- +// Handle +// --------------------------------------------------------------------------- + +/// Internal handle variant. `File` carries a `std::fs::File` (after cap-std +/// has done the safe open); `Dir` keeps the `cap_std::fs::Dir` so we can +/// re-list entries. +enum LocalHandle { + File { + name: String, + file: Arc, + read_only: bool, + }, + Dir { + name: String, + dir_handle: Arc, + }, +} + +fn file_name_for(path: &SmbPath) -> String { + path.file_name().unwrap_or("").to_string() +} + +#[async_trait] +impl Handle for LocalHandle { + async fn read(&self, offset: u64, len: u32) -> SmbResult { + match self { + LocalHandle::File { file, .. } => { + let file = Arc::clone(file); + let n = len as usize; + let bytes = spawn_blocking(move || -> io::Result { + let mut buf = vec![0u8; n]; + let read = file.read_at(&mut buf, offset)?; + buf.truncate(read); + Ok(Bytes::from(buf)) + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb)?; + Ok(bytes) + } + LocalHandle::Dir { .. } => Err(SmbError::IsDirectory), + } + } + + async fn write(&self, offset: u64, data: &[u8]) -> SmbResult { + self.write_owned(offset, data.to_vec()).await + } + + async fn write_owned(&self, offset: u64, data: Vec) -> SmbResult { + match self { + LocalHandle::File { + file, read_only, .. + } => { + if *read_only { + return Err(SmbError::AccessDenied); + } + let file = Arc::clone(file); + let written = spawn_blocking(move || file.write_at(&data, offset)) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb)?; + Ok(u32::try_from(written).unwrap_or(u32::MAX)) + } + LocalHandle::Dir { .. } => Err(SmbError::IsDirectory), + } + } + + async fn flush(&self) -> SmbResult<()> { + match self { + LocalHandle::File { file, .. } => { + let file = Arc::clone(file); + spawn_blocking(move || file.sync_data()) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + // Flushing a directory is a no-op in SMB semantics. + LocalHandle::Dir { .. } => Ok(()), + } + } + + async fn stat(&self) -> SmbResult { + match self { + LocalHandle::File { file, name, .. } => { + let file = Arc::clone(file); + let name = name.clone(); + spawn_blocking(move || -> io::Result { + let std_md = file.metadata()?; + // Synthesize a cap-std Metadata from the std one so we + // can reuse `file_info_from_metadata`. cap-primitives + // exposes `Metadata::from_just_metadata` for this. + let md = cap_std::fs::Metadata::from_just_metadata(std_md); + Ok(file_info_from_metadata(name, &md)) + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + LocalHandle::Dir { + dir_handle, name, .. + } => { + let dir_handle = Arc::clone(dir_handle); + let name = name.clone(); + spawn_blocking(move || -> io::Result { + let md = dir_handle.dir_metadata()?; + Ok(file_info_from_metadata(name, &md)) + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + } + } + + async fn set_times(&self, times: FileTimes) -> SmbResult<()> { + match self { + LocalHandle::File { + file, read_only, .. + } => { + if *read_only { + return Err(SmbError::AccessDenied); + } + let file = Arc::clone(file); + spawn_blocking(move || -> io::Result<()> { + let mut std_times = std::fs::FileTimes::new(); + if let Some(ft) = times.last_write_time + && let Some(t) = filetime_to_system_time(ft) + { + std_times = std_times.set_modified(t); + } + if let Some(ft) = times.last_access_time + && let Some(t) = filetime_to_system_time(ft) + { + std_times = std_times.set_accessed(t); + } + // creation_time / change_time: stable std::fs::FileTimes + // does not expose setters for these; silently ignored. + file.set_times(std_times) + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + // cap-std's directory handle does not expose set_times in its + // stable API; mark as unsupported on directories. + LocalHandle::Dir { .. } => Err(SmbError::NotSupported), + } + } + + async fn truncate(&self, len: u64) -> SmbResult<()> { + match self { + LocalHandle::File { + file, read_only, .. + } => { + if *read_only { + return Err(SmbError::AccessDenied); + } + let file = Arc::clone(file); + spawn_blocking(move || file.set_len(len)) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + // Protocol layer rejects truncate on dir handles before this; if + // it ever reaches us, surface as NotSupported. + LocalHandle::Dir { .. } => Err(SmbError::NotSupported), + } + } + + async fn list_dir(&self, pattern: Option<&str>) -> SmbResult> { + match self { + LocalHandle::File { .. } => Err(SmbError::NotADirectory), + LocalHandle::Dir { dir_handle, .. } => { + let dir_handle = Arc::clone(dir_handle); + let pat = pattern.map(|s| s.to_owned()); + spawn_blocking(move || -> io::Result> { + let mut out = Vec::new(); + for entry in dir_handle.entries()? { + let entry = entry?; + let os_name = entry.file_name(); + let Some(name) = os_name.to_str().map(str::to_owned) else { + // Skip non-UTF-8 names; SMB wire format is UTF-16 + // and we never want to emit invalid Unicode here. + continue; + }; + if let Some(p) = pat.as_deref() { + // Empty / "*" / "*.*" all mean "match everything" + // in DOS-speak. + if !(p.is_empty() || p == "*" || p == "*.*" || glob_match(p, &name)) { + continue; + } + } + let md = entry.metadata()?; + let info = file_info_from_metadata(name, &md); + out.push(SmbDirEntry { info }); + } + Ok(out) + }) + .await + .map_err(join_to_io) + .map_err(io_to_smb)? + .map_err(io_to_smb) + } + } + } + + async fn close(self: Box) -> SmbResult<()> { + // Drop is sufficient — closing the underlying handle is what the OS + // does when the last `Arc` ref goes away. No flush here: SMB CLOSE + // does not imply fsync. + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::backend::{OpenIntent, OpenOptions}; + use crate::path::SmbPath; + use tempfile::tempdir; + + fn p(s: &str) -> SmbPath { + s.parse::().unwrap() + } + + fn opts_create() -> OpenOptions { + OpenOptions { + read: true, + write: true, + intent: OpenIntent::Create, + directory: false, + non_directory: false, + delete_on_close: false, + } + } + + fn opts_open_rw() -> OpenOptions { + OpenOptions { + read: true, + write: true, + intent: OpenIntent::Open, + directory: false, + non_directory: false, + delete_on_close: false, + } + } + + fn opts_open_ro() -> OpenOptions { + OpenOptions { + read: true, + write: false, + intent: OpenIntent::Open, + directory: false, + non_directory: false, + delete_on_close: false, + } + } + + fn opts_open_dir() -> OpenOptions { + OpenOptions { + read: true, + write: false, + intent: OpenIntent::Open, + directory: true, + non_directory: false, + delete_on_close: false, + } + } + + #[tokio::test] + async fn create_write_read_stat_close() { + let td = tempdir().unwrap(); + let backend = LocalFsBackend::new(td.path()).unwrap(); + + // Create + let h = backend.open(&p("hello.txt"), opts_create()).await.unwrap(); + let n = h.write(0, b"hello world").await.unwrap(); + assert_eq!(n, 11); + h.flush().await.unwrap(); + + // Stat + let info = h.stat().await.unwrap(); + assert_eq!(info.name, "hello.txt"); + assert_eq!(info.end_of_file, 11); + assert!(!info.is_directory); + assert!(info.last_write_time > 0); + h.close().await.unwrap(); + + // Reopen for read + let h2 = backend.open(&p("hello.txt"), opts_open_ro()).await.unwrap(); + let bytes = h2.read(0, 1024).await.unwrap(); + assert_eq!(&bytes[..], b"hello world"); + + // Short-read past EOF returns truncated + let bytes = h2.read(6, 1024).await.unwrap(); + assert_eq!(&bytes[..], b"world"); + + // Read past EOF returns empty + let bytes = h2.read(100, 1024).await.unwrap(); + assert!(bytes.is_empty()); + h2.close().await.unwrap(); + } + + #[tokio::test] + async fn list_dir_finds_created_file() { + let td = tempdir().unwrap(); + let backend = LocalFsBackend::new(td.path()).unwrap(); + let h = backend.open(&p("a.txt"), opts_create()).await.unwrap(); + h.close().await.unwrap(); + + let dir_h = backend + .open(&SmbPath::root(), opts_open_dir()) + .await + .unwrap(); + let entries = dir_h.list_dir(None).await.unwrap(); + assert!(entries.iter().any(|e| e.info.name == "a.txt")); + dir_h.close().await.unwrap(); + } + + #[tokio::test] + async fn read_only_rejects_writes() { + let td = tempdir().unwrap(); + // Pre-create a file via a writable backend so we have something to + // attempt to open RW. + { + let writable = LocalFsBackend::new(td.path()).unwrap(); + let h = writable.open(&p("x.txt"), opts_create()).await.unwrap(); + h.close().await.unwrap(); + } + + let backend = LocalFsBackend::new(td.path()).unwrap().read_only(); + assert!(backend.capabilities().is_read_only); + + // RW open should be rejected. + let err = backend + .open(&p("x.txt"), opts_open_rw()) + .await + .err() + .unwrap(); + assert!(matches!(err, SmbError::AccessDenied)); + + // Create should be rejected. + let err = backend + .open(&p("y.txt"), opts_create()) + .await + .err() + .unwrap(); + assert!(matches!(err, SmbError::AccessDenied)); + + // Pure read open is fine. + let h = backend.open(&p("x.txt"), opts_open_ro()).await.unwrap(); + // Writing through a handle obtained from a read-only backend would + // already be impossible — but if a backend ever yields one, the + // check still bites. + h.close().await.unwrap(); + + // unlink rejected. + let err = backend.unlink(&p("x.txt")).await.err().unwrap(); + assert!(matches!(err, SmbError::AccessDenied)); + } + + #[tokio::test] + async fn unlink_file_then_nonempty_dir_errors() { + let td = tempdir().unwrap(); + let backend = LocalFsBackend::new(td.path()).unwrap(); + + // Create & remove a file. + let h = backend.open(&p("doomed.txt"), opts_create()).await.unwrap(); + h.close().await.unwrap(); + backend.unlink(&p("doomed.txt")).await.unwrap(); + assert!(matches!( + backend.unlink(&p("doomed.txt")).await.err().unwrap(), + SmbError::NotFound + )); + + // Create a non-empty directory; unlink should fail with NotEmpty. + std::fs::create_dir(td.path().join("dir1")).unwrap(); + std::fs::write(td.path().join("dir1").join("inside"), b"x").unwrap(); + + let err = backend.unlink(&p("dir1")).await.err().unwrap(); + assert!( + matches!(err, SmbError::NotEmpty), + "expected NotEmpty, got {err:?}" + ); + + // Empty it and retry. + std::fs::remove_file(td.path().join("dir1").join("inside")).unwrap(); + backend.unlink(&p("dir1")).await.unwrap(); + } + + #[tokio::test] + async fn rename_within_root() { + let td = tempdir().unwrap(); + let backend = LocalFsBackend::new(td.path()).unwrap(); + + let h = backend.open(&p("old.txt"), opts_create()).await.unwrap(); + h.write(0, b"data").await.unwrap(); + h.close().await.unwrap(); + + backend.rename(&p("old.txt"), &p("new.txt")).await.unwrap(); + assert!(td.path().join("new.txt").exists()); + assert!(!td.path().join("old.txt").exists()); + + // Renaming over an existing target should fail. + let h = backend.open(&p("other.txt"), opts_create()).await.unwrap(); + h.close().await.unwrap(); + let err = backend + .rename(&p("other.txt"), &p("new.txt")) + .await + .err() + .unwrap(); + assert!(matches!(err, SmbError::Exists), "got {err:?}"); + } + + #[tokio::test] + async fn list_dir_pattern_matching() { + let td = tempdir().unwrap(); + let backend = LocalFsBackend::new(td.path()).unwrap(); + + for name in ["a.txt", "b.txt", "c.log", "README"] { + let h = backend.open(&p(name), opts_create()).await.unwrap(); + h.close().await.unwrap(); + } + + let dir_h = backend + .open(&SmbPath::root(), opts_open_dir()) + .await + .unwrap(); + + let txts = dir_h.list_dir(Some("*.txt")).await.unwrap(); + let names: Vec<_> = txts.iter().map(|e| e.info.name.as_str()).collect(); + assert_eq!(names.len(), 2, "expected 2 .txt files, got {names:?}"); + assert!(names.contains(&"a.txt")); + assert!(names.contains(&"b.txt")); + + // Single-char wildcard. + let one = dir_h.list_dir(Some("?.log")).await.unwrap(); + let names: Vec<_> = one.iter().map(|e| e.info.name.as_str()).collect(); + assert_eq!(names, vec!["c.log"]); + + // Case-insensitive. + let any_txt = dir_h.list_dir(Some("*.TXT")).await.unwrap(); + assert_eq!(any_txt.len(), 2); + + // "*" matches everything. + let all = dir_h.list_dir(Some("*")).await.unwrap(); + assert_eq!(all.len(), 4); + + dir_h.close().await.unwrap(); + } + + #[test] + fn glob_match_basics() { + assert!(glob_match("*", "anything")); + assert!(glob_match("*.txt", "foo.txt")); + assert!(!glob_match("*.txt", "foo.log")); + assert!(glob_match("a?c", "abc")); + assert!(!glob_match("a?c", "ac")); + assert!(glob_match("a*b*c", "axxxbxxxc")); + assert!(glob_match("FOO", "foo")); + assert!(glob_match("", "")); + assert!(!glob_match("", "a")); + } + + #[test] + fn filetime_round_trip() { + let now = SystemTime::now(); + let ft = system_time_to_filetime(now); + let back = filetime_to_system_time(ft).unwrap(); + let delta = now + .duration_since(back) + .or_else(|e| Ok::<_, std::time::SystemTimeError>(e.duration())) + .unwrap(); + // 100ns granularity — round-trip should be sub-microsecond. + assert!(delta < Duration::from_micros(1), "delta = {delta:?}"); + } +} diff --git a/vendor/smb-server/src/fs/mod.rs b/vendor/smb-server/src/fs/mod.rs new file mode 100644 index 0000000..615e414 --- /dev/null +++ b/vendor/smb-server/src/fs/mod.rs @@ -0,0 +1,5 @@ +//! Local-filesystem [`ShareBackend`] for `smb-server`, sandboxed via `cap-std`. + +mod local; + +pub use local::LocalFsBackend; diff --git a/vendor/smb-server/src/handlers/change_notify.rs b/vendor/smb-server/src/handlers/change_notify.rs new file mode 100644 index 0000000..9edd8cc --- /dev/null +++ b/vendor/smb-server/src/handlers/change_notify.rs @@ -0,0 +1,19 @@ +//! CHANGE_NOTIFY handler — v1 always returns NOT_SUPPORTED. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::ntstatus; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + _conn: &Arc, + _hdr: &Smb2Header, + _body: &[u8], +) -> HandlerResponse { + HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED) +} diff --git a/vendor/smb-server/src/handlers/close.rs b/vendor/smb-server/src/handlers/close.rs new file mode 100644 index 0000000..27432ce --- /dev/null +++ b/vendor/smb-server/src/handlers/close.rs @@ -0,0 +1,107 @@ +//! CLOSE handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{CloseRequest, CloseResponse}; +use tracing::debug; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::lookup_session_tree; +use crate::ntstatus; +use crate::server::ServerState; + +const FLAG_POSTQUERY_ATTRIB: u16 = 0x0001; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match CloseRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let removed = { + let tree = tree_arc.write().await; + let mut opens = tree.opens.write().await; + opens.remove(&req.file_id) + }; + let open_arc = match removed { + Some(o) => o, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + + // Pull state out, close the handle, then optionally unlink. + let mut open = open_arc.write().await; + let handle = open.handle.take(); + let path = open.last_path.clone(); + let delete_on_close = open.delete_on_close; + let want_attrs = req.flags & FLAG_POSTQUERY_ATTRIB != 0; + drop(open); + + // Stat before closing if needed. + let info_before_close = if want_attrs { + if let Some(h) = handle.as_ref() { + h.stat().await.ok() + } else { + None + } + } else { + None + }; + if let Some(h) = handle { + let _ = h.close().await; + } + if delete_on_close { + let tree = tree_arc.read().await; + let backend = tree.share.backend.clone(); + drop(tree); + if let Err(e) = backend.unlink(&path).await { + debug!(error = %e, "delete-on-close unlink failed"); + } + } + + let resp = CloseResponse { + structure_size: 60, + flags: req.flags & FLAG_POSTQUERY_ATTRIB, + reserved: 0, + creation_time: info_before_close + .as_ref() + .map(|i| i.creation_time) + .unwrap_or(0), + last_access_time: info_before_close + .as_ref() + .map(|i| i.last_access_time) + .unwrap_or(0), + last_write_time: info_before_close + .as_ref() + .map(|i| i.last_write_time) + .unwrap_or(0), + change_time: info_before_close + .as_ref() + .map(|i| i.change_time) + .unwrap_or(0), + allocation_size: info_before_close + .as_ref() + .map(|i| i.allocation_size) + .unwrap_or(0), + end_of_file: info_before_close + .as_ref() + .map(|i| i.end_of_file) + .unwrap_or(0), + file_attributes: info_before_close + .as_ref() + .map(|i| i.attributes()) + .unwrap_or(0), + }; + let mut buf = Vec::new(); + resp.write_to(&mut buf).expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/create.rs b/vendor/smb-server/src/handlers/create.rs new file mode 100644 index 0000000..789eaab --- /dev/null +++ b/vendor/smb-server/src/handlers/create.rs @@ -0,0 +1,194 @@ +//! CREATE handler — open or create a file/directory and allocate a FileId. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{CreateRequest, CreateResponse}; +use tracing::{debug, warn}; + +use crate::backend::{OpenIntent, OpenOptions}; +use crate::builder::Access; +use crate::conn::state::{Connection, Open}; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::lookup_session_tree; +use crate::ntstatus; +use crate::path::SmbPath; +use crate::server::ServerState; +use crate::utils::utf16le_to_units; + +// MS-SMB2 §2.2.13 access mask flags +const FILE_READ_DATA: u32 = 0x0000_0001; +const FILE_WRITE_DATA: u32 = 0x0000_0002; +const FILE_APPEND_DATA: u32 = 0x0000_0004; +const FILE_READ_ATTRIBUTES: u32 = 0x0000_0080; +const FILE_WRITE_ATTRIBUTES: u32 = 0x0000_0100; +const DELETE: u32 = 0x0001_0000; +const GENERIC_READ: u32 = 0x8000_0000; +const GENERIC_WRITE: u32 = 0x4000_0000; +const GENERIC_ALL: u32 = 0x1000_0000; +const MAX_ALLOWED: u32 = 0x0200_0000; + +// CreateOptions +const FILE_DIRECTORY_FILE: u32 = 0x0000_0001; +const FILE_NON_DIRECTORY_FILE: u32 = 0x0000_0040; +const FILE_DELETE_ON_CLOSE: u32 = 0x0000_1000; + +// CreateDisposition +const FILE_SUPERSEDE: u32 = 0x0000_0000; +const FILE_OPEN: u32 = 0x0000_0001; +const FILE_CREATE: u32 = 0x0000_0002; +const FILE_OPEN_IF: u32 = 0x0000_0003; +const FILE_OVERWRITE: u32 = 0x0000_0004; +const FILE_OVERWRITE_IF: u32 = 0x0000_0005; + +// CreateAction in response (MS-SMB2 §2.2.14) +const FILE_OPENED: u32 = 0x0000_0001; +const FILE_CREATED: u32 = 0x0000_0002; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match CreateRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let tree = tree_arc.read().await; + let granted = tree.granted_access; + let backend = tree.share.backend.clone(); + drop(tree); + + // Decode path. + let units = match utf16le_to_units(&req.name) { + Some(u) => u, + None => return HandlerResponse::err(ntstatus::STATUS_OBJECT_NAME_INVALID), + }; + let path = match SmbPath::from_utf16(&units) { + Ok(p) => p, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_OBJECT_NAME_INVALID), + }; + + // Translate disposition. + let intent = match req.create_disposition { + FILE_SUPERSEDE | FILE_OVERWRITE_IF => OpenIntent::OverwriteOrCreate, + FILE_OPEN => OpenIntent::Open, + FILE_CREATE => OpenIntent::Create, + FILE_OPEN_IF => OpenIntent::OpenOrCreate, + FILE_OVERWRITE => OpenIntent::Truncate, + _ => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + + // Translate desired access into read/write hints. + let want_read = req.desired_access + & (FILE_READ_DATA | FILE_READ_ATTRIBUTES | GENERIC_READ | GENERIC_ALL | MAX_ALLOWED) + != 0; + let want_write = req.desired_access + & (FILE_WRITE_DATA + | FILE_APPEND_DATA + | FILE_WRITE_ATTRIBUTES + | DELETE + | GENERIC_WRITE + | GENERIC_ALL + | MAX_ALLOWED) + != 0; + + // Reject writes on a read-only tree. + if want_write && !granted.allows_write() { + warn!(path = %path, "write open on read-only tree"); + return HandlerResponse::err(ntstatus::STATUS_ACCESS_DENIED); + } + // Disposition that creates: requires write permission. + if !granted.allows_write() + && matches!( + intent, + OpenIntent::Create + | OpenIntent::OpenOrCreate + | OpenIntent::OverwriteOrCreate + | OpenIntent::Truncate + ) + { + return HandlerResponse::err(ntstatus::STATUS_ACCESS_DENIED); + } + + let directory = req.create_options & FILE_DIRECTORY_FILE != 0; + let non_directory = req.create_options & FILE_NON_DIRECTORY_FILE != 0; + if directory && non_directory { + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + let delete_on_close = req.create_options & FILE_DELETE_ON_CLOSE != 0; + + let opts = OpenOptions { + read: want_read || !want_write, + write: want_write, + intent, + directory, + non_directory, + delete_on_close, + }; + + let handle = match backend.open(&path, opts).await { + Ok(h) => h, + Err(e) => { + debug!(error = %e, path = %path, "backend open failed"); + return HandlerResponse::err(e.to_nt_status()); + } + }; + + // Stat for the response. + let info = match handle.stat().await { + Ok(i) => i, + Err(e) => { + let _ = handle.close().await; + return HandlerResponse::err(e.to_nt_status()); + } + }; + + // Allocate FileId, register Open. + let tree = tree_arc.write().await; + let file_id = tree.alloc_file_id(); + let open = Open::new( + file_id, + handle, + if want_write { granted } else { Access::Read }, + path, + info.is_directory, + delete_on_close, + ); + let open_arc = Arc::new(tokio::sync::RwLock::new(open)); + tree.opens.write().await.insert(file_id, open_arc); + drop(tree); + + let create_action = match intent { + OpenIntent::Create => FILE_CREATED, + OpenIntent::OpenOrCreate | OpenIntent::OverwriteOrCreate => FILE_OPENED, + OpenIntent::Open | OpenIntent::Truncate => FILE_OPENED, + }; + let resp = CreateResponse { + structure_size: 89, + oplock_level: 0, + flags: 0, + create_action, + creation_time: info.creation_time, + last_access_time: info.last_access_time, + last_write_time: info.last_write_time, + change_time: info.change_time, + allocation_size: info.allocation_size, + end_of_file: info.end_of_file, + file_attributes: info.attributes(), + reserved2: 0, + file_id, + create_contexts_offset: 0, + create_contexts_length: 0, + create_contexts: vec![], + }; + let mut buf = Vec::new(); + resp.write_to(&mut buf).expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/echo.rs b/vendor/smb-server/src/handlers/echo.rs new file mode 100644 index 0000000..e50b10f --- /dev/null +++ b/vendor/smb-server/src/handlers/echo.rs @@ -0,0 +1,21 @@ +//! ECHO handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::EchoResponse; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + _conn: &Arc, + _hdr: &Smb2Header, + _body: &[u8], +) -> HandlerResponse { + let mut buf = Vec::new(); + EchoResponse::default().write_to(&mut buf).expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/flush.rs b/vendor/smb-server/src/handlers/flush.rs new file mode 100644 index 0000000..10efd51 --- /dev/null +++ b/vendor/smb-server/src/handlers/flush.rs @@ -0,0 +1,46 @@ +//! FLUSH handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{FileId, FlushRequest, FlushResponse}; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::{lookup_open, lookup_session_tree}; +use crate::ntstatus; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match FlushRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + let fid = FileId::new(req.file_id_persistent, req.file_id_volatile); + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let open_arc = match lookup_open(&tree_arc, fid).await { + Some(o) => o, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + let res = { + let open = open_arc.read().await; + match open.handle.as_ref() { + Some(h) => h.flush().await, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + } + }; + if let Err(e) = res { + return HandlerResponse::err(e.to_nt_status()); + } + let mut buf = Vec::new(); + FlushResponse::default().write_to(&mut buf).expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/ioctl.rs b/vendor/smb-server/src/handlers/ioctl.rs new file mode 100644 index 0000000..2bf59ec --- /dev/null +++ b/vendor/smb-server/src/handlers/ioctl.rs @@ -0,0 +1,59 @@ +//! IOCTL handler — handles FSCTL_VALIDATE_NEGOTIATE_INFO; everything else +//! returns NOT_SUPPORTED. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{Fsctl, IoctlRequest, IoctlResponse}; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::negotiate::{NEGOTIATE_CAPABILITIES, NEGOTIATE_SECURITY_MODE}; +use crate::ntstatus; +use crate::server::ServerState; + +pub async fn handle( + server: &Arc, + conn: &Arc, + _hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match IoctlRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + + match req.fsctl() { + Fsctl::ValidateNegotiateInfo => { + // Build VALIDATE_NEGOTIATE_INFO_RESPONSE per MS-SMB2 §2.2.32.6: + // Capabilities (4) | Guid (16) | SecurityMode (2) | Dialect (2) = 24 bytes. + let dialect = conn.dialect.read().await.map(|d| d.as_u16()).unwrap_or(0); + let mut out = Vec::with_capacity(24); + out.extend_from_slice(&NEGOTIATE_CAPABILITIES.to_le_bytes()); + out.extend_from_slice(server.config.server_guid.as_bytes()); + out.extend_from_slice(&NEGOTIATE_SECURITY_MODE.to_le_bytes()); + out.extend_from_slice(&dialect.to_le_bytes()); + + let resp = IoctlResponse { + structure_size: 49, + reserved: 0, + ctl_code: req.ctl_code, + file_id: req.file_id, + input_offset: 0, + input_count: 0, + output_offset: 0x70, + output_count: out.len() as u32, + flags: 0, + reserved2: 0, + output: out, + }; + let mut buf = Vec::new(); + resp.write_to(&mut buf).expect("IOCTL response encodes"); + HandlerResponse::ok(buf) + } + Fsctl::DfsGetReferrals | Fsctl::DfsGetReferralsEx => { + HandlerResponse::err(ntstatus::STATUS_FS_DRIVER_REQUIRED) + } + _ => HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED), + } +} diff --git a/vendor/smb-server/src/handlers/lock.rs b/vendor/smb-server/src/handlers/lock.rs new file mode 100644 index 0000000..d7e449c --- /dev/null +++ b/vendor/smb-server/src/handlers/lock.rs @@ -0,0 +1,21 @@ +//! LOCK handler — v1 returns success without enforcing locks. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::LockResponse; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + _conn: &Arc, + _hdr: &Smb2Header, + _body: &[u8], +) -> HandlerResponse { + let mut buf = Vec::new(); + LockResponse::default().write_to(&mut buf).expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/logoff.rs b/vendor/smb-server/src/handlers/logoff.rs new file mode 100644 index 0000000..1f0d58b --- /dev/null +++ b/vendor/smb-server/src/handlers/logoff.rs @@ -0,0 +1,28 @@ +//! LOGOFF handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::LogoffResponse; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::ntstatus; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + _body: &[u8], +) -> HandlerResponse { + if hdr.session_id == 0 { + return HandlerResponse::err(ntstatus::STATUS_USER_SESSION_DELETED); + } + conn.close_session(hdr.session_id).await; + let mut buf = Vec::new(); + LogoffResponse::default() + .write_to(&mut buf) + .expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/mod.rs b/vendor/smb-server/src/handlers/mod.rs new file mode 100644 index 0000000..62b21e4 --- /dev/null +++ b/vendor/smb-server/src/handlers/mod.rs @@ -0,0 +1,64 @@ +//! Per-command handlers. +//! +//! Each function here builds a `HandlerResponse` for a specific SMB2 command. +//! Handlers receive the parsed request header and a slice of the body bytes; +//! they return either a successful body or `HandlerResponse::err(ntstatus)`. + +use std::sync::Arc; + +use crate::proto::header::{Command, Smb2Header}; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::ntstatus; +use crate::server::ServerState; + +mod change_notify; +mod close; +mod create; +mod echo; +mod flush; +mod ioctl; +mod lock; +mod logoff; +pub(crate) mod negotiate; +mod oplock_break; +mod query_directory; +mod query_info; +mod read; +mod session_setup; +mod set_info; +pub(crate) mod shared; +mod tree_connect; +mod tree_disconnect; +mod write; + +/// Top-level command router. +pub async fn dispatch_command( + server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + match hdr.command { + Command::Negotiate => negotiate::handle(server, conn, hdr, body).await, + Command::SessionSetup => session_setup::handle(server, conn, hdr, body).await, + Command::Logoff => logoff::handle(server, conn, hdr, body).await, + Command::TreeConnect => tree_connect::handle(server, conn, hdr, body).await, + Command::TreeDisconnect => tree_disconnect::handle(server, conn, hdr, body).await, + Command::Create => create::handle(server, conn, hdr, body).await, + Command::Close => close::handle(server, conn, hdr, body).await, + Command::Flush => flush::handle(server, conn, hdr, body).await, + Command::Read => read::handle(server, conn, hdr, body).await, + Command::Write => write::handle(server, conn, hdr, body).await, + Command::Lock => lock::handle(server, conn, hdr, body).await, + Command::Ioctl => ioctl::handle(server, conn, hdr, body).await, + Command::Echo => echo::handle(server, conn, hdr, body).await, + Command::QueryDirectory => query_directory::handle(server, conn, hdr, body).await, + Command::ChangeNotify => change_notify::handle(server, conn, hdr, body).await, + Command::QueryInfo => query_info::handle(server, conn, hdr, body).await, + Command::SetInfo => set_info::handle(server, conn, hdr, body).await, + Command::OplockBreak => oplock_break::handle(server, conn, hdr, body).await, + Command::Cancel => HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + } +} diff --git a/vendor/smb-server/src/handlers/negotiate.rs b/vendor/smb-server/src/handlers/negotiate.rs new file mode 100644 index 0000000..8087cb1 --- /dev/null +++ b/vendor/smb-server/src/handlers/negotiate.rs @@ -0,0 +1,223 @@ +//! NEGOTIATE handler. + +use std::sync::Arc; + +use crate::proto::auth::spnego::encode_init_response; +use crate::proto::crypto::SigningAlgo; +use crate::proto::header::Smb2Header; +use crate::proto::messages::{ + Dialect, NegotiateContext, NegotiateRequest, NegotiateResponse, PreauthIntegrityCapabilities, + SigningCapabilities, +}; +use tracing::info; +use uuid::Uuid; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::ntstatus; +use crate::server::ServerState; +use crate::utils::{fill_random, now_filetime}; + +// MS-SMB2 §2.2.4 SecurityMode bits. Keep SIGNING_REQUIRED clear: anonymous +// Linux cifs mounts do not send enough NTLM material for the server to derive +// matching SMB3 signing keys. +pub(crate) const NEGOTIATE_SECURITY_MODE: u16 = 0x0001; + +const CAP_DFS: u32 = 0x0000_0001; +const CAP_LEASING: u32 = 0x0000_0002; +const CAP_LARGE_MTU: u32 = 0x0000_0004; +pub(crate) const NEGOTIATE_CAPABILITIES: u32 = CAP_DFS | CAP_LEASING | CAP_LARGE_MTU; + +pub async fn handle( + server: &Arc, + conn: &Arc, + _hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match NegotiateRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + + // Pick the highest dialect we support that the client offered. + const SUPPORTED: &[u16] = &[0x0202, 0x0210, 0x0300, 0x0302, 0x0311]; + let mut chosen: Option = None; + for &d in &req.dialects { + if SUPPORTED.contains(&d) { + chosen = match chosen { + None => Some(d), + Some(prev) if d > prev => Some(d), + Some(prev) => Some(prev), + }; + } + } + let chosen = match chosen { + Some(d) => d, + None => return HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED), + }; + let dialect = match Dialect::from_u16(chosen) { + Some(dialect) => dialect, + None => return HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED), + }; + *conn.dialect.write().await = Some(dialect); + *conn.client_guid.write().await = Uuid::from_bytes(req.client_guid); + *conn.signing_algo.write().await = match dialect { + Dialect::Smb202 | Dialect::Smb210 => SigningAlgo::HmacSha256, + _ => SigningAlgo::AesCmac, + }; + + // Build SPNEGO security blob (mech-list-only, advertising NTLMSSP). + let security_blob = encode_init_response(); + let security_buffer_offset: u16 = 64 + 64; // SMB2 header + fixed NEG response (64 bytes) + let security_buffer_length: u16 = security_blob.len() as u16; + + // For 3.1.1 build negotiate contexts. + let mut contexts_bytes: Vec = Vec::new(); + let mut context_count: u16 = 0; + let mut negotiate_context_offset: u32 = 0; + + if dialect == Dialect::Smb311 { + // PREAUTH_INTEGRITY_CAPABILITIES + let mut salt = [0u8; 32]; + fill_random(&mut salt); + let preauth_caps = PreauthIntegrityCapabilities { + hash_algorithm_count: 1, + salt_length: 32, + hash_algorithms: vec![PreauthIntegrityCapabilities::HASH_SHA512], + salt: salt.to_vec(), + }; + let preauth_data = { + use binrw::BinWrite; + let mut c = std::io::Cursor::new(Vec::new()); + BinWrite::write(&preauth_caps, &mut c).expect("preauth negotiate context encodes"); + c.into_inner() + }; + let preauth_ctx = NegotiateContext { + context_type: NegotiateContext::TYPE_PREAUTH_INTEGRITY, + data_length: preauth_data.len() as u16, + reserved: 0, + data: preauth_data, + }; + + // SIGNING_CAPABILITIES — advertise AES-CMAC. + let signing_caps = SigningCapabilities { + signing_algorithm_count: 1, + signing_algorithms: vec![SigningCapabilities::ALGORITHM_AES_CMAC], + }; + let signing_data = { + use binrw::BinWrite; + let mut c = std::io::Cursor::new(Vec::new()); + BinWrite::write(&signing_caps, &mut c).expect("signing negotiate context encodes"); + c.into_inner() + }; + let signing_ctx = NegotiateContext { + context_type: NegotiateContext::TYPE_SIGNING, + data_length: signing_data.len() as u16, + reserved: 0, + data: signing_data, + }; + + let ctxs = vec![preauth_ctx, signing_ctx]; + if let Err(e) = NegotiateContext::encode_list(&ctxs, &mut contexts_bytes) { + tracing::error!(error = %e, "encode_list failed"); + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + context_count = ctxs.len() as u16; + + // The contexts go after security buffer, 8-byte aligned. + let post_security = security_buffer_offset as u32 + security_buffer_length as u32; + // Round up to next multiple of 8 from the start of the SMB2 header. + negotiate_context_offset = (post_security + 7) & !7; + } + + let max_read_size = *conn.max_read_size.read().await; + let max_write_size = *conn.max_write_size.read().await; + let max_transact_size = max_read_size; // common practice + + let resp = NegotiateResponse { + structure_size: 65, + security_mode: NEGOTIATE_SECURITY_MODE, + dialect_revision: chosen, + negotiate_context_count_or_reserved: context_count, + server_guid: *server.config.server_guid.as_bytes(), + capabilities: NEGOTIATE_CAPABILITIES, + max_transact_size, + max_read_size, + max_write_size, + system_time: now_filetime(), + server_start_time: server.server_start_filetime, + security_buffer_offset, + security_buffer_length, + negotiate_context_offset_or_reserved2: negotiate_context_offset, + security_buffer: security_blob, + }; + + let mut body_out = Vec::new(); + if let Err(e) = resp.write_to(&mut body_out) { + tracing::error!(error = %e, "encode NEGOTIATE response"); + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + // Append padding to align contexts at `negotiate_context_offset`. + if dialect == Dialect::Smb311 && context_count > 0 { + let cur = 64 + body_out.len() as u32; // header + body so far + if cur < negotiate_context_offset { + let pad = (negotiate_context_offset - cur) as usize; + body_out.extend(std::iter::repeat_n(0u8, pad)); + } + body_out.extend_from_slice(&contexts_bytes); + } + info!(?dialect, "NEGOTIATE complete"); + let mut hr = HandlerResponse::ok(body_out); + hr.skip_signing = true; + hr +} + +/// Build the SMB2 NEGOTIATE response sent in reply to an SMB1 multi-protocol +/// NEGOTIATE_REQUEST that listed an SMB2 dialect (MS-SMB2 §3.3.5.3.1). +/// +/// We do NOT commit the connection dialect here — the client will follow up +/// with a real SMB2 NEGOTIATE which goes through [`handle`]. This response +/// only tells the client "yes, I speak SMB2; send me an SMB2 NEGOTIATE next". +pub async fn multi_protocol_response( + server: &Arc, + conn: &Arc, + chosen: u16, +) -> HandlerResponse { + let security_blob = encode_init_response(); + let security_buffer_offset: u16 = 64 + 64; + let security_buffer_length: u16 = security_blob.len() as u16; + let max_read_size = *conn.max_read_size.read().await; + let max_write_size = *conn.max_write_size.read().await; + let max_transact_size = max_read_size; + + let resp = NegotiateResponse { + structure_size: 65, + security_mode: NEGOTIATE_SECURITY_MODE, + dialect_revision: chosen, + negotiate_context_count_or_reserved: 0, + server_guid: *server.config.server_guid.as_bytes(), + capabilities: 0, + max_transact_size, + max_read_size, + max_write_size, + system_time: now_filetime(), + server_start_time: server.server_start_filetime, + security_buffer_offset, + security_buffer_length, + negotiate_context_offset_or_reserved2: 0, + security_buffer: security_blob, + }; + + let mut body_out = Vec::new(); + if let Err(e) = resp.write_to(&mut body_out) { + tracing::error!(error = %e, "encode multi-protocol NEGOTIATE response"); + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + info!( + chosen = %format_args!("0x{chosen:04X}"), + "SMB1 multi-protocol -> SMB2" + ); + let mut hr = HandlerResponse::ok(body_out); + hr.skip_signing = true; + hr +} diff --git a/vendor/smb-server/src/handlers/oplock_break.rs b/vendor/smb-server/src/handlers/oplock_break.rs new file mode 100644 index 0000000..3b75a47 --- /dev/null +++ b/vendor/smb-server/src/handlers/oplock_break.rs @@ -0,0 +1,27 @@ +//! OPLOCK_BREAK handler — acknowledge breaks without granting oplocks. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::FileId; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + _conn: &Arc, + _hdr: &Smb2Header, + _body: &[u8], +) -> HandlerResponse { + // Echo back the same shape as the notification — structure_size=24, level=0. + let mut buf = Vec::new(); + buf.extend_from_slice(&24u16.to_le_bytes()); // structure_size + buf.push(0); // OplockLevel + buf.push(0); // Reserved + buf.extend_from_slice(&0u32.to_le_bytes()); // Reserved2 + buf.extend_from_slice(&FileId::any().persistent.to_le_bytes()); + buf.extend_from_slice(&FileId::any().volatile.to_le_bytes()); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/query_directory.rs b/vendor/smb-server/src/handlers/query_directory.rs new file mode 100644 index 0000000..48e98b1 --- /dev/null +++ b/vendor/smb-server/src/handlers/query_directory.rs @@ -0,0 +1,136 @@ +//! QUERY_DIRECTORY handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{FileInfoClass, QueryDirectoryRequest, QueryDirectoryResponse}; + +use crate::conn::state::{Connection, DirCursor}; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::{lookup_open, lookup_session_tree}; +use crate::info_class::{align8, encode_dir_entry}; +use crate::ntstatus; +use crate::server::ServerState; +use crate::utils::utf16le_to_string; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match QueryDirectoryRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + if FileInfoClass::from_u8(req.file_information_class).is_none() { + return HandlerResponse::err(ntstatus::STATUS_INVALID_INFO_CLASS); + } + let class_byte = req.file_information_class; + + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let open_arc = match lookup_open(&tree_arc, req.file_id).await { + Some(o) => o, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + + let pattern_str = utf16le_to_string(&req.file_name); + let pattern: Option = if pattern_str.is_empty() || pattern_str == "*" { + None + } else { + Some(pattern_str) + }; + + let restart = req.flags & QueryDirectoryRequest::FLAG_RESTART_SCANS != 0 + || req.flags & QueryDirectoryRequest::FLAG_REOPEN != 0; + let single_entry = req.flags & QueryDirectoryRequest::FLAG_RETURN_SINGLE_ENTRY != 0; + + // Populate or refresh the cursor. + { + let mut open = open_arc.write().await; + if !open.is_directory { + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + if open.search_state.is_none() || restart { + let entries = match open.handle.as_ref() { + Some(h) => h.list_dir(pattern.as_deref()).await, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + let entries = match entries { + Ok(e) => e, + Err(e) => return HandlerResponse::err(e.to_nt_status()), + }; + open.search_state = Some(DirCursor { + entries, + next: 0, + pattern: pattern.clone(), + }); + } + } + + // Encode entries into the output buffer. + let mut buf: Vec = Vec::new(); + let mut last_offset_pos: Option = None; + let cap = req.output_buffer_length as usize; + + { + let mut open = open_arc.write().await; + let cursor = open.search_state.as_mut().expect("populated above"); + loop { + if cursor.next >= cursor.entries.len() { + break; + } + let entry = &cursor.entries[cursor.next]; + let file_index = entry.info.file_index; + let mut bytes = encode_dir_entry(class_byte, entry, file_index); + if bytes.is_empty() { + cursor.next += 1; + continue; + } + + // Determine total size with padding for chaining. + let entry_aligned = align8(bytes.len()); + // If this is *not* the first entry, we already padded the previous + // entry up to entry_aligned. We commit only if total fits. + let prev_len = buf.len(); + let total_after = prev_len + entry_aligned; + if total_after > cap && !buf.is_empty() { + // No room for this entry; stop. + break; + } + // Patch previous NextEntryOffset. + if let Some(prev_off) = last_offset_pos { + let delta = (prev_len - prev_off) as u32; + buf[prev_off..prev_off + 4].copy_from_slice(&delta.to_le_bytes()); + } + // Track NextEntryOffset position for the entry we are appending. + last_offset_pos = Some(prev_len); + // Append the entry, then pad to 8. + let target_len = prev_len + entry_aligned; + buf.append(&mut bytes); + while buf.len() < target_len { + buf.push(0); + } + cursor.next += 1; + if single_entry { + break; + } + } + } + if buf.is_empty() { + return HandlerResponse::err(ntstatus::STATUS_NO_MORE_FILES); + } + + let resp = QueryDirectoryResponse { + structure_size: 9, + output_buffer_offset: 64 + 8, + output_buffer_length: buf.len() as u32, + buffer: buf, + }; + let mut out = Vec::new(); + resp.write_to(&mut out).expect("encode"); + HandlerResponse::ok(out) +} diff --git a/vendor/smb-server/src/handlers/query_info.rs b/vendor/smb-server/src/handlers/query_info.rs new file mode 100644 index 0000000..8685602 --- /dev/null +++ b/vendor/smb-server/src/handlers/query_info.rs @@ -0,0 +1,144 @@ +//! QUERY_INFO handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{InfoType, QueryInfoRequest, QueryInfoResponse}; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::{lookup_open, lookup_session_tree}; +use crate::info_class as ic; +use crate::ntstatus; +use crate::server::ServerState; + +const FILE_DEVICE_DISK: u32 = 0x0000_0007; +const FILE_REMOTE_DEVICE: u32 = 0x0000_0010; + +// FS attribute flags (MS-FSCC §2.5.1) +const FILE_CASE_SENSITIVE_SEARCH: u32 = 0x0000_0001; +const FILE_CASE_PRESERVED_NAMES: u32 = 0x0000_0002; +const FILE_UNICODE_ON_DISK: u32 = 0x0000_0004; +const FILE_PERSISTENT_ACLS: u32 = 0x0000_0008; +const FILE_FILE_COMPRESSION: u32 = 0x0000_0010; +const FILE_SUPPORTS_HARD_LINKS: u32 = 0x0040_0000; +const FILE_SUPPORTS_EXTENDED_ATTRIBUTES: u32 = 0x0080_0000; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match QueryInfoRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + let info_type = match req.info_type_enum() { + Some(t) => t, + None => return HandlerResponse::err(ntstatus::STATUS_INVALID_INFO_CLASS), + }; + + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let open_arc = match lookup_open(&tree_arc, req.file_id).await { + Some(o) => o, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + + // Pull the file index (we use FileId.volatile as the unique handle id). + let (file_index, info_res) = { + let open = open_arc.read().await; + let fid = open.file_id; + match open.handle.as_ref() { + Some(h) => (fid.volatile, h.stat().await), + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + } + }; + + let buf: Vec = match info_type { + InfoType::File => { + let info = match info_res { + Ok(i) => i, + Err(e) => return HandlerResponse::err(e.to_nt_status()), + }; + match req.file_information_class { + ic::FILE_BASIC_INFORMATION => ic::encode_file_basic_information(&info), + ic::FILE_STANDARD_INFORMATION => ic::encode_file_standard_information(&info), + ic::FILE_INTERNAL_INFORMATION => ic::encode_file_internal_information(file_index), + ic::FILE_EA_INFORMATION => ic::encode_file_ea_information(), + ic::FILE_FULL_EA_INFORMATION => { + return HandlerResponse::err(ntstatus::STATUS_NO_EAS_ON_FILE); + } + ic::FILE_ACCESS_INFORMATION => ic::encode_file_access_information(0x001F_01FF), + ic::FILE_POSITION_INFORMATION => ic::encode_file_position_information(), + ic::FILE_MODE_INFORMATION => ic::encode_file_mode_information(0), + ic::FILE_ALIGNMENT_INFORMATION => ic::encode_file_alignment_information(), + ic::FILE_NAME_INFORMATION => ic::encode_file_name_information(&info.name), + ic::FILE_ALL_INFORMATION => { + ic::encode_file_all_information(&info, file_index, 0x001F_01FF) + } + ic::FILE_NETWORK_OPEN_INFORMATION => { + ic::encode_file_network_open_information(&info) + } + ic::FILE_STREAM_INFORMATION => ic::encode_file_stream_information(&info), + _ => return HandlerResponse::err(ntstatus::STATUS_INVALID_INFO_CLASS), + } + } + InfoType::FileSystem => { + // For FS info we use the open's tree's backend for context. + let creation_time = info_res.as_ref().map(|i| i.creation_time).unwrap_or(0); + match req.file_information_class { + ic::FS_VOLUME_INFORMATION => { + ic::encode_fs_volume_information(creation_time, 0xCAFE_BABE, "smb-server") + } + ic::FS_SIZE_INFORMATION => { + // 1 PiB free pseudo-volume, 4 KiB cluster. + ic::encode_fs_size_information( + 1u64 << 40, // total + 1u64 << 39, // free + 1, // sectors per cluster + 4096, // bytes per sector + ) + } + ic::FS_DEVICE_INFORMATION => { + ic::encode_fs_device_information(FILE_DEVICE_DISK, FILE_REMOTE_DEVICE) + } + ic::FS_ATTRIBUTE_INFORMATION => ic::encode_fs_attribute_information( + FILE_CASE_SENSITIVE_SEARCH + | FILE_CASE_PRESERVED_NAMES + | FILE_UNICODE_ON_DISK + | FILE_PERSISTENT_ACLS + | FILE_FILE_COMPRESSION + | FILE_SUPPORTS_HARD_LINKS + | FILE_SUPPORTS_EXTENDED_ATTRIBUTES, + 255, + "NTFS", + ), + ic::FS_FULL_SIZE_INFORMATION => { + ic::encode_fs_full_size_information(1u64 << 40, 1u64 << 39, 1u64 << 39, 1, 4096) + } + _ => return HandlerResponse::err(ntstatus::STATUS_INVALID_INFO_CLASS), + } + } + InfoType::Security => ic::encode_minimal_security_descriptor(), + InfoType::Quota => return HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED), + }; + + if buf.len() as u32 > req.output_buffer_length { + return HandlerResponse::err(ntstatus::STATUS_INFO_LENGTH_MISMATCH); + } + + let resp = QueryInfoResponse { + structure_size: 9, + output_buffer_offset: 64 + 8, + output_buffer_length: buf.len() as u32, + buffer: buf, + }; + let mut out = Vec::new(); + resp.write_to(&mut out) + .expect("QUERY_INFO response encodes"); + HandlerResponse::ok(out) +} diff --git a/vendor/smb-server/src/handlers/read.rs b/vendor/smb-server/src/handlers/read.rs new file mode 100644 index 0000000..72b001a --- /dev/null +++ b/vendor/smb-server/src/handlers/read.rs @@ -0,0 +1,62 @@ +//! READ handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{ReadRequest, ReadResponse}; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::{lookup_open, lookup_session_tree}; +use crate::ntstatus; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match ReadRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + let max_read = *conn.max_read_size.read().await; + if req.length > max_read { + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let open_arc = match lookup_open(&tree_arc, req.file_id).await { + Some(o) => o, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + let result = { + let open = open_arc.read().await; + match open.handle.as_ref() { + Some(h) => h.read(req.offset, req.length).await, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + } + }; + let bytes = match result { + Ok(b) => b, + Err(e) => return HandlerResponse::err(e.to_nt_status()), + }; + if bytes.is_empty() && req.length > 0 { + return HandlerResponse::err(ntstatus::STATUS_END_OF_FILE); + } + let resp = ReadResponse { + structure_size: 17, + data_offset: ReadResponse::STANDARD_DATA_OFFSET, + reserved: 0, + data_length: bytes.len() as u32, + data_remaining: 0, + flags: 0, + data: bytes.to_vec(), + }; + let mut buf = Vec::new(); + resp.write_to(&mut buf).expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/session_setup.rs b/vendor/smb-server/src/handlers/session_setup.rs new file mode 100644 index 0000000..888744b --- /dev/null +++ b/vendor/smb-server/src/handlers/session_setup.rs @@ -0,0 +1,262 @@ +//! SESSION_SETUP handler — drives the SPNEGO + NTLMv2 state machine. + +use std::sync::Arc; + +use crate::proto::auth::ntlm::{Identity, NtlmServer, NtlmTargetInfo, UserCreds}; +use crate::proto::auth::spnego::{ + NegState, OID_NTLMSSP, decode_init_token, decode_resp_token, encode_resp_token, +}; +use crate::proto::crypto::signing_key_30; +use crate::proto::header::Smb2Header; +use crate::proto::messages::{Dialect, SessionSetupRequest, SessionSetupResponse}; +use tracing::{debug, info, warn}; + +use crate::conn::state::{Connection, Session}; +use crate::dispatch::HandlerResponse; +use crate::ntstatus; +use crate::server::ServerState; +use crate::utils::{fill_random, now_filetime}; + +pub async fn handle( + server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match SessionSetupRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + + let blob = req.security_buffer; + if blob.is_empty() { + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + if tracing::enabled!(tracing::Level::DEBUG) { + let mut first8 = String::with_capacity(16); + for b in blob.iter().take(8) { + use std::fmt::Write as _; + write!(&mut first8, "{b:02x}").expect("writing to String cannot fail"); + } + tracing::debug!( + first8 = %first8, + len = blob.len(), + sid = hdr.session_id, + "session setup blob" + ); + } + + // Decide which form the security blob takes: + // * GSS-API NegTokenInit — starts with 0x60. + // * SPNEGO NegTokenResp — starts with 0xa1 ([1] context tag). + // * Raw NTLMSSP message — starts with "NTLMSSP\0" (RFC 4178 + // §4.2.1 lets the client skip SPNEGO once the mech is settled; both + // Win11 reauth and Linux cifs.ko use this form). + const NTLMSSP_MAGIC: &[u8] = b"NTLMSSP\0"; + let inner_token: Vec; + let is_first_round: bool; + let is_raw_ntlmssp: bool; + if blob.starts_with(NTLMSSP_MAGIC) { + // Raw NTLMSSP. Decide round by message-type at offset 8. + let msg_type = if blob.len() >= 12 { + u32::from_le_bytes([blob[8], blob[9], blob[10], blob[11]]) + } else { + 0 + }; + // 1 = NEGOTIATE (first), 3 = AUTHENTICATE (second). 2 is server-only. + is_first_round = msg_type == 1; + is_raw_ntlmssp = true; + inner_token = blob.to_vec(); + } else if blob[0] == 0x60 { + // GSS-API outer wrapper — NegTokenInit. + let init = match decode_init_token(&blob) { + Ok(t) => t, + Err(e) => { + warn!(error = %e, "SPNEGO init decode failed"); + return HandlerResponse::err(ntstatus::STATUS_LOGON_FAILURE); + } + }; + if !init.mech_types.iter().any(|m| m == OID_NTLMSSP) { + return HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED); + } + inner_token = init.mech_token.unwrap_or_default(); + is_first_round = true; + is_raw_ntlmssp = false; + } else { + // NegTokenResp follow-up. + let resp = match decode_resp_token(&blob) { + Ok(r) => r, + Err(e) => { + warn!(error = %e, "SPNEGO resp decode failed"); + return HandlerResponse::err(ntstatus::STATUS_LOGON_FAILURE); + } + }; + inner_token = resp.response_token.unwrap_or_default(); + is_first_round = false; + is_raw_ntlmssp = false; + } + + if is_first_round { + // Allocate a fresh session id and start the NTLM state machine. + let new_sid = conn.alloc_session_id(); + let mut server_challenge = [0u8; 8]; + fill_random(&mut server_challenge); + let netbios = server.config.netbios_name.clone(); + let mut acceptor = NtlmServer::new( + server_challenge, + NtlmTargetInfo::new(netbios.clone(), netbios.clone(), netbios, "", ""), + now_filetime(), + ); + + // Step 1: parse client NEGOTIATE. + if let Err(e) = acceptor.step1_negotiate(&inner_token) { + warn!(error = %e, "NTLM step1 failed"); + return HandlerResponse::err(ntstatus::STATUS_LOGON_FAILURE); + } + let challenge_blob = acceptor.challenge(); + // Reply form mirrors the request: raw NTLMSSP if the client skipped + // SPNEGO, else SPNEGO-wrapped. + let outbound = if is_raw_ntlmssp { + challenge_blob + } else { + encode_resp_token( + NegState::AcceptIncomplete, + Some(OID_NTLMSSP), + Some(&challenge_blob), + None, + ) + }; + + // Stash the acceptor for the next round; remember the form so the + // success response can match. + { + let mut pa = conn.pending_auths.write().await; + pa.insert( + new_sid, + Arc::new(std::sync::Mutex::new((acceptor, is_raw_ntlmssp))), + ); + } + + let body_out = + build_session_setup_response(ntstatus::STATUS_MORE_PROCESSING_REQUIRED, &outbound, 0); + return HandlerResponse { + body: body_out, + status: ntstatus::STATUS_MORE_PROCESSING_REQUIRED, + override_tree_id: None, + override_session_id: Some(new_sid), + skip_signing: true, // no key yet + take_preauth_snapshot_for_session: None, + }; + } + + // Follow-up round: look up pending acceptor by session id from header. + let sid = hdr.session_id; + if sid == 0 { + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + let acceptor_arc = { + let mut pa = conn.pending_auths.write().await; + pa.remove(&sid) + }; + let acceptor_arc = match acceptor_arc { + Some(a) => a, + None => return HandlerResponse::err(ntstatus::STATUS_USER_SESSION_DELETED), + }; + let users = server.users.table.read().await.clone(); + let (outcome, raw_form) = { + let pair = acceptor_arc + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + let (acceptor, raw_form) = (&pair.0, pair.1); + let lookup = |u: &str, _d: &str| -> Option { users.get(u).cloned() }; + let outcome = match acceptor.authenticate(&inner_token, lookup) { + Ok(o) => o, + Err(e) => { + info!(error = %e, "NTLM authenticate failed"); + return HandlerResponse::err(ntstatus::STATUS_LOGON_FAILURE); + } + }; + (outcome, raw_form) + }; + + // Anonymous gating. + if matches!(outcome.identity, Identity::Anonymous) && !server.anonymous_allowed().await { + return HandlerResponse::err(ntstatus::STATUS_LOGON_FAILURE); + } + + let session_base_key = outcome.session_key; + let dialect = *conn.dialect.read().await; + let signing_key = match dialect { + Some(Dialect::Smb311) => [0u8; 16], + Some(_) => signing_key_30(&session_base_key), + None => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + + let session_flags = if matches!(outcome.identity, Identity::Anonymous) { + SessionSetupResponse::FLAG_IS_GUEST + } else { + 0 + }; + let signing_required = false; + + let session = Session::new( + sid, + outcome.identity.clone(), + session_base_key, + signing_key, + signing_required, + None, + ); + let session_arc = Arc::new(tokio::sync::RwLock::new(session)); + { + let mut sessions = conn.sessions.write().await; + sessions.insert(sid, session_arc); + } + + // Empty buffer for raw NTLMSSP path; SPNEGO accept-completed for SPNEGO. + let success_buf: Vec = if raw_form { + Vec::new() + } else { + empty_completed() + }; + let body_out = + build_session_setup_response(ntstatus::STATUS_SUCCESS, &success_buf, session_flags); + + let take_snapshot = if dialect == Some(Dialect::Smb311) { + Some(sid) + } else { + None + }; + + info!(?outcome.identity, "session established"); + + HandlerResponse { + body: body_out, + status: ntstatus::STATUS_SUCCESS, + override_tree_id: None, + override_session_id: Some(sid), + // Anonymous responses are not signed (no key). Signed responses for + // authenticated sessions get signed by the dispatcher's normal path. + skip_signing: matches!(outcome.identity, Identity::Anonymous), + take_preauth_snapshot_for_session: take_snapshot, + } +} + +fn build_session_setup_response(_status: u32, spnego_blob: &[u8], session_flags: u16) -> Vec { + let resp = SessionSetupResponse { + structure_size: 9, + session_flags, + security_buffer_offset: 64 + 8, // SMB2 header + fixed prefix + security_buffer_length: spnego_blob.len() as u16, + security_buffer: spnego_blob.to_vec(), + }; + let mut buf = Vec::new(); + resp.write_to(&mut buf) + .expect("SESSION_SETUP response encodes"); + debug!(len = buf.len(), "SESSION_SETUP response built"); + buf +} + +fn empty_completed() -> Vec { + encode_resp_token(NegState::AcceptCompleted, None, None, None) +} diff --git a/vendor/smb-server/src/handlers/set_info.rs b/vendor/smb-server/src/handlers/set_info.rs new file mode 100644 index 0000000..f8d4945 --- /dev/null +++ b/vendor/smb-server/src/handlers/set_info.rs @@ -0,0 +1,143 @@ +//! SET_INFO handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{InfoType, SetInfoRequest, SetInfoResponse}; + +use crate::backend::FileTimes; +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::{lookup_open, lookup_session_tree}; +use crate::info_class as ic; +use crate::ntstatus; +use crate::path::SmbPath; +use crate::server::ServerState; +use crate::utils::utf16le_to_units; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match SetInfoRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + let info_type = match InfoType::from_u8(req.info_type) { + Some(t) => t, + None => return HandlerResponse::err(ntstatus::STATUS_INVALID_INFO_CLASS), + }; + if !matches!(info_type, InfoType::File) { + return HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED); + } + + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let open_arc = match lookup_open(&tree_arc, req.file_id).await { + Some(o) => o, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + + let class = req.file_information_class; + let buffer = req.buffer; + let backend = { + let tree = tree_arc.read().await; + tree.share.backend.clone() + }; + + let result = match class { + ic::FILE_BASIC_INFORMATION => { + if buffer.len() < 36 { + return HandlerResponse::err(ntstatus::STATUS_INFO_LENGTH_MISMATCH); + } + let creation = u64::from_le_bytes(buffer[0..8].try_into().unwrap()); + let access = u64::from_le_bytes(buffer[8..16].try_into().unwrap()); + let write = u64::from_le_bytes(buffer[16..24].try_into().unwrap()); + let change = u64::from_le_bytes(buffer[24..32].try_into().unwrap()); + // 0 means "do not change", -1 (u64::MAX) means "do not change" too per spec. + let to_some = |v: u64| { + if v == 0 || v == u64::MAX { + None + } else { + Some(v) + } + }; + let times = FileTimes { + creation_time: to_some(creation), + last_access_time: to_some(access), + last_write_time: to_some(write), + change_time: to_some(change), + }; + let open = open_arc.read().await; + match open.handle.as_ref() { + Some(h) => h.set_times(times).await, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + } + } + ic::FILE_END_OF_FILE_INFORMATION => { + if buffer.len() < 8 { + return HandlerResponse::err(ntstatus::STATUS_INFO_LENGTH_MISMATCH); + } + let new_len = u64::from_le_bytes(buffer[0..8].try_into().unwrap()); + let open = open_arc.read().await; + match open.handle.as_ref() { + Some(h) => h.truncate(new_len).await, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + } + } + ic::FILE_DISPOSITION_INFORMATION => { + if buffer.is_empty() { + return HandlerResponse::err(ntstatus::STATUS_INFO_LENGTH_MISMATCH); + } + let mut open = open_arc.write().await; + open.delete_on_close = buffer[0] != 0; + Ok(()) + } + ic::FILE_RENAME_INFORMATION => { + // FILE_RENAME_INFORMATION layout (MS-FSCC §2.4.37): + // ReplaceIfExists (1) | Reserved (7) | RootDirectory (8) | FileNameLength (4) | FileName... + if buffer.len() < 20 { + return HandlerResponse::err(ntstatus::STATUS_INFO_LENGTH_MISMATCH); + } + let name_len = u32::from_le_bytes(buffer[16..20].try_into().unwrap()) as usize; + if buffer.len() < 20 + name_len { + return HandlerResponse::err(ntstatus::STATUS_INFO_LENGTH_MISMATCH); + } + let name_bytes = &buffer[20..20 + name_len]; + let units = match utf16le_to_units(name_bytes) { + Some(u) => u, + None => return HandlerResponse::err(ntstatus::STATUS_OBJECT_NAME_INVALID), + }; + let new_path = match SmbPath::from_utf16(&units) { + Ok(p) => p, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_OBJECT_NAME_INVALID), + }; + let from = open_arc.read().await.last_path.clone(); + match backend.rename(&from, &new_path).await { + Ok(()) => { + open_arc.write().await.last_path = new_path; + Ok(()) + } + Err(e) => Err(e), + } + } + ic::FILE_ALLOCATION_INFORMATION => { + // We don't preallocate; respond OK. + Ok(()) + } + _ => return HandlerResponse::err(ntstatus::STATUS_NOT_SUPPORTED), + }; + + if let Err(e) = result { + return HandlerResponse::err(e.to_nt_status()); + } + let mut buf = Vec::new(); + SetInfoResponse::default() + .write_to(&mut buf) + .expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/shared.rs b/vendor/smb-server/src/handlers/shared.rs new file mode 100644 index 0000000..826c83d --- /dev/null +++ b/vendor/smb-server/src/handlers/shared.rs @@ -0,0 +1,46 @@ +//! Internal helpers shared across handlers — tree/open lookup, etc. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::FileId; +use tokio::sync::RwLock; + +use crate::conn::state::{Connection, Open, Session, TreeConnect}; +use crate::ntstatus; + +/// Look up the session and tree referenced by `hdr`, returning the tree +/// inside the session. Returns the appropriate NTSTATUS on miss. +pub async fn lookup_session_tree( + conn: &Arc, + hdr: &Smb2Header, +) -> Result>, u32> { + let tid = hdr.tree_id().ok_or(ntstatus::STATUS_INVALID_PARAMETER)?; + let sess_arc = lookup_session(conn, hdr.session_id).await?; + let sess = sess_arc.read().await; + let trees = sess.trees.read().await; + trees + .get(&tid) + .cloned() + .ok_or(ntstatus::STATUS_NETWORK_NAME_DELETED) +} + +pub async fn lookup_session(conn: &Arc, sid: u64) -> Result>, u32> { + if sid == 0 { + return Err(ntstatus::STATUS_USER_SESSION_DELETED); + } + let sessions = conn.sessions.read().await; + sessions + .get(&sid) + .cloned() + .ok_or(ntstatus::STATUS_USER_SESSION_DELETED) +} + +pub async fn lookup_open( + tree: &Arc>, + file_id: FileId, +) -> Option>> { + let tree = tree.read().await; + let opens = tree.opens.read().await; + opens.get(&file_id).cloned() +} diff --git a/vendor/smb-server/src/handlers/tree_connect.rs b/vendor/smb-server/src/handlers/tree_connect.rs new file mode 100644 index 0000000..89a09de --- /dev/null +++ b/vendor/smb-server/src/handlers/tree_connect.rs @@ -0,0 +1,140 @@ +//! TREE_CONNECT handler — share lookup + authorization. + +use std::sync::Arc; + +use crate::proto::auth::ntlm::Identity; +use crate::proto::header::Smb2Header; +use crate::proto::messages::{TreeConnectRequest, TreeConnectResponse}; +use tracing::{info, warn}; + +use crate::builder::Access; +use crate::conn::state::{Connection, TreeConnect}; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::lookup_session; +use crate::ntstatus; +use crate::server::{ServerState, ShareMode}; + +const SHARE_TYPE_DISK: u8 = 0x01; +const SHARE_TYPE_PIPE: u8 = 0x02; + +const FILE_GENERIC_READ: u32 = 0x0012_0089; +const FILE_GENERIC_EXECUTE: u32 = 0x0012_00A0; +const FILE_ALL_ACCESS: u32 = 0x001F_01FF; + +pub async fn handle( + server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match TreeConnectRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + let path = req.path_str().unwrap_or_default(); + tracing::debug!(%path, "tree connect path"); + let share_name = match extract_share_name(&path) { + Some(s) => s, + None => { + tracing::warn!(%path, "tree connect: empty share name"); + return HandlerResponse::err(ntstatus::STATUS_BAD_NETWORK_NAME); + } + }; + tracing::debug!(%share_name, "tree connect lookup"); + let sess_arc = match lookup_session(conn, hdr.session_id).await { + Ok(s) => s, + Err(s) => return HandlerResponse::err(s), + }; + let sess = sess_arc.read().await; + let identity = sess.identity.clone(); + drop(sess); + + // IPC$: synthetic share. Accept at TREE_CONNECT (Windows always probes + // it before mounting an actual share); downstream CREATE/IOCTL on it + // return NotSupported via the no-op backend. + let share = if share_name.eq_ignore_ascii_case("IPC$") { + crate::server::ShareBindings::ipc() + } else { + match server.find_share(&share_name).await { + Some(s) => s, + None => return HandlerResponse::err(ntstatus::STATUS_BAD_NETWORK_NAME), + } + }; + + // Authorize. + let acl = share.acl.read().await; + let granted = match authorize(&acl.mode, &acl.users, &identity) { + Some(a) => a, + None => { + warn!(?identity, share = %share.name, "TREE_CONNECT denied"); + return HandlerResponse::err(ntstatus::STATUS_ACCESS_DENIED); + } + }; + drop(acl); + // Backend cap. + let granted = if share.backend.capabilities().is_read_only { + granted.clamp_to(Access::Read) + } else { + granted + }; + + let tree_id = sess_arc.read().await.alloc_tree_id(); + let tc = Arc::new(tokio::sync::RwLock::new(TreeConnect::new( + tree_id, + share.clone(), + granted, + ))); + { + let sess = sess_arc.read().await; + let mut trees = sess.trees.write().await; + trees.insert(tree_id, tc); + } + + let maximal_access = match granted { + Access::Read => FILE_GENERIC_READ | FILE_GENERIC_EXECUTE, + Access::ReadWrite => FILE_ALL_ACCESS, + }; + let resp = TreeConnectResponse { + structure_size: 16, + share_type: if share.is_ipc { + SHARE_TYPE_PIPE + } else { + SHARE_TYPE_DISK + }, + reserved: 0, + share_flags: 0, + capabilities: 0, + maximal_access, + }; + let mut buf = Vec::new(); + resp.write_to(&mut buf).expect("encode"); + info!(tree_id, share = %share.name, ?granted, "tree connect"); + let mut hr = HandlerResponse::ok(buf); + hr.override_tree_id = Some(tree_id); + hr +} + +fn extract_share_name(unc: &str) -> Option { + // \\server\share or \\server\share\ + let trimmed = unc.trim_end_matches(['\\', '/']); + let parts: Vec<&str> = trimmed + .split(['\\', '/']) + .filter(|s| !s.is_empty()) + .collect(); + parts.last().map(|s| s.to_string()) +} + +fn authorize( + mode: &ShareMode, + users: &std::collections::HashMap, + identity: &Identity, +) -> Option { + match mode { + ShareMode::Public => Some(Access::ReadWrite), + ShareMode::PublicReadOnly => Some(Access::Read), + ShareMode::AuthenticatedOnly => match identity { + Identity::Anonymous => None, + Identity::User { user, .. } => users.get(user).copied(), + }, + } +} diff --git a/vendor/smb-server/src/handlers/tree_disconnect.rs b/vendor/smb-server/src/handlers/tree_disconnect.rs new file mode 100644 index 0000000..1f47093 --- /dev/null +++ b/vendor/smb-server/src/handlers/tree_disconnect.rs @@ -0,0 +1,36 @@ +//! TREE_DISCONNECT handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::TreeDisconnectResponse; + +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::lookup_session; +use crate::ntstatus; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + _body: &[u8], +) -> HandlerResponse { + let tid = match hdr.tree_id() { + Some(t) => t, + None => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + + if lookup_session(conn, hdr.session_id).await.is_err() { + return HandlerResponse::err(ntstatus::STATUS_USER_SESSION_DELETED); + } + if !conn.close_tree(hdr.session_id, tid).await { + return HandlerResponse::err(ntstatus::STATUS_NETWORK_NAME_DELETED); + } + let mut buf = Vec::new(); + TreeDisconnectResponse::default() + .write_to(&mut buf) + .expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/handlers/write.rs b/vendor/smb-server/src/handlers/write.rs new file mode 100644 index 0000000..16735b1 --- /dev/null +++ b/vendor/smb-server/src/handlers/write.rs @@ -0,0 +1,60 @@ +//! WRITE handler. + +use std::sync::Arc; + +use crate::proto::header::Smb2Header; +use crate::proto::messages::{WriteRequest, WriteResponse}; + +use crate::builder::Access; +use crate::conn::state::Connection; +use crate::dispatch::HandlerResponse; +use crate::handlers::shared::{lookup_open, lookup_session_tree}; +use crate::ntstatus; +use crate::server::ServerState; + +pub async fn handle( + _server: &Arc, + conn: &Arc, + hdr: &Smb2Header, + body: &[u8], +) -> HandlerResponse { + let req = match WriteRequest::parse(body) { + Ok(r) => r, + Err(_) => return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER), + }; + let max_write = *conn.max_write_size.read().await; + if req.length > max_write { + return HandlerResponse::err(ntstatus::STATUS_INVALID_PARAMETER); + } + let tree_arc = match lookup_session_tree(conn, hdr).await { + Ok(t) => t, + Err(s) => return HandlerResponse::err(s), + }; + let granted = { + let tree = tree_arc.read().await; + tree.granted_access + }; + if !matches!(granted, Access::ReadWrite) { + return HandlerResponse::err(ntstatus::STATUS_ACCESS_DENIED); + } + let open_arc = match lookup_open(&tree_arc, req.file_id).await { + Some(o) => o, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + }; + let result = { + let open = open_arc.read().await; + match open.handle.as_ref() { + Some(h) => h.write_owned(req.offset, req.data).await, + None => return HandlerResponse::err(ntstatus::STATUS_FILE_CLOSED), + } + }; + let count = match result { + Ok(n) => n, + Err(e) => return HandlerResponse::err(e.to_nt_status()), + }; + let mut buf = Vec::new(); + WriteResponse::new(count) + .write_to(&mut buf) + .expect("encode"); + HandlerResponse::ok(buf) +} diff --git a/vendor/smb-server/src/info_class.rs b/vendor/smb-server/src/info_class.rs new file mode 100644 index 0000000..42a3e9d --- /dev/null +++ b/vendor/smb-server/src/info_class.rs @@ -0,0 +1,470 @@ +//! File / FileSystem / Security info-class encoders used by QUERY_INFO, +//! SET_INFO, and QUERY_DIRECTORY. +//! +//! These are byte-for-byte wire encodings per MS-FSCC §2.4 (file info) / +//! §2.5 (filesystem info) / MS-DTYP §2.4 (security descriptor). + +use crate::backend::{DirEntry, FileInfo}; +use crate::utils::utf16le; + +// --------------------------------------------------------------------------- +// File info classes (MS-FSCC §2.4) +// --------------------------------------------------------------------------- + +pub const FILE_DIRECTORY_INFORMATION: u8 = 0x01; +pub const FILE_FULL_DIRECTORY_INFORMATION: u8 = 0x02; +pub const FILE_BOTH_DIRECTORY_INFORMATION: u8 = 0x03; +pub const FILE_BASIC_INFORMATION: u8 = 0x04; +pub const FILE_STANDARD_INFORMATION: u8 = 0x05; +pub const FILE_INTERNAL_INFORMATION: u8 = 0x06; +pub const FILE_EA_INFORMATION: u8 = 0x07; +pub const FILE_ACCESS_INFORMATION: u8 = 0x08; +pub const FILE_NAME_INFORMATION: u8 = 0x09; +pub const FILE_NAMES_INFORMATION: u8 = 0x0C; +pub const FILE_POSITION_INFORMATION: u8 = 0x0E; +pub const FILE_FULL_EA_INFORMATION: u8 = 0x0F; +pub const FILE_MODE_INFORMATION: u8 = 0x10; +pub const FILE_ALIGNMENT_INFORMATION: u8 = 0x11; +pub const FILE_ALL_INFORMATION: u8 = 0x12; +pub const FILE_ALLOCATION_INFORMATION: u8 = 0x13; +pub const FILE_END_OF_FILE_INFORMATION: u8 = 0x14; +pub const FILE_STREAM_INFORMATION: u8 = 0x16; +pub const FILE_DISPOSITION_INFORMATION: u8 = 0x0D; +pub const FILE_RENAME_INFORMATION: u8 = 0x0A; +pub const FILE_NETWORK_OPEN_INFORMATION: u8 = 0x22; +pub const FILE_ID_BOTH_DIRECTORY_INFORMATION: u8 = 0x25; +pub const FILE_ID_FULL_DIRECTORY_INFORMATION: u8 = 0x26; + +// --------------------------------------------------------------------------- +// FileBasicInformation (MS-FSCC §2.4.7) — 40 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_basic_information(info: &FileInfo) -> Vec { + let mut out = Vec::with_capacity(40); + out.extend_from_slice(&info.creation_time.to_le_bytes()); + out.extend_from_slice(&info.last_access_time.to_le_bytes()); + out.extend_from_slice(&info.last_write_time.to_le_bytes()); + out.extend_from_slice(&info.change_time.to_le_bytes()); + out.extend_from_slice(&info.attributes().to_le_bytes()); + out.extend_from_slice(&0u32.to_le_bytes()); // Reserved + out +} + +// --------------------------------------------------------------------------- +// FileStandardInformation (MS-FSCC §2.4.41) — 24 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_standard_information(info: &FileInfo) -> Vec { + let mut out = Vec::with_capacity(24); + out.extend_from_slice(&info.allocation_size.to_le_bytes()); + out.extend_from_slice(&info.end_of_file.to_le_bytes()); + out.extend_from_slice(&1u32.to_le_bytes()); // NumberOfLinks = 1 + out.push(0); // DeletePending + out.push(if info.is_directory { 1 } else { 0 }); // Directory + out.extend_from_slice(&0u16.to_le_bytes()); // Reserved + out +} + +// --------------------------------------------------------------------------- +// FileInternalInformation (MS-FSCC §2.4.20) — 8 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_internal_information(file_index: u64) -> Vec { + file_index.to_le_bytes().to_vec() +} + +// --------------------------------------------------------------------------- +// FileEaInformation (MS-FSCC §2.4.12) — 4 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_ea_information() -> Vec { + 0u32.to_le_bytes().to_vec() +} + +// --------------------------------------------------------------------------- +// FileAccessInformation (MS-FSCC §2.4.1) — 4 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_access_information(access_mask: u32) -> Vec { + access_mask.to_le_bytes().to_vec() +} + +// --------------------------------------------------------------------------- +// FilePositionInformation (MS-FSCC §2.4.32) — 8 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_position_information() -> Vec { + 0u64.to_le_bytes().to_vec() +} + +// --------------------------------------------------------------------------- +// FileModeInformation (MS-FSCC §2.4.24) — 4 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_mode_information(mode: u32) -> Vec { + mode.to_le_bytes().to_vec() +} + +// --------------------------------------------------------------------------- +// FileAlignmentInformation (MS-FSCC §2.4.3) — 4 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_alignment_information() -> Vec { + // FILE_BYTE_ALIGNMENT (0) — no alignment requirement. + 0u32.to_le_bytes().to_vec() +} + +// --------------------------------------------------------------------------- +// FileNameInformation (MS-FSCC §2.4.27) — 4 bytes + UTF-16LE name +// --------------------------------------------------------------------------- + +pub fn encode_file_name_information(name: &str) -> Vec { + let n = utf16le(name); + let mut out = Vec::with_capacity(4 + n.len()); + out.extend_from_slice(&(n.len() as u32).to_le_bytes()); + out.extend_from_slice(&n); + out +} + +// --------------------------------------------------------------------------- +// FileAllInformation (MS-FSCC §2.4.2) — concatenation of basic, standard, +// internal, EA, access, position, mode, alignment, name. +// --------------------------------------------------------------------------- + +pub fn encode_file_all_information(info: &FileInfo, file_index: u64, access_mask: u32) -> Vec { + let mut out = Vec::new(); + out.extend_from_slice(&encode_file_basic_information(info)); + out.extend_from_slice(&encode_file_standard_information(info)); + out.extend_from_slice(&encode_file_internal_information(file_index)); + out.extend_from_slice(&encode_file_ea_information()); + out.extend_from_slice(&encode_file_access_information(access_mask)); + out.extend_from_slice(&encode_file_position_information()); + out.extend_from_slice(&encode_file_mode_information(0)); + out.extend_from_slice(&encode_file_alignment_information()); + out.extend_from_slice(&encode_file_name_information(&info.name)); + // Linux cifs checks FileAllInformation against its struct with + // FileName[1], so the empty-name root case must still be at least 101 + // bytes. + if out.len() < 101 { + out.push(0); + } + out +} + +// --------------------------------------------------------------------------- +// FileNetworkOpenInformation (MS-FSCC §2.4.30) — 56 bytes +// --------------------------------------------------------------------------- + +pub fn encode_file_network_open_information(info: &FileInfo) -> Vec { + let mut out = Vec::with_capacity(56); + out.extend_from_slice(&info.creation_time.to_le_bytes()); + out.extend_from_slice(&info.last_access_time.to_le_bytes()); + out.extend_from_slice(&info.last_write_time.to_le_bytes()); + out.extend_from_slice(&info.change_time.to_le_bytes()); + out.extend_from_slice(&info.allocation_size.to_le_bytes()); + out.extend_from_slice(&info.end_of_file.to_le_bytes()); + out.extend_from_slice(&info.attributes().to_le_bytes()); + out.extend_from_slice(&0u32.to_le_bytes()); // Reserved + out +} + +// --------------------------------------------------------------------------- +// FileStreamInformation (MS-FSCC §2.4.43) — for non-directories, one default +// stream entry (`::$DATA`); for directories, empty buffer. +// --------------------------------------------------------------------------- + +pub fn encode_file_stream_information(info: &FileInfo) -> Vec { + if info.is_directory { + return Vec::new(); + } + let stream_name = utf16le("::$DATA"); + let stream_name_len = stream_name.len() as u32; + let mut out = Vec::new(); + out.extend_from_slice(&0u32.to_le_bytes()); // NextEntryOffset = 0 + out.extend_from_slice(&stream_name_len.to_le_bytes()); // StreamNameLength + out.extend_from_slice(&info.end_of_file.to_le_bytes()); // StreamSize + out.extend_from_slice(&info.allocation_size.to_le_bytes()); // StreamAllocationSize + out.extend_from_slice(&stream_name); + out +} + +// --------------------------------------------------------------------------- +// FS info classes (MS-FSCC §2.5) +// --------------------------------------------------------------------------- + +pub const FS_VOLUME_INFORMATION: u8 = 0x01; +pub const FS_SIZE_INFORMATION: u8 = 0x03; +pub const FS_DEVICE_INFORMATION: u8 = 0x04; +pub const FS_ATTRIBUTE_INFORMATION: u8 = 0x05; +pub const FS_FULL_SIZE_INFORMATION: u8 = 0x07; + +/// FileFsVolumeInformation (MS-FSCC §2.5.9). Volume creation time, serial, +/// label. +pub fn encode_fs_volume_information(creation_time: u64, serial: u32, label: &str) -> Vec { + let label_u16 = utf16le(label); + let mut out = Vec::new(); + out.extend_from_slice(&creation_time.to_le_bytes()); + out.extend_from_slice(&serial.to_le_bytes()); + out.extend_from_slice(&(label_u16.len() as u32).to_le_bytes()); + out.push(0); // SupportsObjects + out.push(0); // Reserved + out.extend_from_slice(&label_u16); + out +} + +/// FileFsSizeInformation (MS-FSCC §2.5.7) — 24 bytes. +pub fn encode_fs_size_information( + total_alloc_units: u64, + avail_alloc_units: u64, + sectors_per_unit: u32, + bytes_per_sector: u32, +) -> Vec { + let mut out = Vec::with_capacity(24); + out.extend_from_slice(&total_alloc_units.to_le_bytes()); + out.extend_from_slice(&avail_alloc_units.to_le_bytes()); + out.extend_from_slice(§ors_per_unit.to_le_bytes()); + out.extend_from_slice(&bytes_per_sector.to_le_bytes()); + out +} + +/// FileFsDeviceInformation (MS-FSCC §2.5.10) — 8 bytes. +pub fn encode_fs_device_information(device_type: u32, characteristics: u32) -> Vec { + let mut out = Vec::with_capacity(8); + out.extend_from_slice(&device_type.to_le_bytes()); + out.extend_from_slice(&characteristics.to_le_bytes()); + out +} + +/// FileFsAttributeInformation (MS-FSCC §2.5.1) — variable. +pub fn encode_fs_attribute_information( + attributes: u32, + max_component_len: u32, + fs_name: &str, +) -> Vec { + let name_u16 = utf16le(fs_name); + let mut out = Vec::new(); + out.extend_from_slice(&attributes.to_le_bytes()); + out.extend_from_slice(&max_component_len.to_le_bytes()); + out.extend_from_slice(&(name_u16.len() as u32).to_le_bytes()); + out.extend_from_slice(&name_u16); + out +} + +/// FileFsFullSizeInformation (MS-FSCC §2.5.4) — 32 bytes. +pub fn encode_fs_full_size_information( + total_alloc_units: u64, + caller_avail_alloc_units: u64, + actual_avail_alloc_units: u64, + sectors_per_unit: u32, + bytes_per_sector: u32, +) -> Vec { + let mut out = Vec::with_capacity(32); + out.extend_from_slice(&total_alloc_units.to_le_bytes()); + out.extend_from_slice(&caller_avail_alloc_units.to_le_bytes()); + out.extend_from_slice(&actual_avail_alloc_units.to_le_bytes()); + out.extend_from_slice(§ors_per_unit.to_le_bytes()); + out.extend_from_slice(&bytes_per_sector.to_le_bytes()); + out +} + +// --------------------------------------------------------------------------- +// Minimal SECURITY_DESCRIPTOR with owner=Everyone, DACL=Everyone allowed. +// --------------------------------------------------------------------------- + +/// Build a minimal absolute-form SECURITY_DESCRIPTOR per MS-DTYP §2.4.6. +/// +/// Owner = Everyone (S-1-1-0). No group. DACL = single Allow ACE granting +/// `0x001F_01FF` (FILE_ALL_ACCESS) to Everyone. Self-relative format so it +/// embeds cleanly in the QUERY_INFO buffer. +pub fn encode_minimal_security_descriptor() -> Vec { + // SID Everyone (S-1-1-0): 1, 1, [0,0,0,0,0,1], [0,0,0,0] + // Total length: 1 (Revision) + 1 (SubAuthorityCount=1) + 6 (Identifier) + 4 (subauth) = 12 + let everyone: Vec = vec![ + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + ]; + + // Build ACE: AccessAllowedAce + // Header: 4 bytes (Type=0, Flags=0, Size) + // Mask: 4 bytes + // Sid: variable + let mut ace = Vec::new(); + ace.push(0x00); // ACCESS_ALLOWED_ACE_TYPE + ace.push(0x00); // AceFlags + let ace_size: u16 = (4 + 4 + everyone.len()) as u16; + ace.extend_from_slice(&ace_size.to_le_bytes()); + ace.extend_from_slice(&0x001F_01FFu32.to_le_bytes()); // FILE_ALL_ACCESS + ace.extend_from_slice(&everyone); + + // ACL: Revision (1), Sbz1 (1), AclSize (2), AceCount (2), Sbz2 (2), then ACEs. + let acl_size: u16 = (8 + ace.len()) as u16; + let mut dacl = Vec::new(); + dacl.push(0x02); // Revision = ACL_REVISION + dacl.push(0x00); // Sbz1 + dacl.extend_from_slice(&acl_size.to_le_bytes()); + dacl.extend_from_slice(&1u16.to_le_bytes()); // AceCount + dacl.extend_from_slice(&0u16.to_le_bytes()); // Sbz2 + dacl.extend_from_slice(&ace); + + // SECURITY_DESCRIPTOR (self-relative): + // Revision (1), Sbz1 (1), Control (2), + // OwnerOffset (4), GroupOffset (4), SaclOffset (4), DaclOffset (4) + // Then concatenated entities. + const SE_DACL_PRESENT: u16 = 0x0004; + const SE_SELF_RELATIVE: u16 = 0x8000; + let mut sd = Vec::new(); + sd.push(0x01); // Revision = SECURITY_DESCRIPTOR_REVISION + sd.push(0x00); // Sbz1 + sd.extend_from_slice(&(SE_DACL_PRESENT | SE_SELF_RELATIVE).to_le_bytes()); + let header_len: u32 = 20; + let owner_off = header_len; + let group_off = 0u32; + let sacl_off = 0u32; + let dacl_off = owner_off + everyone.len() as u32; + sd.extend_from_slice(&owner_off.to_le_bytes()); + sd.extend_from_slice(&group_off.to_le_bytes()); + sd.extend_from_slice(&sacl_off.to_le_bytes()); + sd.extend_from_slice(&dacl_off.to_le_bytes()); + sd.extend_from_slice(&everyone); + sd.extend_from_slice(&dacl); + sd +} + +// --------------------------------------------------------------------------- +// Directory information classes (MS-FSCC §2.4.{8,14,17,30,31}) +// --------------------------------------------------------------------------- + +/// Encode a single FileBothDirectoryInformation entry. Returns the encoded +/// bytes. The caller patches `NextEntryOffset` for chained entries. +pub fn encode_dir_entry(class: u8, entry: &DirEntry, file_index: u64) -> Vec { + let info = &entry.info; + let name_u16 = utf16le(&info.name); + match class { + FILE_DIRECTORY_INFORMATION => { + // 64 bytes fixed + name + let mut out = Vec::new(); + write_dir_entry_prefix(&mut out, info, file_index, name_u16.len()); + out.extend_from_slice(&name_u16); + out + } + FILE_FULL_DIRECTORY_INFORMATION => { + let mut out = Vec::new(); + write_dir_entry_prefix(&mut out, info, file_index, name_u16.len()); + out.extend_from_slice(&0u32.to_le_bytes()); // EaSize + out.extend_from_slice(&name_u16); + out + } + FILE_BOTH_DIRECTORY_INFORMATION => { + let mut out = Vec::new(); + write_dir_entry_prefix(&mut out, info, file_index, name_u16.len()); + out.extend_from_slice(&0u32.to_le_bytes()); // EaSize + out.push(0); // ShortNameLength + out.push(0); // Reserved1 + // ShortName: 24 bytes (12 UTF-16 chars). + out.extend_from_slice(&[0u8; 24]); + out.extend_from_slice(&name_u16); + out + } + FILE_ID_BOTH_DIRECTORY_INFORMATION => { + let mut out = Vec::new(); + write_dir_entry_prefix(&mut out, info, file_index, name_u16.len()); + out.extend_from_slice(&0u32.to_le_bytes()); // EaSize + out.push(0); // ShortNameLength + out.push(0); // Reserved1 + out.extend_from_slice(&[0u8; 24]); // ShortName + out.extend_from_slice(&0u16.to_le_bytes()); // Reserved2 + out.extend_from_slice(&file_index.to_le_bytes()); // FileId + out.extend_from_slice(&name_u16); + out + } + FILE_ID_FULL_DIRECTORY_INFORMATION => { + let mut out = Vec::new(); + write_dir_entry_prefix(&mut out, info, file_index, name_u16.len()); + out.extend_from_slice(&0u32.to_le_bytes()); // EaSize + out.extend_from_slice(&0u32.to_le_bytes()); // Reserved + out.extend_from_slice(&file_index.to_le_bytes()); // FileId + out.extend_from_slice(&name_u16); + out + } + FILE_NAMES_INFORMATION => { + let mut out = Vec::new(); + out.extend_from_slice(&0u32.to_le_bytes()); + out.extend_from_slice(&(file_index as u32).to_le_bytes()); + out.extend_from_slice(&(name_u16.len() as u32).to_le_bytes()); + out.extend_from_slice(&name_u16); + out + } + _ => Vec::new(), + } +} + +fn write_dir_entry_prefix(out: &mut Vec, info: &FileInfo, file_index: u64, name_len: usize) { + out.extend_from_slice(&0u32.to_le_bytes()); // NextEntryOffset (patched later) + out.extend_from_slice(&(file_index as u32).to_le_bytes()); // FileIndex + out.extend_from_slice(&info.creation_time.to_le_bytes()); + out.extend_from_slice(&info.last_access_time.to_le_bytes()); + out.extend_from_slice(&info.last_write_time.to_le_bytes()); + out.extend_from_slice(&info.change_time.to_le_bytes()); + out.extend_from_slice(&info.end_of_file.to_le_bytes()); + out.extend_from_slice(&info.allocation_size.to_le_bytes()); + out.extend_from_slice(&info.attributes().to_le_bytes()); + out.extend_from_slice(&(name_len as u32).to_le_bytes()); +} + +/// Round up `n` to the next multiple of 8. +pub fn align8(n: usize) -> usize { + (n + 7) & !7 +} + +#[cfg(test)] +mod tests { + use super::*; + + fn fake_info() -> FileInfo { + FileInfo { + name: "file.txt".to_string(), + end_of_file: 100, + allocation_size: 100, + creation_time: 0x01D9_0000_0000_0000, + last_access_time: 0x01D9_0000_0000_0000, + last_write_time: 0x01D9_0000_0000_0000, + change_time: 0x01D9_0000_0000_0000, + is_directory: false, + file_index: 1, + } + } + + #[test] + fn basic_information_is_40_bytes() { + let bytes = encode_file_basic_information(&fake_info()); + assert_eq!(bytes.len(), 40); + } + + #[test] + fn standard_information_is_24_bytes() { + let bytes = encode_file_standard_information(&fake_info()); + assert_eq!(bytes.len(), 24); + } + + #[test] + fn network_open_information_is_56_bytes() { + let bytes = encode_file_network_open_information(&fake_info()); + assert_eq!(bytes.len(), 56); + } + + #[test] + fn file_all_information_empty_name_keeps_linux_minimum_size() { + let mut info = fake_info(); + info.name.clear(); + let bytes = encode_file_all_information(&info, 1, 0x001F_01FF); + assert_eq!(bytes.len(), 101); + } + + #[test] + fn security_descriptor_is_self_relative() { + let sd = encode_minimal_security_descriptor(); + // Revision=1, then Control bits 8000 set => self-relative. + assert_eq!(sd[0], 0x01); + let control = u16::from_le_bytes([sd[2], sd[3]]); + assert!(control & 0x8000 != 0); + } +} diff --git a/vendor/smb-server/src/lib.rs b/vendor/smb-server/src/lib.rs new file mode 100644 index 0000000..01fb8ae --- /dev/null +++ b/vendor/smb-server/src/lib.rs @@ -0,0 +1,52 @@ +//! SMB2/3 file-sharing server with pluggable storage backends. +//! +//! See `docs/superpowers/specs/2026-04-27-rust-smb-server-design.md` for the +//! v1 design. The public API is small on purpose: +//! +//! ```no_run +//! use smb_server::{SmbServer, Share, Access, ShareBackend}; +//! # async fn run(backend: B) -> Result<(), Box> { +//! SmbServer::builder() +//! .listen("0.0.0.0:4445".parse()?) +//! .user("alice", "password") +//! .share(Share::new("home", backend).user("alice", Access::ReadWrite)) +//! .build()? +//! .serve() +//! .await?; +//! # Ok(()) } +//! ``` + +mod backend; +mod builder; +pub(crate) mod conn; +mod dispatch; +mod error; +#[cfg(feature = "localfs")] +mod fs; +mod handlers; +pub(crate) mod info_class; +pub mod ntstatus; +mod path; +mod proto; +mod server; +mod utils; + +pub use backend::{BackendCapabilities, DirEntry, FileInfo, FileTimes, Handle, OpenIntent, OpenOptions, ShareBackend}; +pub use error::SmbError; +pub use path::SmbPath; +pub use builder::{Access, Share}; +#[cfg(feature = "localfs")] +pub use fs::LocalFsBackend; +pub use proto::auth::ntlm::Identity; +pub use server::{ConfigHandle, ShareMode, ShutdownHandle, SmbServer}; + +pub mod wire { + pub use crate::proto::header; + pub use crate::proto::messages; +} + +#[cfg(test)] +mod tests { + mod dynamic_config; + mod memfs; +} diff --git a/vendor/smb-server/src/ntstatus.rs b/vendor/smb-server/src/ntstatus.rs new file mode 100644 index 0000000..581c7e3 --- /dev/null +++ b/vendor/smb-server/src/ntstatus.rs @@ -0,0 +1,41 @@ +//! NTSTATUS constants used by SMB2 handlers. +//! +//! Cross-referenced with MS-ERREF §2.3.1 (NTSTATUS Values). Only the codes the +//! v1 server actually emits or recognizes live here — kept tight on purpose. + +pub const STATUS_SUCCESS: u32 = 0x0000_0000; +pub const STATUS_PENDING: u32 = 0x0000_0103; +pub const STATUS_NOTIFY_CLEANUP: u32 = 0x0000_010B; +pub const STATUS_NOTIFY_ENUM_DIR: u32 = 0x0000_010C; +pub const STATUS_BUFFER_OVERFLOW: u32 = 0x8000_0005; +pub const STATUS_NO_MORE_FILES: u32 = 0x8000_0006; + +pub const STATUS_INVALID_HANDLE: u32 = 0xC000_0008; +pub const STATUS_INVALID_PARAMETER: u32 = 0xC000_000D; +pub const STATUS_NO_SUCH_FILE: u32 = 0xC000_000F; +pub const STATUS_OBJECT_NAME_NOT_FOUND: u32 = 0xC000_000F; +pub const STATUS_INVALID_DEVICE_REQUEST: u32 = 0xC000_0010; +pub const STATUS_END_OF_FILE: u32 = 0xC000_0011; +pub const STATUS_MORE_PROCESSING_REQUIRED: u32 = 0xC000_0016; +pub const STATUS_ACCESS_DENIED: u32 = 0xC000_0022; +pub const STATUS_BUFFER_TOO_SMALL: u32 = 0xC000_0023; +pub const STATUS_OBJECT_NAME_INVALID: u32 = 0xC000_0033; +pub const STATUS_OBJECT_NAME_COLLISION: u32 = 0xC000_0035; +pub const STATUS_OBJECT_PATH_NOT_FOUND: u32 = 0xC000_003A; +pub const STATUS_OBJECT_PATH_SYNTAX_BAD: u32 = 0xC000_003B; +pub const STATUS_SHARING_VIOLATION: u32 = 0xC000_0043; +pub const STATUS_DELETE_PENDING: u32 = 0xC000_0056; +pub const STATUS_LOGON_FAILURE: u32 = 0xC000_006D; +pub const STATUS_FS_DRIVER_REQUIRED: u32 = 0xC000_019C; +pub const STATUS_NOT_SUPPORTED: u32 = 0xC000_00BB; +pub const STATUS_FILE_IS_A_DIRECTORY: u32 = 0xC000_00BA; +pub const STATUS_NETWORK_NAME_DELETED: u32 = 0xC000_00C9; +pub const STATUS_BAD_NETWORK_NAME: u32 = 0xC000_00CC; +pub const STATUS_UNEXPECTED_IO_ERROR: u32 = 0xC000_009C; +pub const STATUS_DIRECTORY_NOT_EMPTY: u32 = 0xC000_0101; +pub const STATUS_NOT_A_DIRECTORY: u32 = 0xC000_0103; +pub const STATUS_USER_SESSION_DELETED: u32 = 0xC000_015C; +pub const STATUS_INFO_LENGTH_MISMATCH: u32 = 0xC000_0004; +pub const STATUS_FILE_CLOSED: u32 = 0xC000_0128; +pub const STATUS_INVALID_INFO_CLASS: u32 = 0xC000_0003; +pub const STATUS_NO_EAS_ON_FILE: u32 = 0xC000_0052; diff --git a/vendor/smb-server/src/path.rs b/vendor/smb-server/src/path.rs new file mode 100644 index 0000000..c1955a3 --- /dev/null +++ b/vendor/smb-server/src/path.rs @@ -0,0 +1,280 @@ +//! `SmbPath` — validated, normalized SMB path used between dispatcher and +//! backend. +//! +//! Construction is exclusively from a `&[u16]` (UTF-16LE-decoded) buffer, per +//! spec §7. The protocol layer turns wire bytes into `&[u16]`; this module +//! turns `&[u16]` into a path that backends can blindly trust. + +use std::str::FromStr; + +use crate::error::{SmbError, SmbResult}; + +/// A validated, component-list path. No `..`, no Windows-forbidden chars, no +/// alternate streams. Always relative to the share root — the empty path is +/// the root. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SmbPath { + components: Vec, +} + +impl SmbPath { + /// The share root. + pub fn root() -> Self { + Self::default() + } + + /// Construct from a UTF-16 code-unit slice (already decoded from UTF-16LE + /// wire bytes). + pub fn from_utf16(units: &[u16]) -> SmbResult { + // 1. Convert to UTF-8 lossily — but reject if conversion produced any + // replacement characters that didn't exist in the input. We test + // the round-trip: invalid surrogates are rejected. + let s = decode_utf16_strict(units)?; + s.parse() + } + + fn parse_components(s: &str) -> SmbResult { + // Strip a leading separator (clients sometimes prefix `\` or `/`). + let trimmed = s + .strip_prefix('\\') + .or_else(|| s.strip_prefix('/')) + .unwrap_or(s); + if trimmed.is_empty() { + return Ok(Self::root()); + } + + // 2. Reject forbidden characters anywhere in the path. + for ch in trimmed.chars() { + if ch == '\0' || ('\u{0001}'..='\u{001F}').contains(&ch) { + return Err(SmbError::NameInvalid); + } + // Allow `\` and `/` as separators, reject the rest of the + // Windows-forbidden set anywhere. + match ch { + '<' | '>' | ':' | '"' | '|' | '?' | '*' => return Err(SmbError::NameInvalid), + _ => {} + } + } + + // 3. Split on `\` or `/`; reject `..` and empty components; skip `.`. + let mut components = Vec::new(); + for raw in trimmed.split(['\\', '/']) { + if raw.is_empty() { + // Doubled separator like `foo\\bar` — reject. + return Err(SmbError::NameInvalid); + } + if raw == "." { + continue; + } + if raw == ".." { + return Err(SmbError::NameInvalid); + } + // 4. Reject reserved DOS device names. + if is_reserved_dos_name(raw) { + return Err(SmbError::NameInvalid); + } + components.push(raw.to_string()); + } + Ok(Self { components }) + } + + /// Path components in order. Empty for the root. + pub fn components(&self) -> &[String] { + &self.components + } + + /// Is this the share root? + pub fn is_root(&self) -> bool { + self.components.is_empty() + } + + /// Return the parent path, or `None` if this is the root. + pub fn parent(&self) -> Option { + if self.is_root() { + return None; + } + let mut parent = self.components.clone(); + parent.pop(); + Some(SmbPath { components: parent }) + } + + /// Return the last component, if any. + pub fn file_name(&self) -> Option<&str> { + self.components.last().map(|s| s.as_str()) + } + + /// Append a single, already-validated last component to this path. + pub fn join(&self, last: &str) -> SmbResult { + // Run `last` through the same validator (treating it as a single- + // component path). + let extra = last.parse::()?; + let mut out = self.clone(); + out.components.extend(extra.components); + Ok(out) + } + + /// Render as a backslash-separated string. Empty for root. + pub fn display_backslash(&self) -> String { + self.components.join("\\") + } +} + +impl FromStr for SmbPath { + type Err = SmbError; + + fn from_str(s: &str) -> Result { + Self::parse_components(s) + } +} + +impl std::fmt::Display for SmbPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.is_root() { + f.write_str("\\") + } else { + f.write_str(&self.display_backslash()) + } + } +} + +fn is_reserved_dos_name(s: &str) -> bool { + // Strip extension before checking, e.g. "CON.txt" is also reserved. + let stem = match s.rsplit_once('.') { + Some((stem, _)) => stem, + None => s, + }; + let upper = stem.to_ascii_uppercase(); + matches!(upper.as_str(), "CON" | "PRN" | "AUX" | "NUL") || matches_com_or_lpt(&upper) +} + +fn matches_com_or_lpt(s: &str) -> bool { + if s.len() != 4 { + return false; + } + let bytes = s.as_bytes(); + let prefix = &bytes[..3]; + let last = bytes[3] as char; + if !matches!(last, '1'..='9') { + return false; + } + prefix == b"COM" || prefix == b"LPT" +} + +fn decode_utf16_strict(units: &[u16]) -> SmbResult { + // Reject unpaired surrogates explicitly. `String::from_utf16` does this + // already; we surface its error as NameInvalid. + String::from_utf16(units).map_err(|_| SmbError::NameInvalid) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn utf16(s: &str) -> Vec { + s.encode_utf16().collect() + } + + #[test] + fn root_paths() { + assert!("".parse::().unwrap().is_root()); + assert!("\\".parse::().unwrap().is_root()); + assert!("/".parse::().unwrap().is_root()); + assert!(SmbPath::from_utf16(&utf16("")).unwrap().is_root()); + } + + #[test] + fn simple_paths_split() { + let p = "dir\\sub\\file.txt".parse::().unwrap(); + assert_eq!(p.components(), &["dir", "sub", "file.txt"]); + assert_eq!(p.display_backslash(), "dir\\sub\\file.txt"); + assert!(!p.is_root()); + assert_eq!(p.file_name(), Some("file.txt")); + } + + #[test] + fn forward_slash_accepted() { + let p = "a/b/c".parse::().unwrap(); + assert_eq!(p.components(), &["a", "b", "c"]); + } + + #[test] + fn dot_components_skipped() { + let p = "a\\.\\b".parse::().unwrap(); + assert_eq!(p.components(), &["a", "b"]); + } + + #[test] + fn parent_returns_one_component_less() { + let p = "a\\b\\c".parse::().unwrap(); + let parent = p.parent().unwrap(); + assert_eq!(parent.components(), &["a", "b"]); + let grand = parent.parent().unwrap(); + assert_eq!(grand.components(), &["a"]); + let root = grand.parent().unwrap(); + assert!(root.is_root()); + assert!(root.parent().is_none()); + } + + #[test] + fn join_appends_component() { + let p = "a".parse::().unwrap(); + let q = p.join("b").unwrap(); + assert_eq!(q.components(), &["a", "b"]); + } + + #[test] + fn rejects_double_dot() { + assert!("a\\..\\b".parse::().is_err()); + assert!("..".parse::().is_err()); + } + + #[test] + fn rejects_double_separator() { + assert!("a\\\\b".parse::().is_err()); + } + + #[test] + fn rejects_forbidden_chars() { + for bad in ["ab", "a:b", "a\"b", "a|b", "a?b", "a*b"] { + assert!(bad.parse::().is_err(), "{bad}"); + } + } + + #[test] + fn rejects_control_chars() { + let s = format!("a{}b", '\u{0001}'); + assert!(s.parse::().is_err()); + let s = format!("a{}b", '\u{0000}'); + assert!(s.parse::().is_err()); + } + + #[test] + fn rejects_reserved_dos_names() { + for bad in [ + "CON", "con", "PRN", "AUX", "NUL", "COM1", "LPT9", "Con.txt", "NUL.dat", + ] { + assert!(bad.parse::().is_err(), "{bad}"); + } + } + + #[test] + fn allows_lookalike_names() { + // Not reserved. + assert!("CON1".parse::().is_ok()); + assert!("LPT".parse::().is_ok()); + assert!("LPT0".parse::().is_ok()); // 0 is not in the 1-9 range + assert!("NUL_FILE.txt".parse::().is_ok()); + } + + #[test] + fn rejects_unpaired_surrogate() { + let units: [u16; 2] = [0xD800, 0x0061]; // unpaired high surrogate + assert!(SmbPath::from_utf16(&units).is_err()); + } + + #[test] + fn round_trip_via_utf16() { + let p = SmbPath::from_utf16(&utf16("a\\b")).unwrap(); + assert_eq!(p.components(), &["a", "b"]); + } +} diff --git a/vendor/smb-server/src/proto/auth.rs b/vendor/smb-server/src/proto/auth.rs new file mode 100644 index 0000000..304f42c --- /dev/null +++ b/vendor/smb-server/src/proto/auth.rs @@ -0,0 +1,11 @@ +//! NTLMv2 server-side authentication and minimal SPNEGO outer envelope. +//! +//! See: +//! * MS-NLMP — NT LAN Manager (NTLM) Authentication Protocol +//! * MS-SPNG — Simple and Protected GSS-API Negotiation Mechanism +//! +//! v1 implements **only** the NTLM (NTLMSSP) mechanism inside SPNEGO. +//! Kerberos is out of scope (revisit in v0.2). + +pub mod ntlm; +pub mod spnego; diff --git a/vendor/smb-server/src/proto/auth/ntlm.rs b/vendor/smb-server/src/proto/auth/ntlm.rs new file mode 100644 index 0000000..eb8a57b --- /dev/null +++ b/vendor/smb-server/src/proto/auth/ntlm.rs @@ -0,0 +1,1053 @@ +//! NTLMv2 server-side authentication. +//! +//! Spec references (all from MS-NLMP): +//! * §2.2.1 NTLM messages (NEGOTIATE, CHALLENGE, AUTHENTICATE) +//! * §2.2.2 Common structures (AV_PAIR, NTLMv2_RESPONSE, NTLMv2_CLIENT_CHALLENGE) +//! * §3.3.2 NTLM v2 Authentication algorithm +//! * §3.4 Key derivation (`NTOWFv2`, `LMOWFv2`) +//! * §3.4.4 Message Integrity Code (MIC) +//! * §4.2.4 Known-answer test vectors +//! +//! This module implements the **server side only**. We parse incoming +//! `NEGOTIATE_MESSAGE` (Type 1) and `AUTHENTICATE_MESSAGE` (Type 3) blobs, +//! produce the `CHALLENGE_MESSAGE` (Type 2) reply, and validate the client's +//! NT response to derive a session key. + +use hmac::{Hmac, Mac}; +use md4::{Digest, Md4}; +use md5::Md5; +use rc4::Rc4; +use rc4::cipher::{KeyInit, StreamCipher}; + +use crate::proto::error::{ProtoError, ProtoResult}; + +type HmacMd5 = Hmac; + +// --- NTLMSSP signature & message types -------------------------------------- + +/// 8-byte signature `"NTLMSSP\0"` prefixing every NTLMSSP message. +pub const NTLMSSP_SIGNATURE: &[u8; 8] = b"NTLMSSP\0"; + +pub const MSG_NEGOTIATE: u32 = 0x0000_0001; +pub const MSG_CHALLENGE: u32 = 0x0000_0002; +pub const MSG_AUTHENTICATE: u32 = 0x0000_0003; + +// --- NTLMSSP negotiate flags (MS-NLMP §2.2.2.5) ----------------------------- + +pub mod flags { + pub const NTLMSSP_NEGOTIATE_UNICODE: u32 = 0x0000_0001; + pub const NTLMSSP_REQUEST_TARGET: u32 = 0x0000_0004; + pub const NTLMSSP_NEGOTIATE_SIGN: u32 = 0x0000_0010; + pub const NTLMSSP_NEGOTIATE_NTLM: u32 = 0x0000_0200; + #[cfg(test)] + pub const NTLMSSP_NEGOTIATE_ANONYMOUS: u32 = 0x0000_0800; + pub const NTLMSSP_NEGOTIATE_ALWAYS_SIGN: u32 = 0x0000_8000; + pub const NTLMSSP_TARGET_TYPE_SERVER: u32 = 0x0002_0000; + pub const NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY: u32 = 0x0008_0000; + pub const NTLMSSP_NEGOTIATE_TARGET_INFO: u32 = 0x0080_0000; + pub const NTLMSSP_NEGOTIATE_VERSION: u32 = 0x0200_0000; + pub const NTLMSSP_NEGOTIATE_128: u32 = 0x2000_0000; + pub const NTLMSSP_NEGOTIATE_KEY_EXCH: u32 = 0x4000_0000; + pub const NTLMSSP_NEGOTIATE_56: u32 = 0x8000_0000; +} + +// --- AV_PAIR types (MS-NLMP §2.2.2.1) --------------------------------------- + +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u16)] +pub enum AvId { + Eol = 0x0000, + NbComputerName = 0x0001, + NbDomainName = 0x0002, + DnsComputerName = 0x0003, + DnsDomainName = 0x0004, + DnsTreeName = 0x0005, + Flags = 0x0006, + Timestamp = 0x0007, + SingleHost = 0x0008, + TargetName = 0x0009, + ChannelBindings = 0x000A, +} + +/// One AV_PAIR (attribute–value pair) from a target-info / authenticate-target-info list. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AvPair { + pub id: u16, + pub value: Vec, +} + +impl AvPair { + pub fn new(id: AvId, value: Vec) -> Self { + Self { + id: id as u16, + value, + } + } +} + +/// Encode a list of AV_PAIRs in wire format (each: 2-byte LE id, 2-byte LE +/// length, value bytes), terminated by an `MsvAvEOL` (id=0, len=0) entry. +pub fn encode_av_pairs(pairs: &[AvPair]) -> Vec { + let mut out = Vec::new(); + for p in pairs { + out.extend_from_slice(&p.id.to_le_bytes()); + out.extend_from_slice(&(p.value.len() as u16).to_le_bytes()); + out.extend_from_slice(&p.value); + } + // Terminator + out.extend_from_slice(&(AvId::Eol as u16).to_le_bytes()); + out.extend_from_slice(&0u16.to_le_bytes()); + out +} + +/// Decode AV_PAIRs from a byte slice; stops at (and consumes) the EOL entry. +#[cfg(test)] +pub fn decode_av_pairs(buf: &[u8]) -> ProtoResult> { + let mut out = Vec::new(); + let mut i = 0usize; + loop { + if buf.len() < i + 4 { + return Err(ProtoError::Auth("av_pair list truncated")); + } + let id = u16::from_le_bytes([buf[i], buf[i + 1]]); + let len = u16::from_le_bytes([buf[i + 2], buf[i + 3]]) as usize; + i += 4; + if id == AvId::Eol as u16 { + // EOL must have len=0; tolerate stray bytes. + break; + } + if buf.len() < i + len { + return Err(ProtoError::Auth("av_pair value truncated")); + } + out.push(AvPair { + id, + value: buf[i..i + len].to_vec(), + }); + i += len; + } + Ok(out) +} + +// --- Helpers --------------------------------------------------------------- + +/// UTF-8 → UTF-16LE bytes (no BOM). +pub fn utf16le(s: &str) -> Vec { + let mut out = Vec::with_capacity(s.len() * 2); + for unit in s.encode_utf16() { + out.extend_from_slice(&unit.to_le_bytes()); + } + out +} + +/// Decode UTF-16LE bytes into a `String` (lossy on bad surrogates). +fn utf16le_to_string(bytes: &[u8]) -> String { + let units: Vec = bytes + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect(); + String::from_utf16_lossy(&units) +} + +/// MD4 of UTF-16LE password — the "NT hash". +pub fn nt_hash(password: &str) -> [u8; 16] { + let mut h = Md4::new(); + h.update(utf16le(password)); + let out = h.finalize(); + let mut o = [0u8; 16]; + o.copy_from_slice(&out); + o +} + +/// `NTOWFv2(password, user, domain) = HMAC_MD5(NT_hash(password), UTF-16LE(UPPER(user) || domain))`. +/// +/// The user name is uppercased; the domain is **not** (per MS-NLMP §3.4 NTOWFv2). +pub fn ntowf_v2(nt_hash_bytes: &[u8; 16], user: &str, domain: &str) -> [u8; 16] { + let mut mac = HmacMd5::new_from_slice(nt_hash_bytes).expect("HMAC accepts any key length"); + mac.update(&utf16le(&user.to_uppercase())); + mac.update(&utf16le(domain)); + let res = mac.finalize().into_bytes(); + let mut out = [0u8; 16]; + out.copy_from_slice(&res); + out +} + +/// Constant-time 16-byte comparison. +fn ct_eq_16(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + let mut diff = 0u8; + for (x, y) in a.iter().zip(b.iter()) { + diff |= x ^ y; + } + diff == 0 +} + +// --- Read helpers for NTLMSSP fields ----------------------------------------- + +/// A `(len, max_len, offset)` field descriptor as used throughout NTLMSSP messages. +/// The parser keeps the fields used to slice payloads; `max_len` is ignored. +#[derive(Debug, Clone, Copy)] +struct Field { + len: u16, + offset: u32, +} + +fn read_field(buf: &[u8], at: usize) -> ProtoResult { + if buf.len() < at + 8 { + return Err(ProtoError::Auth("field descriptor truncated")); + } + let _max_len = u16::from_le_bytes([buf[at + 2], buf[at + 3]]); + Ok(Field { + len: u16::from_le_bytes([buf[at], buf[at + 1]]), + offset: u32::from_le_bytes([buf[at + 4], buf[at + 5], buf[at + 6], buf[at + 7]]), + }) +} + +fn slice_field(buf: &[u8], f: Field) -> ProtoResult<&[u8]> { + let start = f.offset as usize; + let end = start.saturating_add(f.len as usize); + if end > buf.len() { + return Err(ProtoError::Auth("field slice out of range")); + } + Ok(&buf[start..end]) +} + +fn check_signature(buf: &[u8], expected_msg: u32) -> ProtoResult<()> { + if buf.len() < 12 { + return Err(ProtoError::Auth("ntlmssp message too short")); + } + if &buf[..8] != NTLMSSP_SIGNATURE { + return Err(ProtoError::Auth("ntlmssp signature mismatch")); + } + let msg = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]); + if msg != expected_msg { + return Err(ProtoError::Auth("unexpected ntlmssp message type")); + } + Ok(()) +} + +// --- NEGOTIATE_MESSAGE (Type 1) --------------------------------------------- + +#[derive(Debug, Clone, Default)] +pub struct NtlmNegotiate { + pub flags: u32, + pub domain: Vec, + pub workstation: Vec, + /// Raw bytes of the original message — needed for MIC computation later. + pub raw: Vec, +} + +impl NtlmNegotiate { + /// Parse a Type 1 NEGOTIATE_MESSAGE. + /// + /// Layout (MS-NLMP §2.2.1.1): + /// ```text + /// 0 : "NTLMSSP\0" + /// 8 : MessageType = 0x01 (u32 LE) + /// 12 : NegotiateFlags (u32 LE) + /// 16 : DomainNameFields (8 bytes: len, maxlen, offset) + /// 24 : WorkstationFields (8 bytes) + /// 32 : Version (8 bytes, optional — only if NTLMSSP_NEGOTIATE_VERSION set) + /// ``` + pub fn parse(buf: &[u8]) -> ProtoResult { + check_signature(buf, MSG_NEGOTIATE)?; + if buf.len() < 32 { + return Err(ProtoError::Auth("NEGOTIATE_MESSAGE too short")); + } + let flags = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]); + let domain_field = read_field(buf, 16)?; + let ws_field = read_field(buf, 24)?; + + // Fields may be empty (offset/len = 0) when supplied flags don't set them. + let domain = if domain_field.len == 0 { + Vec::new() + } else { + slice_field(buf, domain_field)?.to_vec() + }; + let workstation = if ws_field.len == 0 { + Vec::new() + } else { + slice_field(buf, ws_field)?.to_vec() + }; + + Ok(Self { + flags, + domain, + workstation, + raw: buf.to_vec(), + }) + } +} + +// --- CHALLENGE_MESSAGE (Type 2) --------------------------------------------- + +/// Server-side construction parameters for the CHALLENGE_MESSAGE. +#[derive(Debug, Clone)] +pub struct ChallengeParams<'a> { + pub server_challenge: [u8; 8], + pub target_name: &'a str, + pub nb_domain_name: &'a str, + pub nb_computer_name: &'a str, + pub dns_domain_name: &'a str, + pub dns_computer_name: &'a str, + /// Windows FILETIME (100-ns intervals since 1601-01-01) — caller-provided + /// so this module stays clock-free. + pub timestamp: u64, + /// Negotiated flags (already AND-ed with server policy). + pub flags: u32, +} + +/// Build a Type 2 CHALLENGE_MESSAGE. +/// +/// Layout (MS-NLMP §2.2.1.2): +/// ```text +/// 0 : "NTLMSSP\0" +/// 8 : MessageType = 0x02 +/// 12 : TargetNameFields (8 bytes) +/// 20 : NegotiateFlags (4 bytes) +/// 24 : ServerChallenge (8 bytes) +/// 32 : Reserved (8 bytes, zeroed) +/// 40 : TargetInfoFields (8 bytes) +/// 48 : Version (8 bytes) +/// 56 : Payload... +/// ``` +pub fn build_challenge(p: &ChallengeParams<'_>) -> Vec { + let target_name_utf16 = utf16le(p.target_name); + let av_pairs = vec![ + AvPair::new(AvId::NbDomainName, utf16le(p.nb_domain_name)), + AvPair::new(AvId::NbComputerName, utf16le(p.nb_computer_name)), + AvPair::new(AvId::DnsDomainName, utf16le(p.dns_domain_name)), + AvPair::new(AvId::DnsComputerName, utf16le(p.dns_computer_name)), + AvPair::new(AvId::Timestamp, p.timestamp.to_le_bytes().to_vec()), + ]; + let target_info = encode_av_pairs(&av_pairs); + + let header_len: u32 = 56; + let target_name_offset = header_len; + let target_info_offset = target_name_offset + target_name_utf16.len() as u32; + + let mut out = + Vec::with_capacity(header_len as usize + target_name_utf16.len() + target_info.len()); + // 0..8: signature + out.extend_from_slice(NTLMSSP_SIGNATURE); + // 8..12: message type + out.extend_from_slice(&MSG_CHALLENGE.to_le_bytes()); + // 12..20: TargetNameFields + let tn_len = target_name_utf16.len() as u16; + out.extend_from_slice(&tn_len.to_le_bytes()); + out.extend_from_slice(&tn_len.to_le_bytes()); + out.extend_from_slice(&target_name_offset.to_le_bytes()); + // 20..24: NegotiateFlags + out.extend_from_slice(&p.flags.to_le_bytes()); + // 24..32: ServerChallenge + out.extend_from_slice(&p.server_challenge); + // 32..40: Reserved + out.extend_from_slice(&[0u8; 8]); + // 40..48: TargetInfoFields + let ti_len = target_info.len() as u16; + out.extend_from_slice(&ti_len.to_le_bytes()); + out.extend_from_slice(&ti_len.to_le_bytes()); + out.extend_from_slice(&target_info_offset.to_le_bytes()); + // 48..56: Version (we report 6.1.7600 / NTLMSSP rev 0x0F as a stable choice). + // Per spec, only meaningful if NTLMSSP_NEGOTIATE_VERSION is set; harmless otherwise. + out.extend_from_slice(&[6, 1, 0, 0x1D, 0, 0, 0, 0x0F]); + // payload + out.extend_from_slice(&target_name_utf16); + out.extend_from_slice(&target_info); + out +} + +// --- AUTHENTICATE_MESSAGE (Type 3) ------------------------------------------ + +#[derive(Debug, Clone)] +pub struct NtlmAuthenticate { + pub flags: u32, + #[allow(dead_code)] + pub lm_response: Vec, + pub nt_response: Vec, + pub domain: String, + pub user: String, + #[allow(dead_code)] + pub workstation: String, + pub encrypted_random_session_key: Vec, + /// Optional MIC (16 bytes, zeroed in source bytes during the MIC HMAC). + pub mic: Option<[u8; 16]>, + /// Offset of the MIC field within `raw`, if present (for re-zero during validation). + pub mic_offset: Option, + /// Raw bytes of the original message — needed for MIC computation. + pub raw: Vec, +} + +impl NtlmAuthenticate { + /// Parse a Type 3 AUTHENTICATE_MESSAGE. + /// + /// Layout (MS-NLMP §2.2.1.3): + /// ```text + /// 0 : "NTLMSSP\0" + /// 8 : MessageType = 0x03 + /// 12 : LmChallengeResponseFields + /// 20 : NtChallengeResponseFields + /// 28 : DomainNameFields + /// 36 : UserNameFields + /// 44 : WorkstationFields + /// 52 : EncryptedRandomSessionKeyFields + /// 60 : NegotiateFlags (4 bytes) + /// 64 : Version (8 bytes) + /// 72 : MIC (16 bytes — present only in some versions) + /// 88 : Payload... + /// ``` + /// The MIC is present only when an `MsvAvFlags` AV_PAIR with bit 0x2 was + /// echoed by the client. We detect "MIC present" heuristically by checking + /// whether the smallest field-payload offset ≥ 88; if it is ≥ 88, bytes + /// 72..88 are interpreted as the MIC. Otherwise no MIC. + pub fn parse(buf: &[u8]) -> ProtoResult { + check_signature(buf, MSG_AUTHENTICATE)?; + if buf.len() < 64 { + return Err(ProtoError::Auth("AUTHENTICATE_MESSAGE too short")); + } + let lm_field = read_field(buf, 12)?; + let nt_field = read_field(buf, 20)?; + let domain_field = read_field(buf, 28)?; + let user_field = read_field(buf, 36)?; + let ws_field = read_field(buf, 44)?; + let key_field = read_field(buf, 52)?; + let flags = u32::from_le_bytes([buf[60], buf[61], buf[62], buf[63]]); + + // Determine where the payload starts to know whether the MIC field is present. + // The smallest non-zero offset among the fields tells us. + let mut min_off: u32 = u32::MAX; + for f in [ + lm_field, + nt_field, + domain_field, + user_field, + ws_field, + key_field, + ] { + if f.len > 0 && f.offset > 0 && f.offset < min_off { + min_off = f.offset; + } + } + + let (mic, mic_offset) = if min_off != u32::MAX && min_off as usize >= 88 && buf.len() >= 88 + { + let mut mic = [0u8; 16]; + mic.copy_from_slice(&buf[72..88]); + (Some(mic), Some(72usize)) + } else { + (None, None) + }; + + let lm_response = if lm_field.len == 0 { + Vec::new() + } else { + slice_field(buf, lm_field)?.to_vec() + }; + let nt_response = if nt_field.len == 0 { + Vec::new() + } else { + slice_field(buf, nt_field)?.to_vec() + }; + let domain_bytes = if domain_field.len == 0 { + Vec::new() + } else { + slice_field(buf, domain_field)?.to_vec() + }; + let user_bytes = if user_field.len == 0 { + Vec::new() + } else { + slice_field(buf, user_field)?.to_vec() + }; + let ws_bytes = if ws_field.len == 0 { + Vec::new() + } else { + slice_field(buf, ws_field)?.to_vec() + }; + let encrypted_random_session_key = if key_field.len == 0 { + Vec::new() + } else { + slice_field(buf, key_field)?.to_vec() + }; + + // Per NTLMSSP_NEGOTIATE_UNICODE flag, names are UTF-16LE; otherwise OEM. + // We require Unicode — we only advertise it. Decode UTF-16LE. + let domain = utf16le_to_string(&domain_bytes); + let user = utf16le_to_string(&user_bytes); + let workstation = utf16le_to_string(&ws_bytes); + + Ok(Self { + flags, + lm_response, + nt_response, + domain, + user, + workstation, + encrypted_random_session_key, + mic, + mic_offset, + raw: buf.to_vec(), + }) + } +} + +// --- Public state machine --------------------------------------------------- + +/// Identity recovered from a successful (or anonymous) authentication. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Identity { + Anonymous, + User { user: String, domain: String }, +} + +/// Successful authentication outcome: identity + 16-byte session key. +#[derive(Debug, Clone)] +pub struct AuthOutcome { + pub identity: Identity, + pub session_key: [u8; 16], +} + +/// Caller-supplied user record. We store only the precomputed NT hash — +/// callers should derive it from the password at builder time and discard +/// the plaintext. +#[derive(Debug, Clone)] +pub struct UserCreds { + pub nt_hash: [u8; 16], +} + +impl UserCreds { + /// Derive the NT hash from a plaintext password (UTF-16LE then MD4). + pub fn from_password(password: &str) -> Self { + Self { + nt_hash: nt_hash(password), + } + } + + /// Construct from a precomputed NT hash. + pub fn from_nt_hash(nt_hash: [u8; 16]) -> Self { + Self { nt_hash } + } +} + +#[derive(Debug, Clone)] +pub struct NtlmTargetInfo { + pub target_name: String, + pub nb_domain: String, + pub nb_computer: String, + pub dns_domain: String, + pub dns_computer: String, +} + +impl NtlmTargetInfo { + pub fn new( + target_name: impl Into, + nb_domain: impl Into, + nb_computer: impl Into, + dns_domain: impl Into, + dns_computer: impl Into, + ) -> Self { + Self { + target_name: target_name.into(), + nb_domain: nb_domain.into(), + nb_computer: nb_computer.into(), + dns_domain: dns_domain.into(), + dns_computer: dns_computer.into(), + } + } +} + +/// Server-side state machine driving SESSION_SETUP for a single connection. +/// +/// Lifecycle: +/// 1. `NtlmServer::new(...)` +/// 2. `step1_negotiate(blob)` — record the client's NEGOTIATE bytes (for MIC). +/// 3. `challenge()` — produce CHALLENGE_MESSAGE bytes; record them too. +/// 4. `authenticate(blob, lookup)` — validate AUTHENTICATE; return outcome. +pub struct NtlmServer { + server_challenge: [u8; 8], + target_name: String, + nb_domain: String, + nb_computer: String, + dns_domain: String, + dns_computer: String, + timestamp: u64, + /// Flags we will advertise in the CHALLENGE. + server_flags: u32, + /// Negotiated flags after considering the client's NEGOTIATE. + negotiated_flags: u32, + + /// Bytes of the client NEGOTIATE_MESSAGE (for MIC HMAC over N||C||A). + negotiate_bytes: Vec, + /// Bytes of the server CHALLENGE_MESSAGE (for MIC HMAC). + challenge_bytes: Vec, +} + +impl NtlmServer { + /// Create a new server-side acceptor. + pub fn new(server_challenge: [u8; 8], target: NtlmTargetInfo, timestamp: u64) -> Self { + // Default server flag set — what we are willing to support. + let server_flags = flags::NTLMSSP_NEGOTIATE_UNICODE + | flags::NTLMSSP_REQUEST_TARGET + | flags::NTLMSSP_NEGOTIATE_NTLM + | flags::NTLMSSP_NEGOTIATE_SIGN + | flags::NTLMSSP_NEGOTIATE_ALWAYS_SIGN + | flags::NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY + | flags::NTLMSSP_TARGET_TYPE_SERVER + | flags::NTLMSSP_NEGOTIATE_TARGET_INFO + | flags::NTLMSSP_NEGOTIATE_VERSION + | flags::NTLMSSP_NEGOTIATE_128 + | flags::NTLMSSP_NEGOTIATE_KEY_EXCH + | flags::NTLMSSP_NEGOTIATE_56; + + Self { + server_challenge, + target_name: target.target_name, + nb_domain: target.nb_domain, + nb_computer: target.nb_computer, + dns_domain: target.dns_domain, + dns_computer: target.dns_computer, + timestamp, + server_flags, + negotiated_flags: server_flags, + negotiate_bytes: Vec::new(), + challenge_bytes: Vec::new(), + } + } + + /// Record the client's NEGOTIATE_MESSAGE bytes and intersect flags. + /// This must be called before `challenge()` if a MIC will be validated. + pub fn step1_negotiate(&mut self, blob: &[u8]) -> ProtoResult { + let n = NtlmNegotiate::parse(blob)?; + // Negotiate down: only keep flags both sides set, then keep our must-have ones. + self.negotiated_flags = (self.server_flags & n.flags) + | flags::NTLMSSP_NEGOTIATE_TARGET_INFO + | flags::NTLMSSP_TARGET_TYPE_SERVER + | flags::NTLMSSP_NEGOTIATE_UNICODE; + self.negotiate_bytes = n.raw.clone(); + Ok(n) + } + + /// Build the CHALLENGE_MESSAGE blob. Stores the bytes for later MIC use. + pub fn challenge(&mut self) -> Vec { + let blob = build_challenge(&ChallengeParams { + server_challenge: self.server_challenge, + target_name: &self.target_name, + nb_domain_name: &self.nb_domain, + nb_computer_name: &self.nb_computer, + dns_domain_name: &self.dns_domain, + dns_computer_name: &self.dns_computer, + timestamp: self.timestamp, + flags: self.negotiated_flags, + }); + self.challenge_bytes = blob.clone(); + blob + } + + /// Validate the AUTHENTICATE_MESSAGE. + /// + /// `lookup` is the application's user-database hook: given the user/domain + /// from the wire, return `Some(UserCreds)` if known, `None` otherwise. + /// + /// Returns `AuthOutcome::session_key` to be plugged into SMB2 KDF. + /// Anonymous logon (empty user + empty NT response) returns a zeroed key + /// and `Identity::Anonymous`. + pub fn authenticate(&self, blob: &[u8], lookup: F) -> ProtoResult + where + F: Fn(&str, &str) -> Option, + { + let auth = NtlmAuthenticate::parse(blob)?; + + // ---- Anonymous fast path. MS-NLMP §3.2.5.1.2: empty user + empty NT + // response (or single-zero-byte LM response) means anonymous logon. + if auth.user.is_empty() && auth.nt_response.is_empty() { + return Ok(AuthOutcome { + identity: Identity::Anonymous, + session_key: [0u8; 16], + }); + } + + // ---- Locate creds. + let creds = lookup(&auth.user, &auth.domain).ok_or(ProtoError::Auth("unknown user"))?; + + // ---- NTOWFv2 = HMAC_MD5(NT_hash, UTF-16LE(UPPER(user) || domain)) + let response_key_nt = ntowf_v2(&creds.nt_hash, &auth.user, &auth.domain); + + // ---- NTLMv2 response layout (MS-NLMP §2.2.2.8): + // 16 bytes NTProofStr || NTLMv2_CLIENT_CHALLENGE blob + if auth.nt_response.len() < 16 { + return Err(ProtoError::Auth("NT response too short")); + } + let (nt_proof_supplied, client_challenge) = auth.nt_response.split_at(16); + + // ---- NTProofStr = HMAC_MD5(response_key_nt, ServerChallenge || ClientChallenge) + let mut mac = HmacMd5::new_from_slice(&response_key_nt).expect("hmac key"); + mac.update(&self.server_challenge); + mac.update(client_challenge); + let nt_proof_computed = mac.finalize().into_bytes(); + + if !ct_eq_16(nt_proof_supplied, &nt_proof_computed) { + return Err(ProtoError::Auth("NT proof mismatch")); + } + + // ---- SessionBaseKey = HMAC_MD5(response_key_nt, NTProofStr) + // (MS-NLMP §3.4 — for NTLMv2, KeyExchangeKey = SessionBaseKey.) + let mut mac = HmacMd5::new_from_slice(&response_key_nt).expect("hmac key"); + mac.update(&nt_proof_computed); + let session_base_key_bytes = mac.finalize().into_bytes(); + let mut key_exchange_key = [0u8; 16]; + key_exchange_key.copy_from_slice(&session_base_key_bytes); + + // ---- Optional RC4-wrapped random session key. + let session_key = if (auth.flags & flags::NTLMSSP_NEGOTIATE_KEY_EXCH) != 0 + && !auth.encrypted_random_session_key.is_empty() + { + if auth.encrypted_random_session_key.len() != 16 { + return Err(ProtoError::Auth("encrypted session key not 16 bytes")); + } + let mut buf = [0u8; 16]; + buf.copy_from_slice(&auth.encrypted_random_session_key); + // RC4(KeyExchangeKey) over the encrypted session key. + let mut rc4 = Rc4::new_from_slice(&key_exchange_key) + .map_err(|_| ProtoError::Auth("rc4 key length"))?; + rc4.apply_keystream(&mut buf); + buf + } else { + key_exchange_key + }; + + // ---- MIC validation: HMAC_MD5(SessionKey, NEGOTIATE || CHALLENGE || AUTHENTICATE-with-MIC-zeroed). + // We only validate if the client supplied a MIC (i.e. presence + // detected during parse) AND we actually have the negotiate/challenge + // bytes. If absent, treat as not supplied. This v1 server does not + // enforce MsvAvFlags bit 0x2 from the challenge target-info. + if let (Some(mic_off), true) = (auth.mic_offset, !self.negotiate_bytes.is_empty()) + && let Some(supplied) = auth.mic + { + let mut auth_zeroed = auth.raw.clone(); + if auth_zeroed.len() < mic_off + 16 { + return Err(ProtoError::Auth("MIC offset out of range")); + } + for b in &mut auth_zeroed[mic_off..mic_off + 16] { + *b = 0; + } + let mut mac = HmacMd5::new_from_slice(&session_key).expect("hmac key"); + mac.update(&self.negotiate_bytes); + mac.update(&self.challenge_bytes); + mac.update(&auth_zeroed); + let computed = mac.finalize().into_bytes(); + if !ct_eq_16(&supplied, &computed) { + return Err(ProtoError::Auth("MIC mismatch")); + } + } + + Ok(AuthOutcome { + identity: Identity::User { + user: auth.user.clone(), + domain: auth.domain.clone(), + }, + session_key, + }) + } +} + +// =========================================================================== +// Tests +// =========================================================================== + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn av_pair_round_trip() { + let pairs = vec![ + AvPair::new(AvId::NbDomainName, utf16le("DOMAIN")), + AvPair::new(AvId::NbComputerName, utf16le("SERVER")), + AvPair::new( + AvId::Timestamp, + 0x1234_5678_9abc_def0u64.to_le_bytes().to_vec(), + ), + ]; + let bytes = encode_av_pairs(&pairs); + let decoded = decode_av_pairs(&bytes).unwrap(); + assert_eq!(decoded, pairs); + } + + #[test] + fn negotiate_round_trip() { + // Build a minimal Type 1 by hand and parse it. + let mut buf = Vec::new(); + buf.extend_from_slice(NTLMSSP_SIGNATURE); + buf.extend_from_slice(&MSG_NEGOTIATE.to_le_bytes()); + let flags = flags::NTLMSSP_NEGOTIATE_UNICODE + | flags::NTLMSSP_NEGOTIATE_NTLM + | flags::NTLMSSP_NEGOTIATE_TARGET_INFO; + buf.extend_from_slice(&flags.to_le_bytes()); + // Domain + workstation fields: empty. + buf.extend_from_slice(&[0u8; 8]); + buf.extend_from_slice(&[0u8; 8]); + // Version (8 bytes). + buf.extend_from_slice(&[0u8; 8]); + + let n = NtlmNegotiate::parse(&buf).unwrap(); + assert_eq!(n.flags, flags); + assert!(n.domain.is_empty()); + assert!(n.workstation.is_empty()); + } + + #[test] + fn challenge_round_trip_structure() { + let blob = build_challenge(&ChallengeParams { + server_challenge: [1, 2, 3, 4, 5, 6, 7, 8], + target_name: "SERVER", + nb_domain_name: "DOMAIN", + nb_computer_name: "SERVER", + dns_domain_name: "domain.local", + dns_computer_name: "server.domain.local", + timestamp: 0, + flags: flags::NTLMSSP_NEGOTIATE_UNICODE + | flags::NTLMSSP_NEGOTIATE_NTLM + | flags::NTLMSSP_NEGOTIATE_TARGET_INFO, + }); + // Signature + message type. + assert_eq!(&blob[..8], NTLMSSP_SIGNATURE); + assert_eq!( + u32::from_le_bytes([blob[8], blob[9], blob[10], blob[11]]), + MSG_CHALLENGE + ); + // Server challenge at offset 24. + assert_eq!(&blob[24..32], &[1, 2, 3, 4, 5, 6, 7, 8]); + // Decode AV_PAIRs from the target-info section. + let ti_off = u32::from_le_bytes([blob[44], blob[45], blob[46], blob[47]]) as usize; + let ti_len = u16::from_le_bytes([blob[40], blob[41]]) as usize; + let av = decode_av_pairs(&blob[ti_off..ti_off + ti_len]).unwrap(); + assert!(av.iter().any(|p| p.id == AvId::NbDomainName as u16)); + assert!(av.iter().any(|p| p.id == AvId::Timestamp as u16)); + } + + /// MS-NLMP §4.2.4 known-answer test for NTLMv2: + /// User="User", Domain="Domain", Password="Password" + /// ServerChallenge = 01 23 45 67 89 ab cd ef + /// ClientChallenge AV-pair blob = 01 01 00 00 00 00 00 00 + /// 00 00 00 00 00 00 00 00 + /// aa aa aa aa aa aa aa aa + /// 00 00 00 00 02 00 0c 00 + /// 44 00 6f 00 6d 00 61 00 + /// 69 00 6e 00 01 00 0c 00 + /// 53 00 65 00 72 00 76 00 + /// 65 00 72 00 00 00 00 00 + /// 00 00 00 00 + /// Expected NTProofStr = 68 cd 0a b8 51 e5 1c 96 aa bc 92 7b eb ef 6a 1c + /// (Note: there are several editions of MS-NLMP with subtly different + /// vectors; this matches the §4.2.4.1.3 vector that includes the trailing + /// 4 zero bytes, common across recent revisions.) + #[test] + fn ntlmv2_known_answer() { + let nt = nt_hash("Password"); + // NT hash of "Password" — MS-NLMP §4.2.2.1.4: a4 f4 9c 40 65 10 bd ca b6 82 4e e7 c3 0f d8 52 + assert_eq!( + nt, + [ + 0xa4, 0xf4, 0x9c, 0x40, 0x65, 0x10, 0xbd, 0xca, 0xb6, 0x82, 0x4e, 0xe7, 0xc3, 0x0f, + 0xd8, 0x52 + ] + ); + + // NTOWFv2("Password","User","Domain") + // MS-NLMP §4.2.4.1.1: 0c 86 8a 40 3b fd 7a 93 a3 00 1e f2 2e f0 2e 3f + let key_nt = ntowf_v2(&nt, "User", "Domain"); + assert_eq!( + key_nt, + [ + 0x0c, 0x86, 0x8a, 0x40, 0x3b, 0xfd, 0x7a, 0x93, 0xa3, 0x00, 0x1e, 0xf2, 0x2e, 0xf0, + 0x2e, 0x3f + ] + ); + + let server_challenge: [u8; 8] = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]; + // NTLMv2_CLIENT_CHALLENGE blob from §4.2.4.1.3 + let client_challenge_blob: &[u8] = &[ + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RespType, HiRespType, Reserved + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // TimeStamp + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, // ChallengeFromClient + 0x00, 0x00, 0x00, 0x00, // Reserved + // AV pairs + 0x02, 0x00, 0x0c, 0x00, // MsvAvNbDomainName, len=12 + 0x44, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x61, 0x00, 0x69, 0x00, 0x6e, + 0x00, // "Domain" + 0x01, 0x00, 0x0c, 0x00, // MsvAvNbComputerName, len=12 + 0x53, 0x00, 0x65, 0x00, 0x72, 0x00, 0x76, 0x00, 0x65, 0x00, 0x72, + 0x00, // "Server" + 0x00, 0x00, 0x00, 0x00, // EOL + 0x00, 0x00, 0x00, 0x00, // trailing 4 zero bytes (padding seen in spec) + ]; + + let mut mac = HmacMd5::new_from_slice(&key_nt).unwrap(); + mac.update(&server_challenge); + mac.update(client_challenge_blob); + let nt_proof = mac.finalize().into_bytes(); + + // MS-NLMP §4.2.4.2.2: + // NTProofStr = 68 cd 0a b8 51 e5 1c 96 aa bc 92 7b eb ef 6a 1c + assert_eq!( + nt_proof.as_slice(), + [ + 0x68, 0xcd, 0x0a, 0xb8, 0x51, 0xe5, 0x1c, 0x96, 0xaa, 0xbc, 0x92, 0x7b, 0xeb, 0xef, + 0x6a, 0x1c + ] + ); + } + + #[test] + fn server_round_trip_authenticates_user() { + // End-to-end: build a fake AUTHENTICATE_MESSAGE with a known proof and + // make sure NtlmServer accepts it. + let mut srv = NtlmServer::new( + [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef], + NtlmTargetInfo::new( + "SERVER", + "DOMAIN", + "SERVER", + "domain.local", + "server.domain.local", + ), + 0, + ); + // Skip step1_negotiate — MIC will be absent. + let _challenge = srv.challenge(); + + // Compute NTProofStr the same way the client would. + let nt = nt_hash("Password"); + let key_nt = ntowf_v2(&nt, "User", "Domain"); + let client_challenge_blob: Vec = vec![ + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0x00, 0x00, 0x00, 0x00, + // Empty AV pair list (EOL only) + 0x00, 0x00, 0x00, 0x00, + ]; + let mut mac = HmacMd5::new_from_slice(&key_nt).unwrap(); + mac.update(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]); + mac.update(&client_challenge_blob); + let nt_proof = mac.finalize().into_bytes(); + + let mut nt_response = Vec::new(); + nt_response.extend_from_slice(&nt_proof); + nt_response.extend_from_slice(&client_challenge_blob); + + // Build AUTHENTICATE_MESSAGE. + let user_u16 = utf16le("User"); + let dom_u16 = utf16le("Domain"); + let ws_u16 = utf16le("CLIENT"); + let lm_response: Vec = vec![0u8; 24]; + + // Layout: header is 72 bytes when no MIC is present + // (signature 8 + msgtype 4 + 6×8-byte fields + flags 4 + version 8 = 72). + // With MIC, it would be 88. + let header_len: u32 = 72; + let mut payload = Vec::new(); + let lm_off = header_len; + payload.extend_from_slice(&lm_response); + let nt_off = header_len + payload.len() as u32; + payload.extend_from_slice(&nt_response); + let dom_off = header_len + payload.len() as u32; + payload.extend_from_slice(&dom_u16); + let user_off = header_len + payload.len() as u32; + payload.extend_from_slice(&user_u16); + let ws_off = header_len + payload.len() as u32; + payload.extend_from_slice(&ws_u16); + let key_off = header_len + payload.len() as u32; + // No encrypted session key. + + let mut buf = Vec::new(); + buf.extend_from_slice(NTLMSSP_SIGNATURE); + buf.extend_from_slice(&MSG_AUTHENTICATE.to_le_bytes()); + // Lm + buf.extend_from_slice(&(lm_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(lm_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&lm_off.to_le_bytes()); + // Nt + buf.extend_from_slice(&(nt_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(nt_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&nt_off.to_le_bytes()); + // Domain + buf.extend_from_slice(&(dom_u16.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(dom_u16.len() as u16).to_le_bytes()); + buf.extend_from_slice(&dom_off.to_le_bytes()); + // User + buf.extend_from_slice(&(user_u16.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(user_u16.len() as u16).to_le_bytes()); + buf.extend_from_slice(&user_off.to_le_bytes()); + // Workstation + buf.extend_from_slice(&(ws_u16.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(ws_u16.len() as u16).to_le_bytes()); + buf.extend_from_slice(&ws_off.to_le_bytes()); + // EncryptedRandomSessionKey + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&key_off.to_le_bytes()); + // Flags + buf.extend_from_slice(&flags::NTLMSSP_NEGOTIATE_UNICODE.to_le_bytes()); + // Version (8 bytes) + buf.extend_from_slice(&[0u8; 8]); + // No MIC — header is 64 bytes flat. + assert_eq!(buf.len() as u32, header_len); + buf.extend_from_slice(&payload); + + let creds = UserCreds::from_password("Password"); + let outcome = srv + .authenticate(&buf, |u, d| { + if u == "User" && d == "Domain" { + Some(creds.clone()) + } else { + None + } + }) + .expect("auth should succeed"); + + assert_eq!( + outcome.identity, + Identity::User { + user: "User".to_string(), + domain: "Domain".to_string() + } + ); + + // Wrong password should fail with constant-time mismatch. + let bad = UserCreds::from_password("WrongPassword"); + let err = srv + .authenticate(&buf, |_u, _d| Some(bad.clone())) + .unwrap_err(); + assert!(matches!(err, ProtoError::Auth(_))); + } + + #[test] + fn anonymous_logon() { + let mut srv = NtlmServer::new( + [0u8; 8], + NtlmTargetInfo::new("SERVER", "DOMAIN", "SERVER", "d.local", "s.d.local"), + 0, + ); + let _ = srv.challenge(); + + // Build an AUTHENTICATE_MESSAGE with empty user + empty NT response. + let header_len: u32 = 72; + let mut buf = Vec::new(); + buf.extend_from_slice(NTLMSSP_SIGNATURE); + buf.extend_from_slice(&MSG_AUTHENTICATE.to_le_bytes()); + for _ in 0..6 { + // 6 empty fields (Lm, Nt, Domain, User, Workstation, Key) + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&header_len.to_le_bytes()); + } + buf.extend_from_slice(&flags::NTLMSSP_NEGOTIATE_ANONYMOUS.to_le_bytes()); + buf.extend_from_slice(&[0u8; 8]); // version + + let outcome = srv + .authenticate(&buf, |_u, _d| None) + .expect("anonymous should succeed"); + assert_eq!(outcome.identity, Identity::Anonymous); + assert_eq!(outcome.session_key, [0u8; 16]); + } +} diff --git a/vendor/smb-server/src/proto/auth/spnego.rs b/vendor/smb-server/src/proto/auth/spnego.rs new file mode 100644 index 0000000..ecf863a --- /dev/null +++ b/vendor/smb-server/src/proto/auth/spnego.rs @@ -0,0 +1,524 @@ +//! Minimal hand-rolled DER codec for SPNEGO (MS-SPNG / RFC 4178). +//! +//! v1 advertises **only** the NTLMSSP mechanism. We don't pull in a full +//! ASN.1 crate; this is a tiny subset of DER for the few SPNEGO tokens we +//! need to encode/decode during SESSION_SETUP. +//! +//! ASN.1 sketch: +//! +//! ```text +//! GSSAPI-Token (RFC 2743) ::= [APPLICATION 0] IMPLICIT SEQUENCE { +//! thisMech OBJECT IDENTIFIER, -- SPNEGO 1.3.6.1.5.5.2 +//! innerContextToken ANY DEFINED BY thisMech +//! } +//! +//! NegotiationToken ::= CHOICE { +//! negTokenInit [0] NegTokenInit, +//! negTokenResp [1] NegTokenResp +//! } +//! +//! NegTokenInit ::= SEQUENCE { +//! mechTypes [0] MechTypeList, +//! reqFlags [1] ContextFlags OPTIONAL, +//! mechToken [2] OCTET STRING OPTIONAL, +//! mechListMIC [3] OCTET STRING OPTIONAL +//! } +//! +//! NegTokenResp ::= SEQUENCE { +//! negState [0] ENUMERATED OPTIONAL, +//! supportedMech [1] OBJECT IDENTIFIER OPTIONAL, +//! responseToken [2] OCTET STRING OPTIONAL, +//! mechListMIC [3] OCTET STRING OPTIONAL +//! } +//! ``` + +use crate::proto::error::{ProtoError, ProtoResult}; + +// --- Universal & well-known tags -------------------------------------------- + +const TAG_SEQUENCE: u8 = 0x30; // SEQUENCE OF / SEQUENCE (constructed) +const TAG_OBJECT: u8 = 0x06; // OBJECT IDENTIFIER +const TAG_OCTET: u8 = 0x04; // OCTET STRING +const TAG_ENUMERATED: u8 = 0x0a; // ENUMERATED + +const TAG_APP_0: u8 = 0x60; // [APPLICATION 0] IMPLICIT — GSS-API outer +const TAG_CTX_0: u8 = 0xa0; +const TAG_CTX_1: u8 = 0xa1; +const TAG_CTX_2: u8 = 0xa2; +const TAG_CTX_3: u8 = 0xa3; + +// --- OIDs ------------------------------------------------------------------ + +/// SPNEGO `1.3.6.1.5.5.2` encoded as the *content* of an OBJECT IDENTIFIER +/// (i.e. **without** the leading 0x06 tag + length). +pub const OID_SPNEGO: &[u8] = &[0x2b, 0x06, 0x01, 0x05, 0x05, 0x02]; + +/// NTLMSSP `1.3.6.1.4.1.311.2.2.10` encoded as OID *content*. +pub const OID_NTLMSSP: &[u8] = &[0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x02, 0x02, 0x0a]; + +// --- NegState -------------------------------------------------------------- + +/// Values of the `negState` field in NegTokenResp. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum NegState { + AcceptCompleted = 0, + AcceptIncomplete = 1, + Reject = 2, + RequestMic = 3, +} + +impl NegState { + fn from_byte(b: u8) -> ProtoResult { + match b { + 0 => Ok(NegState::AcceptCompleted), + 1 => Ok(NegState::AcceptIncomplete), + 2 => Ok(NegState::Reject), + 3 => Ok(NegState::RequestMic), + _ => Err(ProtoError::Auth("invalid NegState")), + } + } +} + +// --- DER length helpers ---------------------------------------------------- + +/// Encode a DER length (definite-length form, MS-SPNG always uses definite). +fn der_len(n: usize, out: &mut Vec) { + if n < 0x80 { + out.push(n as u8); + return; + } + // Long form. Find minimum number of bytes. + let mut tmp = [0u8; 8]; + let mut nb = 0; + let mut v = n; + while v > 0 { + tmp[nb] = (v & 0xff) as u8; + v >>= 8; + nb += 1; + } + out.push(0x80 | nb as u8); + for i in (0..nb).rev() { + out.push(tmp[i]); + } +} + +/// Read a DER length from `buf` starting at `pos`. Returns `(length, next_pos)`. +fn read_len(buf: &[u8], pos: usize) -> ProtoResult<(usize, usize)> { + if pos >= buf.len() { + return Err(ProtoError::Auth("DER length truncated")); + } + let first = buf[pos]; + if first < 0x80 { + return Ok((first as usize, pos + 1)); + } + let nb = (first & 0x7f) as usize; + if nb == 0 || nb > 4 { + // Indefinite (nb=0) — never used by SPNEGO. + // We cap at 4 bytes (max ~4 GiB), more than enough for tokens. + return Err(ProtoError::Auth("DER length form unsupported")); + } + if pos + 1 + nb > buf.len() { + return Err(ProtoError::Auth("DER length truncated")); + } + let mut v = 0usize; + for i in 0..nb { + v = (v << 8) | buf[pos + 1 + i] as usize; + } + Ok((v, pos + 1 + nb)) +} + +/// Read `(tag, content_slice, next_pos)` at `pos`. Verifies the expected tag. +fn read_tlv(buf: &[u8], pos: usize, expected_tag: u8) -> ProtoResult<(&[u8], usize)> { + if pos >= buf.len() { + return Err(ProtoError::Auth("DER tag truncated")); + } + if buf[pos] != expected_tag { + return Err(ProtoError::Auth("unexpected DER tag")); + } + let (len, after_len) = read_len(buf, pos + 1)?; + let end = after_len + len; + if end > buf.len() { + return Err(ProtoError::Auth("DER content truncated")); + } + Ok((&buf[after_len..end], end)) +} + +/// Read any TLV (returning its tag plus the content slice & end position). +fn read_any_tlv(buf: &[u8], pos: usize) -> ProtoResult<(u8, &[u8], usize)> { + if pos >= buf.len() { + return Err(ProtoError::Auth("DER tag truncated")); + } + let tag = buf[pos]; + let (len, after_len) = read_len(buf, pos + 1)?; + let end = after_len + len; + if end > buf.len() { + return Err(ProtoError::Auth("DER content truncated")); + } + Ok((tag, &buf[after_len..end], end)) +} + +// --- TLV writer helper ----------------------------------------------------- + +fn write_tlv(tag: u8, content: &[u8], out: &mut Vec) { + out.push(tag); + der_len(content.len(), out); + out.extend_from_slice(content); +} + +// --- Public API ------------------------------------------------------------ + +/// Decoded `NegTokenInit` payload — only the bits we care about. +#[derive(Debug, Clone)] +pub struct NegTokenInit { + /// List of mechanism OIDs (each entry is the OID content bytes, no 0x06 tag). + pub mech_types: Vec>, + /// `mechToken [2]` if present — typically the NTLMSSP NEGOTIATE_MESSAGE bytes. + pub mech_token: Option>, +} + +/// Decoded `NegTokenResp` payload. +#[derive(Debug, Clone, Default)] +pub struct NegTokenResp { + pub neg_state: Option, + /// `supportedMech [1]` (OID content bytes). + pub supported_mech: Option>, + /// `responseToken [2]` — typically inner NTLMSSP CHALLENGE/AUTHENTICATE bytes. + pub response_token: Option>, + pub mech_list_mic: Option>, +} + +/// Decode the **initial** SPNEGO blob from the client. This is wrapped in +/// the GSS-API outer `[APPLICATION 0]` tag, contains a `thisMech` OID +/// (SPNEGO), and a `[0] NegTokenInit`. +/// +/// Returns the parsed `NegTokenInit`. +pub fn decode_init_token(buf: &[u8]) -> ProtoResult { + // [APPLICATION 0] IMPLICIT SEQUENCE { thisMech OID, NegotiationToken } + let (gss_inner, _end) = read_tlv(buf, 0, TAG_APP_0)?; + + // thisMech + let (mech, after_mech) = read_tlv(gss_inner, 0, TAG_OBJECT)?; + if mech != OID_SPNEGO { + return Err(ProtoError::Auth("not an SPNEGO token")); + } + + // NegotiationToken — choice tagged [0] for init. + let (init_inner, _) = read_tlv(gss_inner, after_mech, TAG_CTX_0)?; + parse_neg_token_init_body(init_inner) +} + +fn parse_neg_token_init_body(inner: &[u8]) -> ProtoResult { + // Inner is a SEQUENCE. + let (seq_body, _) = read_tlv(inner, 0, TAG_SEQUENCE)?; + let mut pos = 0usize; + let mut mech_types: Vec> = Vec::new(); + let mut mech_token: Option> = None; + + while pos < seq_body.len() { + let (tag, content, next) = read_any_tlv(seq_body, pos)?; + match tag { + TAG_CTX_0 => { + // mechTypes [0] MechTypeList ::= SEQUENCE OF MechType (OID) + let (mt_seq, _) = read_tlv(content, 0, TAG_SEQUENCE)?; + let mut p = 0usize; + while p < mt_seq.len() { + let (oid, e) = read_tlv(mt_seq, p, TAG_OBJECT)?; + mech_types.push(oid.to_vec()); + p = e; + } + } + TAG_CTX_1 => { + // reqFlags — ignored. + } + TAG_CTX_2 => { + // mechToken [2] OCTET STRING + let (oct, _) = read_tlv(content, 0, TAG_OCTET)?; + mech_token = Some(oct.to_vec()); + } + TAG_CTX_3 => { + // mechListMIC — ignored on init. + } + _ => { + // Unknown — skip silently (forward-compat). + } + } + pos = next; + } + + Ok(NegTokenInit { + mech_types, + mech_token, + }) +} + +/// Decode a subsequent `NegTokenResp`. These are sent without the GSS-API +/// outer wrapper — they begin directly with the `[1]` choice tag. +pub fn decode_resp_token(buf: &[u8]) -> ProtoResult { + let (resp_inner, _) = read_tlv(buf, 0, TAG_CTX_1)?; + let (seq_body, _) = read_tlv(resp_inner, 0, TAG_SEQUENCE)?; + let mut pos = 0usize; + let mut out = NegTokenResp::default(); + + while pos < seq_body.len() { + let (tag, content, next) = read_any_tlv(seq_body, pos)?; + match tag { + TAG_CTX_0 => { + let (en, _) = read_tlv(content, 0, TAG_ENUMERATED)?; + if en.len() != 1 { + return Err(ProtoError::Auth("NegState ENUMERATED not 1 byte")); + } + out.neg_state = Some(NegState::from_byte(en[0])?); + } + TAG_CTX_1 => { + let (oid, _) = read_tlv(content, 0, TAG_OBJECT)?; + out.supported_mech = Some(oid.to_vec()); + } + TAG_CTX_2 => { + let (oct, _) = read_tlv(content, 0, TAG_OCTET)?; + out.response_token = Some(oct.to_vec()); + } + TAG_CTX_3 => { + let (oct, _) = read_tlv(content, 0, TAG_OCTET)?; + out.mech_list_mic = Some(oct.to_vec()); + } + _ => {} + } + pos = next; + } + + Ok(out) +} + +/// Encode the **initial** server response to NEGOTIATE — a GSS-API-wrapped +/// `NegTokenInit` advertising NTLMSSP only. Used during SMB2 NEGOTIATE +/// when the server publishes its security blob. +pub fn encode_init_response() -> Vec { + // mechTypes SEQUENCE { OID NTLMSSP } + let mut mech_types_seq = Vec::new(); + write_tlv(TAG_OBJECT, OID_NTLMSSP, &mut mech_types_seq); + let mut mech_types_outer = Vec::new(); + write_tlv(TAG_SEQUENCE, &mech_types_seq, &mut mech_types_outer); + // mechTypes is [0] tagged. + let mut mech_types_ctx0 = Vec::new(); + write_tlv(TAG_CTX_0, &mech_types_outer, &mut mech_types_ctx0); + + // NegTokenInit SEQUENCE { mechTypes [0] } + let mut neg_token_init = Vec::new(); + write_tlv(TAG_SEQUENCE, &mech_types_ctx0, &mut neg_token_init); + + // [0] NegTokenInit (negotiationToken choice) + let mut choice_init = Vec::new(); + write_tlv(TAG_CTX_0, &neg_token_init, &mut choice_init); + + // Inside [APPLICATION 0]: { OID SPNEGO, [0] NegTokenInit } + let mut gss_inner = Vec::new(); + write_tlv(TAG_OBJECT, OID_SPNEGO, &mut gss_inner); + gss_inner.extend_from_slice(&choice_init); + + let mut out = Vec::new(); + write_tlv(TAG_APP_0, &gss_inner, &mut out); + out +} + +/// Encode a `NegTokenResp` wrapping the server's response token (typically +/// the NTLMSSP CHALLENGE_MESSAGE or a final empty-token AcceptCompleted). +/// +/// `supported_mech` is included only with `AcceptIncomplete` (i.e. the very +/// first response to a NegTokenInit) — per RFC 4178 §4.2.2. +pub fn encode_resp_token( + state: NegState, + supported_mech: Option<&[u8]>, + response_token: Option<&[u8]>, + mech_list_mic: Option<&[u8]>, +) -> Vec { + let mut seq = Vec::new(); + + // [0] negState + { + let mut en = Vec::new(); + write_tlv(TAG_ENUMERATED, &[state as u8], &mut en); + write_tlv(TAG_CTX_0, &en, &mut seq); + } + // [1] supportedMech + if let Some(oid) = supported_mech { + let mut o = Vec::new(); + write_tlv(TAG_OBJECT, oid, &mut o); + write_tlv(TAG_CTX_1, &o, &mut seq); + } + // [2] responseToken + if let Some(tok) = response_token { + let mut o = Vec::new(); + write_tlv(TAG_OCTET, tok, &mut o); + write_tlv(TAG_CTX_2, &o, &mut seq); + } + // [3] mechListMIC + if let Some(mic) = mech_list_mic { + let mut o = Vec::new(); + write_tlv(TAG_OCTET, mic, &mut o); + write_tlv(TAG_CTX_3, &o, &mut seq); + } + + let mut inner = Vec::new(); + write_tlv(TAG_SEQUENCE, &seq, &mut inner); + let mut out = Vec::new(); + write_tlv(TAG_CTX_1, &inner, &mut out); + out +} + +// =========================================================================== +// Tests +// =========================================================================== + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn der_len_short() { + let mut v = Vec::new(); + der_len(0x42, &mut v); + assert_eq!(v, [0x42]); + } + + #[test] + fn der_len_long_one_byte() { + let mut v = Vec::new(); + der_len(0xC8, &mut v); + assert_eq!(v, [0x81, 0xC8]); + } + + #[test] + fn der_len_long_two_byte() { + let mut v = Vec::new(); + der_len(0x1234, &mut v); + assert_eq!(v, [0x82, 0x12, 0x34]); + } + + #[test] + fn read_len_round_trip() { + for n in [0usize, 1, 0x7F, 0x80, 0xFF, 0x100, 0xFFFF, 0x10000] { + let mut buf = Vec::new(); + der_len(n, &mut buf); + let (got, next) = read_len(&buf, 0).unwrap(); + assert_eq!(got, n); + assert_eq!(next, buf.len()); + } + } + + #[test] + fn init_response_is_decodable() { + let blob = encode_init_response(); + // Must start with [APPLICATION 0] (0x60) tag. + assert_eq!(blob[0], TAG_APP_0); + // Decode with our own decoder going via decode_init_token. + // We craft a synthetic "init" by appending an empty mechToken? — not + // needed; decode_init_token tolerates absence. Test that the OID and + // the [0] mechTypes are reachable. + let init = decode_init_token(&blob).unwrap(); + assert_eq!(init.mech_types.len(), 1); + assert_eq!(init.mech_types[0], OID_NTLMSSP); + assert!(init.mech_token.is_none()); + } + + #[test] + fn resp_token_round_trip_with_response() { + let payload = b"\x01\x02\x03\x04inner-blob"; + let enc = encode_resp_token( + NegState::AcceptIncomplete, + Some(OID_NTLMSSP), + Some(payload), + None, + ); + let dec = decode_resp_token(&enc).unwrap(); + assert_eq!(dec.neg_state, Some(NegState::AcceptIncomplete)); + assert_eq!(dec.supported_mech.as_deref(), Some(OID_NTLMSSP)); + assert_eq!(dec.response_token.as_deref(), Some(&payload[..])); + assert!(dec.mech_list_mic.is_none()); + } + + #[test] + fn resp_token_round_trip_completed() { + let enc = encode_resp_token(NegState::AcceptCompleted, None, None, None); + let dec = decode_resp_token(&enc).unwrap(); + assert_eq!(dec.neg_state, Some(NegState::AcceptCompleted)); + assert!(dec.supported_mech.is_none()); + assert!(dec.response_token.is_none()); + } + + #[test] + fn resp_token_with_mic() { + let mic = vec![0xAAu8; 16]; + let enc = encode_resp_token(NegState::AcceptCompleted, None, None, Some(&mic)); + let dec = decode_resp_token(&enc).unwrap(); + assert_eq!(dec.mech_list_mic.as_deref(), Some(mic.as_slice())); + } + + /// Build a NegTokenInit by hand (containing a mechToken) and decode it. + #[test] + fn decode_init_with_mech_token() { + let inner_token = b"NTLMSSP\x00fakeNegotiate"; + + // mechTypes + let mut mts = Vec::new(); + write_tlv(TAG_OBJECT, OID_NTLMSSP, &mut mts); + let mut mts_seq = Vec::new(); + write_tlv(TAG_SEQUENCE, &mts, &mut mts_seq); + let mut mts_ctx0 = Vec::new(); + write_tlv(TAG_CTX_0, &mts_seq, &mut mts_ctx0); + + // mechToken [2] OCTET STRING + let mut mt_oct = Vec::new(); + write_tlv(TAG_OCTET, inner_token, &mut mt_oct); + let mut mt_ctx2 = Vec::new(); + write_tlv(TAG_CTX_2, &mt_oct, &mut mt_ctx2); + + // SEQUENCE { [0] mechTypes, [2] mechToken } + let mut seq = Vec::new(); + seq.extend_from_slice(&mts_ctx0); + seq.extend_from_slice(&mt_ctx2); + + let mut neg_token_init = Vec::new(); + write_tlv(TAG_SEQUENCE, &seq, &mut neg_token_init); + + let mut choice = Vec::new(); + write_tlv(TAG_CTX_0, &neg_token_init, &mut choice); + + let mut gss_inner = Vec::new(); + write_tlv(TAG_OBJECT, OID_SPNEGO, &mut gss_inner); + gss_inner.extend_from_slice(&choice); + + let mut blob = Vec::new(); + write_tlv(TAG_APP_0, &gss_inner, &mut blob); + + let dec = decode_init_token(&blob).unwrap(); + assert_eq!(dec.mech_types.len(), 1); + assert_eq!(dec.mech_types[0], OID_NTLMSSP); + assert_eq!(dec.mech_token.as_deref(), Some(&inner_token[..])); + } + + #[test] + fn rejects_non_spnego_oid() { + // Build a GSS token with a different OID inside. + let bad_oid = [0x2bu8, 0x06, 0x01, 0x01, 0x01, 0x01]; + let mut gss_inner = Vec::new(); + write_tlv(TAG_OBJECT, &bad_oid, &mut gss_inner); + // Empty [0] payload. + let mut empty = Vec::new(); + write_tlv(TAG_SEQUENCE, &[], &mut empty); + let mut choice = Vec::new(); + write_tlv(TAG_CTX_0, &empty, &mut choice); + gss_inner.extend_from_slice(&choice); + let mut blob = Vec::new(); + write_tlv(TAG_APP_0, &gss_inner, &mut blob); + + let err = decode_init_token(&blob).unwrap_err(); + assert!(matches!(err, ProtoError::Auth(_))); + } + + #[test] + fn rejects_truncated_blob() { + let err = decode_init_token(&[0x60, 0x05, 0xAA, 0xBB]).unwrap_err(); + assert!(matches!(err, ProtoError::Auth(_))); + } +} diff --git a/vendor/smb-server/src/proto/crypto.rs b/vendor/smb-server/src/proto/crypto.rs new file mode 100644 index 0000000..a67fd19 --- /dev/null +++ b/vendor/smb-server/src/proto/crypto.rs @@ -0,0 +1,20 @@ +//! SMB signing, key derivation, pre-auth integrity. +//! +//! Submodules: +//! * [`kdf`] — SP 800-108 CTR-mode KDF (`SMB2KDF`) and SMB-specific +//! signing/application key helpers (MS-SMB2 §3.1.4.2). +//! * [`sign`] — HMAC-SHA-256 (SMB 2.x) and AES-CMAC (SMB 3.x) signing of +//! SMB2 messages (MS-SMB2 §3.1.4.1). +//! * [`preauth`] — SMB 3.1.1 pre-auth integrity running SHA-512 hash +//! (MS-SMB2 §3.1.4.4.1, §3.3.5.4). +//! +//! Encryption (AES-CCM/AES-GCM) is intentionally out of scope for v1; see the +//! design spec. + +pub mod kdf; +pub mod preauth; +pub mod sign; + +pub use kdf::{signing_key_30, signing_key_311}; +pub use preauth::PreauthIntegrity; +pub use sign::{SigningAlgo, sign, verify}; diff --git a/vendor/smb-server/src/proto/crypto/kdf.rs b/vendor/smb-server/src/proto/crypto/kdf.rs new file mode 100644 index 0000000..aff4d25 --- /dev/null +++ b/vendor/smb-server/src/proto/crypto/kdf.rs @@ -0,0 +1,146 @@ +//! SP 800-108 CTR-mode KDF using HMAC-SHA-256, as required by MS-SMB2 §3.1.4.2. +//! +//! Fixed input fed to the PRF (HMAC-SHA-256) is: +//! +//! ```text +//! i (u32be=1) || Label || 0x00 || Context || L (u32be=128) +//! ``` +//! +//! Convention in this crate: +//! * Callers pass `label` and `context` *already including* a trailing `\x00`. +//! * The KDF then **also** emits a single `0x00` separator between `label` +//! and `context`, so the wire-level input has two consecutive NULs at that +//! boundary. This matches what real Windows clients require — a single NUL +//! produces a different signing key and Windows rejects with +//! `STATUS_ACCESS_DENIED` / event 31013 "signing validation failed". + +use hmac::{Hmac, Mac}; +use sha2::Sha256; + +type HmacSha256 = Hmac; + +/// SP 800-108 CTR-mode KDF using HMAC-SHA-256. +/// +/// * `key` — the input key (session key, typically 16 bytes). +/// * `label` — the label string with trailing NUL (e.g. `b"SMB2AESCMAC\x00"`). +/// * `context` — the context string with trailing NUL (e.g. `b"SmbSign\x00"`). +/// +/// Returns the first 16 bytes of `HMAC-SHA-256(key, fixed_input)` where +/// `fixed_input = [0,0,0,1] || label || 0x00 || context || [0,0,0,0x80]`. +/// The single separator `0x00` between `label` and `context` is required for +/// Windows interop; do not remove. +pub fn smb2_kdf(key: &[u8], label: &[u8], context: &[u8]) -> [u8; 16] { + let mut mac = + ::new_from_slice(key).expect("HMAC-SHA-256 accepts keys of any length"); + + // i = 1 (big-endian u32) + mac.update(&[0x00, 0x00, 0x00, 0x01]); + // Label (including trailing NUL provided by caller) + mac.update(label); + // SP 800-108 separator byte between Label and Context (in addition to any + // trailing NUL the caller already included in `label`). + mac.update(&[0x00]); + // Context (including trailing NUL provided by caller, or for 3.1.1 the + // 64-byte preauth hash) + mac.update(context); + // L = 128 bits (big-endian u32) + mac.update(&[0x00, 0x00, 0x00, 0x80]); + + let full = mac.finalize().into_bytes(); + let mut out = [0u8; 16]; + out.copy_from_slice(&full[..16]); + out +} + +// --- Convenience helpers --------------------------------------------------- + +/// Signing key for SMB 3.0 / 3.0.2. +/// +/// Label = `"SMB2AESCMAC\x00"`, Context = `"SmbSign\x00"` (MS-SMB2 §3.1.4.2). +pub fn signing_key_30(session_key: &[u8]) -> [u8; 16] { + smb2_kdf(session_key, b"SMB2AESCMAC\x00", b"SmbSign\x00") +} + +/// Signing key for SMB 3.1.1. +/// +/// Label = `"SMBSigningKey\x00"`, Context = pre-auth integrity hash +/// (the SHA-512 snapshot taken at SESSION_SETUP completion). +pub fn signing_key_311(session_key: &[u8], preauth_hash: &[u8; 64]) -> [u8; 16] { + smb2_kdf(session_key, b"SMBSigningKey\x00", preauth_hash) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Determinism / shape sanity: the function always produces 16 bytes and + /// is reproducible for the same inputs. + #[test] + fn smb2_kdf_is_deterministic() { + let key = [0x11u8; 16]; + let a = smb2_kdf(&key, b"SMB2AESCMAC\x00", b"SmbSign\x00"); + let b = smb2_kdf(&key, b"SMB2AESCMAC\x00", b"SmbSign\x00"); + assert_eq!(a, b); + assert_eq!(a.len(), 16); + } + + /// Different label or context → different output. + #[test] + fn smb2_kdf_label_and_context_matter() { + let key = [0x42u8; 16]; + let signing = smb2_kdf(&key, b"SMB2AESCMAC\x00", b"SmbSign\x00"); + let app = smb2_kdf(&key, b"SMB2APP\x00", b"SmbRpc\x00"); + assert_ne!(signing, app); + + let other_ctx = smb2_kdf(&key, b"SMB2AESCMAC\x00", b"OtherCtx\x00"); + assert_ne!(signing, other_ctx); + } + + /// Known-answer test computed directly from the documented fixed-input + /// construction. This pins the exact byte layout we feed to HMAC. + /// + /// Reference computation (Python): + /// ```text + /// import hmac, hashlib + /// key = bytes(16) # all zeros + /// label = b"SMB2AESCMAC\x00" + /// context = b"SmbSign\x00" + /// data = b"\x00\x00\x00\x01" + label + b"\x00" + context + b"\x00\x00\x00\x80" + /// hmac.new(key, data, hashlib.sha256).hexdigest()[:32] + /// # = "9951088b83220f39d99420419d16d393" + /// ``` + #[test] + fn smb2_kdf_known_answer_zero_key_signing_30() { + let key = [0u8; 16]; + let out = signing_key_30(&key); + let expected = hex::decode("9951088b83220f39d99420419d16d393").unwrap(); + assert_eq!(out.as_slice(), expected.as_slice()); + } + + /// 3.1.1 derivation differs from 3.0 (different label, 64-byte context). + #[test] + fn smb2_kdf_311_differs_from_30() { + let key = [0u8; 16]; + let preauth = [0u8; 64]; + let k30 = signing_key_30(&key); + let k311 = signing_key_311(&key, &preauth); + assert_ne!(k30, k311); + } + + /// Known-answer test for 3.1.1 with zero key and zero pre-auth hash. + /// + /// Reference computation (Python): + /// ```text + /// data = b"\x00\x00\x00\x01" + b"SMBSigningKey\x00" + b"\x00" + bytes(64) + b"\x00\x00\x00\x80" + /// hmac.new(bytes(16), data, hashlib.sha256).hexdigest()[:32] + /// # = "a06a153e09bd0f34706a5c671acaa37d" + /// ``` + #[test] + fn smb2_kdf_known_answer_zero_key_signing_311() { + let key = [0u8; 16]; + let preauth = [0u8; 64]; + let out = signing_key_311(&key, &preauth); + let expected = hex::decode("a06a153e09bd0f34706a5c671acaa37d").unwrap(); + assert_eq!(out.as_slice(), expected.as_slice()); + } +} diff --git a/vendor/smb-server/src/proto/crypto/preauth.rs b/vendor/smb-server/src/proto/crypto/preauth.rs new file mode 100644 index 0000000..7ed4452 --- /dev/null +++ b/vendor/smb-server/src/proto/crypto/preauth.rs @@ -0,0 +1,115 @@ +//! SMB 3.1.1 pre-auth integrity (MS-SMB2 §3.1.4.4.1, §3.3.5.4). +//! +//! A running SHA-512 hash, initialized to all zeros, that absorbs SMB 3.1.1 +//! preauth messages (transport prefix excluded). Connection state uses this for +//! NEGOTIATE; each SESSION_SETUP exchange forks its own instance. Per spec: +//! +//! ```text +//! PreauthIntegrityHashValue = +//! SHA-512(PreauthIntegrityHashValue || RequestOrResponse) +//! ``` + +use sha2::{Digest, Sha512}; + +/// Running SMB 3.1.1 preauth integrity hash. +#[derive(Debug, Clone)] +pub struct PreauthIntegrity { + hash: [u8; 64], +} + +impl Default for PreauthIntegrity { + fn default() -> Self { + Self::new() + } +} + +impl PreauthIntegrity { + /// Create a fresh state, hash initialized to all zeros. + pub fn new() -> Self { + Self { hash: [0u8; 64] } + } + + /// Absorb a frame's bytes (excluding the 4-byte Direct-TCP transport + /// prefix). Updates `hash` in place. + pub fn update(&mut self, frame: &[u8]) { + let mut hasher = Sha512::new(); + hasher.update(self.hash); + hasher.update(frame); + let out = hasher.finalize(); + self.hash.copy_from_slice(&out); + } + + /// Take a copy of the current hash. Used as the KDF context for session + /// keys at SESSION_SETUP completion. + pub fn snapshot(&self) -> [u8; 64] { + self.hash + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sha2::{Digest, Sha512}; + + #[test] + fn new_starts_at_zero() { + let p = PreauthIntegrity::new(); + assert_eq!(p.snapshot(), [0u8; 64]); + } + + #[test] + fn default_starts_at_zero() { + let p = PreauthIntegrity::default(); + assert_eq!(p.snapshot(), [0u8; 64]); + } + + /// Two-step chain matches the literal spec formula. + #[test] + fn chain_two_buffers_matches_precomputed() { + let mut p = PreauthIntegrity::new(); + + let buf1 = b"NEGOTIATE_REQUEST_FIXTURE"; + let buf2 = b"NEGOTIATE_RESPONSE_FIXTURE"; + p.update(buf1); + p.update(buf2); + + // Precomputed using Python: + // h = bytes(64) + // h = sha512(h + buf1).digest() + // h = sha512(h + buf2).digest() + let expected = hex::decode( + "62deb17d9d07d155b7c634dbfec3ac10c32b80981d925333499a6fbd168d0ee3\ + 4d29b093a185529fd927ade8d851c8e8b0d9b55608c7674e4d3e8d438343c95c", + ) + .unwrap(); + assert_eq!(p.snapshot().as_slice(), expected.as_slice()); + } + + /// Chained call equivalence: explicit SHA-512(prev || frame) on the side + /// must match what `update` produces internally. + #[test] + fn update_equals_manual_sha512() { + let buf = b"SOME_FRAME_BYTES_HERE_0123456789"; + + let mut p = PreauthIntegrity::new(); + p.update(buf); + + let mut hasher = Sha512::new(); + hasher.update([0u8; 64]); + hasher.update(buf); + let manual = hasher.finalize(); + + assert_eq!(p.snapshot().as_slice(), manual.as_slice()); + } + + /// Snapshot must not be aliased — modifying state after snapshot must not + /// affect the snapshot already taken. + #[test] + fn snapshot_is_a_copy() { + let mut p = PreauthIntegrity::new(); + p.update(b"first"); + let snap = p.snapshot(); + p.update(b"second"); + assert_ne!(p.snapshot(), snap); + } +} diff --git a/vendor/smb-server/src/proto/crypto/sign.rs b/vendor/smb-server/src/proto/crypto/sign.rs new file mode 100644 index 0000000..b7010d0 --- /dev/null +++ b/vendor/smb-server/src/proto/crypto/sign.rs @@ -0,0 +1,259 @@ +//! SMB2/3 message signing per MS-SMB2 §3.1.4.1. +//! +//! Two algorithms are supported: +//! 1. **HMAC-SHA-256** for SMB 2.0.2 / 2.1 / 3.0 negotiating without 3.x +//! signing. +//! 2. **AES-CMAC** for SMB 3.0+. +//! +//! Both produce a 16-byte signature that lives at bytes 48..64 of the SMB2 +//! header (the `Signature` field, MS-SMB2 §2.2.1.2). +//! +//! Algorithm: +//! 1. Zero out bytes 48..64 of the message. +//! 2. Compute MAC over the **entire** message (header + body). +//! 3. Place the first 16 bytes of MAC at bytes 48..64. + +use aes::Aes128; +use cmac::Cmac; +use hmac::{Hmac, Mac}; +use sha2::Sha256; + +use crate::proto::error::{ProtoError, ProtoResult}; + +type HmacSha256 = Hmac; +type CmacAes128 = Cmac; + +/// SMB2 header is 64 bytes; the 16-byte signature field starts at offset 48. +const SIG_OFF: usize = 48; +const SIG_LEN: usize = 16; +const SMB2_HEADER_LEN: usize = 64; + +/// Which signing algorithm to use for a given session/dialect. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SigningAlgo { + /// HMAC-SHA-256, used by SMB 2.x. + HmacSha256, + /// AES-CMAC over AES-128, used by SMB 3.0+. + AesCmac, +} + +/// Compute the 16-byte MAC over `msg` as if the SMB2 signature field were +/// zeroed, without copying the whole message. +fn compute_mac_zeroed_signature(msg: &[u8], key: &[u8; 16], algo: SigningAlgo) -> [u8; SIG_LEN] { + let mut out = [0u8; SIG_LEN]; + let zero_signature = [0u8; SIG_LEN]; + let prefix = &msg[..SIG_OFF]; + let suffix = &msg[SIG_OFF + SIG_LEN..]; + + match algo { + SigningAlgo::HmacSha256 => { + let mut mac = ::new_from_slice(key) + .expect("HMAC-SHA-256 accepts keys of any length"); + mac.update(prefix); + mac.update(&zero_signature); + mac.update(suffix); + let full = mac.finalize().into_bytes(); + out.copy_from_slice(&full[..SIG_LEN]); + } + SigningAlgo::AesCmac => { + let mut mac = ::new_from_slice(key) + .expect("AES-128-CMAC requires a 16-byte key, which we have"); + mac.update(prefix); + mac.update(&zero_signature); + mac.update(suffix); + let full = mac.finalize().into_bytes(); + out.copy_from_slice(&full[..SIG_LEN]); + } + } + out +} + +/// Compute and embed a signature in `msg`. Mutates `msg` in place. +/// +/// The caller is responsible for setting the SMB2 SIGNED flag (`0x00000008`) +/// on the header *before* calling — it is part of the bytes that get MAC'd. +/// +/// Errors if `msg` is too short to contain an SMB2 header (< 64 bytes). +pub fn sign(msg: &mut [u8], key: &[u8; 16], algo: SigningAlgo) -> ProtoResult<()> { + if msg.len() < SMB2_HEADER_LEN { + return Err(ProtoError::Crypto("message too short to sign")); + } + + // Compute MAC over the whole message with the signature field treated as + // zero, then place the MAC into the signature field. + let mac = compute_mac_zeroed_signature(msg, key, algo); + msg[SIG_OFF..SIG_OFF + SIG_LEN].copy_from_slice(&mac); + + Ok(()) +} + +/// Verify the signature in `msg`. Does **not** modify `msg`. +/// +/// Uses constant-time comparison. Returns `Ok(())` if the embedded signature +/// matches the freshly computed MAC. +pub fn verify(msg: &[u8], key: &[u8; 16], algo: SigningAlgo) -> ProtoResult<()> { + if msg.len() < SMB2_HEADER_LEN { + return Err(ProtoError::Crypto("message too short to verify")); + } + + // Capture the embedded signature. + let mut embedded = [0u8; SIG_LEN]; + embedded.copy_from_slice(&msg[SIG_OFF..SIG_OFF + SIG_LEN]); + + let computed = compute_mac_zeroed_signature(msg, key, algo); + + if constant_time_eq(&embedded, &computed) { + Ok(()) + } else { + Err(ProtoError::Crypto("signature mismatch")) + } +} + +/// Constant-time comparison of two 16-byte arrays. +#[inline] +fn constant_time_eq(a: &[u8; SIG_LEN], b: &[u8; SIG_LEN]) -> bool { + let mut diff: u8 = 0; + for i in 0..SIG_LEN { + diff |= a[i] ^ b[i]; + } + diff == 0 +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a 100-byte message: a plausible 64-byte SMB2 header followed by + /// 36 bytes of body. The signature region (bytes 48..64) is left zero; + /// `sign` will overwrite it. + fn fixture_message() -> Vec { + let mut msg = vec![0u8; 100]; + // Magic: 0xFE 'S' 'M' 'B' + msg[0..4].copy_from_slice(&[0xFE, b'S', b'M', b'B']); + // StructureSize = 64 + msg[4..6].copy_from_slice(&64u16.to_le_bytes()); + // Pretend ChannelSequence = 0 + msg[6..8].copy_from_slice(&0u16.to_le_bytes()); + // Command = NEGOTIATE (0) + msg[12..14].copy_from_slice(&0u16.to_le_bytes()); + // Flags: SIGNED (0x00000008) + msg[16..20].copy_from_slice(&0x0000_0008u32.to_le_bytes()); + // Body filler + for (i, b) in msg[64..].iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(7); + } + msg + } + + #[test] + fn sign_and_verify_hmac_sha256() { + let key = [0xAAu8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::HmacSha256).expect("sign ok"); + + // Signature should now be non-zero (overwhelmingly likely). + assert_ne!(&msg[SIG_OFF..SIG_OFF + SIG_LEN], &[0u8; 16]); + + verify(&msg, &key, SigningAlgo::HmacSha256).expect("verify ok"); + } + + #[test] + fn sign_and_verify_aes_cmac() { + let key = [0x55u8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::AesCmac).expect("sign ok"); + assert_ne!(&msg[SIG_OFF..SIG_OFF + SIG_LEN], &[0u8; 16]); + verify(&msg, &key, SigningAlgo::AesCmac).expect("verify ok"); + } + + #[test] + fn tamper_outside_sig_fails_verify_hmac() { + let key = [0xAAu8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::HmacSha256).expect("sign ok"); + + // Flip one body byte. + msg[80] ^= 0x01; + let res = verify(&msg, &key, SigningAlgo::HmacSha256); + assert!(matches!(res, Err(ProtoError::Crypto(_)))); + } + + #[test] + fn tamper_outside_sig_fails_verify_cmac() { + let key = [0x55u8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::AesCmac).expect("sign ok"); + + // Flip a header byte (not in the sig region). + msg[10] ^= 0xFF; + let res = verify(&msg, &key, SigningAlgo::AesCmac); + assert!(matches!(res, Err(ProtoError::Crypto(_)))); + } + + #[test] + fn tamper_signature_fails_verify() { + let key = [0xAAu8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::HmacSha256).expect("sign ok"); + msg[SIG_OFF] ^= 0x01; + let res = verify(&msg, &key, SigningAlgo::HmacSha256); + assert!(matches!(res, Err(ProtoError::Crypto(_)))); + } + + #[test] + fn wrong_key_fails_verify() { + let key = [0xAAu8; 16]; + let bad_key = [0xBBu8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::HmacSha256).expect("sign ok"); + let res = verify(&msg, &bad_key, SigningAlgo::HmacSha256); + assert!(matches!(res, Err(ProtoError::Crypto(_)))); + } + + #[test] + fn too_short_message_errors() { + let mut tiny = [0u8; 10]; + let key = [0u8; 16]; + let res = sign(&mut tiny, &key, SigningAlgo::HmacSha256); + assert!(matches!(res, Err(ProtoError::Crypto(_)))); + let res = verify(&tiny, &key, SigningAlgo::HmacSha256); + assert!(matches!(res, Err(ProtoError::Crypto(_)))); + } + + #[test] + fn verify_does_not_mutate_message_hmac_sha256() { + let key = [0xAAu8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::HmacSha256).expect("sign ok"); + let snapshot = msg.clone(); + verify(&msg, &key, SigningAlgo::HmacSha256).expect("verify ok"); + assert_eq!(msg, snapshot); + } + + #[test] + fn verify_does_not_mutate_message_aes_cmac() { + let key = [0x55u8; 16]; + let mut msg = fixture_message(); + sign(&mut msg, &key, SigningAlgo::AesCmac).expect("sign ok"); + let snapshot = msg.clone(); + verify(&msg, &key, SigningAlgo::AesCmac).expect("verify ok"); + assert_eq!(msg, snapshot); + } + + #[test] + fn sign_ignores_existing_signature_bytes() { + let key = [0xAAu8; 16]; + let mut clean = fixture_message(); + let mut dirty = fixture_message(); + dirty[SIG_OFF..SIG_OFF + SIG_LEN].fill(0xCC); + + sign(&mut clean, &key, SigningAlgo::HmacSha256).expect("sign clean"); + sign(&mut dirty, &key, SigningAlgo::HmacSha256).expect("sign dirty"); + + assert_eq!( + &clean[SIG_OFF..SIG_OFF + SIG_LEN], + &dirty[SIG_OFF..SIG_OFF + SIG_LEN] + ); + verify(&dirty, &key, SigningAlgo::HmacSha256).expect("verify dirty"); + } +} diff --git a/vendor/smb-server/src/proto/error.rs b/vendor/smb-server/src/proto/error.rs new file mode 100644 index 0000000..f70fd17 --- /dev/null +++ b/vendor/smb-server/src/proto/error.rs @@ -0,0 +1,26 @@ +//! Crate-wide error type for the internal SMB protocol layer. + +use thiserror::Error; + +pub type ProtoResult = Result; + +#[derive(Debug, Error)] +pub enum ProtoError { + #[error("malformed wire frame: {0}")] + Malformed(&'static str), + + #[error("unsupported dialect: 0x{0:04x}")] + UnsupportedDialect(u16), + + #[error("auth failure: {0}")] + Auth(&'static str), + + #[error("crypto failure: {0}")] + Crypto(&'static str), + + #[error("io error: {0}")] + Io(#[from] std::io::Error), + + #[error("binrw error: {0}")] + Binrw(#[from] binrw::Error), +} diff --git a/vendor/smb-server/src/proto/framing.rs b/vendor/smb-server/src/proto/framing.rs new file mode 100644 index 0000000..4643703 --- /dev/null +++ b/vendor/smb-server/src/proto/framing.rs @@ -0,0 +1,155 @@ +//! Direct-TCP / NetBIOS-over-TCP framing for SMB2/3. +//! +//! MS-SMB2 §2.1 requires a 4-byte big-endian length prefix on every TCP frame: +//! +//! ```text +//! +-------+--------------------------------+ +//! | 0x00 | 24-bit big-endian payload len | +//! +-------+--------------------------------+ +//! | SMB2 packet ... | +//! +----------------------------------------+ +//! ``` +//! +//! The top byte is reserved (must be zero in Direct-TCP transport — it is the +//! NetBIOS session-message-type byte from RFC 1002 §4.3.1). The remaining 24 +//! bits encode the payload length, so the absolute maximum on the wire is +//! `2^24 - 1 = 16_777_215` bytes (16 MiB - 1). We enforce that as the cap. +//! +//! This module is async-runtime-agnostic. Only sync helpers operating on byte +//! slices and `Vec` live here; the server crate wraps these with tokio +//! I/O. + +use crate::proto::error::{ProtoError, ProtoResult}; + +/// Length of the Direct-TCP frame header (4 bytes). +pub const FRAME_HEADER_LEN: usize = 4; + +/// Maximum payload size representable by the 3-byte length field. +/// +/// MS-SMB2 §2.1 — `2^24 - 1 = 16_777_215` bytes. +pub const MAX_FRAME_PAYLOAD: u32 = 0x00FF_FFFF; + +/// Encode a single Direct-TCP frame: 4-byte header + payload. +/// +/// Panics in debug if the payload exceeds [`MAX_FRAME_PAYLOAD`]; release builds +/// silently truncate the high byte. +pub fn encode_frame(payload: &[u8], out: &mut Vec) { + debug_assert!( + payload.len() as u64 <= MAX_FRAME_PAYLOAD as u64, + "frame payload exceeds 16 MiB - 1" + ); + let len = payload.len() as u32; + // Top byte is the NetBIOS session-message type (0x00 for Direct-TCP). + // Lower 3 bytes are payload length, big-endian. + out.reserve(FRAME_HEADER_LEN + payload.len()); + out.push(0x00); + out.push(((len >> 16) & 0xFF) as u8); + out.push(((len >> 8) & 0xFF) as u8); + out.push((len & 0xFF) as u8); + out.extend_from_slice(payload); +} + +/// Decode the 4-byte frame header, returning the payload length. +/// +/// Returns [`ProtoError::Malformed`] if the top byte is non-zero (NetBIOS +/// session-message type other than `SESSION MESSAGE` is not supported in +/// Direct-TCP transport). +pub fn decode_frame_header(bytes: &[u8; FRAME_HEADER_LEN]) -> ProtoResult { + if bytes[0] != 0x00 { + return Err(ProtoError::Malformed( + "NetBIOS session-message type byte must be 0x00 for Direct-TCP", + )); + } + let len = (u32::from(bytes[1]) << 16) | (u32::from(bytes[2]) << 8) | u32::from(bytes[3]); + Ok(len) +} + +/// Convenience: read one full frame from a contiguous byte slice. +/// +/// Returns the payload slice and the remaining bytes after the frame. +#[cfg(test)] +pub fn decode_frame(buf: &[u8]) -> ProtoResult<(&[u8], &[u8])> { + if buf.len() < FRAME_HEADER_LEN { + return Err(ProtoError::Malformed("short frame header")); + } + let mut hdr = [0u8; FRAME_HEADER_LEN]; + hdr.copy_from_slice(&buf[..FRAME_HEADER_LEN]); + let len = decode_frame_header(&hdr)? as usize; + let total = FRAME_HEADER_LEN + len; + if buf.len() < total { + return Err(ProtoError::Malformed("truncated frame body")); + } + Ok((&buf[FRAME_HEADER_LEN..total], &buf[total..])) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encodes_empty_frame() { + let mut out = Vec::new(); + encode_frame(&[], &mut out); + assert_eq!(out, [0x00, 0x00, 0x00, 0x00]); + } + + #[test] + fn encodes_simple_frame() { + let mut out = Vec::new(); + encode_frame(&[0xAA, 0xBB, 0xCC], &mut out); + assert_eq!(out, [0x00, 0x00, 0x00, 0x03, 0xAA, 0xBB, 0xCC]); + } + + #[test] + fn round_trips_random_payload() { + let payload: Vec = (0u8..=200).collect(); + let mut wire = Vec::new(); + encode_frame(&payload, &mut wire); + + let (decoded, rest) = decode_frame(&wire).unwrap(); + assert_eq!(decoded, payload.as_slice()); + assert!(rest.is_empty()); + } + + #[test] + fn decodes_header_three_byte_length() { + // 0x00_12_34_56 -> length 0x123456 + let len = decode_frame_header(&[0x00, 0x12, 0x34, 0x56]).unwrap(); + assert_eq!(len, 0x0012_3456); + } + + #[test] + fn decodes_header_max_length() { + let len = decode_frame_header(&[0x00, 0xFF, 0xFF, 0xFF]).unwrap(); + assert_eq!(len, MAX_FRAME_PAYLOAD); + } + + #[test] + fn rejects_nonzero_top_byte() { + let err = decode_frame_header(&[0x81, 0x00, 0x00, 0x00]).unwrap_err(); + assert!(matches!(err, ProtoError::Malformed(_))); + } + + #[test] + fn decode_frame_handles_trailing_data() { + let mut wire = Vec::new(); + encode_frame(&[1, 2, 3], &mut wire); + wire.extend_from_slice(&[9, 9, 9]); // simulate a partial second frame + + let (payload, rest) = decode_frame(&wire).unwrap(); + assert_eq!(payload, &[1, 2, 3]); + assert_eq!(rest, &[9, 9, 9]); + } + + #[test] + fn decode_frame_short_header() { + let err = decode_frame(&[0x00, 0x00]).unwrap_err(); + assert!(matches!(err, ProtoError::Malformed(_))); + } + + #[test] + fn decode_frame_truncated_body() { + let err = decode_frame(&[0x00, 0x00, 0x00, 0x05, 0xAA]).unwrap_err(); + assert!(matches!(err, ProtoError::Malformed(_))); + } +} diff --git a/vendor/smb-server/src/proto/header.rs b/vendor/smb-server/src/proto/header.rs new file mode 100644 index 0000000..b6c2656 --- /dev/null +++ b/vendor/smb-server/src/proto/header.rs @@ -0,0 +1,471 @@ +//! SMB2 fixed 64-byte packet header (sync + async forms). +//! +//! References: +//! * MS-SMB2 §2.2.1 — Common header preamble. +//! * MS-SMB2 §2.2.1.1 — Async form (`Flags & SMB2_FLAGS_ASYNC_COMMAND`). +//! * MS-SMB2 §2.2.1.2 — Sync form. +//! +//! ## Encoding choice +//! +//! The two forms differ only in the 12-byte block at offset 0x18..0x24: +//! +//! * **Sync**: `ChannelSequence` (u16) + `Reserved` (u16) + `Reserved2` (u32) + `TreeId` (u32) +//! wait — actually the sync form is: `Reserved` (u32) + `TreeId` (u32) (bytes 0x20..0x28). +//! * **Async**: `AsyncId` (u64) at bytes 0x20..0x28. +//! +//! In *both* forms, bytes 0x10..0x14 are `Status` (or `ChannelSequence + Reserved` on +//! 3.x channel-sequence-aware requests; we treat them as a single u32 named +//! `channel_sequence_status`). Bytes 0x14..0x18 are `Command + CreditReqResp`, +//! 0x18..0x1C are `Flags`, 0x1C..0x20 are `NextCommand`, 0x20..0x28 are `MessageId`. +//! The discriminated 8-byte block lives at 0x28..0x30, followed by the 16-byte +//! `Signature` at 0x30..0x40 — totalling 64 bytes. +//! +//! We model this as a single `Smb2Header` struct with a `tail: HeaderTail` enum +//! that is `Sync { reserved: u32, tree_id: u32 }` or `Async { async_id: u64 }`, +//! discriminated by `Flags & SMB2_FLAGS_ASYNC_COMMAND`. This is the cleanest +//! mapping to the spec — every other field is shared. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::{ProtoError, ProtoResult}; + +/// SMB2 protocol identifier ("\xfeSMB"). +pub const SMB2_MAGIC: [u8; 4] = [0xFE, b'S', b'M', b'B']; + +/// Fixed `StructureSize` of the SMB2 header (MS-SMB2 §2.2.1.1/§2.2.1.2). +pub const SMB2_HEADER_STRUCTURE_SIZE: u16 = 64; + +/// Total wire size of the SMB2 header. +pub const SMB2_HEADER_LEN: usize = 64; + +// --------------------------------------------------------------------------- +// Flags (MS-SMB2 §2.2.1.2 Flags field) +// --------------------------------------------------------------------------- + +/// `SMB2_FLAGS_SERVER_TO_REDIR` — set on responses. +pub const SMB2_FLAGS_SERVER_TO_REDIR: u32 = 0x0000_0001; +/// `SMB2_FLAGS_ASYNC_COMMAND` — selects the async header form. +pub const SMB2_FLAGS_ASYNC_COMMAND: u32 = 0x0000_0002; +/// `SMB2_FLAGS_RELATED_OPERATIONS` — compound chain marker. +pub const SMB2_FLAGS_RELATED_OPERATIONS: u32 = 0x0000_0004; +/// `SMB2_FLAGS_SIGNED` — message is signed. +pub const SMB2_FLAGS_SIGNED: u32 = 0x0000_0008; +/// `SMB2_FLAGS_PRIORITY_MASK` — bits 4..6 hold priority (3.1.1+). +pub const SMB2_FLAGS_PRIORITY_MASK: u32 = 0x0000_0070; +/// `SMB2_FLAGS_DFS_OPERATIONS`. +pub const SMB2_FLAGS_DFS_OPERATIONS: u32 = 0x1000_0000; +/// `SMB2_FLAGS_REPLAY_OPERATION`. +pub const SMB2_FLAGS_REPLAY_OPERATION: u32 = 0x2000_0000; + +// --------------------------------------------------------------------------- +// Command opcodes (MS-SMB2 §2.2.1.2 Command field) +// --------------------------------------------------------------------------- + +/// SMB2 command opcodes (the 19 commands in v1). +#[binrw] +#[brw(little, repr = u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Command { + Negotiate = 0x0000, + SessionSetup = 0x0001, + Logoff = 0x0002, + TreeConnect = 0x0003, + TreeDisconnect = 0x0004, + Create = 0x0005, + Close = 0x0006, + Flush = 0x0007, + Read = 0x0008, + Write = 0x0009, + Lock = 0x000A, + Ioctl = 0x000B, + Cancel = 0x000C, + Echo = 0x000D, + QueryDirectory = 0x000E, + ChangeNotify = 0x000F, + QueryInfo = 0x0010, + SetInfo = 0x0011, + OplockBreak = 0x0012, +} + +impl Command { + /// Raw opcode for diagnostics. + pub const fn as_u16(self) -> u16 { + self as u16 + } +} + +// --------------------------------------------------------------------------- +// Header struct +// --------------------------------------------------------------------------- + +/// The 12-byte tail of the header that differs between sync and async forms. +/// +/// The discriminant is `flags & SMB2_FLAGS_ASYNC_COMMAND`. We can't easily use +/// binrw's args+if without making the parent struct generic over the runtime +/// flag value, so the parent reads/writes this manually via `parse` / `write` +/// helpers and we expose a regular Rust enum here. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HeaderTail { + /// Sync form: `Reserved (u32)` + `TreeId (u32)` at bytes 0x24..0x2C. + /// (See note in module docs about offsets.) + Sync { reserved: u32, tree_id: u32 }, + /// Async form: `AsyncId (u64)` at bytes 0x24..0x2C. + Async { async_id: u64 }, +} + +impl HeaderTail { + /// Default sync tail with `TreeId = 0`. + pub const fn sync(tree_id: u32) -> Self { + HeaderTail::Sync { + reserved: 0, + tree_id, + } + } + + /// Default async tail. + pub const fn async_(async_id: u64) -> Self { + HeaderTail::Async { async_id } + } +} + +/// SMB2 fixed 64-byte header. +/// +/// On the wire the layout is (offsets in decimal — total 64 bytes): +/// +/// | Offset | Size | Field | +/// |-------:|-----:|-------| +/// | 0 | 4 | ProtocolId (`0xFE 'S' 'M' 'B'`) | +/// | 4 | 2 | StructureSize (always 64) | +/// | 6 | 2 | CreditCharge | +/// | 8 | 4 | (Channel)Status | +/// | 12 | 2 | Command | +/// | 14 | 2 | CreditRequest/CreditResponse | +/// | 16 | 4 | Flags | +/// | 20 | 4 | NextCommand | +/// | 24 | 8 | MessageId | +/// | 32 | 8 | Reserved/TreeId (sync) **or** AsyncId (async) | +/// | 40 | 8 | SessionId | +/// | 48 | 16 | Signature | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Smb2Header { + pub credit_charge: u16, + /// Bytes 8..12: in client→server requests on 3.x this can split into + /// `ChannelSequence(u16)` + `Reserved(u16)`; in server→client responses + /// it carries `Status` (NTSTATUS). We expose the raw u32 — handlers/ + /// signing code interpret it. + pub channel_sequence_status: u32, + pub command: Command, + /// On requests this is `CreditRequest`; on responses, `CreditResponse`. + pub credit_request_response: u16, + pub flags: u32, + /// Offset to the next header in a compound chain, or 0 for the last. + pub next_command: u32, + pub message_id: u64, + /// Sync: `(reserved, tree_id)`. Async: `async_id`. Discriminated by + /// `flags & SMB2_FLAGS_ASYNC_COMMAND`. + pub tail: HeaderTail, + pub session_id: u64, + /// 16-byte signature; zeroed on unsigned messages. + pub signature: [u8; 16], +} + +impl Default for Smb2Header { + fn default() -> Self { + Self { + credit_charge: 0, + channel_sequence_status: 0, + command: Command::Negotiate, + credit_request_response: 0, + flags: 0, + next_command: 0, + message_id: 0, + tail: HeaderTail::sync(0), + session_id: 0, + signature: [0u8; 16], + } + } +} + +impl Smb2Header { + /// Convenience: is this an async-form header? + pub fn is_async(&self) -> bool { + self.flags & SMB2_FLAGS_ASYNC_COMMAND != 0 + } + + /// Convenience: is this a server→client response? + pub fn is_response(&self) -> bool { + self.flags & SMB2_FLAGS_SERVER_TO_REDIR != 0 + } + + /// Convenience: tree_id from a sync header (panics if async). + pub fn tree_id(&self) -> Option { + match self.tail { + HeaderTail::Sync { tree_id, .. } => Some(tree_id), + HeaderTail::Async { .. } => None, + } + } + + /// Convenience: async_id from an async header. + pub fn async_id(&self) -> Option { + match self.tail { + HeaderTail::Async { async_id } => Some(async_id), + HeaderTail::Sync { .. } => None, + } + } + + /// Parse from a byte slice. Returns the header and the remaining bytes. + pub fn parse(buf: &[u8]) -> ProtoResult<(Self, &[u8])> { + if buf.len() < SMB2_HEADER_LEN { + return Err(ProtoError::Malformed("short SMB2 header")); + } + let mut cursor = Cursor::new(&buf[..SMB2_HEADER_LEN]); + let raw = RawHeader::read(&mut cursor)?; + if raw.protocol_id != SMB2_MAGIC { + return Err(ProtoError::Malformed("bad SMB2 magic")); + } + if raw.structure_size != SMB2_HEADER_STRUCTURE_SIZE { + return Err(ProtoError::Malformed("SMB2 header structure_size != 64")); + } + let command = match Command::read_le(&mut Cursor::new(raw.command_raw.to_le_bytes())) { + Ok(c) => c, + Err(_) => { + return Err(ProtoError::Malformed("unknown SMB2 command opcode")); + } + }; + let tail = if raw.flags & SMB2_FLAGS_ASYNC_COMMAND != 0 { + HeaderTail::Async { + async_id: u64::from_le_bytes(raw.tail_bytes), + } + } else { + let reserved = u32::from_le_bytes([ + raw.tail_bytes[0], + raw.tail_bytes[1], + raw.tail_bytes[2], + raw.tail_bytes[3], + ]); + let tree_id = u32::from_le_bytes([ + raw.tail_bytes[4], + raw.tail_bytes[5], + raw.tail_bytes[6], + raw.tail_bytes[7], + ]); + HeaderTail::Sync { reserved, tree_id } + }; + Ok(( + Smb2Header { + credit_charge: raw.credit_charge, + channel_sequence_status: raw.channel_sequence_status, + command, + credit_request_response: raw.credit_request_response, + flags: raw.flags, + next_command: raw.next_command, + message_id: raw.message_id, + tail, + session_id: raw.session_id, + signature: raw.signature, + }, + &buf[SMB2_HEADER_LEN..], + )) + } + + /// Serialize the 64-byte header into `out`. + pub fn write(&self, out: &mut Vec) -> ProtoResult<()> { + let tail_bytes = match self.tail { + HeaderTail::Sync { reserved, tree_id } => { + let mut b = [0u8; 8]; + b[..4].copy_from_slice(&reserved.to_le_bytes()); + b[4..].copy_from_slice(&tree_id.to_le_bytes()); + b + } + HeaderTail::Async { async_id } => async_id.to_le_bytes(), + }; + let raw = RawHeader { + protocol_id: SMB2_MAGIC, + structure_size: SMB2_HEADER_STRUCTURE_SIZE, + credit_charge: self.credit_charge, + channel_sequence_status: self.channel_sequence_status, + command_raw: self.command.as_u16(), + credit_request_response: self.credit_request_response, + flags: self.flags, + next_command: self.next_command, + message_id: self.message_id, + tail_bytes, + session_id: self.session_id, + signature: self.signature, + }; + let start = out.len(); + let mut cursor = Cursor::new(Vec::with_capacity(SMB2_HEADER_LEN)); + raw.write(&mut cursor)?; + out.extend_from_slice(&cursor.into_inner()); + debug_assert_eq!(out.len() - start, SMB2_HEADER_LEN); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Internal raw header for binrw plumbing. +// --------------------------------------------------------------------------- + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, Copy)] +struct RawHeader { + protocol_id: [u8; 4], + structure_size: u16, + credit_charge: u16, + channel_sequence_status: u32, + command_raw: u16, + credit_request_response: u16, + flags: u32, + next_command: u32, + message_id: u64, + tail_bytes: [u8; 8], + session_id: u64, + signature: [u8; 16], +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_sync() -> Smb2Header { + Smb2Header { + credit_charge: 1, + channel_sequence_status: 0, + command: Command::Negotiate, + credit_request_response: 1, + flags: 0, + next_command: 0, + message_id: 0, + tail: HeaderTail::Sync { + reserved: 0, + tree_id: 0, + }, + session_id: 0, + signature: [0u8; 16], + } + } + + fn sample_async() -> Smb2Header { + Smb2Header { + credit_charge: 4, + channel_sequence_status: 0, + command: Command::Read, + credit_request_response: 1, + flags: SMB2_FLAGS_ASYNC_COMMAND | SMB2_FLAGS_SERVER_TO_REDIR, + next_command: 0, + message_id: 42, + tail: HeaderTail::Async { + async_id: 0xDEAD_BEEF_CAFE_F00D, + }, + session_id: 0x1122_3344_5566_7788, + signature: [0xAA; 16], + } + } + + #[test] + fn sync_round_trips() { + let hdr = sample_sync(); + let mut buf = Vec::new(); + hdr.write(&mut buf).unwrap(); + assert_eq!(buf.len(), SMB2_HEADER_LEN); + // First 4 bytes must be the magic. + assert_eq!(&buf[..4], &SMB2_MAGIC); + // StructureSize at offset 4 == 64 + assert_eq!(u16::from_le_bytes([buf[4], buf[5]]), 64); + + let (decoded, rest) = Smb2Header::parse(&buf).unwrap(); + assert!(rest.is_empty()); + assert_eq!(decoded, hdr); + } + + #[test] + fn async_round_trips() { + let hdr = sample_async(); + let mut buf = Vec::new(); + hdr.write(&mut buf).unwrap(); + assert_eq!(buf.len(), SMB2_HEADER_LEN); + + let (decoded, _rest) = Smb2Header::parse(&buf).unwrap(); + assert_eq!(decoded, hdr); + assert!(decoded.is_async()); + assert!(decoded.is_response()); + assert_eq!(decoded.async_id(), Some(0xDEAD_BEEF_CAFE_F00D)); + assert_eq!(decoded.tree_id(), None); + } + + #[test] + fn rejects_bad_magic() { + let hdr = sample_sync(); + let mut buf = Vec::new(); + hdr.write(&mut buf).unwrap(); + buf[0] = 0xFF; + let err = Smb2Header::parse(&buf).unwrap_err(); + assert!(matches!(err, ProtoError::Malformed(_))); + } + + #[test] + fn rejects_bad_structure_size() { + let hdr = sample_sync(); + let mut buf = Vec::new(); + hdr.write(&mut buf).unwrap(); + buf[4] = 0; // wreck the structure_size LE bytes + buf[5] = 0; + let err = Smb2Header::parse(&buf).unwrap_err(); + assert!(matches!(err, ProtoError::Malformed(_))); + } + + #[test] + fn rejects_short_buffer() { + let err = Smb2Header::parse(&[0u8; 32]).unwrap_err(); + assert!(matches!(err, ProtoError::Malformed(_))); + } + + #[test] + fn handcrafted_sync_negotiate_request() { + // Hand-built Sync NEGOTIATE request header: magic, size=64, no flags, + // command=0, mid=0, tree_id=0, sid=0, no signature. + let mut buf = vec![0u8; 64]; + buf[..4].copy_from_slice(&SMB2_MAGIC); + buf[4..6].copy_from_slice(&64u16.to_le_bytes()); + // command at offset 12 = 0 (NEGOTIATE), already zero + // everything else zero + let (hdr, _) = Smb2Header::parse(&buf).unwrap(); + assert_eq!(hdr.command, Command::Negotiate); + assert!(!hdr.is_async()); + assert_eq!(hdr.tree_id(), Some(0)); + } + + #[test] + fn command_round_trips_via_binrw() { + for cmd in [ + Command::Negotiate, + Command::SessionSetup, + Command::Logoff, + Command::TreeConnect, + Command::TreeDisconnect, + Command::Create, + Command::Close, + Command::Flush, + Command::Read, + Command::Write, + Command::Lock, + Command::Ioctl, + Command::Cancel, + Command::Echo, + Command::QueryDirectory, + Command::ChangeNotify, + Command::QueryInfo, + Command::SetInfo, + Command::OplockBreak, + ] { + let mut hdr = sample_sync(); + hdr.command = cmd; + let mut buf = Vec::new(); + hdr.write(&mut buf).unwrap(); + let (decoded, _) = Smb2Header::parse(&buf).unwrap(); + assert_eq!(decoded.command, cmd); + } + } +} diff --git a/vendor/smb-server/src/proto/messages/cancel.rs b/vendor/smb-server/src/proto/messages/cancel.rs new file mode 100644 index 0000000..ac32d82 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/cancel.rs @@ -0,0 +1,49 @@ +//! CANCEL Request (MS-SMB2 §2.2.30). No response — server cancels in place. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CancelRequest { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for CancelRequest { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +impl CancelRequest { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips() { + let r = CancelRequest::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(buf.len(), 4); + assert_eq!(CancelRequest::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/change_notify.rs b/vendor/smb-server/src/proto/messages/change_notify.rs new file mode 100644 index 0000000..2cc0d5f --- /dev/null +++ b/vendor/smb-server/src/proto/messages/change_notify.rs @@ -0,0 +1,93 @@ +//! CHANGE_NOTIFY Request/Response (MS-SMB2 §2.2.35 / §2.2.36). +//! +//! V1 returns `STATUS_NOT_SUPPORTED`, but we still parse/encode the wire +//! form so the dispatcher can recognize it. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ChangeNotifyRequest { + pub structure_size: u16, + pub flags: u16, + pub output_buffer_length: u32, + pub file_id: FileId, + pub completion_filter: u32, + pub reserved: u32, +} + +impl ChangeNotifyRequest { + /// Flag: SMB2_WATCH_TREE. + pub const FLAG_WATCH_TREE: u16 = 0x0001; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ChangeNotifyResponse { + pub structure_size: u16, + pub output_buffer_offset: u16, + pub output_buffer_length: u32, + #[br(count = output_buffer_length as usize)] + pub buffer: Vec, +} + +impl ChangeNotifyResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_round_trips() { + let r = ChangeNotifyRequest { + structure_size: 32, + flags: ChangeNotifyRequest::FLAG_WATCH_TREE, + output_buffer_length: 0x1000, + file_id: FileId::new(1, 2), + completion_filter: 0xFF, + reserved: 0, + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(ChangeNotifyRequest::parse(&buf).unwrap(), r); + } + + #[test] + fn response_round_trips() { + let r = ChangeNotifyResponse { + structure_size: 9, + output_buffer_offset: 0x48, + output_buffer_length: 0, + buffer: vec![], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(ChangeNotifyResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/close.rs b/vendor/smb-server/src/proto/messages/close.rs new file mode 100644 index 0000000..43cea4c --- /dev/null +++ b/vendor/smb-server/src/proto/messages/close.rs @@ -0,0 +1,93 @@ +//! CLOSE Request/Response (MS-SMB2 §2.2.15 / §2.2.16). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CloseRequest { + pub structure_size: u16, + pub flags: u16, + pub reserved: u32, + pub file_id: FileId, +} + +impl CloseRequest { + /// Flag: SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB. + pub const FLAG_POSTQUERY_ATTRIB: u16 = 0x0001; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct CloseResponse { + pub structure_size: u16, + pub flags: u16, + pub reserved: u32, + pub creation_time: u64, + pub last_access_time: u64, + pub last_write_time: u64, + pub change_time: u64, + pub allocation_size: u64, + pub end_of_file: u64, + pub file_attributes: u32, +} + +impl CloseResponse { + pub fn new() -> Self { + Self { + structure_size: 60, + ..Default::default() + } + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips() { + let r = CloseRequest { + structure_size: 24, + flags: CloseRequest::FLAG_POSTQUERY_ATTRIB, + reserved: 0, + file_id: FileId::new(0x1, 0x2), + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(CloseRequest::parse(&buf).unwrap(), r); + + let r = CloseResponse { + structure_size: 60, + ..CloseResponse::new() + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(CloseResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/create.rs b/vendor/smb-server/src/proto/messages/create.rs new file mode 100644 index 0000000..52194ef --- /dev/null +++ b/vendor/smb-server/src/proto/messages/create.rs @@ -0,0 +1,437 @@ +//! CREATE Request/Response (MS-SMB2 §2.2.13 / §2.2.14). +//! +//! `create_contexts` is a chained sequence of `SMB2_CREATE_CONTEXT` records +//! (MS-SMB2 §2.2.13.2). Each record has `Next` (offset to the next entry, +//! relative to the start of *this* entry; 0 marks the last), a name + data +//! pair, and 8-byte alignment. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::{ProtoError, ProtoResult}; + +/// SMB2 FileId — opaque 16 bytes (volatile + persistent). +/// +/// MS-SMB2 §2.2.14.1. We expose both halves; the server uses identical values +/// for both since durable handles are out of scope (spec §2 in the v1 design). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct FileId { + pub persistent: u64, + pub volatile: u64, +} + +impl FileId { + pub const fn new(persistent: u64, volatile: u64) -> Self { + Self { + persistent, + volatile, + } + } + + /// MS-SMB2: the "any" FileId is `0xFFFF…FFFF`. + pub const fn any() -> Self { + Self { + persistent: u64::MAX, + volatile: u64::MAX, + } + } +} + +/// MS-SMB2 §2.2.13 CREATE Request — fixed prefix. +/// +/// Variable-length tail: the file `name` (UTF-16LE) and `create_contexts` +/// blob, each at absolute offsets from the start of the SMB2 header. We hold +/// them as length-counted byte buffers immediately following the fixed +/// portion. The server crate parses contexts with [`CreateContext::parse_chain`]. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CreateRequest { + pub structure_size: u16, + pub security_flags: u8, + pub requested_oplock_level: u8, + pub impersonation_level: u32, + pub smb_create_flags: u64, + pub reserved: u64, + pub desired_access: u32, + pub file_attributes: u32, + pub share_access: u32, + pub create_disposition: u32, + pub create_options: u32, + pub name_offset: u16, + pub name_length: u16, + pub create_contexts_offset: u32, + pub create_contexts_length: u32, + /// UTF-16LE filename. + #[br(count = name_length as usize)] + pub name: Vec, + /// Raw create-contexts chain bytes; parse with + /// [`CreateContext::parse_chain`]. + #[br(count = create_contexts_length as usize)] + pub create_contexts: Vec, +} + +impl CreateRequest { + /// Decode the UTF-16LE filename. + pub fn name_str(&self) -> Option { + if !self.name.len().is_multiple_of(2) { + return None; + } + let units: Vec = self + .name + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect(); + Some(String::from_utf16_lossy(&units)) + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// MS-SMB2 §2.2.14 CREATE Response. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CreateResponse { + pub structure_size: u16, + pub oplock_level: u8, + pub flags: u8, + pub create_action: u32, + pub creation_time: u64, + pub last_access_time: u64, + pub last_write_time: u64, + pub change_time: u64, + pub allocation_size: u64, + pub end_of_file: u64, + pub file_attributes: u32, + pub reserved2: u32, + pub file_id: FileId, + pub create_contexts_offset: u32, + pub create_contexts_length: u32, + #[br(count = create_contexts_length as usize)] + pub create_contexts: Vec, +} + +impl CreateResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Create contexts (MS-SMB2 §2.2.13.2) +// --------------------------------------------------------------------------- + +/// Generic SMB2_CREATE_CONTEXT envelope. +/// +/// Per MS-SMB2 §2.2.13.2 each entry has: +/// * `Next` — offset (bytes) from the start of *this* entry to the start of +/// the next entry in the chain, or 0 for the last entry. +/// * `NameOffset`/`NameLength` — name (typically a 4-byte ASCII tag) at an +/// offset relative to the entry start. +/// * `Reserved` — 2 bytes. +/// * `DataOffset`/`DataLength` — payload at an offset relative to the entry +/// start. +/// +/// We model the entry as `name` + `data` byte vectors plus the raw flags. The +/// chain reader / writer below handles `Next` and 8-byte alignment between +/// entries. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct CreateContext { + pub name: Vec, + pub data: Vec, +} + +impl CreateContext { + // Well-known names (MS-SMB2 §2.2.13.2 table). 4-byte ASCII tags. + pub const NAME_EXTA: &'static [u8; 4] = b"ExtA"; // SMB2_CREATE_EA_BUFFER + pub const NAME_SECD: &'static [u8; 4] = b"SecD"; // SMB2_CREATE_SD_BUFFER + pub const NAME_DHNQ: &'static [u8; 4] = b"DHnQ"; // DURABLE_HANDLE_REQUEST + pub const NAME_DHNC: &'static [u8; 4] = b"DHnC"; // DURABLE_HANDLE_RECONNECT + pub const NAME_ALSI: &'static [u8; 4] = b"AlSi"; // ALLOCATION_SIZE + pub const NAME_MXAC: &'static [u8; 4] = b"MxAc"; // QUERY_MAXIMAL_ACCESS + pub const NAME_TWRP: &'static [u8; 4] = b"TWrp"; // TIMEWARP_TOKEN + pub const NAME_QFID: &'static [u8; 4] = b"QFid"; // QUERY_ON_DISK_ID + pub const NAME_RQLS: &'static [u8; 4] = b"RqLs"; // REQUEST_LEASE + pub const NAME_DH2Q: &'static [u8; 4] = b"DH2Q"; // DURABLE_HANDLE_REQUEST_V2 + pub const NAME_DH2C: &'static [u8; 4] = b"DH2C"; // DURABLE_HANDLE_RECONNECT_V2 + + /// Parse a chain of create-contexts from the raw chain bytes. + /// + /// The chain is empty if `chain.is_empty()`. Otherwise we walk `Next` + /// offsets until we hit a zero terminator, validating bounds at each step. + pub fn parse_chain(chain: &[u8]) -> ProtoResult> { + let mut out = Vec::new(); + if chain.is_empty() { + return Ok(out); + } + let mut cursor_off = 0usize; + loop { + let entry = &chain + .get(cursor_off..) + .ok_or(ProtoError::Malformed("create context out of range"))?; + if entry.len() < 16 { + return Err(ProtoError::Malformed("create context too short")); + } + let next = u32::from_le_bytes([entry[0], entry[1], entry[2], entry[3]]) as usize; + let name_offset = u16::from_le_bytes([entry[4], entry[5]]) as usize; + let name_length = u16::from_le_bytes([entry[6], entry[7]]) as usize; + // entry[8..10] = reserved + let data_offset = u16::from_le_bytes([entry[10], entry[11]]) as usize; + let data_length = + u32::from_le_bytes([entry[12], entry[13], entry[14], entry[15]]) as usize; + + let name = entry + .get(name_offset..name_offset + name_length) + .ok_or(ProtoError::Malformed("create context name out of range"))? + .to_vec(); + let data = if data_length == 0 { + Vec::new() + } else { + entry + .get(data_offset..data_offset + data_length) + .ok_or(ProtoError::Malformed("create context data out of range"))? + .to_vec() + }; + out.push(CreateContext { name, data }); + + if next == 0 { + break; + } + cursor_off = cursor_off + .checked_add(next) + .ok_or(ProtoError::Malformed("create context next overflow"))?; + } + Ok(out) + } + + /// Encode a chain of create-contexts into `out`. Inserts `Next` offsets + /// and 8-byte alignment padding between entries. + pub fn encode_chain(list: &[CreateContext], out: &mut Vec) -> ProtoResult<()> { + if list.is_empty() { + return Ok(()); + } + // We build the chain in a scratch buffer, then copy. Each entry is: + // 16-byte header + name + (pad to 8) + data + (pad to 8 if not last) + // The `Next` of every entry except the last is the size from this + // entry's start to the next entry's start. + let mut scratch: Vec = Vec::new(); + let mut entry_starts: Vec = Vec::with_capacity(list.len()); + + for (i, ctx) in list.iter().enumerate() { + // Pad to 8-byte boundary before each entry (except possibly first + // — but contexts must be 8-byte aligned, and the chain itself is + // anchored at an 8-aligned offset by the server). + while !scratch.len().is_multiple_of(8) { + scratch.push(0); + } + entry_starts.push(scratch.len()); + + // Reserve 16 bytes for the header; will fill in once we know + // the actual offsets. + let header_pos = scratch.len(); + scratch.extend_from_slice(&[0u8; 16]); + + // Name immediately follows the header. + let name_offset_rel = (scratch.len() - header_pos) as u16; + scratch.extend_from_slice(&ctx.name); + // Pad to 8 before data. + while !(scratch.len() - header_pos).is_multiple_of(8) { + scratch.push(0); + } + let data_offset_rel = (scratch.len() - header_pos) as u16; + scratch.extend_from_slice(&ctx.data); + + // Now backfill the header bytes (Next is patched after the loop). + let hdr = &mut scratch[header_pos..header_pos + 16]; + hdr[0..4].copy_from_slice(&0u32.to_le_bytes()); // Next, fixed up below + hdr[4..6].copy_from_slice(&name_offset_rel.to_le_bytes()); + hdr[6..8].copy_from_slice(&(ctx.name.len() as u16).to_le_bytes()); + hdr[8..10].copy_from_slice(&0u16.to_le_bytes()); // Reserved + hdr[10..12].copy_from_slice(&data_offset_rel.to_le_bytes()); + hdr[12..16].copy_from_slice(&(ctx.data.len() as u32).to_le_bytes()); + + // For non-last, pad the trailing data area to 8 so the next + // entry starts aligned. + if i + 1 < list.len() { + while !scratch.len().is_multiple_of(8) { + scratch.push(0); + } + } + } + + // Patch `Next` offsets. + for i in 0..(entry_starts.len() - 1) { + let this = entry_starts[i]; + let next = entry_starts[i + 1]; + let delta = (next - this) as u32; + scratch[this..this + 4].copy_from_slice(&delta.to_le_bytes()); + } + // Last entry's Next stays 0. + + out.extend_from_slice(&scratch); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Helper enums (oplock level, impersonation level) +// --------------------------------------------------------------------------- + +/// MS-SMB2 §2.2.13 RequestedOplockLevel / §2.2.14 OplockLevel. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum OplockLevel { + None = 0x00, + Ii = 0x01, + Exclusive = 0x08, + Batch = 0x09, + Lease = 0xFF, +} + +impl OplockLevel { + pub fn from_u8(v: u8) -> Option { + Some(match v { + 0x00 => Self::None, + 0x01 => Self::Ii, + 0x08 => Self::Exclusive, + 0x09 => Self::Batch, + 0xFF => Self::Lease, + _ => return None, + }) + } +} + +/// MS-SMB2 §2.2.13 ImpersonationLevel. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u32)] +pub enum ImpersonationLevel { + Anonymous = 0x0000_0000, + Identification = 0x0000_0001, + Impersonation = 0x0000_0002, + Delegate = 0x0000_0003, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn utf16le(s: &str) -> Vec { + s.encode_utf16().flat_map(u16::to_le_bytes).collect() + } + + #[test] + fn request_round_trips() { + let name = utf16le("dir\\file.txt"); + let r = CreateRequest { + structure_size: 57, + security_flags: 0, + requested_oplock_level: 0, + impersonation_level: ImpersonationLevel::Impersonation as u32, + smb_create_flags: 0, + reserved: 0, + desired_access: 0x0012_0089, + file_attributes: 0, + share_access: 0x0000_0007, + create_disposition: 1, + create_options: 0x0000_0040, + name_offset: 0x78, + name_length: name.len() as u16, + create_contexts_offset: 0, + create_contexts_length: 0, + name, + create_contexts: vec![], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + let decoded = CreateRequest::parse(&buf).unwrap(); + assert_eq!(decoded, r); + assert_eq!(decoded.name_str().unwrap(), "dir\\file.txt"); + } + + #[test] + fn response_round_trips() { + let r = CreateResponse { + structure_size: 89, + oplock_level: 0, + flags: 0, + create_action: 1, + creation_time: 0x01D9_0000_0000_0000, + last_access_time: 0x01D9_0000_0000_0000, + last_write_time: 0x01D9_0000_0000_0000, + change_time: 0x01D9_0000_0000_0000, + allocation_size: 0x1000, + end_of_file: 0x800, + file_attributes: 0x0020, + reserved2: 0, + file_id: FileId::new(0x1234, 0x5678), + create_contexts_offset: 0, + create_contexts_length: 0, + create_contexts: vec![], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + let decoded = CreateResponse::parse(&buf).unwrap(); + assert_eq!(decoded, r); + } + + #[test] + fn create_context_chain_round_trips_single() { + let ctxs = vec![CreateContext { + name: b"MxAc".to_vec(), + data: vec![], + }]; + let mut buf = Vec::new(); + CreateContext::encode_chain(&ctxs, &mut buf).unwrap(); + let decoded = CreateContext::parse_chain(&buf).unwrap(); + assert_eq!(decoded, ctxs); + } + + #[test] + fn create_context_chain_round_trips_multi() { + let ctxs = vec![ + CreateContext { + name: b"DHnQ".to_vec(), + data: vec![0u8; 16], + }, + CreateContext { + name: b"MxAc".to_vec(), + data: vec![], + }, + CreateContext { + name: b"QFid".to_vec(), + data: vec![0xAA; 32], + }, + ]; + let mut buf = Vec::new(); + CreateContext::encode_chain(&ctxs, &mut buf).unwrap(); + let decoded = CreateContext::parse_chain(&buf).unwrap(); + assert_eq!(decoded, ctxs); + } + + #[test] + fn empty_chain_round_trips() { + let ctxs: Vec = vec![]; + let mut buf = Vec::new(); + CreateContext::encode_chain(&ctxs, &mut buf).unwrap(); + assert!(buf.is_empty()); + let decoded = CreateContext::parse_chain(&buf).unwrap(); + assert!(decoded.is_empty()); + } +} diff --git a/vendor/smb-server/src/proto/messages/echo.rs b/vendor/smb-server/src/proto/messages/echo.rs new file mode 100644 index 0000000..4d5f5b9 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/echo.rs @@ -0,0 +1,83 @@ +//! ECHO Request/Response (MS-SMB2 §2.2.28). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EchoRequest { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for EchoRequest { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EchoResponse { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for EchoResponse { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +impl EchoRequest { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +impl EchoResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips() { + let req = EchoRequest::default(); + let mut buf = Vec::new(); + req.write_to(&mut buf).unwrap(); + assert_eq!(buf.len(), 4); + assert_eq!(EchoRequest::parse(&buf).unwrap(), req); + + let resp = EchoResponse::default(); + let mut buf = Vec::new(); + resp.write_to(&mut buf).unwrap(); + assert_eq!(EchoResponse::parse(&buf).unwrap(), resp); + } +} diff --git a/vendor/smb-server/src/proto/messages/error_response.rs b/vendor/smb-server/src/proto/messages/error_response.rs new file mode 100644 index 0000000..b2e070c --- /dev/null +++ b/vendor/smb-server/src/proto/messages/error_response.rs @@ -0,0 +1,84 @@ +//! SMB2 ERROR Response (MS-SMB2 §2.2.2). +//! +//! Sent in place of any normal response when the server returns a non-zero +//! NTSTATUS. The SMB2 header carries the NTSTATUS in `channel_sequence_status`; +//! this body provides extended error context if any. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +/// MS-SMB2 §2.2.2 ERROR Response. +/// +/// `structure_size` is always 9; `byte_count` is the length of `error_data` +/// when there is no structured error context (the common case). When +/// `error_context_count > 0`, `error_data` holds a sequence of +/// [`ErrorContext`] entries (SMB 3.1.1+). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorResponse { + pub structure_size: u16, + pub error_context_count: u8, + pub reserved: u8, + pub byte_count: u32, + #[br(count = if byte_count == 0 { 1 } else { byte_count as usize })] + pub error_data: Vec, +} + +impl ErrorResponse { + /// Build a minimal ERROR response body for the given NTSTATUS. + /// + /// Per MS-SMB2 §2.2.2 a zero-`byte_count` ERROR response still emits a + /// single byte of `error_data` (the field is mandatory, length 1 when + /// there is no payload). + pub fn status(_ntstatus: u32) -> Self { + Self { + structure_size: 9, + error_context_count: 0, + reserved: 0, + byte_count: 0, + error_data: vec![0], + } + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + let mut c = Cursor::new(buf); + Ok(Self::read(&mut c)?) + } + + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// MS-SMB2 §2.2.2.1 ERROR Context Response (3.1.1+). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorContext { + pub error_data_length: u32, + pub error_id: u32, + #[br(count = error_data_length as usize)] + pub error_context_data: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips_status_helper() { + let r = ErrorResponse::status(0xC000_0022 /* STATUS_ACCESS_DENIED */); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + let decoded = ErrorResponse::parse(&buf).unwrap(); + assert_eq!(decoded, r); + // structure_size, contexts, reserved, bytecount, 1 byte payload = 9 bytes + assert_eq!(buf.len(), 9); + } +} diff --git a/vendor/smb-server/src/proto/messages/flush.rs b/vendor/smb-server/src/proto/messages/flush.rs new file mode 100644 index 0000000..46f7f27 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/flush.rs @@ -0,0 +1,86 @@ +//! FLUSH Request/Response (MS-SMB2 §2.2.17 / §2.2.18). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlushRequest { + pub structure_size: u16, + pub reserved1: u16, + pub reserved2: u32, + /// Volatile portion of the FileId. + pub file_id_persistent: u64, + /// Persistent portion of the FileId. + pub file_id_volatile: u64, +} + +impl FlushRequest { + pub fn new(persistent: u64, volatile: u64) -> Self { + Self { + structure_size: 24, + reserved1: 0, + reserved2: 0, + file_id_persistent: persistent, + file_id_volatile: volatile, + } + } +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlushResponse { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for FlushResponse { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +macro_rules! impl_codec { + ($t:ty) => { + impl $t { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } + } + }; +} + +impl_codec!(FlushRequest); +impl_codec!(FlushResponse); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips() { + let r = FlushRequest::new(0x1122_3344_5566_7788, 0xAABB_CCDD_EEFF_0011); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(buf.len(), 24); + assert_eq!(FlushRequest::parse(&buf).unwrap(), r); + + let r = FlushResponse::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(FlushResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/ioctl.rs b/vendor/smb-server/src/proto/messages/ioctl.rs new file mode 100644 index 0000000..55219e3 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/ioctl.rs @@ -0,0 +1,206 @@ +//! IOCTL Request/Response (MS-SMB2 §2.2.31 / §2.2.32). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// File-system control codes we recognize at the wire layer. +/// +/// MS-FSCC catalogues the FSCTL codes; we only enumerate the ones referenced +/// in the spec for v1. Unknown codes round-trip via [`Fsctl::Other`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Fsctl { + /// `FSCTL_VALIDATE_NEGOTIATE_INFO` — required handler in v1. + ValidateNegotiateInfo, + /// `FSCTL_DFS_GET_REFERRALS`. + DfsGetReferrals, + /// `FSCTL_DFS_GET_REFERRALS_EX`. + DfsGetReferralsEx, + /// `FSCTL_PIPE_TRANSCEIVE`. + PipeTranscede, + /// `FSCTL_PIPE_PEEK`. + PipePeek, + /// `FSCTL_PIPE_WAIT`. + PipeWait, + /// `FSCTL_LMR_REQUEST_RESILIENCY`. + LmrRequestResiliency, + /// `FSCTL_QUERY_NETWORK_INTERFACE_INFO`. + QueryNetworkInterfaceInfo, + /// Anything else. + Other(u32), +} + +impl Fsctl { + pub const VALIDATE_NEGOTIATE_INFO: u32 = 0x0014_0204; + pub const DFS_GET_REFERRALS: u32 = 0x0006_0194; + pub const DFS_GET_REFERRALS_EX: u32 = 0x0006_0198; + pub const PIPE_TRANSCEIVE: u32 = 0x0011_C017; + pub const PIPE_PEEK: u32 = 0x0011_400C; + pub const PIPE_WAIT: u32 = 0x0011_C018; + pub const LMR_REQUEST_RESILIENCY: u32 = 0x001C_0017; + pub const QUERY_NETWORK_INTERFACE_INFO: u32 = 0x001F_C017; + + pub fn from_u32(code: u32) -> Self { + match code { + Self::VALIDATE_NEGOTIATE_INFO => Self::ValidateNegotiateInfo, + Self::DFS_GET_REFERRALS => Self::DfsGetReferrals, + Self::DFS_GET_REFERRALS_EX => Self::DfsGetReferralsEx, + Self::PIPE_TRANSCEIVE => Self::PipeTranscede, + Self::PIPE_PEEK => Self::PipePeek, + Self::PIPE_WAIT => Self::PipeWait, + Self::LMR_REQUEST_RESILIENCY => Self::LmrRequestResiliency, + Self::QUERY_NETWORK_INTERFACE_INFO => Self::QueryNetworkInterfaceInfo, + other => Self::Other(other), + } + } + + pub fn as_u32(self) -> u32 { + match self { + Self::ValidateNegotiateInfo => Self::VALIDATE_NEGOTIATE_INFO, + Self::DfsGetReferrals => Self::DFS_GET_REFERRALS, + Self::DfsGetReferralsEx => Self::DFS_GET_REFERRALS_EX, + Self::PipeTranscede => Self::PIPE_TRANSCEIVE, + Self::PipePeek => Self::PIPE_PEEK, + Self::PipeWait => Self::PIPE_WAIT, + Self::LmrRequestResiliency => Self::LMR_REQUEST_RESILIENCY, + Self::QueryNetworkInterfaceInfo => Self::QUERY_NETWORK_INTERFACE_INFO, + Self::Other(c) => c, + } + } +} + +/// SMB2_IOCTL_REQUEST (MS-SMB2 §2.2.31). +/// +/// `input_offset` and `output_offset` are absolute (from the start of the +/// SMB2 header). We model the input buffer immediately following the fixed +/// prefix; the output buffer area is unused on requests but kept for round +/// tripping and extension scenarios. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IoctlRequest { + pub structure_size: u16, + pub reserved: u16, + pub ctl_code: u32, + pub file_id: FileId, + pub input_offset: u32, + pub input_count: u32, + pub max_input_response: u32, + pub output_offset: u32, + pub output_count: u32, + pub max_output_response: u32, + pub flags: u32, + pub reserved2: u32, + #[br(count = input_count as usize)] + pub input: Vec, +} + +impl IoctlRequest { + /// Flag: SMB2_0_IOCTL_IS_FSCTL. + pub const FLAG_IS_FSCTL: u32 = 0x0000_0001; + + pub fn fsctl(&self) -> Fsctl { + Fsctl::from_u32(self.ctl_code) + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_IOCTL_RESPONSE (MS-SMB2 §2.2.32). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IoctlResponse { + pub structure_size: u16, + pub reserved: u16, + pub ctl_code: u32, + pub file_id: FileId, + pub input_offset: u32, + pub input_count: u32, + pub output_offset: u32, + pub output_count: u32, + pub flags: u32, + pub reserved2: u32, + /// Output buffer immediately following the fixed prefix. + #[br(count = output_count as usize)] + pub output: Vec, +} + +impl IoctlResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn fsctl_decode_known() { + assert_eq!(Fsctl::from_u32(0x0014_0204), Fsctl::ValidateNegotiateInfo); + assert_eq!(Fsctl::from_u32(0xDEAD_BEEF), Fsctl::Other(0xDEAD_BEEF)); + assert_eq!(Fsctl::ValidateNegotiateInfo.as_u32(), 0x0014_0204); + assert_eq!(Fsctl::Other(0xDEAD_BEEF).as_u32(), 0xDEAD_BEEF); + } + + #[test] + fn request_round_trips() { + let r = IoctlRequest { + structure_size: 57, + reserved: 0, + ctl_code: Fsctl::VALIDATE_NEGOTIATE_INFO, + file_id: FileId::any(), + input_offset: 0x78, + input_count: 4, + max_input_response: 0, + output_offset: 0, + output_count: 0, + max_output_response: 0x1000, + flags: IoctlRequest::FLAG_IS_FSCTL, + reserved2: 0, + input: vec![0xCA, 0xFE, 0xBA, 0xBE], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + let decoded = IoctlRequest::parse(&buf).unwrap(); + assert_eq!(decoded, r); + assert_eq!(decoded.fsctl(), Fsctl::ValidateNegotiateInfo); + } + + #[test] + fn response_round_trips() { + let r = IoctlResponse { + structure_size: 49, + reserved: 0, + ctl_code: Fsctl::VALIDATE_NEGOTIATE_INFO, + file_id: FileId::any(), + input_offset: 0, + input_count: 0, + output_offset: 0x70, + output_count: 4, + flags: 0, + reserved2: 0, + output: vec![1, 2, 3, 4], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(IoctlResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/lock.rs b/vendor/smb-server/src/proto/messages/lock.rs new file mode 100644 index 0000000..2fc70ae --- /dev/null +++ b/vendor/smb-server/src/proto/messages/lock.rs @@ -0,0 +1,118 @@ +//! LOCK Request/Response (MS-SMB2 §2.2.26 / §2.2.27). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// SMB2_LOCK_ELEMENT (MS-SMB2 §2.2.26.1) — exactly 24 bytes. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LockElement { + pub offset: u64, + pub length: u64, + pub flags: u32, + pub reserved: u32, +} + +impl LockElement { + pub const FLAG_SHARED_LOCK: u32 = 0x0000_0001; + pub const FLAG_EXCLUSIVE_LOCK: u32 = 0x0000_0002; + pub const FLAG_UNLOCK: u32 = 0x0000_0004; + pub const FLAG_FAIL_IMMEDIATELY: u32 = 0x0000_0010; +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LockRequest { + pub structure_size: u16, + pub lock_count: u16, + pub lock_sequence: u32, + pub file_id: FileId, + #[br(count = lock_count as usize)] + pub locks: Vec, +} + +impl LockRequest { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LockResponse { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for LockResponse { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +impl LockResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_round_trips() { + let r = LockRequest { + structure_size: 48, + lock_count: 2, + lock_sequence: 0, + file_id: FileId::new(1, 2), + locks: vec![ + LockElement { + offset: 0, + length: 16, + flags: LockElement::FLAG_EXCLUSIVE_LOCK, + reserved: 0, + }, + LockElement { + offset: 0, + length: 16, + flags: LockElement::FLAG_UNLOCK, + reserved: 0, + }, + ], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(LockRequest::parse(&buf).unwrap(), r); + } + + #[test] + fn response_round_trips() { + let r = LockResponse::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(LockResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/logoff.rs b/vendor/smb-server/src/proto/messages/logoff.rs new file mode 100644 index 0000000..9ebf7bf --- /dev/null +++ b/vendor/smb-server/src/proto/messages/logoff.rs @@ -0,0 +1,77 @@ +//! LOGOFF Request/Response (MS-SMB2 §2.2.7 / §2.2.8). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LogoffRequest { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for LogoffRequest { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LogoffResponse { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for LogoffResponse { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +macro_rules! impl_codec { + ($t:ty) => { + impl $t { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } + } + }; +} + +impl_codec!(LogoffRequest); +impl_codec!(LogoffResponse); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips() { + let r = LogoffRequest::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(LogoffRequest::parse(&buf).unwrap(), r); + + let r = LogoffResponse::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(LogoffResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/mod.rs b/vendor/smb-server/src/proto/messages/mod.rs new file mode 100644 index 0000000..c9d6231 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/mod.rs @@ -0,0 +1,55 @@ +//! Per-command request/response wire structs. +//! +//! Each SMB2 command (MS-SMB2 §2.2.3 — §2.2.18, §2.2.31, §2.2.37, §2.2.39) +//! gets its own submodule with a `…Request` and `…Response` struct, both +//! `binrw`-driven and round-trip safe. +//! +//! The crate does **not** implement command behavior — it only encodes/decodes +//! the wire bytes. The server crate owns dispatch and state. + +pub mod cancel; +pub mod change_notify; +pub mod close; +pub mod create; +pub mod echo; +pub mod error_response; +pub mod flush; +pub mod ioctl; +pub mod lock; +pub mod logoff; +pub mod negotiate; +pub mod oplock_break; +pub mod query_directory; +pub mod query_info; +pub mod read; +pub mod session_setup; +pub mod set_info; +pub mod tree_connect; +pub mod tree_disconnect; +pub mod write; + +pub use cancel::CancelRequest; +pub use change_notify::{ChangeNotifyRequest, ChangeNotifyResponse}; +pub use close::{CloseRequest, CloseResponse}; +pub use create::{ + CreateContext, CreateRequest, CreateResponse, FileId, ImpersonationLevel, OplockLevel, +}; +pub use echo::{EchoRequest, EchoResponse}; +pub use error_response::{ErrorContext, ErrorResponse}; +pub use flush::{FlushRequest, FlushResponse}; +pub use ioctl::{Fsctl, IoctlRequest, IoctlResponse}; +pub use lock::{LockElement, LockRequest, LockResponse}; +pub use logoff::{LogoffRequest, LogoffResponse}; +pub use negotiate::{ + Dialect, EncryptionCapabilities, NegotiateContext, NegotiateContextData, NegotiateRequest, + NegotiateResponse, PreauthIntegrityCapabilities, SigningCapabilities, +}; +pub use oplock_break::{OplockBreakAck, OplockBreakNotification}; +pub use query_directory::{FileInfoClass, QueryDirectoryRequest, QueryDirectoryResponse}; +pub use query_info::{InfoType, QueryInfoRequest, QueryInfoResponse}; +pub use read::{ReadRequest, ReadResponse}; +pub use session_setup::{SessionSetupRequest, SessionSetupResponse}; +pub use set_info::{SetInfoRequest, SetInfoResponse}; +pub use tree_connect::{TreeConnectRequest, TreeConnectResponse}; +pub use tree_disconnect::{TreeDisconnectRequest, TreeDisconnectResponse}; +pub use write::{WriteRequest, WriteResponse}; diff --git a/vendor/smb-server/src/proto/messages/negotiate.rs b/vendor/smb-server/src/proto/messages/negotiate.rs new file mode 100644 index 0000000..e2027e7 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/negotiate.rs @@ -0,0 +1,384 @@ +//! NEGOTIATE Request/Response (MS-SMB2 §2.2.3 / §2.2.4) including the SMB +//! 3.1.1 negotiate-context machinery from §2.2.3.1.x and §2.2.4.x. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +// --------------------------------------------------------------------------- +// Dialect +// --------------------------------------------------------------------------- + +/// SMB2 dialect revision codes (MS-SMB2 §2.2.3 — DialectRevision). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u16)] +pub enum Dialect { + Smb202 = 0x0202, + Smb210 = 0x0210, + Smb300 = 0x0300, + Smb302 = 0x0302, + Smb311 = 0x0311, + /// Sent by SMB 2.0.2/2.1 clients via SMB1 negotiate; we accept it as a + /// signal to multi-protocol-negotiate. Value 0x02FF. + Smb2Wildcard = 0x02FF, +} + +impl Dialect { + pub fn from_u16(v: u16) -> Option { + Some(match v { + 0x0202 => Self::Smb202, + 0x0210 => Self::Smb210, + 0x0300 => Self::Smb300, + 0x0302 => Self::Smb302, + 0x0311 => Self::Smb311, + 0x02FF => Self::Smb2Wildcard, + _ => return None, + }) + } + + pub const fn as_u16(self) -> u16 { + self as u16 + } +} + +// --------------------------------------------------------------------------- +// Negotiate request +// --------------------------------------------------------------------------- + +/// MS-SMB2 §2.2.3 NEGOTIATE Request. +/// +/// `dialects` is a sequence of u16 little-endian dialect codes; for SMB 3.1.1 +/// the trailing `negotiate_context_list` carries variable-length contexts at +/// `negotiate_context_offset`. +/// +/// Note on parsing: we deliberately don't try to read `negotiate_context_list` +/// here automatically, because its position is given by an absolute offset +/// from the *start of the SMB2 header*, not from the start of this body. +/// The server crate decodes this body, then if `dialects` includes 3.1.1 it +/// resolves `negotiate_context_offset` against the original packet buffer +/// and parses the contexts via [`NegotiateContext::parse_list`]. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegotiateRequest { + pub structure_size: u16, + pub dialect_count: u16, + pub security_mode: u16, + pub reserved: u16, + pub capabilities: u32, + pub client_guid: [u8; 16], + /// 3.1.1: NegotiateContextOffset. 2.x/3.0/3.0.2: ClientStartTime. + pub negotiate_context_offset_or_client_start_time: u64, + #[br(count = dialect_count as usize)] + pub dialects: Vec, +} + +impl NegotiateRequest { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Negotiate response +// --------------------------------------------------------------------------- + +/// MS-SMB2 §2.2.4 NEGOTIATE Response. +/// +/// The trailing `security_buffer` and (3.1.1) `negotiate_context_list` are +/// referenced by absolute offsets from the start of the SMB2 header. This +/// struct encodes the *fixed* portion plus a `security_buffer` that we treat +/// as a length-counted blob immediately following the fixed portion (the +/// common server layout). For 3.1.1 contexts, the server crate writes the +/// fixed portion via [`NegotiateResponse::write_to`], then appends 8-byte- +/// aligned negotiate contexts and patches `negotiate_context_offset` to the +/// post-padding offset. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegotiateResponse { + pub structure_size: u16, + pub security_mode: u16, + pub dialect_revision: u16, + /// 3.1.1: NegotiateContextCount. 2.x/3.0/3.0.2: Reserved. + pub negotiate_context_count_or_reserved: u16, + pub server_guid: [u8; 16], + pub capabilities: u32, + pub max_transact_size: u32, + pub max_read_size: u32, + pub max_write_size: u32, + /// 100ns ticks since 1601-01-01 UTC. + pub system_time: u64, + pub server_start_time: u64, + pub security_buffer_offset: u16, + pub security_buffer_length: u16, + /// 3.1.1: NegotiateContextOffset. 2.x/3.0/3.0.2: Reserved2. + pub negotiate_context_offset_or_reserved2: u32, + #[br(count = security_buffer_length as usize)] + pub security_buffer: Vec, +} + +impl NegotiateResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Negotiate contexts (SMB 3.1.1) +// --------------------------------------------------------------------------- + +/// MS-SMB2 §2.2.3.1 / §2.2.4.x — NEGOTIATE_CONTEXT generic header. +/// +/// Contexts are 8-byte-aligned in the chain (the trailing padding is between +/// contexts; see §2.2.3.1 "Each NEGOTIATE_CONTEXT MUST be 8-byte aligned"). +/// `parse_list` / `encode_list` handle the alignment. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegotiateContext { + pub context_type: u16, + pub data_length: u16, + pub reserved: u32, + #[br(count = data_length as usize)] + pub data: Vec, +} + +impl NegotiateContext { + pub const TYPE_PREAUTH_INTEGRITY: u16 = 0x0001; + pub const TYPE_ENCRYPTION: u16 = 0x0002; + pub const TYPE_COMPRESSION: u16 = 0x0003; + pub const TYPE_NETNAME_NEGOTIATE: u16 = 0x0005; + pub const TYPE_TRANSPORT_CAPS: u16 = 0x0006; + pub const TYPE_RDMA_TRANSFORM: u16 = 0x0007; + pub const TYPE_SIGNING: u16 = 0x0008; + + /// Parse a chain of negotiate contexts from `buf`. The chain is a series + /// of (8-byte-aligned) [`NegotiateContext`] entries. `count` comes from + /// the parent message's `NegotiateContextCount`. + pub fn parse_list(mut buf: &[u8], count: u16) -> ProtoResult> { + let mut out = Vec::with_capacity(count as usize); + let mut consumed_total = 0usize; + for _ in 0..count { + // Pad to 8-byte alignment relative to the start of the list. + let pad = (8 - (consumed_total % 8)) % 8; + if pad > 0 { + if buf.len() < pad { + return Err(crate::proto::error::ProtoError::Malformed( + "negotiate context alignment underflow", + )); + } + buf = &buf[pad..]; + consumed_total += pad; + } + let mut c = Cursor::new(buf); + let ctx = NegotiateContext::read(&mut c)?; + let consumed = c.position() as usize; + buf = &buf[consumed..]; + consumed_total += consumed; + out.push(ctx); + } + Ok(out) + } + + /// Encode a chain of negotiate contexts into `out`, inserting 8-byte + /// padding between entries. + pub fn encode_list(list: &[NegotiateContext], out: &mut Vec) -> ProtoResult<()> { + let start = out.len(); + for (i, ctx) in list.iter().enumerate() { + if i > 0 { + let pad = (8 - ((out.len() - start) % 8)) % 8; + out.extend(std::iter::repeat_n(0u8, pad)); + } + let mut c = Cursor::new(Vec::new()); + BinWrite::write(ctx, &mut c)?; + out.extend_from_slice(&c.into_inner()); + } + Ok(()) + } +} + +/// Parsed payload of a known [`NegotiateContext`] type. Convenience wrapper — +/// the wire form is always [`NegotiateContext`]; this enum is for callers who +/// prefer typed access. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NegotiateContextData { + PreauthIntegrity(PreauthIntegrityCapabilities), + Encryption(EncryptionCapabilities), + Signing(SigningCapabilities), + /// Unknown / unhandled context — preserve raw bytes for round-tripping. + Other { + context_type: u16, + data: Vec, + }, +} + +/// MS-SMB2 §2.2.3.1.1 / §2.2.4.1 SMB2_PREAUTH_INTEGRITY_CAPABILITIES. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PreauthIntegrityCapabilities { + pub hash_algorithm_count: u16, + pub salt_length: u16, + #[br(count = hash_algorithm_count as usize)] + pub hash_algorithms: Vec, + #[br(count = salt_length as usize)] + pub salt: Vec, +} + +impl PreauthIntegrityCapabilities { + /// Hash algorithm: SHA-512 (the only one defined in MS-SMB2 §2.2.3.1.1). + pub const HASH_SHA512: u16 = 0x0001; +} + +/// MS-SMB2 §2.2.3.1.2 / §2.2.4.2 SMB2_ENCRYPTION_CAPABILITIES. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptionCapabilities { + pub cipher_count: u16, + #[br(count = cipher_count as usize)] + pub ciphers: Vec, +} + +impl EncryptionCapabilities { + pub const CIPHER_AES_128_CCM: u16 = 0x0001; + pub const CIPHER_AES_128_GCM: u16 = 0x0002; + pub const CIPHER_AES_256_CCM: u16 = 0x0003; + pub const CIPHER_AES_256_GCM: u16 = 0x0004; +} + +/// MS-SMB2 §2.2.3.1.7 / §2.2.4.7 SMB2_SIGNING_CAPABILITIES. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SigningCapabilities { + pub signing_algorithm_count: u16, + #[br(count = signing_algorithm_count as usize)] + pub signing_algorithms: Vec, +} + +impl SigningCapabilities { + pub const ALGORITHM_HMAC_SHA256: u16 = 0x0000; + pub const ALGORITHM_AES_CMAC: u16 = 0x0001; + pub const ALGORITHM_AES_GMAC: u16 = 0x0002; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn negotiate_request_round_trips() { + let req = NegotiateRequest { + structure_size: 36, + dialect_count: 5, + security_mode: 0x0001, // signing enabled + reserved: 0, + capabilities: 0x0000_007F, + client_guid: [0xAB; 16], + negotiate_context_offset_or_client_start_time: 0x0000_0070_0000_0000, + dialects: vec![0x0202, 0x0210, 0x0300, 0x0302, 0x0311], + }; + let mut buf = Vec::new(); + req.write_to(&mut buf).unwrap(); + let decoded = NegotiateRequest::parse(&buf).unwrap(); + assert_eq!(decoded, req); + } + + #[test] + fn negotiate_response_round_trips() { + let resp = NegotiateResponse { + structure_size: 65, + security_mode: 0x0003, + dialect_revision: Dialect::Smb311.as_u16(), + negotiate_context_count_or_reserved: 3, + server_guid: [0xCD; 16], + capabilities: 0x0000_007F, + max_transact_size: 0x0010_0000, + max_read_size: 0x0010_0000, + max_write_size: 0x0010_0000, + system_time: 0x01D9_1234_5678_9ABC, + server_start_time: 0, + security_buffer_offset: 0x80, + security_buffer_length: 8, + negotiate_context_offset_or_reserved2: 0x100, + security_buffer: vec![1, 2, 3, 4, 5, 6, 7, 8], + }; + let mut buf = Vec::new(); + resp.write_to(&mut buf).unwrap(); + let decoded = NegotiateResponse::parse(&buf).unwrap(); + assert_eq!(decoded, resp); + } + + #[test] + fn dialect_round_trips() { + for d in [ + Dialect::Smb202, + Dialect::Smb210, + Dialect::Smb300, + Dialect::Smb302, + Dialect::Smb311, + Dialect::Smb2Wildcard, + ] { + assert_eq!(Dialect::from_u16(d.as_u16()), Some(d)); + } + assert_eq!(Dialect::from_u16(0xBEEF), None); + } + + #[test] + fn preauth_caps_round_trips() { + let p = PreauthIntegrityCapabilities { + hash_algorithm_count: 1, + salt_length: 32, + hash_algorithms: vec![PreauthIntegrityCapabilities::HASH_SHA512], + salt: vec![0xAA; 32], + }; + let mut buf = Vec::new(); + let mut c = Cursor::new(&mut buf); + BinWrite::write(&p, &mut c).unwrap(); + let decoded = PreauthIntegrityCapabilities::read(&mut Cursor::new(&buf)).unwrap(); + assert_eq!(decoded, p); + } + + #[test] + fn negotiate_context_list_round_trips() { + let list = vec![ + NegotiateContext { + context_type: NegotiateContext::TYPE_PREAUTH_INTEGRITY, + data_length: 6, + reserved: 0, + data: vec![0x01, 0x00, 0x20, 0x00, 0x01, 0x00], + }, + NegotiateContext { + context_type: NegotiateContext::TYPE_ENCRYPTION, + data_length: 4, + reserved: 0, + data: vec![0x02, 0x00, 0x02, 0x00], + }, + NegotiateContext { + context_type: NegotiateContext::TYPE_SIGNING, + data_length: 4, + reserved: 0, + data: vec![0x01, 0x00, 0x01, 0x00], + }, + ]; + let mut buf = Vec::new(); + NegotiateContext::encode_list(&list, &mut buf).unwrap(); + let parsed = NegotiateContext::parse_list(&buf, 3).unwrap(); + assert_eq!(parsed, list); + } +} diff --git a/vendor/smb-server/src/proto/messages/oplock_break.rs b/vendor/smb-server/src/proto/messages/oplock_break.rs new file mode 100644 index 0000000..5aaa139 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/oplock_break.rs @@ -0,0 +1,59 @@ +//! OPLOCK_BREAK Notification + Acknowledgement (MS-SMB2 §2.2.23 / §2.2.24). +//! +//! V1 never grants oplocks, so we never *send* a notification, but the +//! handler exists for safety. A client may send an OPLOCK_BREAK ACK before +//! the server has cleared its oplock state in the (rare) edge case during +//! teardown. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// SMB2_OPLOCK_BREAK_NOTIFICATION (MS-SMB2 §2.2.23.1). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OplockBreakNotification { + pub structure_size: u16, + pub oplock_level: u8, + pub reserved: u8, + pub reserved2: u32, + pub file_id: FileId, +} + +impl OplockBreakNotification { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_OPLOCK_BREAK_ACK (MS-SMB2 §2.2.24.1) — same wire shape as the +/// notification. +pub type OplockBreakAck = OplockBreakNotification; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips() { + let r = OplockBreakNotification { + structure_size: 24, + oplock_level: 0, + reserved: 0, + reserved2: 0, + file_id: FileId::new(1, 2), + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(OplockBreakNotification::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/query_directory.rs b/vendor/smb-server/src/proto/messages/query_directory.rs new file mode 100644 index 0000000..b7a52e7 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/query_directory.rs @@ -0,0 +1,136 @@ +//! QUERY_DIRECTORY Request/Response (MS-SMB2 §2.2.33 / §2.2.34). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// File-info-class identifiers used in QUERY_DIRECTORY (MS-SMB2 §2.2.33 +/// FileInformationClass). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum FileInfoClass { + FileDirectoryInformation = 0x01, + FileFullDirectoryInformation = 0x02, + FileBothDirectoryInformation = 0x03, + FileNamesInformation = 0x0C, + FileIdBothDirectoryInformation = 0x25, + FileIdFullDirectoryInformation = 0x26, +} + +impl FileInfoClass { + pub fn from_u8(v: u8) -> Option { + Some(match v { + 0x01 => Self::FileDirectoryInformation, + 0x02 => Self::FileFullDirectoryInformation, + 0x03 => Self::FileBothDirectoryInformation, + 0x0C => Self::FileNamesInformation, + 0x25 => Self::FileIdBothDirectoryInformation, + 0x26 => Self::FileIdFullDirectoryInformation, + _ => return None, + }) + } +} + +/// SMB2_QUERY_DIRECTORY_REQUEST (MS-SMB2 §2.2.33). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryDirectoryRequest { + pub structure_size: u16, + pub file_information_class: u8, + pub flags: u8, + pub file_index: u32, + pub file_id: FileId, + pub file_name_offset: u16, + pub file_name_length: u16, + pub output_buffer_length: u32, + /// UTF-16LE search pattern (e.g. "*"). + #[br(count = file_name_length as usize)] + pub file_name: Vec, +} + +impl QueryDirectoryRequest { + pub const FLAG_RESTART_SCANS: u8 = 0x01; + pub const FLAG_RETURN_SINGLE_ENTRY: u8 = 0x02; + pub const FLAG_INDEX_SPECIFIED: u8 = 0x04; + pub const FLAG_REOPEN: u8 = 0x10; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_QUERY_DIRECTORY_RESPONSE (MS-SMB2 §2.2.34). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryDirectoryResponse { + pub structure_size: u16, + /// `OutputBufferOffset` is from the start of the SMB2 header. + pub output_buffer_offset: u16, + pub output_buffer_length: u32, + /// Variable-length info-class-specific buffer. + #[br(count = output_buffer_length as usize)] + pub buffer: Vec, +} + +impl QueryDirectoryResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn utf16le(s: &str) -> Vec { + s.encode_utf16().flat_map(u16::to_le_bytes).collect() + } + + #[test] + fn request_round_trips() { + let pat = utf16le("*"); + let r = QueryDirectoryRequest { + structure_size: 33, + file_information_class: FileInfoClass::FileIdBothDirectoryInformation as u8, + flags: QueryDirectoryRequest::FLAG_RESTART_SCANS, + file_index: 0, + file_id: FileId::new(1, 2), + file_name_offset: 0x60, + file_name_length: pat.len() as u16, + output_buffer_length: 0x10000, + file_name: pat, + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(QueryDirectoryRequest::parse(&buf).unwrap(), r); + } + + #[test] + fn response_round_trips() { + let r = QueryDirectoryResponse { + structure_size: 9, + output_buffer_offset: 0x48, + output_buffer_length: 8, + buffer: vec![1, 2, 3, 4, 5, 6, 7, 8], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(QueryDirectoryResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/query_info.rs b/vendor/smb-server/src/proto/messages/query_info.rs new file mode 100644 index 0000000..e90f188 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/query_info.rs @@ -0,0 +1,140 @@ +//! QUERY_INFO Request/Response (MS-SMB2 §2.2.37 / §2.2.38). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// `InfoType` values (MS-SMB2 §2.2.37 InfoType field). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum InfoType { + File = 0x01, + FileSystem = 0x02, + Security = 0x03, + Quota = 0x04, +} + +impl InfoType { + pub fn from_u8(v: u8) -> Option { + Some(match v { + 0x01 => Self::File, + 0x02 => Self::FileSystem, + 0x03 => Self::Security, + 0x04 => Self::Quota, + _ => return None, + }) + } +} + +/// SMB2_QUERY_INFO_REQUEST (MS-SMB2 §2.2.37). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryInfoRequest { + pub structure_size: u16, + pub info_type: u8, + pub file_information_class: u8, + pub output_buffer_length: u32, + pub input_buffer_offset: u16, + pub reserved: u16, + pub input_buffer_length: u32, + /// `AdditionalInformation`: which fields of the security descriptor to + /// return when `info_type == Security`. Otherwise an additional info-class + /// selector for FS info. + pub additional_information: u32, + pub flags: u32, + pub file_id: FileId, + /// Optional input buffer (used by FILE/FS info classes that need it, e.g. + /// `FileFullEaInformation` extended-attribute name lists). + #[br(count = input_buffer_length as usize)] + pub input_buffer: Vec, +} + +impl QueryInfoRequest { + /// Flag: SL_RESTART_SCAN. + pub const FLAG_RESTART_SCAN: u32 = 0x0000_0001; + /// Flag: SL_RETURN_SINGLE_ENTRY. + pub const FLAG_RETURN_SINGLE_ENTRY: u32 = 0x0000_0002; + /// Flag: SL_INDEX_SPECIFIED. + pub const FLAG_INDEX_SPECIFIED: u32 = 0x0000_0004; + + pub fn info_type_enum(&self) -> Option { + InfoType::from_u8(self.info_type) + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_QUERY_INFO_RESPONSE (MS-SMB2 §2.2.38). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryInfoResponse { + pub structure_size: u16, + pub output_buffer_offset: u16, + pub output_buffer_length: u32, + #[br(count = output_buffer_length as usize)] + pub buffer: Vec, +} + +impl QueryInfoResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_round_trips() { + let r = QueryInfoRequest { + structure_size: 41, + info_type: InfoType::File as u8, + file_information_class: 0x05, // FileStandardInformation + output_buffer_length: 0x1000, + input_buffer_offset: 0, + reserved: 0, + input_buffer_length: 0, + additional_information: 0, + flags: 0, + file_id: FileId::new(1, 2), + input_buffer: vec![], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + let decoded = QueryInfoRequest::parse(&buf).unwrap(); + assert_eq!(decoded, r); + assert_eq!(decoded.info_type_enum(), Some(InfoType::File)); + } + + #[test] + fn response_round_trips() { + let r = QueryInfoResponse { + structure_size: 9, + output_buffer_offset: 0x48, + output_buffer_length: 4, + buffer: vec![0xAB, 0xCD, 0xEF, 0x01], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(QueryInfoResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/read.rs b/vendor/smb-server/src/proto/messages/read.rs new file mode 100644 index 0000000..3e0e8db --- /dev/null +++ b/vendor/smb-server/src/proto/messages/read.rs @@ -0,0 +1,141 @@ +//! READ Request/Response (MS-SMB2 §2.2.19 / §2.2.20). +//! +//! ## Data buffer offsets +//! +//! Both the READ request `ReadChannelInfoOffset` and the READ response +//! `DataOffset` are measured from the **start of the SMB2 header**, not from +//! the start of this structure (MS-SMB2 §2.2.20 explicitly: "DataOffset (1 +//! byte): The offset, in bytes, from the beginning of the SMB2 header to the +//! data being read"). When constructing a response, the server crate must +//! compute `DataOffset = SMB2_HEADER_LEN + offset_within_body_of_data`. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// SMB2_READ_REQUEST (MS-SMB2 §2.2.19). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReadRequest { + pub structure_size: u16, + pub padding: u8, + /// 3.0+ flags (`SMB2_READFLAG_*`); reserved on 2.x. + pub flags: u8, + pub length: u32, + pub offset: u64, + pub file_id: FileId, + pub minimum_count: u32, + pub channel: u32, + pub remaining_bytes: u32, + pub read_channel_info_offset: u16, + pub read_channel_info_length: u16, + /// MS-SMB2: "If ReadChannelInfoOffset and ReadChannelInfoLength are both + /// 0, the client MUST set this field to a single 0 byte." We follow that + /// — at least one byte of buffer is required on the wire. + #[br(count = if read_channel_info_length == 0 { 1 } else { read_channel_info_length as usize })] + pub buffer: Vec, +} + +impl ReadRequest { + /// Flag: SMB2_READFLAG_READ_UNBUFFERED (3.0.2+). + pub const FLAG_READ_UNBUFFERED: u8 = 0x01; + /// Flag: SMB2_READFLAG_REQUEST_COMPRESSED (3.1.1+). + pub const FLAG_REQUEST_COMPRESSED: u8 = 0x02; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_READ_RESPONSE (MS-SMB2 §2.2.20). +/// +/// `data_offset` is from the start of the SMB2 header. Use +/// [`ReadResponse::standard_data_offset`] for the canonical "data immediately +/// after the fixed prefix" layout. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReadResponse { + pub structure_size: u16, + pub data_offset: u8, + pub reserved: u8, + pub data_length: u32, + pub data_remaining: u32, + /// 3.x: `Flags`. 2.x: reserved. + pub flags: u32, + #[br(count = data_length as usize)] + pub data: Vec, +} + +impl ReadResponse { + /// Canonical `DataOffset` value when the data buffer immediately follows + /// the fixed 16-byte response prefix and the SMB2 header (64 + 16 = 80). + /// + /// Most servers (ksmbd, Samba) emit 0x50 = 80 here. + pub const STANDARD_DATA_OFFSET: u8 = 0x50; + + pub const fn standard_data_offset() -> u8 { + Self::STANDARD_DATA_OFFSET + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_round_trips() { + let r = ReadRequest { + structure_size: 49, + padding: 0x50, + flags: 0, + length: 0x1000, + offset: 0x2000, + file_id: FileId::new(0xAAAA, 0xBBBB), + minimum_count: 1, + channel: 0, + remaining_bytes: 0, + read_channel_info_offset: 0, + read_channel_info_length: 0, + buffer: vec![0], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(ReadRequest::parse(&buf).unwrap(), r); + } + + #[test] + fn response_round_trips() { + let r = ReadResponse { + structure_size: 17, + data_offset: ReadResponse::STANDARD_DATA_OFFSET, + reserved: 0, + data_length: 5, + data_remaining: 0, + flags: 0, + data: vec![1, 2, 3, 4, 5], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(ReadResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/session_setup.rs b/vendor/smb-server/src/proto/messages/session_setup.rs new file mode 100644 index 0000000..928cc53 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/session_setup.rs @@ -0,0 +1,113 @@ +//! SESSION_SETUP Request/Response (MS-SMB2 §2.2.5 / §2.2.6). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +/// SMB2_SESSION_SETUP_REQUEST (MS-SMB2 §2.2.5). +/// +/// `security_buffer` is opaque GSS-API/SPNEGO data — the auth agent decodes it. +/// The wire offset is from the start of the SMB2 header; we encode/decode it +/// as length-counted data immediately following the fixed prefix, which is +/// the canonical layout. Server crate may patch the offset if it needs an +/// unusual layout. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionSetupRequest { + pub structure_size: u16, + pub flags: u8, + pub security_mode: u8, + pub capabilities: u32, + pub channel: u32, + pub security_buffer_offset: u16, + pub security_buffer_length: u16, + pub previous_session_id: u64, + #[br(count = security_buffer_length as usize)] + pub security_buffer: Vec, +} + +impl SessionSetupRequest { + /// Flag: SMB2_SESSION_FLAG_BINDING — bind to existing session (3.x). + pub const FLAG_BINDING: u8 = 0x01; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_SESSION_SETUP_RESPONSE (MS-SMB2 §2.2.6). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionSetupResponse { + pub structure_size: u16, + pub session_flags: u16, + pub security_buffer_offset: u16, + pub security_buffer_length: u16, + #[br(count = security_buffer_length as usize)] + pub security_buffer: Vec, +} + +impl SessionSetupResponse { + /// Session flag: IS_GUEST. + pub const FLAG_IS_GUEST: u16 = 0x0001; + /// Session flag: IS_NULL (anonymous). + pub const FLAG_IS_NULL: u16 = 0x0002; + /// Session flag: ENCRYPT_DATA. + pub const FLAG_ENCRYPT_DATA: u16 = 0x0004; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_round_trips() { + let r = SessionSetupRequest { + structure_size: 25, + flags: 0, + security_mode: 0x01, + capabilities: 0x01, + channel: 0, + security_buffer_offset: 0x58, + security_buffer_length: 6, + previous_session_id: 0, + security_buffer: vec![0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(SessionSetupRequest::parse(&buf).unwrap(), r); + } + + #[test] + fn response_round_trips() { + let r = SessionSetupResponse { + structure_size: 9, + session_flags: SessionSetupResponse::FLAG_IS_GUEST, + security_buffer_offset: 0x48, + security_buffer_length: 4, + security_buffer: vec![1, 2, 3, 4], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(SessionSetupResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/set_info.rs b/vendor/smb-server/src/proto/messages/set_info.rs new file mode 100644 index 0000000..0c79d27 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/set_info.rs @@ -0,0 +1,94 @@ +//! SET_INFO Request/Response (MS-SMB2 §2.2.39 / §2.2.40). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// SMB2_SET_INFO_REQUEST (MS-SMB2 §2.2.39). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SetInfoRequest { + pub structure_size: u16, + pub info_type: u8, + pub file_information_class: u8, + pub buffer_length: u32, + pub buffer_offset: u16, + pub reserved: u16, + pub additional_information: u32, + pub file_id: FileId, + #[br(count = buffer_length as usize)] + pub buffer: Vec, +} + +impl SetInfoRequest { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_SET_INFO_RESPONSE (MS-SMB2 §2.2.40). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SetInfoResponse { + pub structure_size: u16, +} + +impl Default for SetInfoResponse { + fn default() -> Self { + Self { structure_size: 2 } + } +} + +impl SetInfoResponse { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_round_trips() { + let r = SetInfoRequest { + structure_size: 33, + info_type: 0x01, // File + file_information_class: 0x14, // FileEndOfFileInformation + buffer_length: 8, + buffer_offset: 0x60, + reserved: 0, + additional_information: 0, + file_id: FileId::new(1, 2), + buffer: vec![0, 0, 0, 0x10, 0, 0, 0, 0], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(SetInfoRequest::parse(&buf).unwrap(), r); + } + + #[test] + fn response_round_trips() { + let r = SetInfoResponse::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(SetInfoResponse::parse(&buf).unwrap(), r); + assert_eq!(buf.len(), 2); + } +} diff --git a/vendor/smb-server/src/proto/messages/tree_connect.rs b/vendor/smb-server/src/proto/messages/tree_connect.rs new file mode 100644 index 0000000..da132c4 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/tree_connect.rs @@ -0,0 +1,131 @@ +//! TREE_CONNECT Request/Response (MS-SMB2 §2.2.9 / §2.2.10). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +/// SMB2_TREE_CONNECT_REQUEST (MS-SMB2 §2.2.9). +/// +/// `path` is UTF-16LE. The wire format gives `PathOffset` (from the start of +/// the SMB2 header) and `PathLength`; we encode/decode the path immediately +/// following the fixed prefix. The 3.1.1 tree-connect-context machinery +/// (extension `flags`, `path_offset`/`path_length` interpretation) is +/// preserved on the wire and the server crate inspects `flags` if needed. +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TreeConnectRequest { + pub structure_size: u16, + /// 3.1.1: flags. 2.x/3.0/3.0.2: reserved. + pub flags: u16, + pub path_offset: u16, + pub path_length: u16, + /// UTF-16LE share path bytes (e.g. `\\server\share`). + #[br(count = path_length as usize)] + pub path: Vec, +} + +impl TreeConnectRequest { + /// Flag: SMB2_TREE_CONNECT_FLAG_CLUSTER_RECONNECT (3.1.1). + pub const FLAG_CLUSTER_RECONNECT: u16 = 0x0001; + /// Flag: SMB2_TREE_CONNECT_FLAG_REDIRECT_TO_OWNER (3.1.1). + pub const FLAG_REDIRECT_TO_OWNER: u16 = 0x0002; + /// Flag: SMB2_TREE_CONNECT_FLAG_EXTENSION_PRESENT (3.1.1). + pub const FLAG_EXTENSION_PRESENT: u16 = 0x0004; + + /// Decode the UTF-16LE share path into a `String`. Returns `None` if the + /// stored bytes are not an even length (malformed UTF-16LE). + pub fn path_str(&self) -> Option { + if !self.path.len().is_multiple_of(2) { + return None; + } + let units: Vec = self + .path + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect(); + Some(String::from_utf16_lossy(&units)) + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_TREE_CONNECT_RESPONSE (MS-SMB2 §2.2.10). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TreeConnectResponse { + pub structure_size: u16, + pub share_type: u8, + pub reserved: u8, + pub share_flags: u32, + pub capabilities: u32, + pub maximal_access: u32, +} + +impl TreeConnectResponse { + /// Share type: SMB2_SHARE_TYPE_DISK. + pub const SHARE_TYPE_DISK: u8 = 0x01; + pub const SHARE_TYPE_PIPE: u8 = 0x02; + pub const SHARE_TYPE_PRINT: u8 = 0x03; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn utf16le(s: &str) -> Vec { + s.encode_utf16().flat_map(u16::to_le_bytes).collect() + } + + #[test] + fn request_round_trips() { + let path = utf16le(r"\\server\share"); + let r = TreeConnectRequest { + structure_size: 9, + flags: 0, + path_offset: 0x48, + path_length: path.len() as u16, + path, + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + let decoded = TreeConnectRequest::parse(&buf).unwrap(); + assert_eq!(decoded, r); + assert_eq!(decoded.path_str().unwrap(), r"\\server\share"); + } + + #[test] + fn response_round_trips() { + let r = TreeConnectResponse { + structure_size: 16, + share_type: TreeConnectResponse::SHARE_TYPE_DISK, + reserved: 0, + share_flags: 0, + capabilities: 0, + maximal_access: 0x001F_01FF, + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(TreeConnectResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/tree_disconnect.rs b/vendor/smb-server/src/proto/messages/tree_disconnect.rs new file mode 100644 index 0000000..00a6ca3 --- /dev/null +++ b/vendor/smb-server/src/proto/messages/tree_disconnect.rs @@ -0,0 +1,77 @@ +//! TREE_DISCONNECT Request/Response (MS-SMB2 §2.2.11 / §2.2.12). + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use crate::proto::error::ProtoResult; + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TreeDisconnectRequest { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for TreeDisconnectRequest { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TreeDisconnectResponse { + pub structure_size: u16, + pub reserved: u16, +} + +impl Default for TreeDisconnectResponse { + fn default() -> Self { + Self { + structure_size: 4, + reserved: 0, + } + } +} + +macro_rules! impl_codec { + ($t:ty) => { + impl $t { + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } + } + }; +} + +impl_codec!(TreeDisconnectRequest); +impl_codec!(TreeDisconnectResponse); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trips() { + let r = TreeDisconnectRequest::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(TreeDisconnectRequest::parse(&buf).unwrap(), r); + + let r = TreeDisconnectResponse::default(); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(TreeDisconnectResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/messages/write.rs b/vendor/smb-server/src/proto/messages/write.rs new file mode 100644 index 0000000..8a5ab8a --- /dev/null +++ b/vendor/smb-server/src/proto/messages/write.rs @@ -0,0 +1,123 @@ +//! WRITE Request/Response (MS-SMB2 §2.2.21 / §2.2.22). +//! +//! ## Data buffer offsets +//! +//! `DataOffset` is from the **start of the SMB2 header**, not from the start +//! of this structure (MS-SMB2 §2.2.21). The canonical layout puts the data +//! immediately after the fixed 48-byte prefix, giving 64 + 48 = 112 = 0x70. + +use binrw::{BinRead, BinWrite, binrw}; +use std::io::Cursor; + +use super::create::FileId; +use crate::proto::error::ProtoResult; + +/// SMB2_WRITE_REQUEST (MS-SMB2 §2.2.21). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WriteRequest { + pub structure_size: u16, + pub data_offset: u16, + pub length: u32, + pub offset: u64, + pub file_id: FileId, + pub channel: u32, + pub remaining_bytes: u32, + pub write_channel_info_offset: u16, + pub write_channel_info_length: u16, + pub flags: u32, + /// MS-SMB2: at least 1 byte of payload buffer is required on the wire + /// even when length=0. + #[br(count = if length == 0 { 1 } else { length as usize })] + pub data: Vec, +} + +impl WriteRequest { + /// Canonical `DataOffset` placing the data buffer immediately after the + /// fixed 48-byte WRITE prefix: 64 (SMB2 header) + 48 = 112 = 0x70. + pub const STANDARD_DATA_OFFSET: u16 = 0x70; + /// Flag: SMB2_WRITEFLAG_WRITE_THROUGH. + pub const FLAG_WRITE_THROUGH: u32 = 0x0000_0001; + /// Flag: SMB2_WRITEFLAG_WRITE_UNBUFFERED (3.0.2+). + pub const FLAG_WRITE_UNBUFFERED: u32 = 0x0000_0002; + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +/// SMB2_WRITE_RESPONSE (MS-SMB2 §2.2.22). +#[binrw] +#[brw(little)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct WriteResponse { + pub structure_size: u16, + pub reserved: u16, + pub count: u32, + pub remaining: u32, + pub write_channel_info_offset: u16, + pub write_channel_info_length: u16, +} + +impl WriteResponse { + pub fn new(count: u32) -> Self { + Self { + structure_size: 17, + reserved: 0, + count, + remaining: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + } + } + + pub fn parse(buf: &[u8]) -> ProtoResult { + Ok(Self::read(&mut Cursor::new(buf))?) + } + pub fn write_to(&self, out: &mut Vec) -> ProtoResult<()> { + let mut c = Cursor::new(Vec::new()); + BinWrite::write(self, &mut c)?; + out.extend_from_slice(&c.into_inner()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_round_trips() { + let r = WriteRequest { + structure_size: 49, + data_offset: WriteRequest::STANDARD_DATA_OFFSET, + length: 4, + offset: 0x100, + file_id: FileId::new(0xAA, 0xBB), + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: vec![1, 2, 3, 4], + }; + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(WriteRequest::parse(&buf).unwrap(), r); + } + + #[test] + fn response_round_trips() { + let r = WriteResponse::new(0x1000); + let mut buf = Vec::new(); + r.write_to(&mut buf).unwrap(); + assert_eq!(WriteResponse::parse(&buf).unwrap(), r); + } +} diff --git a/vendor/smb-server/src/proto/mod.rs b/vendor/smb-server/src/proto/mod.rs new file mode 100644 index 0000000..ac75f2f --- /dev/null +++ b/vendor/smb-server/src/proto/mod.rs @@ -0,0 +1,16 @@ +//! SMB2/3 wire-format types, framing, signing, and authentication primitives. +//! +//! Layered into: +//! * [`framing`] — Direct-TCP/NetBIOS transport framing. +//! * [`header`] — SMB2 64-byte fixed header. +//! * [`messages`] — Per-command request/response structs. +//! * [`auth`] — NTLMv2 server-side authentication and minimal SPNEGO. +//! * [`crypto`] — Signing, key derivation, pre-auth integrity. +//! * [`error`] — Crate-wide error type. + +pub mod auth; +pub mod crypto; +pub mod error; +pub mod framing; +pub mod header; +pub mod messages; diff --git a/vendor/smb-server/src/server.rs b/vendor/smb-server/src/server.rs new file mode 100644 index 0000000..31f0fb5 --- /dev/null +++ b/vendor/smb-server/src/server.rs @@ -0,0 +1,566 @@ +//! Top-level `SmbServer` lifecycle: builder integration, accept loop, +//! graceful shutdown. + +use std::collections::HashMap; +use std::io; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Weak}; + +use crate::proto::auth::ntlm::UserCreds; +use thiserror::Error; +use tokio::net::TcpListener; +use tokio::sync::{Notify, RwLock}; +use tracing::{Instrument, error, info, info_span}; +use uuid::Uuid; + +use crate::backend::ShareBackend; +use crate::builder::{Access, Share, SmbServerBuilder}; +use crate::conn::connection_loop; +use crate::conn::state::Connection; +use crate::utils::now_filetime; + +// --------------------------------------------------------------------------- +// ShareMode / ShareBindings +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ShareMode { + Public, + PublicReadOnly, + /// Default — closed share. Only users in the explicit `users` map allowed. + AuthenticatedOnly, +} + +#[derive(Clone)] +pub struct ShareAcl { + pub mode: ShareMode, + pub users: HashMap, +} + +/// Compiled binding for a single share — the per-server-state form of `Share`. +pub struct ShareBindings { + pub name: String, + pub backend: Arc, + pub acl: RwLock, + /// `IPC$` synthetic share. Accepted at TREE_CONNECT for client compatibility + /// (Windows always probes IPC$ before mounting an actual share). All + /// downstream ops on an IPC$ tree return `STATUS_NOT_SUPPORTED`. + pub is_ipc: bool, +} + +impl ShareBindings { + pub(crate) fn new( + name: String, + backend: Arc, + mode: ShareMode, + users: HashMap, + is_ipc: bool, + ) -> Arc { + Arc::new(Self { + name, + backend, + acl: RwLock::new(ShareAcl { mode, users }), + is_ipc, + }) + } + + /// Synthetic IPC$ share. The backend is a no-op; clients that try to + /// CREATE on it get `STATUS_NOT_SUPPORTED` from the CREATE handler. + pub fn ipc() -> Arc { + Self::new( + "IPC$".to_string(), + Arc::new(crate::backend::NotSupportedBackend), + ShareMode::PublicReadOnly, + HashMap::new(), + true, + ) + } +} + +// --------------------------------------------------------------------------- +// ServerConfig / ServerUsers / ServerState +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct ServerConfig { + pub listen_addr: SocketAddr, + pub netbios_name: String, + pub max_read_size: u32, + pub max_write_size: u32, + pub server_guid: Uuid, +} + +pub struct ServerUsers { + /// Username → precomputed NT hash record. + pub table: RwLock>, +} + +pub struct ServerShares { + by_name: RwLock>>, +} + +impl ServerShares { + pub fn new(shares: Vec>) -> Self { + let mut by_name = HashMap::with_capacity(shares.len()); + for share in shares { + by_name.insert(share.name.to_ascii_lowercase(), share); + } + Self { + by_name: RwLock::new(by_name), + } + } + + pub async fn find(&self, name: &str) -> Option> { + self.by_name + .read() + .await + .get(&name.to_ascii_lowercase()) + .cloned() + } + + pub async fn insert(&self, share: Arc) -> Result<(), ConfigError> { + let key = share.name.to_ascii_lowercase(); + let mut by_name = self.by_name.write().await; + if by_name.contains_key(&key) { + return Err(ConfigError::DuplicateShare(share.name.clone())); + } + by_name.insert(key, share); + Ok(()) + } + + pub async fn remove(&self, name: &str) -> Option> { + self.by_name + .write() + .await + .remove(&name.to_ascii_lowercase()) + } + + pub async fn all(&self) -> Vec> { + self.by_name.read().await.values().cloned().collect() + } +} + +pub struct ActiveConnections { + next_id: AtomicU64, + conns: RwLock>>, +} + +impl ActiveConnections { + pub fn new() -> Self { + Self { + next_id: AtomicU64::new(1), + conns: RwLock::new(HashMap::new()), + } + } + + pub async fn register(&self, conn: &Arc) -> u64 { + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + self.conns.write().await.insert(id, Arc::downgrade(conn)); + id + } + + pub async fn unregister(&self, id: u64) { + self.conns.write().await.remove(&id); + } + + pub async fn live(&self) -> Vec> { + let mut live = Vec::new(); + let mut conns = self.conns.write().await; + conns.retain(|_, weak| { + if let Some(conn) = weak.upgrade() { + live.push(conn); + true + } else { + false + } + }); + live + } +} + +impl Default for ActiveConnections { + fn default() -> Self { + Self::new() + } +} + +/// Top-level immutable-ish state shared across connections. +pub struct ServerState { + pub config: ServerConfig, + pub users: ServerUsers, + pub shares: ServerShares, + pub active_connections: ActiveConnections, + pub server_start_filetime: u64, + /// Set when `shutdown()` is invoked; the accept loop stops on the next + /// iteration and connection loops abandon their next read. + pub shutdown: Arc, + pub shutting_down: Arc, +} + +impl ServerState { + pub fn new(config: ServerConfig, users: ServerUsers, shares: Vec>) -> Self { + Self { + config, + users, + shares: ServerShares::new(shares), + active_connections: ActiveConnections::new(), + server_start_filetime: now_filetime(), + shutdown: Arc::new(Notify::new()), + shutting_down: Arc::new(AtomicBool::new(false)), + } + } + + /// Find a share by case-insensitive name. + pub async fn find_share(&self, name: &str) -> Option> { + self.shares.find(name).await + } + + /// Look up a user's NT hash by name. + pub async fn lookup_user(&self, name: &str) -> Option { + self.users.table.read().await.get(name).cloned() + } + + /// Whether anonymous logon is permitted (i.e. at least one share is public). + pub async fn anonymous_allowed(&self) -> bool { + for share in self.shares.all().await { + let acl = share.acl.read().await; + if matches!(acl.mode, ShareMode::Public | ShareMode::PublicReadOnly) { + return true; + } + } + false + } +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ConfigError { + #[error("user `{0}` does not exist")] + UnknownUser(String), + #[error("share `{0}` does not exist")] + UnknownShare(String), + #[error("duplicate share `{0}`")] + DuplicateShare(String), + #[error("share `{0}` mixes public mode with explicit users")] + PublicMixedWithUsers(String), + #[error("user name `{0}` is reserved")] + ReservedUserName(String), + #[error("user name must be non-empty")] + EmptyUserName, + #[error("share name `{0}` is reserved")] + ReservedShareName(String), +} + +#[derive(Clone)] +pub struct ConfigHandle { + state: Arc, +} + +impl ConfigHandle { + pub async fn add_user( + &self, + name: impl Into, + password: impl AsRef, + ) -> Result<(), ConfigError> { + let name = name.into(); + validate_user_name(&name)?; + let creds = UserCreds::from_password(password.as_ref()); + self.state.users.table.write().await.insert(name, creds); + Ok(()) + } + + pub async fn remove_user(&self, name: &str) -> Result<(), ConfigError> { + validate_user_name(name)?; + let removed = self.state.users.table.write().await.remove(name); + if removed.is_none() { + return Err(ConfigError::UnknownUser(name.to_string())); + } + + for share in self.state.shares.all().await { + share.acl.write().await.users.remove(name); + } + + for conn in self.state.active_connections.live().await { + conn.close_sessions_for_user(name).await; + } + Ok(()) + } + + pub async fn add_share(&self, share: Share) -> Result<(), ConfigError> { + validate_share_name(&share.name)?; + let is_public = matches!(share.mode, ShareMode::Public | ShareMode::PublicReadOnly); + if is_public && !share.users.is_empty() { + return Err(ConfigError::PublicMixedWithUsers(share.name)); + } + let users = self.state.users.table.read().await; + for user in share.users.keys() { + if !users.contains_key(user) { + return Err(ConfigError::UnknownUser(user.clone())); + } + } + + let binding = ShareBindings::new(share.name, share.backend, share.mode, share.users, false); + self.state.shares.insert(binding).await + } + + pub async fn remove_share(&self, name: &str) -> Result<(), ConfigError> { + validate_share_name(name)?; + let removed = self.state.shares.remove(name).await; + if removed.is_none() { + return Err(ConfigError::UnknownShare(name.to_string())); + } + + for conn in self.state.active_connections.live().await { + conn.close_trees_for_share(name).await; + } + Ok(()) + } + + pub async fn grant_share_user( + &self, + share_name: &str, + user: &str, + access: Access, + ) -> Result<(), ConfigError> { + validate_user_name(user)?; + validate_share_name(share_name)?; + let users = self.state.users.table.read().await; + if !users.contains_key(user) { + return Err(ConfigError::UnknownUser(user.to_string())); + } + let share = self + .state + .find_share(share_name) + .await + .ok_or_else(|| ConfigError::UnknownShare(share_name.to_string()))?; + let mut acl = share.acl.write().await; + if matches!(acl.mode, ShareMode::Public | ShareMode::PublicReadOnly) { + return Err(ConfigError::PublicMixedWithUsers(share.name.clone())); + } + acl.users.insert(user.to_string(), access); + Ok(()) + } + + pub async fn revoke_share_user(&self, share_name: &str, user: &str) -> Result<(), ConfigError> { + validate_user_name(user)?; + validate_share_name(share_name)?; + let share = self + .state + .find_share(share_name) + .await + .ok_or_else(|| ConfigError::UnknownShare(share_name.to_string()))?; + share.acl.write().await.users.remove(user); + + for conn in self.state.active_connections.live().await { + conn.close_trees_for_user_share(user, share_name).await; + } + Ok(()) + } + + pub async fn set_share_mode( + &self, + share_name: &str, + mode: ShareMode, + ) -> Result<(), ConfigError> { + validate_share_name(share_name)?; + let share = self + .state + .find_share(share_name) + .await + .ok_or_else(|| ConfigError::UnknownShare(share_name.to_string()))?; + let mut acl = share.acl.write().await; + if matches!(mode, ShareMode::Public | ShareMode::PublicReadOnly) && !acl.users.is_empty() { + return Err(ConfigError::PublicMixedWithUsers(share.name.clone())); + } + if acl.mode == mode { + return Ok(()); + } + acl.mode = mode; + drop(acl); + + for conn in self.state.active_connections.live().await { + conn.close_trees_for_share(share_name).await; + } + Ok(()) + } +} + +fn validate_user_name(name: &str) -> Result<(), ConfigError> { + if name.is_empty() { + return Err(ConfigError::EmptyUserName); + } + if name.eq_ignore_ascii_case("anonymous") { + return Err(ConfigError::ReservedUserName(name.to_string())); + } + Ok(()) +} + +fn validate_share_name(name: &str) -> Result<(), ConfigError> { + if name.eq_ignore_ascii_case("IPC$") { + return Err(ConfigError::ReservedShareName(name.to_string())); + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// SmbServer +// --------------------------------------------------------------------------- + +/// A built but not-yet-running SMB server. +/// +/// Use `serve()` to bind the configured listener and run until shutdown. +pub struct SmbServer { + state: Arc, + /// The listener is bound lazily inside `serve()` so we can return a + /// useful `local_addr` only after binding. Pre-bind helpers: `serve` is + /// the only path that opens the socket. + bound: tokio::sync::Mutex>, + /// Resolved local address once `bind_local()` has been called. Tests + /// expect to ask for the address before serving (port 0 case). + local_addr: tokio::sync::Mutex>, +} + +impl SmbServer { + pub fn builder() -> SmbServerBuilder { + SmbServerBuilder::default() + } + + pub(crate) fn from_state(state: ServerState) -> Self { + Self { + state: Arc::new(state), + bound: tokio::sync::Mutex::new(None), + local_addr: tokio::sync::Mutex::new(None), + } + } + + pub fn config_handle(&self) -> ConfigHandle { + ConfigHandle { + state: self.state.clone(), + } + } + + /// Bind the configured listen address without yet entering the accept + /// loop. Required for tests that need the actual port (e.g. when the + /// builder used port 0). + pub async fn bind(&self) -> io::Result { + let mut bound = self.bound.lock().await; + if let Some(l) = bound.as_ref() { + return l.local_addr(); + } + let listener = TcpListener::bind(self.state.config.listen_addr).await?; + let addr = listener.local_addr()?; + *bound = Some(listener); + *self.local_addr.lock().await = Some(addr); + Ok(addr) + } + + /// Returns the actual bound address. `None` if `bind()`/`serve()` have + /// not yet been called. + pub async fn local_addr(&self) -> Option { + *self.local_addr.lock().await + } + + /// Configured listen address (the *intended* address; may be `0.0.0.0:0` + /// before binding). + pub fn configured_addr(&self) -> SocketAddr { + self.state.config.listen_addr + } + + /// Initiate a graceful shutdown. Stops the accept loop and lets in-flight + /// connection tasks complete. + pub fn shutdown(&self) { + self.state.shutting_down.store(true, Ordering::Release); + self.state.shutdown.notify_waiters(); + } + + /// Returns a clonable handle that can request shutdown after `serve()` + /// has consumed the `SmbServer` value. + pub fn shutdown_handle(&self) -> ShutdownHandle { + ShutdownHandle { + shutdown: self.state.shutdown.clone(), + shutting_down: self.state.shutting_down.clone(), + } + } + + /// Run the accept loop until `shutdown()` is called. + pub async fn serve(self) -> io::Result<()> { + // Ensure the listener is bound. (The user may also have called + // `bind()` to pre-extract `local_addr()` for a test.) + if self.bound.lock().await.is_none() { + self.bind().await?; + } + let listener = self + .bound + .lock() + .await + .take() + .expect("listener bound above"); + let local = listener.local_addr().ok(); + let span = info_span!("smb_server", listen = ?local); + async move { + info!("server starting"); + let state = self.state.clone(); + let shutdown = state.shutdown.clone(); + let shutting_down = state.shutting_down.clone(); + + loop { + tokio::select! { + biased; + _ = shutdown.notified() => { + info!("shutdown requested; stopping accept loop"); + break; + } + accept = listener.accept() => { + match accept { + Ok((stream, peer)) => { + if shutting_down.load(Ordering::Acquire) { + drop(stream); + break; + } + let server_state = state.clone(); + let span = info_span!("conn", peer = %peer); + tokio::spawn(async move { + if let Err(e) = connection_loop(stream, server_state).await { + error!(error = %e, "connection loop exited with error"); + } + }.instrument(span)); + } + Err(e) => { + error!(error = %e, "accept failed"); + if shutting_down.load(Ordering::Acquire) { + break; + } + } + } + } + } + } + info!("server stopped"); + Ok::<(), io::Error>(()) + } + .instrument(span) + .await + } + + /// Access shared state for in-crate tests/integrations. + #[doc(hidden)] + pub fn state(&self) -> Arc { + self.state.clone() + } +} + +/// Cheaply-clonable shutdown handle. Outlives `SmbServer::serve` consuming +/// the server. +#[derive(Clone)] +pub struct ShutdownHandle { + shutdown: Arc, + shutting_down: Arc, +} + +impl ShutdownHandle { + /// Request a graceful shutdown. + pub fn shutdown(&self) { + self.shutting_down.store(true, Ordering::Release); + self.shutdown.notify_waiters(); + } +} diff --git a/vendor/smb-server/src/tests/dynamic_config.rs b/vendor/smb-server/src/tests/dynamic_config.rs new file mode 100644 index 0000000..0ed452c --- /dev/null +++ b/vendor/smb-server/src/tests/dynamic_config.rs @@ -0,0 +1,173 @@ +use std::sync::Arc; + +use super::memfs::MemFsBackend; +use crate::conn::state::{Connection, Session, TreeConnect}; +use crate::server::ConfigError; +use crate::{Access, Identity, Share, ShareMode, SmbServer}; + +fn test_server() -> SmbServer { + SmbServer::builder() + .listen("127.0.0.1:0".parse().unwrap()) + .user("alice", "password") + .share( + Share::new("home", MemFsBackend::new().with_file("seed.txt", b"")) + .user("alice", Access::ReadWrite), + ) + .build() + .expect("build") +} + +fn public_server() -> SmbServer { + SmbServer::builder() + .listen("127.0.0.1:0".parse().unwrap()) + .share(Share::new("public", MemFsBackend::new()).public()) + .build() + .expect("build") +} + +async fn register_session( + server: &SmbServer, + identity: Identity, + share_name: &str, +) -> Arc { + let state = server.state(); + let conn = Arc::new(Connection::new( + state.config.server_guid, + state.config.max_read_size, + state.config.max_write_size, + )); + state.active_connections.register(&conn).await; + + let session = Session::new(1, identity, [0; 16], [0; 16], false, None); + let session = Arc::new(tokio::sync::RwLock::new(session)); + let share = state.find_share(share_name).await.expect("share"); + let tree = Arc::new(tokio::sync::RwLock::new(TreeConnect::new( + 1, + share, + Access::ReadWrite, + ))); + { + let sess = session.read().await; + sess.trees.write().await.insert(1, tree); + } + conn.sessions.write().await.insert(1, session); + conn +} + +async fn register_alice_session(server: &SmbServer) -> Arc { + register_session( + server, + Identity::User { + user: "alice".to_string(), + domain: String::new(), + }, + "home", + ) + .await +} + +#[tokio::test] +async fn config_handle_adds_users_and_shares() { + let server = SmbServer::builder() + .listen("127.0.0.1:0".parse().unwrap()) + .build() + .expect("build"); + let config = server.config_handle(); + + config.add_user("bob", "password").await.expect("add user"); + config + .add_share(Share::new("media", MemFsBackend::new()).user("bob", Access::Read)) + .await + .expect("add share"); + + let state = server.state(); + assert!(state.lookup_user("bob").await.is_some()); + assert!(state.find_share("media").await.is_some()); +} + +#[tokio::test] +async fn removing_user_revokes_active_sessions() { + let server = test_server(); + let conn = register_alice_session(&server).await; + + server + .config_handle() + .remove_user("alice") + .await + .expect("remove user"); + + assert!(server.state().lookup_user("alice").await.is_none()); + assert!(conn.sessions.read().await.is_empty()); +} + +#[tokio::test] +async fn removing_share_revokes_active_trees() { + let server = test_server(); + let conn = register_alice_session(&server).await; + + server + .config_handle() + .remove_share("home") + .await + .expect("remove share"); + + assert!(server.state().find_share("home").await.is_none()); + let sessions = conn.sessions.read().await; + let session = sessions.get(&1).expect("session remains").read().await; + assert!(session.trees.read().await.is_empty()); +} + +#[tokio::test] +async fn revoking_user_from_share_revokes_only_that_tree() { + let server = test_server(); + let conn = register_alice_session(&server).await; + + server + .config_handle() + .revoke_share_user("home", "alice") + .await + .expect("revoke user share"); + + assert!(conn.sessions.read().await.contains_key(&1)); + let sessions = conn.sessions.read().await; + let session = sessions.get(&1).expect("session remains").read().await; + assert!(session.trees.read().await.is_empty()); +} + +#[tokio::test] +async fn changing_share_mode_revokes_active_trees() { + let server = public_server(); + let conn = register_session(&server, Identity::Anonymous, "public").await; + + server + .config_handle() + .set_share_mode("public", ShareMode::PublicReadOnly) + .await + .expect("set mode"); + + let sessions = conn.sessions.read().await; + let session = sessions.get(&1).expect("session remains").read().await; + assert!(session.trees.read().await.is_empty()); +} + +#[tokio::test] +async fn public_share_cannot_mix_explicit_users() { + let server = SmbServer::builder() + .listen("127.0.0.1:0".parse().unwrap()) + .share(Share::new("public", MemFsBackend::new()).public()) + .build() + .expect("build"); + + let config = server.config_handle(); + config + .add_user("alice", "password") + .await + .expect("add user"); + + let err = config + .grant_share_user("public", "alice", Access::Read) + .await + .expect_err("grant should fail"); + + assert_eq!(err, ConfigError::PublicMixedWithUsers("public".to_string())); +} diff --git a/vendor/smb-server/src/tests/memfs.rs b/vendor/smb-server/src/tests/memfs.rs new file mode 100644 index 0000000..bc65f81 --- /dev/null +++ b/vendor/smb-server/src/tests/memfs.rs @@ -0,0 +1,300 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +use crate::backend::{ + BackendCapabilities, DirEntry, FileInfo, FileTimes, Handle, OpenIntent, OpenOptions, + ShareBackend, +}; +use crate::error::{SmbError, SmbResult}; +use crate::path::SmbPath; +use async_trait::async_trait; +use bytes::Bytes; + +/// Minimal in-memory FS used by integration tests. Files are byte vectors, +/// directories are sets of names. Not threadsafe across workers — only used +/// within one test. +pub struct MemFsBackend { + inner: std::sync::Arc>, +} + +#[derive(Default)] +struct MemInner { + files: HashMap>, + /// All directories present (always includes "" for the root). Each + /// directory is keyed by canonical path string. + dirs: HashMap, +} + +impl Default for MemFsBackend { + fn default() -> Self { + Self::new() + } +} + +impl MemFsBackend { + pub fn new() -> Self { + let mut inner = MemInner::default(); + inner.dirs.insert(String::new(), ()); + Self { + inner: std::sync::Arc::new(Mutex::new(inner)), + } + } + + pub fn with_file(self, path: &str, contents: &[u8]) -> Self { + { + let mut g = self.inner.lock().unwrap(); + g.files.insert(path.to_string(), contents.to_vec()); + } + self + } +} + +fn key(path: &SmbPath) -> String { + path.display_backslash() +} + +#[async_trait] +impl ShareBackend for MemFsBackend { + async fn open(&self, path: &SmbPath, opts: OpenOptions) -> SmbResult> { + let k = key(path); + let mut g = self.inner.lock().unwrap(); + let exists_file = g.files.contains_key(&k); + let exists_dir = g.dirs.contains_key(&k); + + if opts.directory { + if exists_file { + return Err(SmbError::NotADirectory); + } + if !exists_dir { + if matches!(opts.intent, OpenIntent::Create | OpenIntent::OpenOrCreate) { + g.dirs.insert(k.clone(), ()); + } else { + return Err(SmbError::NotFound); + } + } + return Ok(Box::new(MemHandle::dir(self.inner.clone(), k))); + } + + if exists_dir { + return Err(SmbError::IsDirectory); + } + match opts.intent { + OpenIntent::Open => { + if !exists_file { + return Err(SmbError::NotFound); + } + } + OpenIntent::Create => { + if exists_file { + return Err(SmbError::Exists); + } + g.files.insert(k.clone(), Vec::new()); + } + OpenIntent::OpenOrCreate => { + g.files.entry(k.clone()).or_default(); + } + OpenIntent::Truncate => { + if !exists_file { + return Err(SmbError::NotFound); + } + g.files.insert(k.clone(), Vec::new()); + } + OpenIntent::OverwriteOrCreate => { + g.files.insert(k.clone(), Vec::new()); + } + } + Ok(Box::new(MemHandle::file(self.inner.clone(), k))) + } + + async fn unlink(&self, path: &SmbPath) -> SmbResult<()> { + let k = key(path); + let mut g = self.inner.lock().unwrap(); + if g.files.remove(&k).is_some() { + return Ok(()); + } + if g.dirs.remove(&k).is_some() { + return Ok(()); + } + Err(SmbError::NotFound) + } + + async fn rename(&self, from: &SmbPath, to: &SmbPath) -> SmbResult<()> { + let kf = key(from); + let kt = key(to); + let mut g = self.inner.lock().unwrap(); + if g.files.contains_key(&kt) || g.dirs.contains_key(&kt) { + return Err(SmbError::Exists); + } + if let Some(data) = g.files.remove(&kf) { + g.files.insert(kt, data); + return Ok(()); + } + if g.dirs.remove(&kf).is_some() { + g.dirs.insert(kt, ()); + return Ok(()); + } + Err(SmbError::NotFound) + } + + fn capabilities(&self) -> BackendCapabilities { + BackendCapabilities { + is_read_only: false, + case_sensitive: false, + } + } +} + +pub struct MemHandle { + inner: std::sync::Arc>, + key: String, + is_dir: bool, +} + +impl MemHandle { + fn file(inner: std::sync::Arc>, key: String) -> Self { + Self { + inner, + key, + is_dir: false, + } + } + + fn dir(inner: std::sync::Arc>, key: String) -> Self { + Self { + inner, + key, + is_dir: true, + } + } +} + +#[async_trait] +impl Handle for MemHandle { + async fn read(&self, offset: u64, len: u32) -> SmbResult { + if self.is_dir { + return Err(SmbError::IsDirectory); + } + let g = self.inner.lock().unwrap(); + let data = g.files.get(&self.key).ok_or(SmbError::NotFound)?; + let start = offset as usize; + if start >= data.len() { + return Ok(Bytes::new()); + } + let end = (start + len as usize).min(data.len()); + Ok(Bytes::copy_from_slice(&data[start..end])) + } + + async fn write(&self, offset: u64, data: &[u8]) -> SmbResult { + if self.is_dir { + return Err(SmbError::IsDirectory); + } + let mut g = self.inner.lock().unwrap(); + let buf = g.files.get_mut(&self.key).ok_or(SmbError::NotFound)?; + let needed = (offset as usize) + data.len(); + if buf.len() < needed { + buf.resize(needed, 0); + } + buf[offset as usize..offset as usize + data.len()].copy_from_slice(data); + Ok(data.len() as u32) + } + + async fn flush(&self) -> SmbResult<()> { + Ok(()) + } + + async fn stat(&self) -> SmbResult { + let g = self.inner.lock().unwrap(); + let size = if self.is_dir { + 0 + } else { + g.files.get(&self.key).ok_or(SmbError::NotFound)?.len() as u64 + }; + let name = self + .key + .rsplit_once('\\') + .map(|(_, n)| n.to_string()) + .unwrap_or_else(|| self.key.clone()); + Ok(FileInfo { + name, + end_of_file: size, + allocation_size: size, + creation_time: 0x01D9_0000_0000_0000, + last_access_time: 0x01D9_0000_0000_0000, + last_write_time: 0x01D9_0000_0000_0000, + change_time: 0x01D9_0000_0000_0000, + is_directory: self.is_dir, + file_index: 0, + }) + } + + async fn set_times(&self, _times: FileTimes) -> SmbResult<()> { + Ok(()) + } + + async fn truncate(&self, len: u64) -> SmbResult<()> { + if self.is_dir { + return Err(SmbError::IsDirectory); + } + let mut g = self.inner.lock().unwrap(); + let buf = g.files.get_mut(&self.key).ok_or(SmbError::NotFound)?; + buf.resize(len as usize, 0); + Ok(()) + } + + async fn list_dir(&self, _pattern: Option<&str>) -> SmbResult> { + if !self.is_dir { + return Err(SmbError::NotADirectory); + } + let g = self.inner.lock().unwrap(); + let prefix = if self.key.is_empty() { + String::new() + } else { + format!("{}\\", self.key) + }; + let mut entries = Vec::new(); + for (k, v) in g.files.iter() { + if let Some(rest) = k.strip_prefix(&prefix) + && !rest.contains('\\') + { + entries.push(DirEntry { + info: FileInfo { + name: rest.to_string(), + end_of_file: v.len() as u64, + allocation_size: v.len() as u64, + creation_time: 0x01D9_0000_0000_0000, + last_access_time: 0x01D9_0000_0000_0000, + last_write_time: 0x01D9_0000_0000_0000, + change_time: 0x01D9_0000_0000_0000, + is_directory: false, + file_index: 0, + }, + }); + } + } + for k in g.dirs.keys() { + if let Some(rest) = k.strip_prefix(&prefix) + && !rest.is_empty() + && !rest.contains('\\') + { + entries.push(DirEntry { + info: FileInfo { + name: rest.to_string(), + end_of_file: 0, + allocation_size: 0, + creation_time: 0x01D9_0000_0000_0000, + last_access_time: 0x01D9_0000_0000_0000, + last_write_time: 0x01D9_0000_0000_0000, + change_time: 0x01D9_0000_0000_0000, + is_directory: true, + file_index: 0, + }, + }); + } + } + Ok(entries) + } + + async fn close(self: Box) -> SmbResult<()> { + Ok(()) + } +} diff --git a/vendor/smb-server/src/utils.rs b/vendor/smb-server/src/utils.rs new file mode 100644 index 0000000..b52f8b1 --- /dev/null +++ b/vendor/smb-server/src/utils.rs @@ -0,0 +1,69 @@ +//! Small helpers shared across modules. + +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Number of 100-nanosecond intervals between 1601-01-01 (Windows FILETIME +/// epoch) and 1970-01-01 (UNIX epoch). 369 years. +const FILETIME_OFFSET: u64 = 116_444_736_000_000_000; + +/// Convert a `SystemTime` to a Windows FILETIME (100ns ticks since 1601). +pub fn system_time_to_filetime(t: SystemTime) -> u64 { + match t.duration_since(UNIX_EPOCH) { + Ok(d) => FILETIME_OFFSET + (d.as_secs() * 10_000_000) + (d.subsec_nanos() as u64 / 100), + // Pre-1970 — clamp to the FILETIME epoch. + Err(_) => 0, + } +} + +/// Convert "now" to FILETIME. +pub fn now_filetime() -> u64 { + system_time_to_filetime(SystemTime::now()) +} + +/// Encode a `&str` to little-endian UTF-16 bytes. +pub fn utf16le(s: &str) -> Vec { + let mut out = Vec::with_capacity(s.len() * 2); + for unit in s.encode_utf16() { + out.extend_from_slice(&unit.to_le_bytes()); + } + out +} + +/// Decode a UTF-16LE byte slice. Returns an empty string if the buffer is not +/// 2-byte aligned (caller decides what to do); replacement characters on +/// invalid surrogates. +pub fn utf16le_to_string(bytes: &[u8]) -> String { + if !bytes.len().is_multiple_of(2) { + return String::new(); + } + let units: Vec = bytes + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect(); + String::from_utf16_lossy(&units) +} + +/// Decode a UTF-16LE byte slice into a `Vec`, returning `None` on a +/// non-aligned buffer. +pub fn utf16le_to_units(bytes: &[u8]) -> Option> { + if !bytes.len().is_multiple_of(2) { + return None; + } + Some( + bytes + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect(), + ) +} + +/// Fill `out` with cryptographically-strong random bytes via `getrandom`. +/// Falls back to zeros if the OS RNG fails — the caller should treat this as +/// fatal, but we never panic. +pub fn fill_random(out: &mut [u8]) { + if getrandom::fill(out).is_err() { + for b in out.iter_mut() { + *b = 0; + } + } +} diff --git a/vendor/smb2/.cargo/audit.toml b/vendor/smb2/.cargo/audit.toml new file mode 100644 index 0000000..835dd31 --- /dev/null +++ b/vendor/smb2/.cargo/audit.toml @@ -0,0 +1,8 @@ +# Ignored advisories for cargo-audit + +[advisories] +ignore = [ + # Marvin Attack timing sidechannel in `rsa` crate. No fix available. + # Only affects benchmarks/smb/ (via sspi -> rsa), not the smb2 crate itself. + "RUSTSEC-2023-0071", +] diff --git a/vendor/smb2/.claude/rules/docs-maintenance.md b/vendor/smb2/.claude/rules/docs-maintenance.md new file mode 100644 index 0000000..47aad9e --- /dev/null +++ b/vendor/smb2/.claude/rules/docs-maintenance.md @@ -0,0 +1,12 @@ +When modifying code in a directory that contains a `CLAUDE.md` file, check whether your changes affect the documented +architecture, key decisions, or gotchas. If they do, update the `CLAUDE.md` to stay in sync. If you notice a `CLAUDE.md` +missing in a directory where there should be one, add it. Skip this for trivial changes (bug fixes, formatting, small +refactors that don't change the architecture). + +If something failed due to a wrong assumption, add a `Gotcha/Why` entry to the nearest `CLAUDE.md`. + +Add `Decision/Why` entries to the nearest colocated `CLAUDE.md` for key decisions. If the decision has rich evidence +(benchmarks, detailed analysis), put the evidence in `docs/notes/` and link from the CLAUDE.md. + +When writing guides, see [this diff](https://github.com/vdavid/cmdr/commit/13ad8f3#diff-795210f) for the formatting +standard. (Before: AI-written. After: matching our standards for conciseness and clarity.) diff --git a/vendor/smb2/.claude/rules/git-conventions.md b/vendor/smb2/.claude/rules/git-conventions.md new file mode 100644 index 0000000..7a4cc44 --- /dev/null +++ b/vendor/smb2/.claude/rules/git-conventions.md @@ -0,0 +1,14 @@ +## Commit messages + +- Use conventional commit messages. +- Title: Capture the IMPACT of the change, not the tech details. From the title, we need to understand WHY we did this, + what we ACHIEVED with the commit. Length-wise, aim for about 50 chars max. +- Body: Use bullets primarily. No word wrap. Don't hard-wrap body lines at 72 chars or any other width. Let the + terminal/viewer wrap naturally. Enclose entities in ``. No co-author! + +## PRs + +- Use the PR title to summarize the changes in a casual/informal tone. Be information dense and concise. +- In the desc., write a thorough, organized, but concise, often bulleted list of the changes. Use no headings. +- At the bottom of the PR description, use a single "## Test plan" heading, in which, explain how the changes were + tested. Assume that the changes were also tested manually if it makes sense for the type of changes. diff --git a/vendor/smb2/.codegraph/.gitignore b/vendor/smb2/.codegraph/.gitignore new file mode 100644 index 0000000..9de0f16 --- /dev/null +++ b/vendor/smb2/.codegraph/.gitignore @@ -0,0 +1,16 @@ +# CodeGraph data files +# These are local to each machine and should not be committed + +# Database +*.db +*.db-wal +*.db-shm + +# Cache +cache/ + +# Logs +*.log + +# Hook markers +.dirty diff --git a/vendor/smb2/.codegraph/config.json b/vendor/smb2/.codegraph/config.json new file mode 100644 index 0000000..f613aa7 --- /dev/null +++ b/vendor/smb2/.codegraph/config.json @@ -0,0 +1,140 @@ +{ + "version": 1, + "include": [ + "**/*.ts", + "**/*.tsx", + "**/*.js", + "**/*.jsx", + "**/*.py", + "**/*.go", + "**/*.rs", + "**/*.java", + "**/*.c", + "**/*.h", + "**/*.cpp", + "**/*.hpp", + "**/*.cc", + "**/*.cxx", + "**/*.cs", + "**/*.php", + "**/*.rb", + "**/*.swift", + "**/*.kt", + "**/*.kts", + "**/*.dart", + "**/*.svelte", + "**/*.liquid", + "**/*.pas", + "**/*.dpr", + "**/*.dpk", + "**/*.lpr", + "**/*.dfm", + "**/*.fmx" + ], + "exclude": [ + "**/.git/**", + "**/node_modules/**", + "**/vendor/**", + "**/Pods/**", + "**/dist/**", + "**/build/**", + "**/out/**", + "**/bin/**", + "**/obj/**", + "**/target/**", + "**/*.min.js", + "**/*.bundle.js", + "**/.next/**", + "**/.nuxt/**", + "**/.svelte-kit/**", + "**/.output/**", + "**/.turbo/**", + "**/.cache/**", + "**/.parcel-cache/**", + "**/.vite/**", + "**/.astro/**", + "**/.docusaurus/**", + "**/.gatsby/**", + "**/.webpack/**", + "**/.nx/**", + "**/.yarn/cache/**", + "**/.pnpm-store/**", + "**/storybook-static/**", + "**/.expo/**", + "**/web-build/**", + "**/ios/Pods/**", + "**/ios/build/**", + "**/android/build/**", + "**/android/.gradle/**", + "**/__pycache__/**", + "**/.venv/**", + "**/venv/**", + "**/site-packages/**", + "**/dist-packages/**", + "**/.pytest_cache/**", + "**/.mypy_cache/**", + "**/.ruff_cache/**", + "**/.tox/**", + "**/.nox/**", + "**/*.egg-info/**", + "**/.eggs/**", + "**/go/pkg/mod/**", + "**/target/debug/**", + "**/target/release/**", + "**/.gradle/**", + "**/.m2/**", + "**/generated-sources/**", + "**/.kotlin/**", + "**/.dart_tool/**", + "**/.vs/**", + "**/.nuget/**", + "**/artifacts/**", + "**/publish/**", + "**/cmake-build-*/**", + "**/CMakeFiles/**", + "**/bazel-*/**", + "**/vcpkg_installed/**", + "**/.conan/**", + "**/Debug/**", + "**/Release/**", + "**/x64/**", + "**/.pio/**", + "**/release/**", + "**/*.app/**", + "**/*.asar", + "**/DerivedData/**", + "**/.build/**", + "**/.swiftpm/**", + "**/xcuserdata/**", + "**/Carthage/Build/**", + "**/SourcePackages/**", + "**/__history/**", + "**/__recovery/**", + "**/*.dcu", + "**/.composer/**", + "**/storage/framework/**", + "**/bootstrap/cache/**", + "**/.bundle/**", + "**/tmp/cache/**", + "**/public/assets/**", + "**/public/packs/**", + "**/.yardoc/**", + "**/coverage/**", + "**/htmlcov/**", + "**/.nyc_output/**", + "**/test-results/**", + "**/.coverage/**", + "**/.idea/**", + "**/logs/**", + "**/tmp/**", + "**/temp/**", + "**/_build/**", + "**/docs/_build/**", + "**/site/**" + ], + "languages": [], + "frameworks": [], + "maxFileSize": 1048576, + "extractDocstrings": true, + "trackCallSites": true +} \ No newline at end of file diff --git a/vendor/smb2/.env.example b/vendor/smb2/.env.example new file mode 100644 index 0000000..3825808 --- /dev/null +++ b/vendor/smb2/.env.example @@ -0,0 +1,4 @@ +# Copy this to .env and fill in your values. +# .env is gitignored and never committed. + +SMB2_TEST_NAS_PASSWORD=your_nas_password_here diff --git a/vendor/smb2/.gitattributes b/vendor/smb2/.gitattributes new file mode 100644 index 0000000..524f56f --- /dev/null +++ b/vendor/smb2/.gitattributes @@ -0,0 +1,2 @@ +# Force LF line endings for all text files (consistent with rustfmt.toml newline_style = "Unix") +* text=auto eol=lf diff --git a/vendor/smb2/.github/workflows/ci.yml b/vendor/smb2/.github/workflows/ci.yml new file mode 100644 index 0000000..f383a3c --- /dev/null +++ b/vendor/smb2/.github/workflows/ci.yml @@ -0,0 +1,136 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +env: + CARGO_TERM_COLOR: always + +jobs: + check: + name: Check (${{ matrix.os }}, rust ${{ matrix.rust }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-2025] + rust: ["1.85", stable] + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + components: rustfmt, clippy + + - name: Cache cargo registry and target + uses: Swatinem/rust-cache@v2 + + - name: Check formatting + run: cargo fmt --check + + - name: Run clippy lints + run: cargo clippy --all-targets -- -D warnings + + - name: Run tests + run: cargo test + + - name: Build documentation + run: cargo doc --no-deps + + docker-tests: + name: Docker integration tests + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry and target + uses: Swatinem/rust-cache@v2 + + - name: Start SMB containers + run: ./tests/docker/start.sh internal + + - name: Run Docker integration tests + run: cargo test --test docker_integration -- --ignored + env: + RUST_LOG: smb2=info + + - name: Stop containers + if: always() + run: ./tests/docker/stop.sh + + consumer-tests: + name: Consumer integration tests + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry and target + uses: Swatinem/rust-cache@v2 + + - name: Start consumer containers + run: ./tests/docker/start.sh consumer + + - name: Run consumer integration tests + run: cargo test --features testing --test consumer_integration -- --ignored + env: + RUST_LOG: smb2=info + + - name: Stop containers + if: always() + run: ./tests/docker/stop.sh + + msrv: + name: Verify MSRV (1.85) + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + + - name: Install Rust toolchain (MSRV) + uses: dtolnay/rust-toolchain@master + with: + toolchain: "1.85" + + - name: Cache cargo registry and target + uses: Swatinem/rust-cache@v2 + + - name: Check compilation on MSRV + run: cargo check + env: + RUSTFLAGS: "-D warnings" + + ci-ok: + name: CI OK + runs-on: ubuntu-latest + needs: [check, docker-tests, consumer-tests, msrv] + if: always() + steps: + - name: Check all jobs passed + run: | + if [[ "${{ contains(needs.*.result, 'failure') }}" == "true" ]]; then + echo "Some jobs failed" + exit 1 + fi + if [[ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]]; then + echo "Some jobs were cancelled" + exit 1 + fi + echo "All jobs passed" diff --git a/vendor/smb2/.github/workflows/fuzz.yml b/vendor/smb2/.github/workflows/fuzz.yml new file mode 100644 index 0000000..05f27b4 --- /dev/null +++ b/vendor/smb2/.github/workflows/fuzz.yml @@ -0,0 +1,74 @@ +name: Fuzz + +# Short-duration fuzz run: weekly schedule + manual dispatch. Each target +# runs for 5 minutes with the committed seed corpus. For longer hunts, run +# locally: `cargo +nightly fuzz run -- -max_total_time=1800`. +# +# We deliberately do NOT fuzz on every push -- runs are too long for that. + +on: + schedule: + # Mondays 04:15 UTC. + - cron: "15 4 * * 1" + workflow_dispatch: + inputs: + duration_seconds: + description: "Per-target fuzz time (seconds)" + required: false + default: "300" + +env: + CARGO_TERM_COLOR: always + +jobs: + fuzz: + name: Fuzz ${{ matrix.target }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + target: + - fuzz_header_parse + - fuzz_transform_header_parse + - fuzz_compression_transform_header_parse + - fuzz_compound_split + - fuzz_frame_parse + - fuzz_sub_frame_parse + - fuzz_negotiate_request_parse + - fuzz_negotiate_response_parse + - fuzz_create_request_parse + - fuzz_create_response_parse + - fuzz_query_info_response_parse + - fuzz_dfs_referral_response_parse + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + + - name: Install Rust nightly + uses: dtolnay/rust-toolchain@nightly + + - name: Cache cargo registry and target + uses: Swatinem/rust-cache@v2 + with: + workspaces: | + . + fuzz + + - name: Install cargo-fuzz + run: cargo install cargo-fuzz + + - name: Run fuzz target + env: + DURATION: ${{ github.event.inputs.duration_seconds || '300' }} + run: | + cargo +nightly fuzz run "${{ matrix.target }}" \ + -- -max_total_time="${DURATION}" -print_final_stats=1 + + - name: Upload crash artifacts (if any) + if: failure() + uses: actions/upload-artifact@v7 + with: + name: fuzz-crash-${{ matrix.target }} + path: fuzz/artifacts/${{ matrix.target }}/ + if-no-files-found: ignore diff --git a/vendor/smb2/.gitignore b/vendor/smb2/.gitignore new file mode 100644 index 0000000..8db06dc --- /dev/null +++ b/vendor/smb2/.gitignore @@ -0,0 +1,7 @@ +.idea/ +.claude/projects/ +.claude/worktrees/ +related-repos/ +target/ +.DS_Store +.env diff --git a/vendor/smb2/Cargo.toml b/vendor/smb2/Cargo.toml new file mode 100644 index 0000000..96b6c36 --- /dev/null +++ b/vendor/smb2/Cargo.toml @@ -0,0 +1,81 @@ +[package] +name = "smb2" +version = "0.11.3" +edition = "2021" +rust-version = "1.85" +license = "MIT OR Apache-2.0" +description = "Pure-Rust SMB2/3 client library with pipelined I/O" +repository = "https://github.com/vdavid/smb2" +keywords = ["smb", "smb2", "smb3", "cifs", "network"] +categories = ["network-programming", "filesystem"] +readme = "README.md" +documentation = "https://docs.rs/smb2" +exclude = [ + ".github/", + "AGENTS.md", + "docs/", + "justfile", + "deny.toml", + "clippy.toml", + "rustfmt.toml", + "related-repos/", +] + +[package.metadata.docs.rs] +all-features = true + +[dependencies] +# Logging facade -- application picks the backend (env_logger, tracing, etc.) +log = "0.4" + +# Async runtime agnostic +async-trait = "0.1" + +# `FuturesUnordered` for pipelined concurrent `execute` calls. +futures-util = { version = "0.3", default-features = false, features = ["std", "async-await"] } + +# Error handling +thiserror = "2" + +# Enum conversion derives +num_enum = "0.7" + +# Async runtime -- transport layer needs net, io-util, time, sync +tokio = { version = "1", features = ["net", "io-util", "time", "sync", "rt"] } + +# Crypto -- signing, encryption, key derivation +hmac = "0.13" +sha2 = "0.11" +aes = "0.9" +aes-gcm = "=0.11.0-rc.4" +ccm = "=0.6.0-rc.3" +cmac = "=0.8.0-rc.5" +digest = "0.11" + +# NTLM authentication (MS-NLMP) +md-5 = "0.11" +md4 = "0.11" + +# Kerberos key derivation (AES string-to-key) +pbkdf2 = "0.13" +sha1 = "0.11" + +# Cryptographically secure random +getrandom = "0.4" + +# Compression +lz4_flex = "0.13" + +# Optional: `Serialize` derives on diagnostics types. Off by default. +serde = { version = "1", optional = true, features = ["derive"] } + +[features] +testing = [] # Enables smb2::testing module for Docker-based test servers +fuzzing = [] # Exposes parser entry points for `fuzz/` targets; not for applications +serde = ["dep:serde"] # `Serialize` impls on `Diagnostics` types and the protocol enums they embed. + +[dev-dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "net", "io-util"] } +proptest = "1" +env_logger = "0.11" +serde_json = "1" # JSON round-trip tests for the `serde` feature diff --git a/vendor/smb2/src/auth/CLAUDE.md b/vendor/smb2/src/auth/CLAUDE.md new file mode 100644 index 0000000..ea26aba --- /dev/null +++ b/vendor/smb2/src/auth/CLAUDE.md @@ -0,0 +1,120 @@ +# Auth -- NTLM and Kerberos authentication + +NTLMv2 and Kerberos authentication for SMB2 session setup. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | Module exports | +| `der.rs` | Shared ASN.1/DER primitives (TLV encode/decode) | +| `ntlm.rs` | `NtlmAuthenticator` -- 3-message NTLM exchange | +| `spnego.rs` | SPNEGO NegTokenInit/NegTokenResp wrapping | +| `kerberos/mod.rs` | Kerberos module root, re-exports authenticator | +| `kerberos/authenticator.rs` | `KerberosAuthenticator` -- full AS + TGS + AP-REQ flow | +| `kerberos/crypto.rs` | AES-CTS, RC4-HMAC, string-to-key, key derivation | +| `kerberos/messages.rs` | ASN.1/DER encoding/decoding for Kerberos messages | +| `kerberos/kdc.rs` | KDC transport client (UDP/TCP with fallback) | + +## NTLM exchange + +1. `negotiate()` -- builds NEGOTIATE_MESSAGE (Type 1) with default flags +2. Server sends CHALLENGE_MESSAGE (Type 2) with server challenge and target info +3. `authenticate(&challenge_bytes)` -- builds AUTHENTICATE_MESSAGE (Type 3) with NTLMv2 response + +Only NTLMv2 is supported. NTLMv1 is insecure and not implemented. + +## Kerberos exchange + +`KerberosAuthenticator` performs the full Kerberos flow in three steps: + +1. **AS exchange** (client -> KDC): derive user key from password, build PA-ENC-TIMESTAMP + PA-PAC-REQUEST, send AS-REQ, parse AS-REP, decrypt enc-part with user key to get TGT + AS session key. +2. **TGS exchange** (client -> KDC): build AP-REQ wrapping TGT + authenticator (encrypted with AS session key), send TGS-REQ for `cifs/hostname`, parse TGS-REP, decrypt enc-part with AS session key to get service ticket + TGS session key. +3. **AP-REQ construction**: build Authenticator with subkey, encrypt with TGS session key, build AP-REQ with service ticket, wrap in SPNEGO NegTokenInit. + +The flow differs from NTLM: Kerberos contacts the KDC directly (async, network I/O), then produces a single token for SESSION_SETUP (usually 1 round-trip with the SMB server). + +### Key usage numbers (RFC 4120 section 7.5.1) + +- 1: PA-ENC-TIMESTAMP encryption +- 3: AS-REP EncKDCRepPart decryption +- 6: TGS-REQ PA-TGS-REQ Authenticator cksum (body checksum) +- 7: AP-REQ Authenticator encryption +- 8: TGS-REP EncKDCRepPart decryption (tries 8 first, falls back to 9) + +### Encryption types supported + +- AES-256-CTS-HMAC-SHA1-96 (etype 18) -- preferred +- AES-128-CTS-HMAC-SHA1-96 (etype 17) +- RC4-HMAC (etype 23) -- legacy + +### Key derivation constants (RFC 3961) + +Three subkeys are derived from each base key + usage number: +- **Ke** = DK(key, usage || 0xAA) -- encryption subkey, used for AES-CTS +- **Ki** = DK(key, usage || 0x55) -- integrity subkey, used for HMAC inside encrypt/decrypt +- **Kc** = DK(key, usage || 0x99) -- checksum subkey, used for standalone checksum/MIC + +Ki and Kc are NOT the same key. Ki is for the HMAC that's appended to ciphertext in the encrypt() function. Kc is for standalone operations like the body checksum in the TGS-REQ Authenticator. + +### Kerberos wire encryption format (AES) + +1. Derive Ke (with 0xAA) and Ki (with 0x55) from base key + usage +2. Generate 16-byte random confounder +3. plaintext' = confounder || plaintext +4. ciphertext = AES-CTS(Ke, iv=0, plaintext') +5. hmac = HMAC-SHA1-96(Ki, plaintext') -- 12 bytes +6. output = ciphertext || hmac + +## NTLM key derivation flow + +1. `NTOWFv2`: `HMAC-MD5(MD4(password_utf16), uppercase(username) + domain)` +2. `NTProofStr`: `HMAC-MD5(NTOWFv2, server_challenge + client_blob)` +3. `SessionBaseKey`: `HMAC-MD5(NTOWFv2, NTProofStr)` +4. If KEY_EXCH flag: generate random session key, RC4-encrypt with SessionBaseKey +5. `ExportedSessionKey` feeds into SP800-108 KDF (in `crypto/kdf.rs`) + +## MIC computation + +Modern servers include `MsvAvTimestamp` in the challenge target info, which triggers MIC validation. When present: +1. Add `MsvAvFlags` with bit 0x2 (MIC present) to the target info +2. Build the AUTHENTICATE_MESSAGE with a zeroed 16-byte MIC field at offset 72 +3. Compute `HMAC-MD5(ExportedSessionKey, negotiate_msg || challenge_msg || authenticate_msg)` +4. Patch the MIC into bytes 72..88 + +The authenticator retains raw bytes of NEGOTIATE and CHALLENGE messages for this computation. + +## Key decisions + +- **`getrandom` for random values**: Client challenge, random session key, nonces, and confounders use `getrandom` (OS CSPRNG). +- **`test_random_session_key` override**: Tests can inject a deterministic session key for reproducibility. Never used in production. +- **Subkey in AP-REQ Authenticator**: The Kerberos authenticator includes a random subkey, which becomes the SMB session key. This provides forward secrecy. +- **No full `authenticate()` unit tests**: The full flow requires a real KDC. Unit tests cover individual steps (encrypt/decrypt roundtrip, message encoding, etype parsing). Integration tests with Docker are planned. + +## Gotchas + +- **Retain raw challenge bytes for MIC (NTLM)**: The MIC is computed over the exact wire bytes of all three messages. +- **RC4 for key exchange is inline (NTLM)**: ~15 lines of RC4 implementation. +- **MsvAvTimestamp presence changes behavior (NTLM)**: Without it, no MIC is computed. With it, MIC is mandatory. +- **NTLMv1 not supported**: No fallback. +- **Target info modification (NTLM)**: The client modifies the server's target info before including it in the client blob. +- **TGS-REP key usage ambiguity (Kerberos)**: RFC 4120 says key usage 8 for TGS-REP encrypted with session key, but some KDCs use 9. The authenticator tries 8 first, falls back to 9. +- **KDC_ERR_PREAUTH_REQUIRED handling (Kerberos)**: First AS-REQ without pre-auth gets error 25. The authenticator extracts supported etypes from the e-data (ETYPE-INFO2) and retries with pre-authentication. +- **DER primitives in `auth::der`**: Core DER encoding/decoding helpers (`der_length`, `der_tlv`, `parse_der_length`, `parse_der_tlv`) live in `auth/der.rs` and are shared by `spnego.rs` and `kerberos/messages.rs`. Type-specific helpers (INTEGER, GeneralString, etc.) stay in their respective modules. + +## Kerberos key design decisions (from end-to-end testing) + +- **MS Kerberos OID (`1.2.840.48018.1.2.2`)**: Windows AD requires the Microsoft Kerberos OID in SPNEGO NegTokenInit, not the standard RFC 4120 OID. Both are included in mechTypes, with MS OID first. +- **Key usage 11 for SPNEGO AP-REQ Authenticator**: Standard RFC 4120 uses key usage 7 for AP-REQ Authenticator encryption. Windows expects key usage 11 when the AP-REQ is wrapped in SPNEGO (per MS-KILE). Using 7 causes `KRB_AP_ERR_MODIFIED`. +- **Session key etype detection**: The TGS-REQ requests AES-256, AES-128, and RC4 (preference order). The KDC picks the session key type from this list — it may differ from the ticket encryption type. The authenticator detects the actual etype from the TGS-REP `EncKDCRepPart.key.keytype` and uses the matching cipher for Authenticator encryption. +- **Raw ticket pass-through**: The service ticket bytes must be sent to the SMB server exactly as received from the KDC. Re-encoding the ticket from parsed fields produces different DER and causes `KRB_AP_ERR_MODIFIED`. The `Ticket` struct carries `raw_bytes` for this. +- **GSS-API wrapping**: The AP-REQ in SPNEGO NegTokenInit must include the GSS-API OID header (`0x60 len OID ap-req`), not just the raw AP-REQ bytes. +- **Mutual authentication**: AP-REQ sets the mutual-required flag. The server returns an AP-REP (in SPNEGO NegTokenResp) containing a server sub-session key. The client decrypts the AP-REP (key usage 12) to extract this subkey, which becomes the SMB session key. This provides cryptographic proof that the server possesses the service key. The AP-REP may arrive in a `STATUS_SUCCESS` response (not always `STATUS_MORE_PROCESSING_REQUIRED`). + +- **Credential cache (ccache) support**: `kerberos/ccache.rs` parses MIT Kerberos ccache files (v3 and v4). Supports loading cached TGTs (skip AS exchange, do TGS) and cached service tickets (skip both AS and TGS). Integrates via `Session::setup_kerberos_from_ccache()` and `KerberosAuthenticator::authenticate_from_ccache()`. `load_ccache()` reads from a path or `$KRB5CCNAME`. + +## Known tech debt (Kerberos) + +- ~~DER helpers duplicated between `spnego.rs` and `kerberos/messages.rs`~~ (resolved: shared `auth/der.rs`) +- ~~`kerberos/authenticator.rs` mixes crypto wrappers with protocol flow~~ (resolved: `kerberos_encrypt`, `kerberos_decrypt`, `etype_from_i32`, and `generate_random_key` moved to `kerberos/crypto.rs`) +- ~~`#![allow(rustdoc::broken_intra_doc_links)]` hack in `kerberos/messages.rs`~~ (resolved: ASN.1 context tags in doc comments wrapped in backticks) diff --git a/vendor/smb2/src/auth/der.rs b/vendor/smb2/src/auth/der.rs new file mode 100644 index 0000000..a3e471e --- /dev/null +++ b/vendor/smb2/src/auth/der.rs @@ -0,0 +1,196 @@ +//! Shared ASN.1/DER encoding and decoding primitives. +//! +//! These low-level helpers are used by both `spnego.rs` and `kerberos/messages.rs` +//! to build and parse DER-encoded structures. Only the core TLV operations live +//! here; type-specific helpers (INTEGER, GeneralString, etc.) stay in their +//! respective modules. + +use crate::Error; + +/// Encode a DER length field. +/// +/// - Lengths < 128 are encoded as a single byte. +/// - Lengths < 256 are encoded as `0x81` followed by one byte. +/// - Lengths < 65536 are encoded as `0x82` followed by two bytes (big-endian). +pub(crate) fn der_length(len: usize) -> Vec { + if len < 128 { + vec![len as u8] + } else if len < 256 { + vec![0x81, len as u8] + } else { + vec![0x82, (len >> 8) as u8, (len & 0xff) as u8] + } +} + +/// Wrap data in a DER TLV (tag-length-value). +pub(crate) fn der_tlv(tag: u8, data: &[u8]) -> Vec { + let mut out = vec![tag]; + out.extend_from_slice(&der_length(data.len())); + out.extend_from_slice(data); + out +} + +/// Parse a DER length field, returning `(length, bytes_consumed)`. +pub(crate) fn parse_der_length(data: &[u8]) -> Result<(usize, usize), Error> { + if data.is_empty() { + return Err(Error::invalid_data("DER: truncated length")); + } + let first = data[0]; + if first < 128 { + Ok((first as usize, 1)) + } else if first == 0x81 { + if data.len() < 2 { + return Err(Error::invalid_data("DER: truncated length (0x81)")); + } + Ok((data[1] as usize, 2)) + } else if first == 0x82 { + if data.len() < 3 { + return Err(Error::invalid_data("DER: truncated length (0x82)")); + } + let len = ((data[1] as usize) << 8) | (data[2] as usize); + Ok((len, 3)) + } else if first == 0x83 { + if data.len() < 4 { + return Err(Error::invalid_data("DER: truncated length (0x83)")); + } + let len = ((data[1] as usize) << 16) | ((data[2] as usize) << 8) | (data[3] as usize); + Ok((len, 4)) + } else { + Err(Error::invalid_data(format!( + "DER: unsupported length encoding: 0x{first:02x}" + ))) + } +} + +/// Parse a DER TLV, returning `(tag, value_slice, total_bytes_consumed)`. +pub(crate) fn parse_der_tlv(data: &[u8]) -> Result<(u8, &[u8], usize), Error> { + if data.is_empty() { + return Err(Error::invalid_data("DER: truncated TLV")); + } + let tag = data[0]; + let (len, len_bytes) = parse_der_length(&data[1..])?; + let header_len = 1 + len_bytes; + let total = header_len + len; + if data.len() < total { + return Err(Error::invalid_data(format!( + "DER: TLV truncated: need {total} bytes, have {}", + data.len() + ))); + } + Ok((tag, &data[header_len..total], total)) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ======================================================================= + // DER length encoding + // ======================================================================= + + #[test] + fn length_single_byte() { + assert_eq!(der_length(0), vec![0x00]); + assert_eq!(der_length(1), vec![0x01]); + assert_eq!(der_length(127), vec![0x7f]); + } + + #[test] + fn length_two_byte() { + assert_eq!(der_length(128), vec![0x81, 0x80]); + assert_eq!(der_length(255), vec![0x81, 0xff]); + } + + #[test] + fn length_three_byte() { + assert_eq!(der_length(256), vec![0x82, 0x01, 0x00]); + assert_eq!(der_length(65535), vec![0x82, 0xff, 0xff]); + assert_eq!(der_length(1000), vec![0x82, 0x03, 0xe8]); + } + + // ======================================================================= + // DER TLV encoding + // ======================================================================= + + #[test] + fn tlv_simple() { + let result = der_tlv(0x04, &[0x01, 0x02]); + assert_eq!(result, vec![0x04, 0x02, 0x01, 0x02]); + } + + #[test] + fn tlv_empty() { + let result = der_tlv(0x30, &[]); + assert_eq!(result, vec![0x30, 0x00]); + } + + #[test] + fn tlv_long_content() { + let data = vec![0xaa; 200]; + let result = der_tlv(0x04, &data); + assert_eq!(result[0], 0x04); + assert_eq!(result[1], 0x81); + assert_eq!(result[2], 200); + assert_eq!(result.len(), 3 + 200); + } + + // ======================================================================= + // DER length parsing + // ======================================================================= + + #[test] + fn parse_length_single_byte() { + let (len, consumed) = parse_der_length(&[0x05]).unwrap(); + assert_eq!(len, 5); + assert_eq!(consumed, 1); + } + + #[test] + fn parse_length_two_byte() { + let (len, consumed) = parse_der_length(&[0x81, 0x80]).unwrap(); + assert_eq!(len, 128); + assert_eq!(consumed, 2); + } + + #[test] + fn parse_length_three_byte() { + let (len, consumed) = parse_der_length(&[0x82, 0x01, 0x00]).unwrap(); + assert_eq!(len, 256); + assert_eq!(consumed, 3); + } + + #[test] + fn parse_length_four_byte() { + let (len, consumed) = parse_der_length(&[0x83, 0x01, 0x00, 0x00]).unwrap(); + assert_eq!(len, 65536); + assert_eq!(consumed, 4); + } + + #[test] + fn parse_length_truncated() { + assert!(parse_der_length(&[]).is_err()); + assert!(parse_der_length(&[0x81]).is_err()); + assert!(parse_der_length(&[0x82, 0x01]).is_err()); + assert!(parse_der_length(&[0x83, 0x01, 0x00]).is_err()); + } + + // ======================================================================= + // DER TLV parsing + // ======================================================================= + + #[test] + fn parse_tlv_roundtrip() { + let original = der_tlv(0x04, &[0xde, 0xad, 0xbe, 0xef]); + let (tag, value, total) = parse_der_tlv(&original).unwrap(); + assert_eq!(tag, 0x04); + assert_eq!(value, &[0xde, 0xad, 0xbe, 0xef]); + assert_eq!(total, original.len()); + } + + #[test] + fn parse_tlv_truncated() { + assert!(parse_der_tlv(&[]).is_err()); + // Tag present, length says 10 bytes but only 2 available + assert!(parse_der_tlv(&[0x04, 0x0a, 0x01, 0x02]).is_err()); + } +} diff --git a/vendor/smb2/src/auth/kerberos/authenticator.rs b/vendor/smb2/src/auth/kerberos/authenticator.rs new file mode 100644 index 0000000..30bd598 --- /dev/null +++ b/vendor/smb2/src/auth/kerberos/authenticator.rs @@ -0,0 +1,1637 @@ +//! Stateful Kerberos authenticator for SMB2 session setup. +//! +//! Performs the full Kerberos authentication exchange: +//! 1. AS exchange (client -> KDC): get a TGT +//! 2. TGS exchange (client -> KDC): get a service ticket for `cifs/hostname` +//! 3. AP-REQ construction: wrap the service ticket for SESSION_SETUP +//! +//! After [`KerberosAuthenticator::authenticate`] succeeds, call +//! [`token()`](KerberosAuthenticator::token) for the SPNEGO-wrapped AP-REQ +//! and [`session_key()`](KerberosAuthenticator::session_key) for the SMB +//! session key. + +use log::{debug, trace}; +use std::time::Duration; + +use crate::auth::kerberos::crypto::{ + compute_checksum, etype_from_i32, kerberos_decrypt, kerberos_encrypt, string_to_key_aes, + string_to_key_rc4, EncryptionType, +}; +use crate::auth::kerberos::kdc::{send_to_kdc, KdcConfig}; +use crate::auth::kerberos::messages::{ + encode_ap_req, encode_as_req, encode_authenticator, encode_pa_enc_timestamp, encode_tgs_req, + encode_tgs_req_body, parse_enc_kdc_rep_part, parse_kdc_rep, parse_krb_error, EncryptedData, + PaData, PrincipalName, Ticket, +}; +use crate::auth::spnego::{wrap_neg_token_init, OID_KERBEROS, OID_MS_KERBEROS}; +use crate::error::{Error, Result}; + +// --------------------------------------------------------------------------- +// Key usage numbers (RFC 4120 section 7.5.1) +// --------------------------------------------------------------------------- + +/// Key usage for PA-ENC-TIMESTAMP encryption. +const KEY_USAGE_PA_ENC_TIMESTAMP: u32 = 1; + +/// Key usage for AS-REP EncKDCRepPart decryption. +const KEY_USAGE_AS_REP_ENC_PART: u32 = 3; + +/// Key usage for AP-REQ Authenticator encryption (standard, RFC 4120). +/// +/// Used for the PA-TGS-REQ authenticator in TGS exchanges. +const KEY_USAGE_AP_REQ_AUTHENTICATOR: u32 = 7; + +/// Key usage for AP-REQ Authenticator encryption (MS-KILE/SPNEGO). +/// +/// Windows servers expect key usage 11 for the AP-REQ Authenticator +/// in SPNEGO-wrapped SMB SESSION_SETUP exchanges. Impacket uses this. +const KEY_USAGE_AP_REQ_AUTHENTICATOR_SPNEGO: u32 = 11; + +/// Key usage for TGS-REP EncKDCRepPart decryption (sub-session key). +/// +/// Per RFC 4120 section 7.5.1 and MS-KILE, the TGS-REP enc-part is +/// encrypted with key usage 8 when using the TGT session key. +/// However, some implementations use key usage 9. We try 8 first, +/// then fall back to 9 if decryption fails. +const KEY_USAGE_TGS_REP_ENC_PART_SESSION_KEY: u32 = 8; + +/// Fallback key usage for TGS-REP (some KDCs use 9). +const KEY_USAGE_TGS_REP_ENC_PART_SUBKEY: u32 = 9; + +// --------------------------------------------------------------------------- +// KDC error codes (RFC 4120 section 7.5.9) +// --------------------------------------------------------------------------- + +/// KDC_ERR_PREAUTH_REQUIRED: pre-authentication information was needed but +/// not found in the request. +const KDC_ERR_PREAUTH_REQUIRED: i32 = 25; + +// --------------------------------------------------------------------------- +// PA-DATA type constants +// --------------------------------------------------------------------------- + +/// PA-ENC-TIMESTAMP (padata type 2). +const PA_ENC_TIMESTAMP: i32 = 2; + +/// PA-ETYPE-INFO2 (padata type 19). +const PA_ETYPE_INFO2: i32 = 19; + +/// PA-PAC-REQUEST (padata type 128). +const PA_PAC_REQUEST: i32 = 128; + +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- + +/// Credentials for Kerberos authentication. +#[derive(Debug, Clone)] +pub struct KerberosCredentials { + /// Username (without realm). + pub username: String, + /// Password. + pub password: String, + /// Kerberos realm (uppercase, for example, "CORP.EXAMPLE.COM"). + pub realm: String, + /// KDC address (host:port or host, port defaults to 88). + pub kdc_address: String, +} + +/// Stateful Kerberos authenticator. +/// +/// Performs the full Kerberos exchange: AS -> TGS -> AP. +/// After completion, [`session_key()`](Self::session_key) returns the session +/// key for SMB signing/encryption. +pub struct KerberosAuthenticator { + credentials: KerberosCredentials, + /// TGT obtained from the AS exchange. + tgt: Option, + /// Session key from the AS exchange (used to authenticate to the TGS). + as_session_key: Option>, + /// Service ticket obtained from the TGS exchange. + service_ticket: Option, + /// Session key from the TGS exchange (the SMB session key). + tgs_session_key: Option>, + /// SPNEGO-wrapped AP-REQ bytes for SESSION_SETUP. + ap_req_bytes: Option>, + /// Final session key for SMB (same as tgs_session_key). + session_key: Option>, + /// Negotiated encryption type. + etype: EncryptionType, +} + +impl KerberosAuthenticator { + /// Create a new authenticator with the given credentials. + pub fn new(credentials: KerberosCredentials) -> Self { + Self { + credentials, + tgt: None, + as_session_key: None, + service_ticket: None, + tgs_session_key: None, + ap_req_bytes: None, + session_key: None, + etype: EncryptionType::Aes256CtsHmacSha196, + } + } + + /// Perform the full Kerberos exchange (AS + TGS + build AP-REQ). + /// + /// After this returns `Ok(())`, call [`token()`](Self::token) to get the + /// SPNEGO-wrapped AP-REQ for SESSION_SETUP, and + /// [`session_key()`](Self::session_key) for the session key. + /// + /// This is async because it contacts the KDC over the network. + pub async fn authenticate(&mut self, server_hostname: &str) -> Result<()> { + let kdc_config = KdcConfig { + address: self.credentials.kdc_address.clone(), + timeout: Duration::from_secs(10), + }; + + // ── Step 1: AS exchange ── + debug!("kerberos: starting AS exchange"); + self.as_exchange(&kdc_config).await?; + + // ── Step 2: TGS exchange ── + debug!( + "kerberos: starting TGS exchange for cifs/{}", + server_hostname + ); + self.tgs_exchange(&kdc_config, server_hostname).await?; + + // ── Step 3: Build AP-REQ ── + debug!("kerberos: building AP-REQ"); + self.build_ap_req()?; + + debug!("kerberos: authentication complete"); + Ok(()) + } + + /// Authenticate using a cached credential from a ccache file. + /// + /// If the ccache has a service ticket for `cifs/`, uses it + /// directly (no KDC contact needed). If only a TGT is cached, performs a + /// TGS exchange to get the service ticket. + /// + /// After this returns `Ok(())`, call [`token()`](Self::token) and + /// [`session_key()`](Self::session_key) as usual. + pub async fn authenticate_from_ccache( + &mut self, + ccache: &crate::auth::kerberos::ccache::CCache, + server_hostname: &str, + ) -> Result<()> { + let realm = &self.credentials.realm; + + // Try cached service ticket first (no KDC needed). + if let Some(svc) = ccache.find_service_ticket("cifs", server_hostname, realm) { + debug!( + "kerberos: using cached service ticket for cifs/{}", + server_hostname + ); + self.load_service_ticket_from_ccache(svc)?; + self.build_ap_req()?; + debug!("kerberos: authentication complete (from cached service ticket)"); + return Ok(()); + } + + // Fall back to cached TGT + TGS exchange. + if let Some(tgt_cred) = ccache.find_tgt(realm) { + debug!( + "kerberos: using cached TGT, doing TGS exchange for cifs/{}", + server_hostname + ); + self.load_tgt_from_ccache(tgt_cred)?; + + let kdc_config = KdcConfig { + address: self.credentials.kdc_address.clone(), + timeout: Duration::from_secs(10), + }; + self.tgs_exchange(&kdc_config, server_hostname).await?; + self.build_ap_req()?; + debug!("kerberos: authentication complete (TGT from cache + TGS exchange)"); + return Ok(()); + } + + Err(Error::Auth { + message: format!("ccache has no TGT or service ticket for realm {realm}"), + }) + } + + /// Load a service ticket from a ccache credential entry. + fn load_service_ticket_from_ccache( + &mut self, + cred: &crate::auth::kerberos::ccache::CcacheCredential, + ) -> Result<()> { + // Parse the ticket from the raw DER bytes. + let ticket = crate::auth::kerberos::messages::parse_ticket(&cred.ticket)?; + + // Determine the etype from the session key. + let etype = etype_from_code(cred.key_etype as i32)?; + self.etype = etype; + + self.service_ticket = Some(ticket); + self.tgs_session_key = Some(cred.key_data.clone()); + self.session_key = Some(cred.key_data.clone()); + + Ok(()) + } + + /// Load a TGT from a ccache credential entry. + fn load_tgt_from_ccache( + &mut self, + cred: &crate::auth::kerberos::ccache::CcacheCredential, + ) -> Result<()> { + let ticket = crate::auth::kerberos::messages::parse_ticket(&cred.ticket)?; + + let etype = etype_from_code(cred.key_etype as i32)?; + self.etype = etype; + + self.tgt = Some(ticket); + self.as_session_key = Some(cred.key_data.clone()); + + Ok(()) + } + + /// Get the SPNEGO-wrapped AP-REQ token for SESSION_SETUP. + /// + /// Available after [`authenticate()`](Self::authenticate) succeeds. + pub fn token(&self) -> Option<&[u8]> { + self.ap_req_bytes.as_deref() + } + + /// Get the session key for SMB signing/encryption. + /// + /// Available after [`authenticate()`](Self::authenticate) succeeds. + pub fn session_key(&self) -> Option<&[u8]> { + self.session_key.as_deref() + } + + // ===================================================================== + // AS exchange + // ===================================================================== + + /// Perform the AS exchange to get a TGT. + async fn as_exchange(&mut self, kdc_config: &KdcConfig) -> Result<()> { + let realm = &self.credentials.realm; + let username = &self.credentials.username; + + // Client principal: username@REALM + let cname = PrincipalName { + name_type: 1, // KRB_NT_PRINCIPAL + name_string: vec![username.clone()], + }; + + // Service principal for TGT: krbtgt/REALM + let sname = PrincipalName { + name_type: 2, // KRB_NT_SRV_INST + name_string: vec!["krbtgt".to_string(), realm.clone()], + }; + + // Generate a random nonce. + let nonce = generate_nonce(); + + // Requested etypes: prefer AES-256, then AES-128, then RC4. + let etypes = [ + EncryptionType::Aes256CtsHmacSha196, + EncryptionType::Aes128CtsHmacSha196, + EncryptionType::Rc4Hmac, + ]; + + // First attempt: send AS-REQ without pre-authentication. + // Most KDCs will respond with KDC_ERR_PREAUTH_REQUIRED. + let as_req = encode_as_req(&cname, realm, &sname, nonce, &etypes, &[]); + let response = send_to_kdc(kdc_config, &as_req).await?; + + // Check if we got a KRB-ERROR (APPLICATION [30] = 0x7e). + trace!( + "kerberos: AS response first 32 bytes: {:02x?}", + &response[..response.len().min(32)] + ); + let response = if !response.is_empty() && response[0] == 0x7e { + let krb_error = parse_krb_error(&response)?; + + if krb_error.error_code == KDC_ERR_PREAUTH_REQUIRED { + debug!("kerberos: got KDC_ERR_PREAUTH_REQUIRED, retrying with pre-authentication"); + + // Extract supported etypes from e-data if available. + let chosen_etype = if let Some(ref e_data) = krb_error.e_data { + self.extract_best_etype(e_data).unwrap_or(self.etype) + } else { + self.etype + }; + self.etype = chosen_etype; + + // Derive the user's long-term key from the password. + let user_key = self.derive_user_key(); + + // Build PA-ENC-TIMESTAMP. + let (ctime, cusec) = current_kerberos_time(); + let timestamp_plaintext = encode_pa_enc_timestamp(&ctime, cusec); + let encrypted_timestamp = kerberos_encrypt( + &user_key, + KEY_USAGE_PA_ENC_TIMESTAMP, + ×tamp_plaintext, + self.etype, + ); + + let enc_timestamp_data = EncryptedData { + etype: self.etype as i32, + kvno: None, + cipher: encrypted_timestamp, + }; + let pa_enc_ts_value = encode_encrypted_data_raw(&enc_timestamp_data); + + // Build PA-PAC-REQUEST (request the PAC). + let pa_pac_value = encode_pa_pac_request(true); + + let padata = vec![ + PaData { + padata_type: PA_ENC_TIMESTAMP, + padata_value: pa_enc_ts_value, + }, + PaData { + padata_type: PA_PAC_REQUEST, + padata_value: pa_pac_value, + }, + ]; + + // Retry AS-REQ with pre-authentication. + let as_req = encode_as_req(&cname, realm, &sname, nonce, &etypes, &padata); + send_to_kdc(kdc_config, &as_req).await? + } else { + return Err(Error::Auth { + message: format!( + "Kerberos AS exchange failed: KRB-ERROR code {} ({})", + krb_error.error_code, + krb_error.e_text.unwrap_or_default() + ), + }); + } + } else { + response + }; + + // Check for error in the response to the pre-auth attempt. + if !response.is_empty() && response[0] == 0x7e { + let krb_error = parse_krb_error(&response)?; + return Err(Error::Auth { + message: format!( + "Kerberos AS exchange failed: KRB-ERROR code {} ({})", + krb_error.error_code, + krb_error.e_text.unwrap_or_default() + ), + }); + } + + // Parse AS-REP (APPLICATION [11] = 0x6b). + let as_rep = parse_kdc_rep(&response)?; + if as_rep.msg_type != 11 { + return Err(Error::invalid_data(format!( + "Kerberos: expected AS-REP (msg_type 11), got {}", + as_rep.msg_type + ))); + } + + // Update etype from what the KDC actually chose. + self.etype = etype_from_i32(as_rep.enc_part.etype)?; + debug!( + "kerberos: AS-REP etype={}, kvno={:?}, cipher_len={}, crealm={}, cname={:?}", + as_rep.enc_part.etype, + as_rep.enc_part.kvno, + as_rep.enc_part.cipher.len(), + as_rep.crealm, + as_rep.cname.name_string, + ); + + // Derive the user's long-term key (may have been derived already, + // but etype might have changed based on the KDC response). + let user_key = self.derive_user_key(); + debug!( + "kerberos: user_key len={}, etype={:?}, salt={}{}, key_prefix={:02x?}", + user_key.len(), + self.etype, + &self.credentials.realm, + &self.credentials.username, + &user_key[..user_key.len().min(8)], + ); + + // Decrypt the enc-part to get the session key. + let enc_part_plain = kerberos_decrypt( + &user_key, + KEY_USAGE_AS_REP_ENC_PART, + &as_rep.enc_part.cipher, + self.etype, + )?; + + let enc_kdc_rep = parse_enc_kdc_rep_part(&enc_part_plain)?; + + trace!( + "kerberos: AS session key type={}, len={}", + enc_kdc_rep.key.keytype, + enc_kdc_rep.key.keyvalue.len() + ); + + self.tgt = Some(as_rep.ticket); + self.as_session_key = Some(enc_kdc_rep.key.keyvalue); + + Ok(()) + } + + // ===================================================================== + // TGS exchange + // ===================================================================== + + /// Perform the TGS exchange to get a service ticket. + async fn tgs_exchange(&mut self, kdc_config: &KdcConfig, server_hostname: &str) -> Result<()> { + let tgt = self + .tgt + .as_ref() + .ok_or_else(|| Error::Auth { + message: "TGS exchange requires a TGT (run AS exchange first)".to_string(), + })? + .clone(); + let as_session_key = self + .as_session_key + .as_ref() + .ok_or_else(|| Error::Auth { + message: "TGS exchange requires AS session key".to_string(), + })? + .clone(); + + let realm = &self.credentials.realm; + let username = &self.credentials.username; + + // Service principal: cifs/server_hostname + let sname = PrincipalName { + name_type: 2, // KRB_NT_SRV_INST + name_string: vec!["cifs".to_string(), server_hostname.to_string()], + }; + + // Build an AP-REQ wrapping the TGT for the TGS (PA-TGS-REQ). + let cname = PrincipalName { + name_type: 1, + name_string: vec![username.clone()], + }; + + let nonce = generate_nonce(); + // Request etypes in preference order. The KDC picks the session key + // type from this list. AES-256 preferred, with AES-128 and RC4 fallback. + let etypes = [ + EncryptionType::Aes256CtsHmacSha196, + EncryptionType::Aes128CtsHmacSha196, + EncryptionType::Rc4Hmac, + ]; + + // Build the KDC-REQ-BODY first, so we can compute a checksum + // over it for the Authenticator (required per RFC 4120 section 7.2.2). + let req_body = encode_tgs_req_body(realm, &sname, nonce, &etypes); + + // Compute checksum over KDC-REQ-BODY using key usage 6 + // (PA-TGS-REQ padata AP-REQ Authenticator cksum). + let body_checksum = compute_checksum(&as_session_key, 6, &req_body, self.etype); + let checksum_type: i32 = match self.etype { + EncryptionType::Aes256CtsHmacSha196 => 16, // hmac-sha1-96-aes256 + EncryptionType::Aes128CtsHmacSha196 => 15, // hmac-sha1-96-aes128 + EncryptionType::Rc4Hmac => -138, // HMAC_MD5 (KERB_CHECKSUM_HMAC_MD5) + }; + + let (ctime, cusec) = current_kerberos_time(); + let authenticator_plain = encode_authenticator( + realm, + &cname, + &ctime, + cusec, + None, + None, + Some((&body_checksum, checksum_type)), + ); + + debug!( + "kerberos: TGS authenticator plain ({} bytes), session key prefix={:02x?}", + authenticator_plain.len(), + &as_session_key[..as_session_key.len().min(8)] + ); + + let encrypted_authenticator = kerberos_encrypt( + &as_session_key, + KEY_USAGE_AP_REQ_AUTHENTICATOR, + &authenticator_plain, + self.etype, + ); + + let authenticator_enc_data = EncryptedData { + etype: self.etype as i32, + kvno: None, + cipher: encrypted_authenticator, + }; + + let tgt_ap_req = encode_ap_req(&tgt, &authenticator_enc_data, false); + + let tgs_req = encode_tgs_req(realm, &sname, nonce, &etypes, &tgt_ap_req); + let response = send_to_kdc(kdc_config, &tgs_req).await?; + + // Check for KRB-ERROR. + if !response.is_empty() && response[0] == 0x7e { + let krb_error = parse_krb_error(&response)?; + return Err(Error::Auth { + message: format!( + "Kerberos TGS exchange failed: KRB-ERROR code {} ({})", + krb_error.error_code, + krb_error.e_text.unwrap_or_default() + ), + }); + } + + // Parse TGS-REP (APPLICATION [13] = 0x6d). + let tgs_rep = parse_kdc_rep(&response)?; + debug!( + "kerberos: TGS-REP ticket etype={}, kvno={:?}, cipher_len={}", + tgs_rep.ticket.enc_part.etype, + tgs_rep.ticket.enc_part.kvno, + tgs_rep.ticket.enc_part.cipher.len() + ); + debug!( + "kerberos: TGS-REP enc-part etype={}, kvno={:?}", + tgs_rep.enc_part.etype, tgs_rep.enc_part.kvno + ); + if tgs_rep.msg_type != 13 { + return Err(Error::invalid_data(format!( + "Kerberos: expected TGS-REP (msg_type 13), got {}", + tgs_rep.msg_type + ))); + } + + // Decrypt the enc-part with the AS session key. + // Try key usage 8 first (session key), fall back to 9 (subkey). + let enc_part_plain = match kerberos_decrypt( + &as_session_key, + KEY_USAGE_TGS_REP_ENC_PART_SESSION_KEY, + &tgs_rep.enc_part.cipher, + self.etype, + ) { + Ok(plain) => plain, + Err(_) => { + debug!("kerberos: TGS-REP decryption with key usage 8 failed, trying 9"); + kerberos_decrypt( + &as_session_key, + KEY_USAGE_TGS_REP_ENC_PART_SUBKEY, + &tgs_rep.enc_part.cipher, + self.etype, + )? + } + }; + + let enc_kdc_rep = parse_enc_kdc_rep_part(&enc_part_plain)?; + + trace!( + "kerberos: TGS session key type={}, len={}", + enc_kdc_rep.key.keytype, + enc_kdc_rep.key.keyvalue.len() + ); + + // Log ticket raw bytes info. + debug!( + "kerberos: service ticket has raw_bytes={}, raw_len={:?}", + tgs_rep.ticket.raw_bytes.is_some(), + tgs_rep.ticket.raw_bytes.as_ref().map(|b| b.len()) + ); + + // Use the session key's actual etype for Authenticator encryption. + let tgs_key_etype = match enc_kdc_rep.key.keytype { + 18 => EncryptionType::Aes256CtsHmacSha196, + 17 => EncryptionType::Aes128CtsHmacSha196, + 23 => EncryptionType::Rc4Hmac, + other => { + return Err(Error::Auth { + message: format!("TGS session key has unsupported etype {other}"), + }); + } + }; + self.etype = tgs_key_etype; + + self.service_ticket = Some(tgs_rep.ticket); + self.tgs_session_key = Some(enc_kdc_rep.key.keyvalue.clone()); + self.session_key = Some(enc_kdc_rep.key.keyvalue); + + Ok(()) + } + + // ===================================================================== + // AP-REQ construction + // ===================================================================== + + /// Build the AP-REQ and wrap it in SPNEGO for SESSION_SETUP. + fn build_ap_req(&mut self) -> Result<()> { + let service_ticket = self + .service_ticket + .as_ref() + .ok_or_else(|| Error::Auth { + message: "AP-REQ requires a service ticket (run TGS exchange first)".to_string(), + })? + .clone(); + let tgs_session_key = self + .tgs_session_key + .as_ref() + .ok_or_else(|| Error::Auth { + message: "AP-REQ requires TGS session key".to_string(), + })? + .clone(); + + let realm = &self.credentials.realm; + let username = &self.credentials.username; + + let cname = PrincipalName { + name_type: 1, + name_string: vec![username.clone()], + }; + + // Build and encrypt the Authenticator. + let (ctime, cusec) = current_kerberos_time(); + + // Minimal Authenticator: no subkey, no seq-number, no checksum. + // This matches impacket's working implementation. Windows accepts + // this minimal format for SMB Kerberos authentication. + let authenticator_plain = encode_authenticator( + realm, &cname, &ctime, cusec, None, // no subkey + None, // no seq-number + None, // no checksum + ); + + let encrypted_authenticator = kerberos_encrypt( + &tgs_session_key, + KEY_USAGE_AP_REQ_AUTHENTICATOR_SPNEGO, + &authenticator_plain, + self.etype, + ); + + let authenticator_enc_data = EncryptedData { + etype: self.etype as i32, + kvno: None, + cipher: encrypted_authenticator, + }; + + let ap_req = encode_ap_req(&service_ticket, &authenticator_enc_data, true); + + // Wrap the AP-REQ in a Kerberos GSS-API initial context token + // (RFC 1964): APPLICATION [0] { OID, 0x0100, AP-REQ }. + // Windows SPNEGO expects this wrapping in the NegTokenInit mechToken. + let gss_mech_token = { + // Standard Kerberos OID 1.2.840.113554.1.2.2 (for GSS inner token) + let oid_bytes: &[u8] = &OID_KERBEROS[2..]; // skip tag+length + let mut inner = Vec::new(); + inner.push(0x06); // OID tag + inner.push(oid_bytes.len() as u8); + inner.extend_from_slice(oid_bytes); + inner.extend_from_slice(&[0x01, 0x00]); // KRB_AP_REQ token ID + inner.extend_from_slice(&ap_req); + + let mut token = Vec::new(); + token.push(0x60); // APPLICATION [0] + if inner.len() < 128 { + token.push(inner.len() as u8); + } else if inner.len() < 256 { + token.push(0x81); + token.push(inner.len() as u8); + } else { + token.push(0x82); + token.push((inner.len() >> 8) as u8); + token.push((inner.len() & 0xff) as u8); + } + token.extend_from_slice(&inner); + token + }; + + // Wrap in SPNEGO NegTokenInit with MS Kerberos OID. + let spnego_token = wrap_neg_token_init(&[OID_MS_KERBEROS], &gss_mech_token); + + // The SMB session key is the TGS session key. + self.session_key = Some(tgs_session_key); + self.ap_req_bytes = Some(spnego_token); + + Ok(()) + } + + /// Process the server's mutual authentication token from SPNEGO. + /// + /// The token may be GSS-API wrapped. After unwrapping, the 2-byte token ID + /// tells us what it is: + /// - `02 00`: AP-REP — contains optional server subkey + /// - `03 00`: KRB-ERROR — logged but not fatal (session may still be valid) + pub fn process_mutual_auth_token(&mut self, token_bytes: &[u8]) -> Result<()> { + use crate::auth::kerberos::messages::{ + parse_ap_rep, parse_enc_ap_rep_part, parse_krb_error, + }; + + // Unwrap GSS-API APPLICATION [0] wrapper if present. + let inner = if !token_bytes.is_empty() && token_bytes[0] == 0x60 { + // Skip APPLICATION [0] header + OID + let (_, gss_inner, _) = + crate::auth::kerberos::messages::parse_gss_api_wrapper(token_bytes)?; + gss_inner + } else { + token_bytes.to_vec() + }; + + if inner.len() < 2 { + return Err(Error::invalid_data("Kerberos: mutual auth token too short")); + } + + let token_id = [inner[0], inner[1]]; + let krb_data = &inner[2..]; + + match token_id { + [0x02, 0x00] => { + // AP-REP + debug!("kerberos: processing AP-REP from server"); + let ap_rep = parse_ap_rep(krb_data)?; + + const KEY_USAGE_AP_REP_ENC_PART: u32 = 12; + let current_key = self.session_key.as_ref().ok_or_else(|| Error::Auth { + message: "No session key available to decrypt AP-REP".to_string(), + })?; + + let etype = match ap_rep.enc_part.etype { + 18 => EncryptionType::Aes256CtsHmacSha196, + 17 => EncryptionType::Aes128CtsHmacSha196, + 23 => EncryptionType::Rc4Hmac, + other => { + return Err(Error::Auth { + message: format!("AP-REP: unsupported etype {other}"), + }) + } + }; + + let plain = kerberos_decrypt( + current_key, + KEY_USAGE_AP_REP_ENC_PART, + &ap_rep.enc_part.cipher, + etype, + )?; + + let enc_part = parse_enc_ap_rep_part(&plain)?; + + if let Some(server_subkey) = enc_part.subkey { + debug!( + "kerberos: AP-REP server subkey, etype={}, len={}", + server_subkey.keytype, + server_subkey.keyvalue.len() + ); + self.session_key = Some(server_subkey.keyvalue); + } else { + debug!("kerberos: AP-REP has no server subkey"); + } + } + [0x03, 0x00] => { + // KRB-ERROR — the server's Kerberos layer reported an error, + // but the SMB session may still be valid. Log and continue. + match parse_krb_error(krb_data) { + Ok(err) => { + debug!( + "kerberos: mutual auth KRB-ERROR code={}, realm={}, sname={:?}, e_text={:?}, e_data={:02x?}", + err.error_code, err.realm, err.sname, err.e_text, + err.e_data.as_deref().unwrap_or(&[]) + ); + } + Err(e) => { + debug!("kerberos: failed to parse KRB-ERROR in mutual auth: {}", e); + } + } + } + _ => { + debug!( + "kerberos: unexpected mutual auth token ID: {:02x} {:02x}", + token_id[0], token_id[1] + ); + } + } + + Ok(()) + } + + // ===================================================================== + // Helpers + // ===================================================================== + + /// Derive the user's long-term key from the password. + fn derive_user_key(&self) -> Vec { + let salt = format!("{}{}", self.credentials.realm, self.credentials.username); + match self.etype { + EncryptionType::Aes256CtsHmacSha196 => { + string_to_key_aes(&self.credentials.password, &salt, 32) + } + EncryptionType::Aes128CtsHmacSha196 => { + string_to_key_aes(&self.credentials.password, &salt, 16) + } + EncryptionType::Rc4Hmac => string_to_key_rc4(&self.credentials.password), + } + } + + /// Extract the best supported etype from ETYPE-INFO2 in the KRB-ERROR e-data. + /// + /// The e-data for KDC_ERR_PREAUTH_REQUIRED contains a METHOD-DATA + /// (SEQUENCE OF PA-DATA). We look for PA-ETYPE-INFO2 (type 19) which + /// contains a SEQUENCE OF ETYPE-INFO2-ENTRY. + fn extract_best_etype(&self, e_data: &[u8]) -> Option { + // Parse METHOD-DATA: SEQUENCE OF PA-DATA. + // Each PA-DATA is SEQUENCE { [1] padata-type INTEGER, [2] padata-value OCTET STRING }. + // We look for padata-type 19 (PA-ETYPE-INFO2). + let entries = parse_method_data(e_data).ok()?; + + for entry in &entries { + if entry.padata_type == PA_ETYPE_INFO2 { + // Parse ETYPE-INFO2: SEQUENCE OF ETYPE-INFO2-ENTRY + // Each entry: SEQUENCE { [0] etype INTEGER, [1] salt GeneralString OPTIONAL, ... } + if let Some(etype) = parse_etype_info2_best(&entry.padata_value) { + return Some(etype); + } + } + } + + None + } +} + +// ========================================================================= +// DER encoding helpers for PA-DATA values +// ========================================================================= + +/// Encode an EncryptedData as raw DER (for embedding in PA-DATA values). +fn encode_encrypted_data_raw(ed: &EncryptedData) -> Vec { + // EncryptedData ::= SEQUENCE { + // etype [0] Int32, + // kvno [1] UInt32 OPTIONAL, + // cipher [2] OCTET STRING + // } + let etype = der_context(0, &der_integer(ed.etype)); + let cipher = der_context(2, &der_octet_string(&ed.cipher)); + if let Some(kvno) = ed.kvno { + let kvno_enc = der_context(1, &der_integer(kvno)); + der_sequence(&[&etype, &kvno_enc, &cipher]) + } else { + der_sequence(&[&etype, &cipher]) + } +} + +/// Encode a PA-PAC-REQUEST value. +/// +/// KERB-PA-PAC-REQUEST ::= SEQUENCE { +/// include-pac [0] BOOLEAN +/// } +fn encode_pa_pac_request(include_pac: bool) -> Vec { + let bool_val: &[u8] = if include_pac { + &[0x01, 0x01, 0xff] + } else { + &[0x01, 0x01, 0x00] + }; + let include = der_context(0, bool_val); + der_sequence(&[&include]) +} + +// ========================================================================= +// Parsing helpers for PREAUTH_REQUIRED e-data +// ========================================================================= + +/// Parse METHOD-DATA (SEQUENCE OF PA-DATA) from a KRB-ERROR's e-data. +fn parse_method_data(data: &[u8]) -> Result> { + let (tag, seq_data, _) = parse_der_tlv_local(data)?; + if tag != 0x30 { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE for METHOD-DATA, got 0x{tag:02x}" + ))); + } + + let mut entries = Vec::new(); + let mut pos = 0; + while pos < seq_data.len() { + let (entry_tag, entry_data, consumed) = parse_der_tlv_local(&seq_data[pos..])?; + if entry_tag == 0x30 { + // PA-DATA SEQUENCE + let fields = parse_sequence_fields_local(entry_data)?; + let mut padata_type = None; + let mut padata_value = None; + for (ftag, fvalue) in &fields { + match ftag { + 0xa1 => padata_type = Some(parse_der_integer_local(fvalue)?), + 0xa2 => padata_value = Some(parse_der_octet_string_local(fvalue)?), + _ => {} + } + } + if let (Some(pt), Some(pv)) = (padata_type, padata_value) { + entries.push(PaData { + padata_type: pt, + padata_value: pv, + }); + } + } + pos += consumed; + } + + Ok(entries) +} + +/// Parse the best etype from an ETYPE-INFO2 value. +/// +/// Returns the first etype we support, preferring AES-256 > AES-128 > RC4. +fn parse_etype_info2_best(data: &[u8]) -> Option { + let (tag, seq_data, _) = parse_der_tlv_local(data).ok()?; + if tag != 0x30 { + return None; + } + + let mut best: Option = None; + + let mut pos = 0; + while pos < seq_data.len() { + let (entry_tag, entry_data, consumed) = parse_der_tlv_local(&seq_data[pos..]).ok()?; + if entry_tag == 0x30 { + let fields = parse_sequence_fields_local(entry_data).ok()?; + for (ftag, fvalue) in &fields { + if *ftag == 0xa0 { + if let Ok(etype_val) = parse_der_integer_local(fvalue) { + if let Ok(et) = etype_from_i32(etype_val) { + match (&best, et) { + (None, _) => best = Some(et), + (Some(EncryptionType::Rc4Hmac), _) + if et != EncryptionType::Rc4Hmac => + { + best = Some(et); + } + ( + Some(EncryptionType::Aes128CtsHmacSha196), + EncryptionType::Aes256CtsHmacSha196, + ) => { + best = Some(et); + } + _ => {} + } + } + } + } + } + } + pos += consumed; + } + + best +} + +// ========================================================================= +// Minimal DER helpers (local, to avoid depending on messages.rs internals) +// ========================================================================= + +/// Parse a DER TLV, returning `(tag, value_slice, total_bytes_consumed)`. +fn parse_der_tlv_local(data: &[u8]) -> Result<(u8, &[u8], usize)> { + if data.is_empty() { + return Err(Error::invalid_data("Kerberos: truncated DER TLV")); + } + let tag = data[0]; + let (len, len_bytes) = parse_der_length_local(&data[1..])?; + let header_len = 1 + len_bytes; + let total = header_len + len; + if data.len() < total { + return Err(Error::invalid_data(format!( + "Kerberos: DER TLV truncated: need {total} bytes, have {}", + data.len() + ))); + } + Ok((tag, &data[header_len..total], total)) +} + +/// Parse a DER length field. +fn parse_der_length_local(data: &[u8]) -> Result<(usize, usize)> { + if data.is_empty() { + return Err(Error::invalid_data("Kerberos: truncated DER length")); + } + let first = data[0]; + if first < 128 { + Ok((first as usize, 1)) + } else if first == 0x81 { + if data.len() < 2 { + return Err(Error::invalid_data("Kerberos: truncated DER length (0x81)")); + } + Ok((data[1] as usize, 2)) + } else if first == 0x82 { + if data.len() < 3 { + return Err(Error::invalid_data("Kerberos: truncated DER length (0x82)")); + } + let len = ((data[1] as usize) << 8) | (data[2] as usize); + Ok((len, 3)) + } else { + Err(Error::invalid_data(format!( + "Kerberos: unsupported DER length encoding: 0x{first:02x}" + ))) + } +} + +/// Parse all TLV elements in a SEQUENCE body. +fn parse_sequence_fields_local(data: &[u8]) -> Result)>> { + let mut fields = Vec::new(); + let mut pos = 0; + while pos < data.len() { + let (tag, value, consumed) = parse_der_tlv_local(&data[pos..])?; + fields.push((tag, value.to_vec())); + pos += consumed; + } + Ok(fields) +} + +/// Parse a DER INTEGER TLV, returning i32. +fn parse_der_integer_local(data: &[u8]) -> Result { + let (tag, value, _) = parse_der_tlv_local(data)?; + if tag != 0x02 { + return Err(Error::invalid_data(format!( + "Kerberos: expected INTEGER (0x02), got 0x{tag:02x}" + ))); + } + if value.is_empty() { + return Err(Error::invalid_data("Kerberos: empty INTEGER")); + } + let negative = value[0] & 0x80 != 0; + let mut val: i64 = if negative { -1 } else { 0 }; + for &b in value { + val = (val << 8) | (b as i64); + } + Ok(val as i32) +} + +/// Parse a DER OCTET STRING TLV, returning the raw bytes. +fn parse_der_octet_string_local(data: &[u8]) -> Result> { + let (tag, value, _) = parse_der_tlv_local(data)?; + if tag != 0x04 { + return Err(Error::invalid_data(format!( + "Kerberos: expected OCTET STRING (0x04), got 0x{tag:02x}" + ))); + } + Ok(value.to_vec()) +} + +// ========================================================================= +// DER encoding helpers +// ========================================================================= + +/// Encode a DER length field. +fn der_length(len: usize) -> Vec { + if len < 128 { + vec![len as u8] + } else if len < 256 { + vec![0x81, len as u8] + } else { + vec![0x82, (len >> 8) as u8, (len & 0xff) as u8] + } +} + +/// Wrap data in a DER TLV. +fn der_tlv(tag: u8, data: &[u8]) -> Vec { + let mut out = vec![tag]; + out.extend_from_slice(&der_length(data.len())); + out.extend_from_slice(data); + out +} + +/// Encode a context-specific constructed tag. +fn der_context(tag_num: u8, data: &[u8]) -> Vec { + der_tlv(0xa0 | tag_num, data) +} + +/// Encode an ASN.1 INTEGER. +fn der_integer(val: i32) -> Vec { + let bytes = val.to_be_bytes(); + let mut start = 0; + if val >= 0 { + while start < 3 && bytes[start] == 0x00 && bytes[start + 1] & 0x80 == 0 { + start += 1; + } + } else { + while start < 3 && bytes[start] == 0xff && bytes[start + 1] & 0x80 != 0 { + start += 1; + } + } + der_tlv(0x02, &bytes[start..]) +} + +/// Encode a DER OCTET STRING. +fn der_octet_string(data: &[u8]) -> Vec { + der_tlv(0x04, data) +} + +/// Encode a DER SEQUENCE from pre-encoded items. +fn der_sequence(items: &[&[u8]]) -> Vec { + let mut contents = Vec::new(); + for item in items { + contents.extend_from_slice(item); + } + der_tlv(0x30, &contents) +} + +// ========================================================================= +// Time and random helpers +// ========================================================================= + +/// Get the current time in Kerberos GeneralizedTime format and microseconds. +/// +/// Format: "YYYYMMDDHHmmssZ" (UTC). +fn current_kerberos_time() -> (String, u32) { + use std::time::SystemTime; + + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("system clock before epoch"); + + let total_secs = now.as_secs(); + let usec = now.subsec_micros(); + + // Convert seconds since epoch to date/time components. + // This is a simplified UTC calculation (no leap seconds, which is fine + // for Kerberos timestamps). + let (year, month, day, hour, minute, second) = secs_to_datetime(total_secs); + + let time_str = format!( + "{:04}{:02}{:02}{:02}{:02}{:02}Z", + year, month, day, hour, minute, second + ); + + (time_str, usec) +} + +/// Convert seconds since Unix epoch to (year, month, day, hour, minute, second). +fn secs_to_datetime(secs: u64) -> (u32, u32, u32, u32, u32, u32) { + // Days since epoch. + let days = secs / 86400; + let time_of_day = secs % 86400; + + let hour = (time_of_day / 3600) as u32; + let minute = ((time_of_day % 3600) / 60) as u32; + let second = (time_of_day % 60) as u32; + + // Civil date from days since 1970-01-01 (algorithm from Howard Hinnant). + let z = days as i64 + 719468; + let era = if z >= 0 { z } else { z - 146096 } / 146097; + let doe = (z - era * 146097) as u64; // [0, 146096] + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; // [0, 399] + let y = (yoe as i64) + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // [0, 365] + let mp = (5 * doy + 2) / 153; // [0, 11] + let d = doy - (153 * mp + 2) / 5 + 1; // [1, 31] + let m = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12] + let y = if m <= 2 { y + 1 } else { y }; + + (y as u32, m as u32, d as u32, hour, minute, second) +} + +/// Convert an etype integer code to an [`EncryptionType`] enum value. +fn etype_from_code(code: i32) -> Result { + match code { + 18 => Ok(EncryptionType::Aes256CtsHmacSha196), + 17 => Ok(EncryptionType::Aes128CtsHmacSha196), + 23 => Ok(EncryptionType::Rc4Hmac), + other => Err(Error::Auth { + message: format!("unsupported etype {other}"), + }), + } +} + +/// Generate a random 32-bit nonce. +fn generate_nonce() -> u32 { + let mut buf = [0u8; 4]; + getrandom::fill(&mut buf).expect("CSPRNG failed"); + u32::from_ne_bytes(buf) & 0x7FFF_FFFF // Ensure positive (Kerberos nonce is UInt32 but some KDCs treat it as signed) +} + +// ========================================================================= +// Tests +// ========================================================================= + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::kerberos::crypto::{ + generate_random_key, kerberos_decrypt, kerberos_encrypt, string_to_key_aes, + }; + use crate::auth::kerberos::messages::{ + encode_ap_req, encode_as_req, encode_authenticator, encode_pa_enc_timestamp, EncryptedData, + PrincipalName, Ticket, + }; + use crate::auth::spnego::OID_NTLMSSP; + + // ── Time formatting tests ──────────────────────────────────────── + + #[test] + fn secs_to_datetime_epoch() { + let (y, m, d, h, mi, s) = secs_to_datetime(0); + assert_eq!((y, m, d, h, mi, s), (1970, 1, 1, 0, 0, 0)); + } + + #[test] + fn secs_to_datetime_known_date() { + // 2026-04-08 12:00:00 UTC + // Unix timestamp: 1775649600 + let (y, m, d, h, mi, s) = secs_to_datetime(1775649600); + assert_eq!((y, m, d, h, mi, s), (2026, 4, 8, 12, 0, 0)); + } + + #[test] + fn secs_to_datetime_leap_year() { + // 2024-02-29 00:00:00 UTC + // Unix timestamp: 1709164800 + let (y, m, d, _, _, _) = secs_to_datetime(1709164800); + assert_eq!((y, m, d), (2024, 2, 29)); + } + + #[test] + fn current_kerberos_time_format() { + let (time_str, _cusec) = current_kerberos_time(); + assert_eq!( + time_str.len(), + 15, + "GeneralizedTime should be 15 chars: {time_str}" + ); + assert!(time_str.ends_with('Z'), "should end with Z: {time_str}"); + // Should be parseable: YYYYMMDDHHMMSSZ + assert!(time_str[..4].parse::().is_ok(), "year: {time_str}"); + } + + // ── Nonce generation ───────────────────────────────────────────── + + #[test] + fn generate_nonce_is_positive() { + for _ in 0..100 { + let n = generate_nonce(); + assert!(n <= 0x7FFF_FFFF, "nonce should be positive: {n}"); + } + } + + #[test] + fn generate_nonce_not_constant() { + let n1 = generate_nonce(); + let n2 = generate_nonce(); + // With 31 bits, collision probability is ~2^-31, negligible. + // But allow it just in case. + if n1 == n2 { + let n3 = generate_nonce(); + assert!( + n1 != n3 || n2 != n3, + "three consecutive identical nonces is suspicious" + ); + } + } + + // ── PA-PAC-REQUEST encoding ────────────────────────────────────── + + #[test] + fn encode_pa_pac_request_true() { + let encoded = encode_pa_pac_request(true); + // SEQUENCE { [0] BOOLEAN TRUE } + assert_eq!(encoded[0], 0x30); // SEQUENCE + // Should contain 0xff for TRUE + assert!(encoded.windows(3).any(|w| w == [0x01, 0x01, 0xff])); + } + + #[test] + fn encode_pa_pac_request_false() { + let encoded = encode_pa_pac_request(false); + assert_eq!(encoded[0], 0x30); + assert!(encoded.windows(3).any(|w| w == [0x01, 0x01, 0x00])); + } + + // ── PA-ENC-TIMESTAMP encrypt ───────────────────────────────────── + + #[test] + fn pa_enc_timestamp_produces_valid_encrypted_data() { + let key = string_to_key_aes("password", "EXAMPLE.COMuser", 32); + let timestamp_plain = encode_pa_enc_timestamp("20260408120000Z", 123456); + + let ciphertext = kerberos_encrypt( + &key, + KEY_USAGE_PA_ENC_TIMESTAMP, + ×tamp_plain, + EncryptionType::Aes256CtsHmacSha196, + ); + + // Should be non-empty and longer than just the HMAC. + assert!( + ciphertext.len() > 12, + "ciphertext too short: {}", + ciphertext.len() + ); + + // Should decrypt successfully. + let decrypted = kerberos_decrypt( + &key, + KEY_USAGE_PA_ENC_TIMESTAMP, + &ciphertext, + EncryptionType::Aes256CtsHmacSha196, + ) + .unwrap(); + + assert_eq!(decrypted, timestamp_plain); + } + + // ── Authenticator encrypt ──────────────────────────────────────── + + #[test] + fn authenticator_encrypt_decrypt_roundtrip() { + let key = generate_random_key(EncryptionType::Aes256CtsHmacSha196); + + let cname = PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }; + let authenticator_plain = encode_authenticator( + "EXAMPLE.COM", + &cname, + "20260408120000Z", + 0, + None, + None, + None, + ); + + let encrypted = kerberos_encrypt( + &key, + KEY_USAGE_AP_REQ_AUTHENTICATOR, + &authenticator_plain, + EncryptionType::Aes256CtsHmacSha196, + ); + + let decrypted = kerberos_decrypt( + &key, + KEY_USAGE_AP_REQ_AUTHENTICATOR, + &encrypted, + EncryptionType::Aes256CtsHmacSha196, + ) + .unwrap(); + + assert_eq!(decrypted, authenticator_plain); + } + + // ── AP-REQ construction ────────────────────────────────────────── + + #[test] + fn build_ap_req_produces_spnego_wrapped_token() { + // Build a fake service ticket. + let ticket = Ticket { + tkt_vno: 5, + realm: "EXAMPLE.COM".to_string(), + sname: PrincipalName { + name_type: 2, + name_string: vec!["cifs".to_string(), "server.example.com".to_string()], + }, + enc_part: EncryptedData { + etype: 18, + kvno: Some(1), + cipher: vec![0xDE, 0xAD, 0xBE, 0xEF], + }, + raw_bytes: None, + }; + + let session_key = generate_random_key(EncryptionType::Aes256CtsHmacSha196); + + let cname = PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }; + + let authenticator_plain = encode_authenticator( + "EXAMPLE.COM", + &cname, + "20260408120000Z", + 0, + None, + None, + None, + ); + + let encrypted_auth = kerberos_encrypt( + &session_key, + KEY_USAGE_AP_REQ_AUTHENTICATOR, + &authenticator_plain, + EncryptionType::Aes256CtsHmacSha196, + ); + + let auth_enc_data = EncryptedData { + etype: 18, + kvno: None, + cipher: encrypted_auth, + }; + + let ap_req = encode_ap_req(&ticket, &auth_enc_data, false); + + // AP-REQ should start with APPLICATION [14] = 0x6e. + assert_eq!(ap_req[0], 0x6e, "AP-REQ should start with APPLICATION [14]"); + + // Wrap in SPNEGO. + let spnego = wrap_neg_token_init(&[OID_KERBEROS, OID_NTLMSSP], &ap_req); + + // SPNEGO NegTokenInit starts with APPLICATION [0] = 0x60. + assert_eq!( + spnego[0], 0x60, + "SPNEGO token should start with APPLICATION [0]" + ); + + // Should contain the SPNEGO OID. + assert!( + spnego + .windows(OID_KERBEROS.len()) + .any(|w| w == OID_KERBEROS), + "SPNEGO token should contain the Kerberos OID" + ); + } + + // ── AS-REQ construction ────────────────────────────────────────── + + #[test] + fn as_req_with_padata_contains_pa_types() { + let cname = PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }; + let sname = PrincipalName { + name_type: 2, + name_string: vec!["krbtgt".to_string(), "EXAMPLE.COM".to_string()], + }; + + let pa_pac = PaData { + padata_type: PA_PAC_REQUEST, + padata_value: encode_pa_pac_request(true), + }; + + let encoded = encode_as_req( + &cname, + "EXAMPLE.COM", + &sname, + 12345, + &[EncryptionType::Aes256CtsHmacSha196], + &[pa_pac], + ); + + // Should start with APPLICATION [10] = 0x6a. + assert_eq!(encoded[0], 0x6a); + + // Should be non-trivial size (with padata it's bigger). + assert!( + encoded.len() > 50, + "AS-REQ with padata should be substantial" + ); + } + + // ── EncryptedData encoding ─────────────────────────────────────── + + #[test] + fn encode_encrypted_data_raw_has_sequence_tag() { + let ed = EncryptedData { + etype: 18, + kvno: None, + cipher: vec![0x01, 0x02, 0x03], + }; + let encoded = encode_encrypted_data_raw(&ed); + assert_eq!(encoded[0], 0x30, "EncryptedData should be a SEQUENCE"); + } + + #[test] + fn encode_encrypted_data_raw_with_kvno() { + let ed = EncryptedData { + etype: 18, + kvno: Some(2), + cipher: vec![0x01, 0x02, 0x03], + }; + let encoded = encode_encrypted_data_raw(&ed); + // Should contain the kvno field (context tag [1]). + assert!( + encoded.windows(2).any(|w| w[0] == 0xa1), + "should contain kvno field [1]" + ); + } + + // ── ETYPE-INFO2 parsing ────────────────────────────────────────── + + #[test] + fn parse_etype_info2_best_aes256() { + // Build a minimal ETYPE-INFO2 with AES-256 and RC4. + // SEQUENCE { SEQUENCE { [0] INTEGER 18 }, SEQUENCE { [0] INTEGER 23 } } + let entry_18 = der_sequence(&[&der_context(0, &der_integer(18))]); + let entry_23 = der_sequence(&[&der_context(0, &der_integer(23))]); + let etype_info2 = der_sequence(&[&entry_18, &entry_23]); + + let best = parse_etype_info2_best(&etype_info2); + assert_eq!(best, Some(EncryptionType::Aes256CtsHmacSha196)); + } + + #[test] + fn parse_etype_info2_best_prefers_aes256_over_aes128() { + let entry_17 = der_sequence(&[&der_context(0, &der_integer(17))]); + let entry_18 = der_sequence(&[&der_context(0, &der_integer(18))]); + let etype_info2 = der_sequence(&[&entry_17, &entry_18]); + + let best = parse_etype_info2_best(&etype_info2); + assert_eq!(best, Some(EncryptionType::Aes256CtsHmacSha196)); + } + + #[test] + fn parse_etype_info2_best_rc4_only() { + let entry_23 = der_sequence(&[&der_context(0, &der_integer(23))]); + let etype_info2 = der_sequence(&[&entry_23]); + + let best = parse_etype_info2_best(&etype_info2); + assert_eq!(best, Some(EncryptionType::Rc4Hmac)); + } + + #[test] + fn parse_etype_info2_best_unknown_only() { + let entry_99 = der_sequence(&[&der_context(0, &der_integer(99))]); + let etype_info2 = der_sequence(&[&entry_99]); + + let best = parse_etype_info2_best(&etype_info2); + assert_eq!(best, None); + } + + // ── METHOD-DATA parsing ────────────────────────────────────────── + + #[test] + fn parse_method_data_extracts_padata() { + // Build METHOD-DATA: SEQUENCE { PA-DATA { type=19, value= } } + let pa_value = vec![0x01, 0x02, 0x03]; + let pa_type_enc = der_context(1, &der_integer(PA_ETYPE_INFO2)); + let pa_value_enc = der_context(2, &der_octet_string(&pa_value)); + let pa_data = der_sequence(&[&pa_type_enc, &pa_value_enc]); + let method_data = der_sequence(&[&pa_data]); + + let entries = parse_method_data(&method_data).unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].padata_type, PA_ETYPE_INFO2); + assert_eq!(entries[0].padata_value, pa_value); + } + + // ── KerberosAuthenticator state ────────────────────────────────── + + #[test] + fn authenticator_initial_state() { + let auth = KerberosAuthenticator::new(KerberosCredentials { + username: "user".to_string(), + password: "pass".to_string(), + realm: "EXAMPLE.COM".to_string(), + kdc_address: "kdc.example.com".to_string(), + }); + + assert!(auth.token().is_none()); + assert!(auth.session_key().is_none()); + assert!(auth.tgt.is_none()); + assert!(auth.as_session_key.is_none()); + assert!(auth.service_ticket.is_none()); + assert!(auth.tgs_session_key.is_none()); + } + + // ── User key derivation ────────────────────────────────────────── + + #[test] + fn derive_user_key_aes256() { + let auth = KerberosAuthenticator { + credentials: KerberosCredentials { + username: "user".to_string(), + password: "password".to_string(), + realm: "EXAMPLE.COM".to_string(), + kdc_address: "kdc.example.com".to_string(), + }, + tgt: None, + as_session_key: None, + service_ticket: None, + tgs_session_key: None, + ap_req_bytes: None, + session_key: None, + etype: EncryptionType::Aes256CtsHmacSha196, + }; + + let key = auth.derive_user_key(); + assert_eq!(key.len(), 32, "AES-256 key should be 32 bytes"); + + // Should match direct call. + let expected = string_to_key_aes("password", "EXAMPLE.COMuser", 32); + assert_eq!(key, expected); + } + + #[test] + fn derive_user_key_aes128() { + let mut auth = KerberosAuthenticator::new(KerberosCredentials { + username: "user".to_string(), + password: "password".to_string(), + realm: "EXAMPLE.COM".to_string(), + kdc_address: "kdc.example.com".to_string(), + }); + auth.etype = EncryptionType::Aes128CtsHmacSha196; + + let key = auth.derive_user_key(); + assert_eq!(key.len(), 16, "AES-128 key should be 16 bytes"); + } + + #[test] + fn derive_user_key_rc4() { + let mut auth = KerberosAuthenticator::new(KerberosCredentials { + username: "user".to_string(), + password: "password".to_string(), + realm: "EXAMPLE.COM".to_string(), + kdc_address: "kdc.example.com".to_string(), + }); + auth.etype = EncryptionType::Rc4Hmac; + + let key = auth.derive_user_key(); + assert_eq!(key.len(), 16, "RC4 key (NT hash) should be 16 bytes"); + } +} diff --git a/vendor/smb2/src/auth/kerberos/ccache.rs b/vendor/smb2/src/auth/kerberos/ccache.rs new file mode 100644 index 0000000..c852c36 --- /dev/null +++ b/vendor/smb2/src/auth/kerberos/ccache.rs @@ -0,0 +1,447 @@ +//! MIT Kerberos credential cache (ccache) file parser. +//! +//! Reads ccache files (v3 and v4) to extract cached TGTs and service tickets, +//! enabling Kerberos authentication without a password when the user already +//! has a valid ticket (for example, from `kinit`). +//! +//! References: +//! - MIT Kerberos source: `lib/krb5/ccache/cc_file.c` +//! - Format: version(2) + [header(v4)] + default_principal + credentials* + +use crate::error::{Error, Result}; +use log::debug; + +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- + +/// A parsed Kerberos credential cache. +#[derive(Debug, Clone)] +pub struct CCache { + /// File format version (3 or 4). + pub version: u16, + /// Default principal (typically the user who ran `kinit`). + pub default_principal: CcachePrincipal, + /// Cached credentials (TGTs and service tickets). + pub credentials: Vec, +} + +/// A principal name in the ccache. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CcachePrincipal { + /// Name type (1 = KRB_NT_PRINCIPAL, 2 = KRB_NT_SRV_INST, etc.). + pub name_type: u32, + /// Kerberos realm. + pub realm: String, + /// Name components (for example, `["smbtest"]` or `["cifs", "server.domain.com"]`). + pub components: Vec, +} + +/// A single cached credential (ticket + metadata). +#[derive(Debug, Clone)] +pub struct CcacheCredential { + /// Client principal. + pub client: CcachePrincipal, + /// Server (service) principal. + pub server: CcachePrincipal, + /// Session key encryption type. + pub key_etype: u16, + /// Session key bytes. + pub key_data: Vec, + /// Time the ticket was issued (Unix timestamp). + pub authtime: u32, + /// Time the ticket becomes valid (Unix timestamp). + pub starttime: u32, + /// Time the ticket expires (Unix timestamp). + pub endtime: u32, + /// Time the ticket's renewable lifetime expires (Unix timestamp). + pub renew_till: u32, + /// Raw ticket bytes (DER-encoded Kerberos Ticket). + pub ticket: Vec, +} + +// --------------------------------------------------------------------------- +// Parsing +// --------------------------------------------------------------------------- + +/// Read and parse a ccache file from a filesystem path. +/// +/// Reads `$KRB5CCNAME` if `path` is `None`, falling back to +/// `/tmp/krb5cc_` on Unix. +pub fn load_ccache(path: Option<&std::path::Path>) -> Result { + let path = match path { + Some(p) => p.to_path_buf(), + None => { + if let Ok(env_path) = std::env::var("KRB5CCNAME") { + // Strip "FILE:" prefix if present. + let p = env_path.strip_prefix("FILE:").unwrap_or(&env_path); + std::path::PathBuf::from(p) + } else { + // Default: /tmp/krb5cc_ + return Err(Error::invalid_data( + "ccache: no path specified and $KRB5CCNAME not set", + )); + } + } + }; + + let data = std::fs::read(&path).map_err(|e| { + Error::invalid_data(format!("ccache: failed to read {}: {e}", path.display())) + })?; + + parse_ccache(&data) +} + +/// Parse a ccache file from raw bytes. +pub fn parse_ccache(data: &[u8]) -> Result { + let mut pos = 0; + + // Version: 2 bytes, big-endian. We support 0x0503 (v3) and 0x0504 (v4). + if data.len() < 2 { + return Err(Error::invalid_data("ccache: file too short for version")); + } + let version = read_u16(data, &mut pos)?; + if version != 0x0503 && version != 0x0504 { + return Err(Error::invalid_data(format!( + "ccache: unsupported version 0x{version:04x} (expected 0x0503 or 0x0504)" + ))); + } + + // V4 has a header section after the version. + if version == 0x0504 { + let header_len = read_u16(data, &mut pos)? as usize; + if pos + header_len > data.len() { + return Err(Error::invalid_data( + "ccache: header extends past end of file", + )); + } + // Skip header tags (we don't need them). + pos += header_len; + } + + // Default principal. + let default_principal = read_principal(data, &mut pos)?; + + // Credentials: read until EOF. + let mut credentials = Vec::new(); + while pos < data.len() { + match read_credential(data, &mut pos) { + Ok(cred) => credentials.push(cred), + Err(_) => break, // Treat parse errors at the end as EOF. + } + } + + debug!( + "ccache: parsed v{}, principal={}@{}, {} credentials", + version & 0xFF, + default_principal.components.join("/"), + default_principal.realm, + credentials.len() + ); + + Ok(CCache { + version, + default_principal, + credentials, + }) +} + +// --------------------------------------------------------------------------- +// Lookup +// --------------------------------------------------------------------------- + +impl CCache { + /// Find a cached service ticket for the given SPN and realm. + /// + /// Looks for a credential where the server principal matches + /// `service/hostname@realm` (case-insensitive hostname comparison). + pub fn find_service_ticket( + &self, + service: &str, + hostname: &str, + realm: &str, + ) -> Option<&CcacheCredential> { + self.credentials.iter().find(|c| { + c.server.realm.eq_ignore_ascii_case(realm) + && c.server.components.len() == 2 + && c.server.components[0].eq_ignore_ascii_case(service) + && c.server.components[1].eq_ignore_ascii_case(hostname) + }) + } + + /// Find a cached TGT for the given realm. + /// + /// Looks for a credential where the server principal is `krbtgt/REALM@REALM`. + pub fn find_tgt(&self, realm: &str) -> Option<&CcacheCredential> { + self.credentials.iter().find(|c| { + c.server.realm.eq_ignore_ascii_case(realm) + && c.server.components.len() == 2 + && c.server.components[0] == "krbtgt" + && c.server.components[1].eq_ignore_ascii_case(realm) + }) + } +} + +// --------------------------------------------------------------------------- +// Binary reading helpers +// --------------------------------------------------------------------------- + +fn read_u8(data: &[u8], pos: &mut usize) -> Result { + if *pos >= data.len() { + return Err(Error::invalid_data("ccache: unexpected end of data")); + } + let val = data[*pos]; + *pos += 1; + Ok(val) +} + +fn read_u16(data: &[u8], pos: &mut usize) -> Result { + if *pos + 2 > data.len() { + return Err(Error::invalid_data("ccache: unexpected end of data")); + } + let val = u16::from_be_bytes([data[*pos], data[*pos + 1]]); + *pos += 2; + Ok(val) +} + +fn read_u32(data: &[u8], pos: &mut usize) -> Result { + if *pos + 4 > data.len() { + return Err(Error::invalid_data("ccache: unexpected end of data")); + } + let val = u32::from_be_bytes([data[*pos], data[*pos + 1], data[*pos + 2], data[*pos + 3]]); + *pos += 4; + Ok(val) +} + +fn read_bytes(data: &[u8], pos: &mut usize, len: usize) -> Result> { + if *pos + len > data.len() { + return Err(Error::invalid_data("ccache: unexpected end of data")); + } + let val = data[*pos..*pos + len].to_vec(); + *pos += len; + Ok(val) +} + +fn read_string(data: &[u8], pos: &mut usize) -> Result { + let len = read_u32(data, pos)? as usize; + let bytes = read_bytes(data, pos, len)?; + String::from_utf8(bytes).map_err(|_| Error::invalid_data("ccache: invalid UTF-8 in string")) +} + +fn read_principal(data: &[u8], pos: &mut usize) -> Result { + let name_type = read_u32(data, pos)?; + let num_components = read_u32(data, pos)?; + let realm = read_string(data, pos)?; + let mut components = Vec::with_capacity(num_components as usize); + for _ in 0..num_components { + components.push(read_string(data, pos)?); + } + Ok(CcachePrincipal { + name_type, + realm, + components, + }) +} + +fn read_keyblock(data: &[u8], pos: &mut usize) -> Result<(u16, Vec)> { + let enctype = read_u16(data, pos)?; + let key_len = read_u32(data, pos)? as usize; + let key_data = read_bytes(data, pos, key_len)?; + Ok((enctype, key_data)) +} + +fn read_credential(data: &[u8], pos: &mut usize) -> Result { + let client = read_principal(data, pos)?; + let server = read_principal(data, pos)?; + let (key_etype, key_data) = read_keyblock(data, pos)?; + let authtime = read_u32(data, pos)?; + let starttime = read_u32(data, pos)?; + let endtime = read_u32(data, pos)?; + let renew_till = read_u32(data, pos)?; + let _is_skey = read_u8(data, pos)?; + let _ticket_flags = read_u32(data, pos)?; + + // Addresses (count + entries). + let addr_count = read_u32(data, pos)?; + for _ in 0..addr_count { + let _addr_type = read_u16(data, pos)?; + let addr_len = read_u32(data, pos)? as usize; + *pos += addr_len; // skip address data + } + + // Auth data (count + entries). + let authdata_count = read_u32(data, pos)?; + for _ in 0..authdata_count { + let _ad_type = read_u16(data, pos)?; + let ad_len = read_u32(data, pos)? as usize; + *pos += ad_len; // skip authdata + } + + // Ticket. + let ticket_len = read_u32(data, pos)? as usize; + let ticket = read_bytes(data, pos, ticket_len)?; + + // Second ticket. + let second_ticket_len = read_u32(data, pos)? as usize; + let _second_ticket = read_bytes(data, pos, second_ticket_len)?; + + Ok(CcacheCredential { + client, + server, + key_etype, + key_data, + authtime, + starttime, + endtime, + renew_till, + ticket, + }) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_v4_ccache_from_fixture() { + let data = include_bytes!("../../../tests/fixtures/test.ccache"); + let ccache = parse_ccache(data).expect("failed to parse v4 ccache"); + + assert_eq!(ccache.version, 0x0504); + assert_eq!(ccache.default_principal.realm, "TEST.LOCAL"); + assert_eq!(ccache.default_principal.components, vec!["smbtest"]); + assert_eq!(ccache.credentials.len(), 2); + } + + #[test] + fn parse_v3_ccache_from_fixture() { + let data = include_bytes!("../../../tests/fixtures/test_v3.ccache"); + let ccache = parse_ccache(data).expect("failed to parse v3 ccache"); + + assert_eq!(ccache.version, 0x0503); + assert_eq!(ccache.default_principal.realm, "EXAMPLE.COM"); + assert_eq!(ccache.default_principal.components, vec!["user"]); + assert_eq!(ccache.credentials.len(), 1); + } + + #[test] + fn tgt_credential_has_correct_fields() { + let data = include_bytes!("../../../tests/fixtures/test.ccache"); + let ccache = parse_ccache(data).unwrap(); + + let tgt = &ccache.credentials[0]; + assert_eq!(tgt.client.realm, "TEST.LOCAL"); + assert_eq!(tgt.client.components, vec!["smbtest"]); + assert_eq!(tgt.server.realm, "TEST.LOCAL"); + assert_eq!(tgt.server.components, vec!["krbtgt", "TEST.LOCAL"]); + assert_eq!(tgt.key_etype, 23); // RC4-HMAC + assert_eq!(tgt.key_data.len(), 16); + assert_eq!(tgt.authtime, 1744100000); + assert_eq!(tgt.endtime, 1744200000); + } + + #[test] + fn service_ticket_has_correct_fields() { + let data = include_bytes!("../../../tests/fixtures/test.ccache"); + let ccache = parse_ccache(data).unwrap(); + + let svc = &ccache.credentials[1]; + assert_eq!(svc.server.components, vec!["cifs", "server.test.local"]); + assert_eq!(svc.key_etype, 23); + assert_eq!(svc.key_data, (16u8..32).collect::>()); + } + + #[test] + fn find_tgt_by_realm() { + let data = include_bytes!("../../../tests/fixtures/test.ccache"); + let ccache = parse_ccache(data).unwrap(); + + let tgt = ccache.find_tgt("TEST.LOCAL"); + assert!(tgt.is_some()); + assert_eq!(tgt.unwrap().server.components[0], "krbtgt"); + + assert!(ccache.find_tgt("OTHER.REALM").is_none()); + } + + #[test] + fn find_service_ticket_by_spn() { + let data = include_bytes!("../../../tests/fixtures/test.ccache"); + let ccache = parse_ccache(data).unwrap(); + + let svc = ccache.find_service_ticket("cifs", "server.test.local", "TEST.LOCAL"); + assert!(svc.is_some()); + assert_eq!(svc.unwrap().key_data, (16u8..32).collect::>()); + + // Case-insensitive hostname. + assert!(ccache + .find_service_ticket("cifs", "SERVER.TEST.LOCAL", "TEST.LOCAL") + .is_some()); + + // Wrong hostname. + assert!(ccache + .find_service_ticket("cifs", "other.test.local", "TEST.LOCAL") + .is_none()); + + // Wrong service. + assert!(ccache + .find_service_ticket("ldap", "server.test.local", "TEST.LOCAL") + .is_none()); + } + + #[test] + fn find_tgt_case_insensitive() { + let data = include_bytes!("../../../tests/fixtures/test.ccache"); + let ccache = parse_ccache(data).unwrap(); + + assert!(ccache.find_tgt("test.local").is_some()); + } + + #[test] + fn v3_ccache_tgt_has_aes256_key() { + let data = include_bytes!("../../../tests/fixtures/test_v3.ccache"); + let ccache = parse_ccache(data).unwrap(); + + let tgt = ccache.find_tgt("EXAMPLE.COM").unwrap(); + assert_eq!(tgt.key_etype, 18); // AES-256 + assert_eq!(tgt.key_data.len(), 32); + } + + #[test] + fn reject_unsupported_version() { + let data = [0x05, 0x02]; // v2 + let result = parse_ccache(&data); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("unsupported version")); + } + + #[test] + fn reject_truncated_file() { + let result = parse_ccache(&[0x05]); + assert!(result.is_err()); + } + + #[test] + fn empty_credentials_list() { + // A valid ccache with just a version + principal + no credentials + let mut data = vec![0x05, 0x04, 0x00, 0x00]; // v4, no header + // Principal: type=1, components=1, realm="R", component="u" + data.extend_from_slice(&[0, 0, 0, 1]); // name_type + data.extend_from_slice(&[0, 0, 0, 1]); // num_components + data.extend_from_slice(&[0, 0, 0, 1]); // realm length + data.push(b'R'); + data.extend_from_slice(&[0, 0, 0, 1]); // component length + data.push(b'u'); + + let ccache = parse_ccache(&data).unwrap(); + assert_eq!(ccache.credentials.len(), 0); + assert_eq!(ccache.default_principal.realm, "R"); + assert_eq!(ccache.default_principal.components, vec!["u"]); + } +} diff --git a/vendor/smb2/src/auth/kerberos/crypto.rs b/vendor/smb2/src/auth/kerberos/crypto.rs new file mode 100644 index 0000000..99e259d --- /dev/null +++ b/vendor/smb2/src/auth/kerberos/crypto.rs @@ -0,0 +1,1329 @@ +//! Kerberos cryptographic operations. +//! +//! Supports three encryption types (etypes): +//! - **AES256-CTS-HMAC-SHA1-96** (etype 18): AES-256 with CTS mode and HMAC-SHA1 checksums. +//! - **AES128-CTS-HMAC-SHA1-96** (etype 17): AES-128 with CTS mode and HMAC-SHA1 checksums. +//! - **RC4-HMAC** (etype 23): RC4 stream cipher with HMAC-MD5 checksums. +//! +//! References: +//! - RFC 3961: Encryption and Checksum Specifications for Kerberos 5 +//! - RFC 3962: AES Encryption for Kerberos 5 +//! - RFC 4757: RC4-HMAC Kerberos Encryption Types +//! - MS-KILE: Kerberos Protocol Extensions + +use crate::Error; +use digest::KeyInit; + +// --------------------------------------------------------------------------- +// Encryption type enum +// --------------------------------------------------------------------------- + +/// Kerberos encryption type identifiers. +/// +/// Each variant's numeric value matches the IANA-assigned etype number +/// from RFC 3961 and RFC 4757. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EncryptionType { + /// AES-256 with CTS mode and HMAC-SHA1-96 checksum (etype 18). + Aes256CtsHmacSha196 = 18, + /// AES-128 with CTS mode and HMAC-SHA1-96 checksum (etype 17). + Aes128CtsHmacSha196 = 17, + /// RC4 with HMAC-MD5 checksum (etype 23). + Rc4Hmac = 23, +} + +// --------------------------------------------------------------------------- +// String-to-Key: password → encryption key +// --------------------------------------------------------------------------- + +/// Derive an AES encryption key from a password (RFC 3962 section 4). +/// +/// Uses PBKDF2-HMAC-SHA1 with 4096 iterations, then applies the +/// DK(key, "kerberos") random-to-key folding per RFC 3961. +/// +/// Salt is typically `REALM` + `username` (concatenated, case-sensitive). +/// `key_size` is 16 for AES-128 (etype 17) or 32 for AES-256 (etype 18). +pub fn string_to_key_aes(password: &str, salt: &str, key_size: usize) -> Vec { + use sha1::Sha1; + + assert!( + key_size == 16 || key_size == 32, + "key_size must be 16 or 32" + ); + + // Step 1: PBKDF2-HMAC-SHA1 with 4096 iterations. + let mut raw_key = vec![0u8; key_size]; + pbkdf2::pbkdf2_hmac::(password.as_bytes(), salt.as_bytes(), 4096, &mut raw_key); + + // Step 2: DK(raw_key, "kerberos") per RFC 3961. + // This applies the derive-key function with the well-known constant "kerberos". + dk_derive(&raw_key, b"kerberos") +} + +/// Derive an RC4-HMAC key from a password (RFC 4757). +/// +/// This is the NT hash: `MD4(UTF-16LE(password))`. Identical to the +/// NTLM NT hash computation. +pub fn string_to_key_rc4(password: &str) -> Vec { + use digest::Digest; + + let unicode_password: Vec = password + .encode_utf16() + .flat_map(|u| u.to_le_bytes()) + .collect(); + let mut hasher = md4::Md4::new(); + hasher.update(&unicode_password); + hasher.finalize().to_vec() +} + +// --------------------------------------------------------------------------- +// Key derivation (RFC 3961) +// --------------------------------------------------------------------------- + +/// Derive a usage-specific key from a base key (RFC 3961). +/// +/// Uses the `random-to-key(DR(base_key, usage))` construction. +/// The `usage` is a well-known constant (for example, `"signaturekey"`) or a +/// key usage number encoded as bytes with a type suffix: +/// - For encryption: `[usage_be32, 0xAA]` +/// - For checksum: `[usage_be32, 0x99]` +/// - For key derivation: `[usage_be32, 0x55]` +pub fn derive_key_aes(base_key: &[u8], usage: &[u8]) -> Vec { + dk_derive(base_key, usage) +} + +/// Build the 5-byte key usage constant for AES encryption keys. +/// +/// Format: 4-byte big-endian usage number + `0xAA` (encryption). +pub fn usage_enc(usage: u32) -> [u8; 5] { + let mut out = [0u8; 5]; + out[0..4].copy_from_slice(&usage.to_be_bytes()); + out[4] = 0xAA; + out +} + +/// Build the 5-byte key usage constant for AES integrity (Ki) keys. +/// +/// Format: 4-byte big-endian usage number + `0x55`. +/// +/// Per RFC 3961 section 3, the integrity subkey Ki is derived with +/// `0x55` and used for the HMAC inside `encrypt()`/`decrypt()`. +pub fn usage_int(usage: u32) -> [u8; 5] { + let mut out = [0u8; 5]; + out[0..4].copy_from_slice(&usage.to_be_bytes()); + out[4] = 0x55; + out +} + +/// Build the 5-byte key usage constant for AES checksum (Kc) keys. +/// +/// Format: 4-byte big-endian usage number + `0x99`. +/// +/// Per RFC 3961 section 5.4, the checksum subkey Kc is derived with +/// `0x99` and used for standalone `get_mic()` / checksum operations. +pub fn usage_chk(usage: u32) -> [u8; 5] { + let mut out = [0u8; 5]; + out[0..4].copy_from_slice(&usage.to_be_bytes()); + out[4] = 0x99; + out +} + +// --------------------------------------------------------------------------- +// AES-CTS encryption/decryption (RFC 3962 section 3) +// --------------------------------------------------------------------------- + +/// Encrypt data using AES-CTS (Cipher Text Stealing) mode. +/// +/// AES-CTS is AES-CBC with the last two ciphertext blocks swapped +/// and the final block potentially truncated to the plaintext size. +/// For a single block (16 bytes or fewer), uses AES-CBC with zero-padding. +pub fn encrypt_aes_cts(key: &[u8], iv: &[u8], plaintext: &[u8]) -> Vec { + if plaintext.is_empty() { + return Vec::new(); + } + + let block_size = 16; + + // For single-block or less: pad to one full block and encrypt with AES-CBC. + // Per RFC 3962: "If the data [...] has only a single block, that block is + // simply encrypted with AES." The ciphertext is always a full 16-byte block. + if plaintext.len() <= block_size { + let mut padded = [0u8; 16]; + padded[..plaintext.len()].copy_from_slice(plaintext); + // XOR with IV, then ECB encrypt. + for i in 0..16 { + padded[i] ^= iv[i]; + } + let ct = aes_ecb_encrypt(key, &padded); + return ct.to_vec(); + } + + // Multi-block: encrypt with standard CBC, then apply CTS. + // Pad the plaintext to a multiple of block_size. + let n_blocks = plaintext.len().div_ceil(block_size); + let padded_len = n_blocks * block_size; + let mut padded = vec![0u8; padded_len]; + padded[..plaintext.len()].copy_from_slice(plaintext); + + // Encrypt with AES-CBC (no padding -- we padded ourselves). + let cbc_out = aes_cbc_encrypt(key, iv, &padded); + + // CTS: swap the last two ciphertext blocks. + let mut result = cbc_out; + let second_last_start = (n_blocks - 2) * block_size; + let last_start = (n_blocks - 1) * block_size; + + // Swap blocks. + let mut second_last_block = [0u8; 16]; + let mut last_block = [0u8; 16]; + second_last_block.copy_from_slice(&result[second_last_start..second_last_start + block_size]); + last_block.copy_from_slice(&result[last_start..last_start + block_size]); + result[second_last_start..second_last_start + block_size].copy_from_slice(&last_block); + result[last_start..last_start + block_size].copy_from_slice(&second_last_block); + + // Truncate the final block to the original plaintext length. + result.truncate(plaintext.len()); + result +} + +/// Decrypt data using AES-CTS mode. +/// +/// Reverses the CTS transformation: un-swap the last two blocks, +/// then decrypt with AES-CBC. +pub fn decrypt_aes_cts(key: &[u8], iv: &[u8], ciphertext: &[u8]) -> Result, Error> { + if ciphertext.is_empty() { + return Ok(Vec::new()); + } + + let block_size = 16; + + // Single block (16 bytes): ECB decrypt then XOR with IV. + // Per RFC 3962, single-block ciphertext is always exactly 16 bytes. + if ciphertext.len() <= block_size { + if ciphertext.len() != block_size { + return Err(Error::invalid_data(format!( + "AES-CTS single-block ciphertext must be exactly 16 bytes, got {}", + ciphertext.len() + ))); + } + let mut pt = aes_ecb_decrypt(key, ciphertext); + for i in 0..16 { + pt[i] ^= iv[i]; + } + return Ok(pt.to_vec()); + } + + // Multi-block CTS decryption. + let orig_len = ciphertext.len(); + let n_blocks = orig_len.div_ceil(block_size); + let padded_len = n_blocks * block_size; + + // Pad the ciphertext to a full number of blocks. + let mut padded_ct = vec![0u8; padded_len]; + padded_ct[..orig_len].copy_from_slice(ciphertext); + + let second_last_start = (n_blocks - 2) * block_size; + let last_start = (n_blocks - 1) * block_size; + + if orig_len % block_size != 0 { + let tail_len = orig_len - (n_blocks - 1) * block_size; + + // c_{n-1} is the swapped full block (at second_last_start). + let mut c_n_minus_1 = [0u8; 16]; + c_n_minus_1.copy_from_slice(&padded_ct[second_last_start..second_last_start + block_size]); + + // Decrypt c_{n-1} with ECB to get intermediate. + let intermediate = aes_ecb_decrypt(key, &c_n_minus_1); + + // c_n is the partial block (tail_len bytes at last_start). + let mut reconstructed_last = [0u8; 16]; + reconstructed_last[..tail_len] + .copy_from_slice(&padded_ct[last_start..last_start + tail_len]); + // Pad with tail of the intermediate. + reconstructed_last[tail_len..].copy_from_slice(&intermediate[tail_len..]); + + // Now put them back in the right order for CBC decryption. + padded_ct[second_last_start..second_last_start + block_size] + .copy_from_slice(&reconstructed_last); + padded_ct[last_start..last_start + block_size].copy_from_slice(&c_n_minus_1); + } else { + // Block-aligned: swap back. + let mut second_last_block = [0u8; 16]; + let mut last_block = [0u8; 16]; + second_last_block + .copy_from_slice(&padded_ct[second_last_start..second_last_start + block_size]); + last_block.copy_from_slice(&padded_ct[last_start..last_start + block_size]); + padded_ct[second_last_start..second_last_start + block_size].copy_from_slice(&last_block); + padded_ct[last_start..last_start + block_size].copy_from_slice(&second_last_block); + } + + // Decrypt with standard CBC. + let plaintext = aes_cbc_decrypt(key, iv, &padded_ct); + Ok(plaintext[..orig_len].to_vec()) +} + +// --------------------------------------------------------------------------- +// RC4-HMAC encryption/decryption (RFC 4757) +// --------------------------------------------------------------------------- + +/// Encrypt data using RC4-HMAC (etype 23). +/// +/// 1. K1 = HMAC-MD5(key, usage as little-endian i32) +/// 2. Generate random 8-byte confounder +/// 3. Compute HMAC-MD5(K1, confounder + plaintext) → checksum (16 bytes) +/// 4. K3 = HMAC-MD5(K1, checksum) +/// 5. RC4-encrypt (confounder + plaintext) using K3 +/// 6. Output = checksum (16 bytes) + encrypted_data +pub fn encrypt_rc4_hmac(key: &[u8], usage: u32, plaintext: &[u8]) -> Vec { + use hmac::{Hmac, Mac}; + type HmacMd5 = Hmac; + + // K1 = HMAC-MD5(key, usage_le) + // Note: RFC 4757 uses the usage as a signed 32-bit little-endian value. + let usage_bytes = (usage as i32).to_le_bytes(); + let mut mac = HmacMd5::new_from_slice(key).expect("HMAC accepts any key length"); + mac.update(&usage_bytes); + let k1 = mac.finalize().into_bytes(); + + // Generate random 8-byte confounder. + let mut confounder = [0u8; 8]; + getrandom::fill(&mut confounder).expect("CSPRNG failed"); + + // Build confounder + plaintext. + let mut payload = Vec::with_capacity(8 + plaintext.len()); + payload.extend_from_slice(&confounder); + payload.extend_from_slice(plaintext); + + // Checksum = HMAC-MD5(K1, confounder + plaintext) + let mut mac = HmacMd5::new_from_slice(&k1).expect("HMAC accepts any key length"); + mac.update(&payload); + let checksum = mac.finalize().into_bytes(); + + // K3 = HMAC-MD5(K1, checksum) + let mut mac = HmacMd5::new_from_slice(&k1).expect("HMAC accepts any key length"); + mac.update(&checksum); + let k3 = mac.finalize().into_bytes(); + + // Encrypt payload with RC4 using K3. + let encrypted = rc4_transform(&k3, &payload); + + // Output = checksum (16 bytes) + encrypted_data + let mut output = Vec::with_capacity(16 + encrypted.len()); + output.extend_from_slice(&checksum); + output.extend_from_slice(&encrypted); + output +} + +/// Decrypt data using RC4-HMAC (etype 23). +/// +/// Reverses the `encrypt_rc4_hmac` process and verifies the checksum. +pub fn decrypt_rc4_hmac(key: &[u8], usage: u32, ciphertext: &[u8]) -> Result, Error> { + use hmac::{Hmac, Mac}; + type HmacMd5 = Hmac; + + if ciphertext.len() < 24 { + return Err(Error::invalid_data( + "RC4-HMAC ciphertext too short (need at least 16-byte checksum + 8-byte confounder)", + )); + } + + let checksum = &ciphertext[..16]; + let encrypted_data = &ciphertext[16..]; + + // K1 = HMAC-MD5(key, usage_le) + let usage_bytes = (usage as i32).to_le_bytes(); + let mut mac = HmacMd5::new_from_slice(key).expect("HMAC accepts any key length"); + mac.update(&usage_bytes); + let k1 = mac.finalize().into_bytes(); + + // K3 = HMAC-MD5(K1, checksum) + let mut mac = HmacMd5::new_from_slice(&k1).expect("HMAC accepts any key length"); + mac.update(checksum); + let k3 = mac.finalize().into_bytes(); + + // Decrypt payload with RC4 using K3. + let payload = rc4_transform(&k3, encrypted_data); + + // Verify: HMAC-MD5(K1, decrypted_payload) must equal the checksum. + let mut mac = HmacMd5::new_from_slice(&k1).expect("HMAC accepts any key length"); + mac.update(&payload); + let computed_checksum = mac.finalize().into_bytes(); + + if computed_checksum.as_slice() != checksum { + return Err(Error::invalid_data("RC4-HMAC checksum verification failed")); + } + + // Strip the 8-byte confounder. + if payload.len() < 8 { + return Err(Error::invalid_data("RC4-HMAC decrypted payload too short")); + } + Ok(payload[8..].to_vec()) +} + +// --------------------------------------------------------------------------- +// Checksum computation +// --------------------------------------------------------------------------- + +/// Compute a standalone Kerberos checksum (MIC) for the given data. +/// +/// Uses the checksum subkey Kc (derived with `0x99`) per RFC 3961 section 5.4. +/// This is for standalone checksum operations (for example, the body checksum +/// in the TGS-REQ Authenticator), NOT for the HMAC inside encrypt/decrypt +/// (which uses Ki derived with `0x55`). +/// +/// - For AES (etypes 17, 18): HMAC-SHA1 truncated to 12 bytes (96 bits). +/// - For RC4 (etype 23): HMAC-MD5, producing 16 bytes. +pub fn compute_checksum(key: &[u8], usage: u32, data: &[u8], etype: EncryptionType) -> Vec { + match etype { + EncryptionType::Aes128CtsHmacSha196 | EncryptionType::Aes256CtsHmacSha196 => { + // Derive the checksum key Kc for this usage. + let kc = derive_key_aes(key, &usage_chk(usage)); + hmac_sha1_96(&kc, data) + } + EncryptionType::Rc4Hmac => { + use hmac::{Hmac, Mac}; + type HmacMd5 = Hmac; + + // K1 = HMAC-MD5(key, usage_le) + let usage_bytes = (usage as i32).to_le_bytes(); + let mut mac = HmacMd5::new_from_slice(key).expect("HMAC accepts any key length"); + mac.update(&usage_bytes); + let k1 = mac.finalize().into_bytes(); + + // Checksum = HMAC-MD5(K1, data) + let mut mac = HmacMd5::new_from_slice(&k1).expect("HMAC accepts any key length"); + mac.update(data); + mac.finalize().into_bytes().to_vec() + } + } +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +/// HMAC-SHA1 truncated to 12 bytes (96 bits), as used by AES Kerberos checksums. +fn hmac_sha1_96(key: &[u8], data: &[u8]) -> Vec { + use hmac::{Hmac, Mac}; + use sha1::Sha1; + type HmacSha1 = Hmac; + + let mut mac = HmacSha1::new_from_slice(key).expect("HMAC accepts any key length"); + mac.update(data); + let result = mac.finalize().into_bytes(); + result[..12].to_vec() +} + +/// DK(base_key, constant) per RFC 3961 section 5.1. +/// +/// DK = random-to-key(DR(base_key, constant)) +/// DR = k-truncate(E(base_key, n-fold(constant, block_size))) +/// +/// For AES, random-to-key is the identity function, so DK = DR. +fn dk_derive(base_key: &[u8], constant: &[u8]) -> Vec { + let block_size = 16; // AES block size is always 16. + let key_size = base_key.len(); + + // n-fold the constant to the cipher's block size. + let folded = nfold(constant, block_size); + + // DR: repeatedly encrypt to produce enough key material. + let mut result = Vec::with_capacity(key_size); + let mut input = [0u8; 16]; + input.copy_from_slice(&folded); + + while result.len() < key_size { + // Encrypt the input block with AES-ECB (single block, no IV needed). + let encrypted = aes_ecb_encrypt(base_key, &input); + result.extend_from_slice(&encrypted); + input = encrypted; + } + + result.truncate(key_size); + result +} + +/// N-fold operation per RFC 3961 section 5.1. +/// +/// Takes an input byte string and produces an output of `output_len` bytes. +/// The algorithm rotates the input by 13 bits for each successive copy and +/// sums them with one's-complement-like carry propagation. +fn nfold(input: &[u8], output_len: usize) -> Vec { + let in_len = input.len(); + + // Helper: get a single byte from `input` RIGHT-rotated by `rot` bits. + // Right rotation by `rot`: bit `j` of the result comes from + // bit `(j - rot) mod in_bits` of the original. Equivalently, + // bit `(j + in_bits - rot) mod in_bits`. + let rotated_byte = |rot: usize, byte_idx: usize| -> u8 { + let in_bits = in_len * 8; + let rot_mod = rot % in_bits; + let bit = (byte_idx * 8 + in_bits - rot_mod) % in_bits; + let b = bit / 8; + let s = bit % 8; + if s == 0 { + input[b] + } else { + (((input[b] as u16) << s) | ((input[(b + 1) % in_len] as u16) >> (8 - s))) as u8 + } + }; + + let in_bits = in_len * 8; + let out_bits = output_len * 8; + let lcm_bits = lcm(in_bits, out_bits); + + // Total bytes to iterate over all copies laid end-to-end. + let lcm_bytes = lcm_bits / 8; + + // Accumulator (u32 to handle carries). + let mut result = vec![0u32; output_len]; + + // Walk through lcm_bytes bytes, each one coming from a specific + // rotated copy. The output byte it maps to wraps modulo output_len. + for i in 0..lcm_bytes { + // Which copy is this byte from? + let copy = i / in_len; + // Which byte within that copy? + let byte_in_copy = i % in_len; + // Each copy is rotated 13 bits further than the previous. + let rotation = copy * 13; + let val = rotated_byte(rotation, byte_in_copy); + // Map to output position, wrapping. + let out_idx = i % output_len; + result[out_idx] += val as u32; + } + + // Propagate carries from right to left (big-endian addition). + // The carry wraps around from the most-significant byte to the + // least-significant, like one's-complement addition. + loop { + let mut carry = 0u32; + for i in (0..output_len).rev() { + result[i] += carry; + carry = result[i] >> 8; + result[i] &= 0xFF; + } + if carry == 0 { + break; + } + // Wrap carry to LSB. + result[output_len - 1] += carry; + } + + result.iter().map(|&v| v as u8).collect() +} + +/// Least common multiple. +fn lcm(a: usize, b: usize) -> usize { + a / gcd(a, b) * b +} + +/// Greatest common divisor (Euclidean algorithm). +fn gcd(mut a: usize, mut b: usize) -> usize { + while b != 0 { + let t = b; + b = a % b; + a = t; + } + a +} + +/// AES-ECB encrypt a single 16-byte block. +fn aes_ecb_encrypt(key: &[u8], block: &[u8]) -> [u8; 16] { + use aes::cipher::{BlockCipherEncrypt, KeyInit}; + + let mut output = [0u8; 16]; + output.copy_from_slice(block); + + match key.len() { + 16 => { + let cipher = aes::Aes128::new_from_slice(key).expect("valid key"); + cipher.encrypt_block((&mut output).into()); + } + 32 => { + let cipher = aes::Aes256::new_from_slice(key).expect("valid key"); + cipher.encrypt_block((&mut output).into()); + } + _ => panic!("AES key must be 16 or 32 bytes, got {}", key.len()), + } + output +} + +/// AES-ECB decrypt a single 16-byte block. +fn aes_ecb_decrypt(key: &[u8], block: &[u8]) -> [u8; 16] { + use aes::cipher::{BlockCipherDecrypt, KeyInit}; + + let mut output = [0u8; 16]; + output.copy_from_slice(block); + + match key.len() { + 16 => { + let cipher = aes::Aes128::new_from_slice(key).expect("valid key"); + cipher.decrypt_block((&mut output).into()); + } + 32 => { + let cipher = aes::Aes256::new_from_slice(key).expect("valid key"); + cipher.decrypt_block((&mut output).into()); + } + _ => panic!("AES key must be 16 or 32 bytes, got {}", key.len()), + } + output +} + +/// AES-CBC encrypt (no padding -- input must be a multiple of 16 bytes). +/// Implemented manually using AES-ECB to avoid cbc crate API complexity. +fn aes_cbc_encrypt(key: &[u8], iv: &[u8], data: &[u8]) -> Vec { + assert!( + data.len() % 16 == 0, + "AES-CBC input must be a multiple of 16 bytes" + ); + + let n_blocks = data.len() / 16; + let mut output = vec![0u8; data.len()]; + let mut prev = [0u8; 16]; + prev.copy_from_slice(iv); + + for i in 0..n_blocks { + let start = i * 16; + let mut block = [0u8; 16]; + block.copy_from_slice(&data[start..start + 16]); + // XOR with previous ciphertext block (or IV for first block). + for j in 0..16 { + block[j] ^= prev[j]; + } + let encrypted = aes_ecb_encrypt(key, &block); + output[start..start + 16].copy_from_slice(&encrypted); + prev = encrypted; + } + output +} + +/// AES-CBC decrypt (no padding -- input must be a multiple of 16 bytes). +/// Implemented manually using AES-ECB to avoid cbc crate API complexity. +fn aes_cbc_decrypt(key: &[u8], iv: &[u8], data: &[u8]) -> Vec { + assert!( + data.len() % 16 == 0, + "AES-CBC input must be a multiple of 16 bytes" + ); + + let n_blocks = data.len() / 16; + let mut output = vec![0u8; data.len()]; + let mut prev = [0u8; 16]; + prev.copy_from_slice(iv); + + for i in 0..n_blocks { + let start = i * 16; + let mut ct_block = [0u8; 16]; + ct_block.copy_from_slice(&data[start..start + 16]); + let mut decrypted = aes_ecb_decrypt(key, &ct_block); + // XOR with previous ciphertext block (or IV for first block). + for j in 0..16 { + decrypted[j] ^= prev[j]; + } + output[start..start + 16].copy_from_slice(&decrypted); + prev = ct_block; + } + output +} + +/// RC4 stream cipher (symmetric -- encrypt and decrypt are the same operation). +fn rc4_transform(key: &[u8], data: &[u8]) -> Vec { + let mut s: Vec = (0..=255).collect(); + let mut j: u8 = 0; + for i in 0..256 { + j = j.wrapping_add(s[i]).wrapping_add(key[i % key.len()]); + s.swap(i, j as usize); + } + let mut i: u8 = 0; + j = 0; + data.iter() + .map(|&byte| { + i = i.wrapping_add(1); + j = j.wrapping_add(s[i as usize]); + s.swap(i as usize, j as usize); + byte ^ s[s[i as usize].wrapping_add(s[j as usize]) as usize] + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Kerberos encrypt/decrypt (RFC 3961 section 5.3) +// --------------------------------------------------------------------------- +// +// For AES (etypes 17, 18): +// 1. Derive encryption key: Ke = DK(base_key, usage || 0xAA) +// 2. Derive integrity key: Ki = DK(base_key, usage || 0x55) +// 3. Generate random 16-byte confounder +// 4. Plaintext' = confounder || plaintext +// 5. Ciphertext = AES-CTS(Ke, iv=0, plaintext') +// 6. HMAC = HMAC-SHA1-96(Ki, plaintext') +// 7. Output = ciphertext || HMAC (12 bytes) +// +// For RC4-HMAC (etype 23): +// Uses the encrypt_rc4_hmac function directly (it handles confounder +// and checksum internally). + +/// Encrypt data using the Kerberos profile for the given etype and key usage. +pub(crate) fn kerberos_encrypt( + base_key: &[u8], + usage: u32, + plaintext: &[u8], + etype: EncryptionType, +) -> Vec { + match etype { + EncryptionType::Aes128CtsHmacSha196 | EncryptionType::Aes256CtsHmacSha196 => { + // Derive Ke (encryption key) and Ki (integrity key). + let ke = derive_key_aes(base_key, &usage_enc(usage)); + let ki = derive_key_aes(base_key, &usage_int(usage)); + + // Generate 16-byte random confounder. + let mut confounder = [0u8; 16]; + getrandom::fill(&mut confounder).expect("CSPRNG failed"); + + // Build plaintext' = confounder || plaintext. + let mut full_plain = Vec::with_capacity(16 + plaintext.len()); + full_plain.extend_from_slice(&confounder); + full_plain.extend_from_slice(plaintext); + + // Compute HMAC-SHA1-96 over plaintext' using Ki. + let hmac = hmac_sha1_96(&ki, &full_plain); + + // Encrypt plaintext' with AES-CTS using Ke and IV=0. + let iv = [0u8; 16]; + let ciphertext = encrypt_aes_cts(&ke, &iv, &full_plain); + + // Output = ciphertext || HMAC (12 bytes). + let mut output = ciphertext; + output.extend_from_slice(&hmac); + output + } + EncryptionType::Rc4Hmac => encrypt_rc4_hmac(base_key, usage, plaintext), + } +} + +/// Decrypt data using the Kerberos profile for the given etype and key usage. +pub(crate) fn kerberos_decrypt( + base_key: &[u8], + usage: u32, + ciphertext: &[u8], + etype: EncryptionType, +) -> Result, Error> { + match etype { + EncryptionType::Aes128CtsHmacSha196 | EncryptionType::Aes256CtsHmacSha196 => { + // HMAC-SHA1-96 is 12 bytes, appended to the ciphertext. + if ciphertext.len() < 12 + 16 { + return Err(Error::invalid_data( + "Kerberos AES ciphertext too short (need at least confounder + HMAC)", + )); + } + + let hmac_offset = ciphertext.len() - 12; + let enc_data = &ciphertext[..hmac_offset]; + let expected_hmac = &ciphertext[hmac_offset..]; + + // Derive Ke (encryption key) and Ki (integrity key). + let ke = derive_key_aes(base_key, &usage_enc(usage)); + let ki = derive_key_aes(base_key, &usage_int(usage)); + + // Decrypt with AES-CTS using Ke and IV=0. + let iv = [0u8; 16]; + let full_plain = decrypt_aes_cts(&ke, &iv, enc_data)?; + + // Verify HMAC-SHA1-96 using Ki. + let computed_hmac = hmac_sha1_96(&ki, &full_plain); + if computed_hmac != expected_hmac { + return Err(Error::Auth { + message: "Kerberos AES HMAC verification failed".to_string(), + }); + } + + // Strip the 16-byte confounder. + if full_plain.len() < 16 { + return Err(Error::invalid_data( + "Kerberos AES decrypted data too short for confounder", + )); + } + Ok(full_plain[16..].to_vec()) + } + EncryptionType::Rc4Hmac => decrypt_rc4_hmac(base_key, usage, ciphertext), + } +} + +// --------------------------------------------------------------------------- +// Etype conversion +// --------------------------------------------------------------------------- + +/// Convert an etype integer value to our enum. +pub(crate) fn etype_from_i32(val: i32) -> Result { + match val { + 18 => Ok(EncryptionType::Aes256CtsHmacSha196), + 17 => Ok(EncryptionType::Aes128CtsHmacSha196), + 23 => Ok(EncryptionType::Rc4Hmac), + _ => Err(Error::Auth { + message: format!("unsupported Kerberos encryption type: {val}"), + }), + } +} + +// --------------------------------------------------------------------------- +// Random key generation (test support) +// --------------------------------------------------------------------------- + +/// Generate a random key of the appropriate size for the given etype. +#[cfg(test)] +pub(crate) fn generate_random_key(etype: EncryptionType) -> Vec { + let key_size = match etype { + EncryptionType::Aes256CtsHmacSha196 => 32, + EncryptionType::Aes128CtsHmacSha196 => 16, + EncryptionType::Rc4Hmac => 16, + }; + let mut key = vec![0u8; key_size]; + getrandom::fill(&mut key).expect("CSPRNG failed"); + key +} + +// =========================================================================== +// Tests +// =========================================================================== + +#[cfg(test)] +mod tests { + use super::*; + + // ── EncryptionType ──────────────────────────────────────────────── + + #[test] + fn encryption_type_values() { + assert_eq!(EncryptionType::Aes256CtsHmacSha196 as u32, 18); + assert_eq!(EncryptionType::Aes128CtsHmacSha196 as u32, 17); + assert_eq!(EncryptionType::Rc4Hmac as u32, 23); + } + + // ── n-fold ──────────────────────────────────────────────────────── + + #[test] + fn nfold_rfc3961_test_vectors() { + // RFC 3961 section 5.1 test vectors. + // 64-fold("012345") = 0xBE072631276B1955 + let result = nfold(b"012345", 8); + assert_eq!(result, hex("be072631276b1955")); + + // 56-fold("password") = 0x78A07B6CAF85FA + let result = nfold(b"password", 7); + assert_eq!(result, hex("78a07b6caf85fa")); + + // 64-fold("Rough Consensus, and Running Code") + let result = nfold(b"Rough Consensus, and Running Code", 8); + assert_eq!(result, hex("bb6ed30870b7f0e0")); + + // 168-fold("password") + let result = nfold(b"password", 21); + assert_eq!(result, hex("59e4a8ca7c0385c3c37b3f6d2000247cb6e6bd5b3e")); + + // 128-fold("kerberos") + let result = nfold(b"kerberos", 16); + assert_eq!(result, hex("6b65726265726f737b9b5b2b93132b93")); + + // 168-fold("kerberos") + let result = nfold(b"kerberos", 21); + assert_eq!(result, hex("8372c236344e5f1550cd0747e15d62ca7a5a3bcea4")); + + // 256-fold("kerberos") + let result = nfold(b"kerberos", 32); + assert_eq!( + result, + hex("6b65726265726f737b9b5b2b93132b935c9bdcdad95c9899c4cae4dee6d6cae4") + ); + } + + // ── String-to-Key (RC4) ─────────────────────────────────────────── + + #[test] + fn string_to_key_rc4_produces_nt_hash() { + // MS-NLMP test vector: password "Password" + // NT hash = MD4(UTF-16LE("Password")) + // = a4f49c406510bdcab6824ee7c30fd852 + let key = string_to_key_rc4("Password"); + assert_eq!(key, hex("a4f49c406510bdcab6824ee7c30fd852")); + } + + #[test] + fn string_to_key_rc4_empty_password() { + // Empty password still produces a valid 16-byte hash. + let key = string_to_key_rc4(""); + assert_eq!(key.len(), 16); + // MD4 of empty UTF-16LE is: 31d6cfe0d16ae931b73c59d7e0c089c0 + assert_eq!(key, hex("31d6cfe0d16ae931b73c59d7e0c089c0")); + } + + // ── String-to-Key (AES) ────────────────────────────────────────── + + #[test] + fn string_to_key_aes256_rfc3962_test_vector() { + // RFC 3962 Appendix B, Test Vector 4 (iterations = 4096): + // password = "password", salt = "ATHENA.MIT.EDUraeburn" + // Verified with Python hashlib.pbkdf2_hmac + AES-ECB DK derivation. + let key = string_to_key_aes("password", "ATHENA.MIT.EDUraeburn", 32); + assert_eq!( + key, + hex("01b897121d933ab44b47eb5494db15e50eb74530dbdae9b634d65020ff5d88c1") + ); + } + + #[test] + fn string_to_key_aes128_rfc3962_test_vector() { + // RFC 3962 Appendix B, Test Vector 4 (iterations = 4096): + // password = "password", salt = "ATHENA.MIT.EDUraeburn" + // Verified with Python hashlib.pbkdf2_hmac + AES-ECB DK derivation. + let key = string_to_key_aes("password", "ATHENA.MIT.EDUraeburn", 16); + assert_eq!(key, hex("fca822951813fb252154c883f5ee1cf4")); + } + + #[test] + fn string_to_key_aes256_produces_32_bytes() { + let key = string_to_key_aes("test", "EXAMPLE.COMtest", 32); + assert_eq!(key.len(), 32); + } + + #[test] + fn string_to_key_aes128_produces_16_bytes() { + let key = string_to_key_aes("test", "EXAMPLE.COMtest", 16); + assert_eq!(key.len(), 16); + } + + // ── Key Derivation (AES) ───────────────────────────────────────── + + #[test] + fn derive_key_aes_deterministic() { + let base_key = [0xAA; 16]; + let usage = usage_enc(7); + let k1 = derive_key_aes(&base_key, &usage); + let k2 = derive_key_aes(&base_key, &usage); + assert_eq!(k1, k2, "same inputs must produce same output"); + } + + #[test] + fn derive_key_aes_different_usages_produce_different_keys() { + let base_key = [0xBB; 16]; + let k_enc = derive_key_aes(&base_key, &usage_enc(7)); + let k_int = derive_key_aes(&base_key, &usage_int(7)); + assert_ne!( + k_enc, k_int, + "different usage types must produce different keys" + ); + } + + #[test] + fn derive_key_aes_different_usage_numbers_produce_different_keys() { + let base_key = [0xCC; 32]; + let k1 = derive_key_aes(&base_key, &usage_enc(1)); + let k7 = derive_key_aes(&base_key, &usage_enc(7)); + assert_ne!( + k1, k7, + "different usage numbers must produce different keys" + ); + } + + #[test] + fn derive_key_aes128_preserves_key_length() { + let base_key = [0xDD; 16]; + let derived = derive_key_aes(&base_key, &usage_enc(1)); + assert_eq!(derived.len(), 16); + } + + #[test] + fn derive_key_aes256_preserves_key_length() { + let base_key = [0xEE; 32]; + let derived = derive_key_aes(&base_key, &usage_enc(1)); + assert_eq!(derived.len(), 32); + } + + // ── AES-CTS encryption/decryption ──────────────────────────────── + + #[test] + fn aes_cts_empty_input() { + let key = [0x11; 16]; + let iv = [0u8; 16]; + let ct = encrypt_aes_cts(&key, &iv, &[]); + assert!(ct.is_empty()); + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert!(pt.is_empty()); + } + + #[test] + fn aes_cts_single_block_roundtrip() { + let key = [0x22; 16]; + let iv = [0u8; 16]; + let plaintext = b"sixteen bytes!!!"; + assert_eq!(plaintext.len(), 16); + + let ct = encrypt_aes_cts(&key, &iv, plaintext); + assert_eq!(ct.len(), 16); + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert_eq!(pt, plaintext); + } + + #[test] + fn aes_cts_two_blocks_roundtrip() { + let key = [0x33; 16]; + let iv = [0u8; 16]; + let plaintext = [0x42u8; 32]; // Exactly 2 blocks. + + let ct = encrypt_aes_cts(&key, &iv, &plaintext); + assert_eq!(ct.len(), 32); + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert_eq!(pt, plaintext); + } + + #[test] + fn aes_cts_non_block_aligned_roundtrip() { + let key = [0x44; 16]; + let iv = [0u8; 16]; + let plaintext = [0x55u8; 30]; // Not a multiple of 16. + + let ct = encrypt_aes_cts(&key, &iv, &plaintext); + assert_eq!( + ct.len(), + 30, + "CTS ciphertext length equals plaintext length" + ); + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert_eq!(pt, plaintext); + } + + #[test] + fn aes_cts_three_blocks_roundtrip() { + let key = [0x55; 32]; // AES-256 + let iv = [0u8; 16]; + let plaintext = [0x66u8; 48]; // Exactly 3 blocks. + + let ct = encrypt_aes_cts(&key, &iv, &plaintext); + assert_eq!(ct.len(), 48); + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert_eq!(pt, plaintext); + } + + #[test] + fn aes_cts_non_aligned_aes256_roundtrip() { + let key = [0x77; 32]; // AES-256 + let iv = [0u8; 16]; + let plaintext: Vec = (0..50).collect(); // 50 bytes, not block-aligned. + + let ct = encrypt_aes_cts(&key, &iv, &plaintext); + assert_eq!(ct.len(), 50); + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert_eq!(pt, plaintext); + } + + #[test] + fn aes_cts_sub_block_pads_to_full_block() { + // Per RFC 3962, a single block (even if plaintext < 16 bytes) produces + // a full 16-byte ciphertext. The plaintext is zero-padded to 16 bytes + // before encryption. + let key = [0x88; 16]; + let iv = [0u8; 16]; + let plaintext = b"short"; // Less than one block. + + let ct = encrypt_aes_cts(&key, &iv, plaintext); + assert_eq!(ct.len(), 16, "single-block ciphertext is always 16 bytes"); + + // Decrypting gives back the zero-padded 16-byte block. + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert_eq!(pt.len(), 16); + assert_eq!(&pt[..5], plaintext.as_slice()); + assert_eq!(&pt[5..], &[0u8; 11]); // Zero padding. + } + + #[test] + fn aes_cts_ciphertext_differs_from_plaintext() { + let key = [0x99; 16]; + let iv = [0u8; 16]; + let plaintext = [0xAA; 32]; + + let ct = encrypt_aes_cts(&key, &iv, &plaintext); + assert_ne!(ct, plaintext, "ciphertext must differ from plaintext"); + } + + // ── RC4-HMAC encryption/decryption ─────────────────────────────── + + #[test] + fn rc4_hmac_roundtrip() { + let key = hex("a4f49c406510bdcab6824ee7c30fd852"); + let plaintext = b"Hello, Kerberos!"; + let usage = 7u32; + + let ct = encrypt_rc4_hmac(&key, usage, plaintext); + // Ciphertext should be 16-byte checksum + 8-byte confounder + plaintext. + assert_eq!(ct.len(), 16 + 8 + plaintext.len()); + + let pt = decrypt_rc4_hmac(&key, usage, &ct).unwrap(); + assert_eq!(pt, plaintext); + } + + #[test] + fn rc4_hmac_empty_plaintext_roundtrip() { + let key = [0xBB; 16]; + let ct = encrypt_rc4_hmac(&key, 1, &[]); + // 16-byte checksum + 8-byte confounder + 0-byte plaintext. + assert_eq!(ct.len(), 24); + let pt = decrypt_rc4_hmac(&key, 1, &ct).unwrap(); + assert!(pt.is_empty()); + } + + #[test] + fn rc4_hmac_wrong_key_fails() { + let key = [0xCC; 16]; + let ct = encrypt_rc4_hmac(&key, 1, b"secret data"); + + let wrong_key = [0xDD; 16]; + let result = decrypt_rc4_hmac(&wrong_key, 1, &ct); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("checksum verification failed")); + } + + #[test] + fn rc4_hmac_wrong_usage_fails() { + let key = [0xEE; 16]; + let ct = encrypt_rc4_hmac(&key, 1, b"usage test"); + + let result = decrypt_rc4_hmac(&key, 2, &ct); + assert!(result.is_err()); + } + + #[test] + fn rc4_hmac_ciphertext_too_short() { + let key = [0xFF; 16]; + let result = decrypt_rc4_hmac(&key, 1, &[0u8; 23]); // Need at least 24. + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("too short")); + } + + #[test] + fn rc4_hmac_tampered_ciphertext_fails() { + let key = [0x11; 16]; + let mut ct = encrypt_rc4_hmac(&key, 1, b"tamper test"); + + // Flip a byte in the encrypted data (after the 16-byte checksum). + let last = ct.len() - 1; + ct[last] ^= 0xFF; + + let result = decrypt_rc4_hmac(&key, 1, &ct); + assert!(result.is_err()); + } + + // ── Checksum ───────────────────────────────────────────────────── + + #[test] + fn checksum_aes_produces_12_bytes() { + let key = [0x11; 16]; + let data = b"checksum test data"; + let checksum = compute_checksum(&key, 7, data, EncryptionType::Aes128CtsHmacSha196); + assert_eq!(checksum.len(), 12, "HMAC-SHA1-96 produces 12 bytes"); + } + + #[test] + fn checksum_aes256_produces_12_bytes() { + let key = [0x22; 32]; + let data = b"checksum test data"; + let checksum = compute_checksum(&key, 7, data, EncryptionType::Aes256CtsHmacSha196); + assert_eq!(checksum.len(), 12); + } + + #[test] + fn checksum_rc4_produces_16_bytes() { + let key = [0x33; 16]; + let data = b"checksum test data"; + let checksum = compute_checksum(&key, 7, data, EncryptionType::Rc4Hmac); + assert_eq!(checksum.len(), 16, "HMAC-MD5 produces 16 bytes"); + } + + #[test] + fn checksum_aes_deterministic() { + let key = [0x44; 16]; + let data = b"determinism test"; + let c1 = compute_checksum(&key, 7, data, EncryptionType::Aes128CtsHmacSha196); + let c2 = compute_checksum(&key, 7, data, EncryptionType::Aes128CtsHmacSha196); + assert_eq!(c1, c2); + } + + #[test] + fn checksum_different_usage_produces_different_result() { + let key = [0x55; 16]; + let data = b"usage test"; + let c1 = compute_checksum(&key, 1, data, EncryptionType::Aes128CtsHmacSha196); + let c2 = compute_checksum(&key, 2, data, EncryptionType::Aes128CtsHmacSha196); + assert_ne!(c1, c2); + } + + #[test] + fn checksum_rc4_deterministic() { + let key = [0x66; 16]; + let data = b"rc4 checksum test"; + let c1 = compute_checksum(&key, 7, data, EncryptionType::Rc4Hmac); + let c2 = compute_checksum(&key, 7, data, EncryptionType::Rc4Hmac); + assert_eq!(c1, c2); + } + + // ── Usage constant helpers ─────────────────────────────────────── + + #[test] + fn usage_enc_format() { + let u = usage_enc(7); + assert_eq!(u, [0, 0, 0, 7, 0xAA]); + } + + #[test] + fn usage_int_format() { + let u = usage_int(7); + assert_eq!(u, [0, 0, 0, 7, 0x55]); + } + + // ── Helper ─────────────────────────────────────────────────────── + + /// Parse a hex string into bytes (ignores spaces). + fn hex(s: &str) -> Vec { + let s: String = s.chars().filter(|c| !c.is_whitespace()).collect(); + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap()) + .collect() + } + + #[test] + fn string_to_key_aes256_matches_mit_kdc_keytab() { + // Key from MIT KDC keytab for testuser@TEST.LOCAL with password "testpass" + // Salt = "TEST.LOCALtestuser" + let key = string_to_key_aes("testpass", "TEST.LOCALtestuser", 32); + let expected = hex("7964c7e6f475912def26f886f2683da03f58257a987bca47e461daddb18cb336"); + assert_eq!(key, expected, "key must match MIT KDC keytab"); + } + + #[test] + fn aes_cts_known_vectors() { + // AES-CTS test vectors. Key: "chicken teriyaki", IV: all zeros. + // Plaintext: "I would like the General Gau's Chicken, please, and wonton soup." + let key = hex("636869636b656e207465726979616b69"); + let iv = [0u8; 16]; + let full_plain = b"I would like the General Gau's Chicken, please, and wonton soup."; + + // 17 bytes: verified against minikerberos (Python Kerberos reference). + let ct_17 = encrypt_aes_cts(&key, &iv, &full_plain[..17]); + assert_eq!( + ct_17, + hex("c6353568f2bf8cb4d8a580362da7ff7f97"), + "17-byte CTS failed" + ); + + // All CTS vectors must roundtrip correctly. + for len in [17, 31, 32, 47, 48, 64] { + let ct = encrypt_aes_cts(&key, &iv, &full_plain[..len]); + assert_eq!(ct.len(), len, "CTS ciphertext length for {len} bytes"); + let pt = decrypt_aes_cts(&key, &iv, &ct).unwrap(); + assert_eq!(&pt[..], &full_plain[..len], "CTS roundtrip for {len} bytes"); + } + } + + // ── Kerberos encrypt/decrypt roundtrip ─────────────────────────── + + #[test] + fn kerberos_encrypt_decrypt_aes256() { + let key = string_to_key_aes("password", "EXAMPLE.COMuser", 32); + let plaintext = b"Hello, Kerberos!"; + + let ciphertext = kerberos_encrypt(&key, 7, plaintext, EncryptionType::Aes256CtsHmacSha196); + let decrypted = + kerberos_decrypt(&key, 7, &ciphertext, EncryptionType::Aes256CtsHmacSha196).unwrap(); + + assert_eq!(decrypted, plaintext); + } + + #[test] + fn kerberos_encrypt_decrypt_aes128() { + let key = string_to_key_aes("password", "EXAMPLE.COMuser", 16); + let plaintext = b"Hello, Kerberos AES-128!"; + + let ciphertext = kerberos_encrypt(&key, 3, plaintext, EncryptionType::Aes128CtsHmacSha196); + let decrypted = + kerberos_decrypt(&key, 3, &ciphertext, EncryptionType::Aes128CtsHmacSha196).unwrap(); + + assert_eq!(decrypted, plaintext); + } + + #[test] + fn kerberos_encrypt_decrypt_rc4() { + let key = string_to_key_rc4("password"); + let plaintext = b"Hello, RC4!"; + + let ciphertext = kerberos_encrypt(&key, 7, plaintext, EncryptionType::Rc4Hmac); + let decrypted = kerberos_decrypt(&key, 7, &ciphertext, EncryptionType::Rc4Hmac).unwrap(); + + assert_eq!(decrypted, plaintext); + } + + #[test] + fn kerberos_decrypt_wrong_key_fails() { + let key = string_to_key_aes("password", "EXAMPLE.COMuser", 32); + let wrong_key = string_to_key_aes("wrong", "EXAMPLE.COMuser", 32); + let plaintext = b"secret data"; + + let ciphertext = kerberos_encrypt(&key, 1, plaintext, EncryptionType::Aes256CtsHmacSha196); + let result = kerberos_decrypt( + &wrong_key, + 1, + &ciphertext, + EncryptionType::Aes256CtsHmacSha196, + ); + + assert!(result.is_err(), "decryption with wrong key should fail"); + } + + #[test] + fn kerberos_decrypt_wrong_usage_fails() { + let key = string_to_key_aes("password", "EXAMPLE.COMuser", 32); + let plaintext = b"secret data"; + + let ciphertext = kerberos_encrypt(&key, 1, plaintext, EncryptionType::Aes256CtsHmacSha196); + let result = kerberos_decrypt(&key, 7, &ciphertext, EncryptionType::Aes256CtsHmacSha196); + + assert!(result.is_err(), "decryption with wrong usage should fail"); + } + + // ── Etype conversion ───────────────────────────────────────────── + + #[test] + fn etype_from_i32_valid() { + assert_eq!( + etype_from_i32(18).unwrap(), + EncryptionType::Aes256CtsHmacSha196 + ); + assert_eq!( + etype_from_i32(17).unwrap(), + EncryptionType::Aes128CtsHmacSha196 + ); + assert_eq!(etype_from_i32(23).unwrap(), EncryptionType::Rc4Hmac); + } + + #[test] + fn etype_from_i32_unsupported() { + assert!(etype_from_i32(99).is_err()); + assert!(etype_from_i32(0).is_err()); + } + + // ── Random key generation ──────────────────────────────────────── + + #[test] + fn generate_random_key_sizes() { + assert_eq!( + generate_random_key(EncryptionType::Aes256CtsHmacSha196).len(), + 32 + ); + assert_eq!( + generate_random_key(EncryptionType::Aes128CtsHmacSha196).len(), + 16 + ); + assert_eq!(generate_random_key(EncryptionType::Rc4Hmac).len(), 16); + } +} diff --git a/vendor/smb2/src/auth/kerberos/kdc.rs b/vendor/smb2/src/auth/kerberos/kdc.rs new file mode 100644 index 0000000..72352a1 --- /dev/null +++ b/vendor/smb2/src/auth/kerberos/kdc.rs @@ -0,0 +1,890 @@ +//! KDC (Key Distribution Center) transport client. +//! +//! Sends AS-REQ and TGS-REQ messages to a Kerberos KDC on port 88. +//! Tries UDP first (no framing), falls back to TCP (4-byte big-endian +//! length prefix) when the response indicates KRB_ERR_RESPONSE_TOO_BIG +//! (error code 52). +//! +//! Transport details per RFC 4120 section 7.2 and MS-KILE section 2.1: +//! - UDP: raw DER bytes, no length prefix, max 65535 bytes +//! - TCP: 4-byte big-endian length prefix, then DER bytes +//! - Retry: up to 3 attempts with exponential backoff (1s, 2s, 4s) +//! +//! The functions here are transport-only: they send raw bytes and return +//! raw bytes. No ASN.1 parsing beyond detecting error code 52 in the +//! UDP-to-TCP fallback path. + +use log::{debug, trace, warn}; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpStream, UdpSocket}; + +use crate::error::{Error, Result}; + +/// Default Kerberos port (RFC 4120). +const KERBEROS_PORT: u16 = 88; + +/// Maximum UDP receive buffer size. +const UDP_MAX_SIZE: usize = 65535; + +/// KRB_ERR_RESPONSE_TOO_BIG error code (RFC 4120 section 7.2.1). +const KRB_ERR_RESPONSE_TOO_BIG: u32 = 52; + +/// Maximum TCP frame size we accept (1 MB, generous for Kerberos). +const MAX_KDC_FRAME_SIZE: usize = 1024 * 1024; + +/// Number of retry attempts per transport. +const MAX_RETRIES: u32 = 3; + +/// Base retry delay (doubles each attempt). +const RETRY_BASE_DELAY: Duration = Duration::from_secs(1); + +/// Configuration for connecting to a KDC. +#[derive(Debug, Clone)] +pub struct KdcConfig { + /// KDC address (host:port or just host, defaults to port 88). + pub address: String, + /// Connection/request timeout. + pub timeout: Duration, +} + +/// Resolve the KDC address to include a port if not specified. +fn resolve_address(address: &str) -> String { + if address.contains(':') { + address.to_string() + } else { + format!("{}:{}", address, KERBEROS_PORT) + } +} + +/// Send a Kerberos message to the KDC and receive the response. +/// +/// Tries UDP first. If the response indicates the message was too +/// large for UDP (KRB_ERR_RESPONSE_TOO_BIG), retries with TCP. +/// +/// UDP framing: raw DER bytes, no length prefix. +/// TCP framing: 4-byte big-endian length prefix, then DER bytes. +pub async fn send_to_kdc(config: &KdcConfig, message: &[u8]) -> Result> { + let addr = resolve_address(&config.address); + debug!("kdc: sending {} bytes to {}", message.len(), addr); + + // Try UDP first. + match send_udp(&addr, message, config.timeout).await { + Ok(response) => { + if is_response_too_big(&response) { + debug!("kdc: got KRB_ERR_RESPONSE_TOO_BIG, retrying with TCP"); + send_tcp(&addr, message, config.timeout).await + } else { + Ok(response) + } + } + Err(e) => { + warn!("kdc: UDP failed ({}), falling back to TCP", e); + send_tcp(&addr, message, config.timeout).await + } + } +} + +/// Send a Kerberos message via UDP. +async fn send_udp(addr: &str, message: &[u8], timeout: Duration) -> Result> { + let socket = UdpSocket::bind("0.0.0.0:0").await.map_err(Error::Io)?; + + let mut last_err = None; + + for attempt in 0..MAX_RETRIES { + if attempt > 0 { + let delay = RETRY_BASE_DELAY * 2u32.pow(attempt - 1); + debug!("kdc: UDP retry {} after {:?}", attempt, delay); + tokio::time::sleep(delay).await; + } + + // Send the raw DER bytes (no framing for UDP). + match tokio::time::timeout(timeout, socket.send_to(message, addr)).await { + Ok(Ok(n)) => { + trace!("kdc: UDP sent {} bytes", n); + } + Ok(Err(e)) => { + last_err = Some(Error::Io(e)); + continue; + } + Err(_) => { + last_err = Some(Error::Timeout); + continue; + } + } + + // Receive the response. + let mut buf = vec![0u8; UDP_MAX_SIZE]; + match tokio::time::timeout(timeout, socket.recv_from(&mut buf)).await { + Ok(Ok((n, _src))) => { + trace!("kdc: UDP received {} bytes", n); + buf.truncate(n); + return Ok(buf); + } + Ok(Err(e)) => { + last_err = Some(Error::Io(e)); + } + Err(_) => { + last_err = Some(Error::Timeout); + } + } + } + + Err(last_err.unwrap_or(Error::Timeout)) +} + +/// Send a Kerberos message via TCP. +async fn send_tcp(addr: &str, message: &[u8], timeout: Duration) -> Result> { + let mut last_err = None; + + for attempt in 0..MAX_RETRIES { + if attempt > 0 { + let delay = RETRY_BASE_DELAY * 2u32.pow(attempt - 1); + debug!("kdc: TCP retry {} after {:?}", attempt, delay); + tokio::time::sleep(delay).await; + } + + match send_tcp_once(addr, message, timeout).await { + Ok(response) => return Ok(response), + Err(e) => { + last_err = Some(e); + } + } + } + + Err(last_err.unwrap_or(Error::Timeout)) +} + +/// Single TCP send/receive attempt. +async fn send_tcp_once(addr: &str, message: &[u8], timeout: Duration) -> Result> { + // Connect with timeout. + let mut stream = tokio::time::timeout(timeout, TcpStream::connect(addr)) + .await + .map_err(|_| Error::Timeout)? + .map_err(Error::Io)?; + + // Disable Nagle for lower latency. + stream.set_nodelay(true).map_err(Error::Io)?; + + // Send: 4-byte big-endian length prefix + DER bytes. + let len = message.len() as u32; + let len_bytes = len.to_be_bytes(); + + tokio::time::timeout(timeout, async { + stream.write_all(&len_bytes).await.map_err(Error::Io)?; + stream.write_all(message).await.map_err(Error::Io)?; + stream.flush().await.map_err(Error::Io)?; + trace!("kdc: TCP sent {} bytes", message.len()); + Ok::<(), Error>(()) + }) + .await + .map_err(|_| Error::Timeout)??; + + // Receive: 4-byte big-endian length prefix. + let mut len_buf = [0u8; 4]; + tokio::time::timeout(timeout, stream.read_exact(&mut len_buf)) + .await + .map_err(|_| Error::Timeout)? + .map_err(|e| { + if e.kind() == std::io::ErrorKind::UnexpectedEof { + Error::Disconnected + } else { + Error::Io(e) + } + })?; + + let resp_len = u32::from_be_bytes(len_buf) as usize; + if resp_len > MAX_KDC_FRAME_SIZE { + return Err(Error::invalid_data(format!( + "KDC TCP response length {} exceeds maximum {}", + resp_len, MAX_KDC_FRAME_SIZE + ))); + } + + // Read the response body. + let mut buf = vec![0u8; resp_len]; + tokio::time::timeout(timeout, stream.read_exact(&mut buf)) + .await + .map_err(|_| Error::Timeout)? + .map_err(|e| { + if e.kind() == std::io::ErrorKind::UnexpectedEof { + Error::Disconnected + } else { + Error::Io(e) + } + })?; + + trace!("kdc: TCP received {} bytes", resp_len); + Ok(buf) +} + +/// Detect KRB_ERR_RESPONSE_TOO_BIG (error code 52) in a KRB-ERROR response. +/// +/// KRB-ERROR is APPLICATION [30] (tag 0x7e). We parse just enough DER +/// to extract the error-code field (context tag [6]) without a full +/// ASN.1 parser. +fn is_response_too_big(response: &[u8]) -> bool { + // KRB-ERROR starts with APPLICATION [30] = 0x7e. + if response.is_empty() || response[0] != 0x7e { + return false; + } + + match extract_krb_error_code(response) { + Some(code) => code == KRB_ERR_RESPONSE_TOO_BIG, + None => false, + } +} + +/// Extract the error-code from a KRB-ERROR message. +/// +/// KRB-ERROR structure (simplified DER): +/// ```text +/// APPLICATION [30] { +/// SEQUENCE { +/// [0] pvno INTEGER, +/// [1] msg-type INTEGER, +/// [2] ctime (optional), +/// [3] cusec (optional), +/// [4] stime, +/// [5] susec, +/// [6] error-code INTEGER, <-- we want this +/// ... +/// } +/// } +/// ``` +fn extract_krb_error_code(data: &[u8]) -> Option { + let mut pos = 0; + + // Skip APPLICATION [30] tag. + if pos >= data.len() || data[pos] != 0x7e { + return None; + } + pos += 1; + pos = skip_der_length(data, pos)?; + + // Skip SEQUENCE tag (0x30). + if pos >= data.len() || data[pos] != 0x30 { + return None; + } + pos += 1; + pos = skip_der_length(data, pos)?; + + // Now iterate through context-tagged fields until we find [6]. + loop { + if pos >= data.len() { + return None; + } + + let tag = data[pos]; + // Context tags are 0xa0..0xbf for constructed. + if tag & 0xe0 != 0xa0 { + return None; + } + let tag_num = tag & 0x1f; + pos += 1; + + let (field_len, new_pos) = read_der_length(data, pos)?; + let field_end = new_pos + field_len; + + if tag_num == 6 { + // This field contains an INTEGER with the error code. + return parse_der_integer(data, new_pos); + } + + pos = field_end; + } +} + +/// Skip a DER length field and return the position after it. +fn skip_der_length(data: &[u8], pos: usize) -> Option { + let (_len, new_pos) = read_der_length(data, pos)?; + Some(new_pos) +} + +/// Read a DER length field, returning (length, position_after_length). +fn read_der_length(data: &[u8], pos: usize) -> Option<(usize, usize)> { + if pos >= data.len() { + return None; + } + + let first = data[pos]; + match first.cmp(&0x80) { + std::cmp::Ordering::Less => { + // Short form: length is the byte itself. + Some((first as usize, pos + 1)) + } + std::cmp::Ordering::Equal => { + // Indefinite length, not used in DER. + None + } + std::cmp::Ordering::Greater => { + // Long form: first byte & 0x7f = number of subsequent length bytes. + let num_bytes = (first & 0x7f) as usize; + if num_bytes > 4 || pos + 1 + num_bytes > data.len() { + return None; + } + let mut length: usize = 0; + for i in 0..num_bytes { + length = (length << 8) | (data[pos + 1 + i] as usize); + } + Some((length, pos + 1 + num_bytes)) + } + } +} + +/// Parse a DER INTEGER at the given position, returning its value as u32. +fn parse_der_integer(data: &[u8], pos: usize) -> Option { + if pos >= data.len() || data[pos] != 0x02 { + return None; + } + let (len, val_pos) = read_der_length(data, pos + 1)?; + if val_pos + len > data.len() || len == 0 || len > 4 { + return None; + } + + let mut value: u32 = 0; + for i in 0..len { + value = (value << 8) | (data[val_pos + i] as u32); + } + Some(value) +} + +/// Discover KDC addresses for a realm via DNS SRV records. +/// +/// Looks up `_kerberos._udp.{realm}` and `_kerberos._tcp.{realm}`. +/// Returns addresses sorted by priority. +/// +/// For now, this is a placeholder -- initial implementation uses +/// the hardcoded address from KdcConfig. DNS SRV discovery will +/// be added in a future version. +pub async fn discover_kdc(_realm: &str) -> Vec { + // Placeholder: DNS SRV lookup not yet implemented. + // Callers should use KdcConfig.address directly. + debug!("kdc: DNS SRV discovery not yet implemented, returning empty list"); + Vec::new() +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncReadExt; + use tokio::net::TcpListener; + + // ── DER parsing tests ────────────────────────────────────────── + + #[test] + fn read_der_length_short_form() { + assert_eq!(read_der_length(&[0x05], 0), Some((5, 1))); + assert_eq!(read_der_length(&[0x7f], 0), Some((127, 1))); + assert_eq!(read_der_length(&[0x00], 0), Some((0, 1))); + } + + #[test] + fn read_der_length_long_form_one_byte() { + // 0x81, 0x80 = 128 bytes + assert_eq!(read_der_length(&[0x81, 0x80], 0), Some((128, 2))); + } + + #[test] + fn read_der_length_long_form_two_bytes() { + // 0x82, 0x01, 0x00 = 256 bytes + assert_eq!(read_der_length(&[0x82, 0x01, 0x00], 0), Some((256, 3))); + } + + #[test] + fn read_der_length_indefinite_returns_none() { + assert_eq!(read_der_length(&[0x80], 0), None); + } + + #[test] + fn read_der_length_truncated_returns_none() { + // Says 2 length bytes follow but only 1 is present. + assert_eq!(read_der_length(&[0x82, 0x01], 0), None); + } + + #[test] + fn parse_der_integer_single_byte() { + // INTEGER tag 0x02, length 1, value 52. + assert_eq!(parse_der_integer(&[0x02, 0x01, 0x34], 0), Some(52)); + } + + #[test] + fn parse_der_integer_two_bytes() { + // INTEGER tag 0x02, length 2, value 0x0100 = 256. + assert_eq!(parse_der_integer(&[0x02, 0x02, 0x01, 0x00], 0), Some(256)); + } + + #[test] + fn parse_der_integer_not_integer_tag() { + assert_eq!(parse_der_integer(&[0x03, 0x01, 0x34], 0), None); + } + + // ── KRB-ERROR detection tests ────────────────────────────────── + + /// Build a minimal KRB-ERROR with the given error code. + /// + /// This constructs a valid DER-encoded KRB-ERROR with fields: + /// [0] pvno = 5, [1] msg-type = 30, [4] stime, [5] susec = 0, + /// [6] error-code = the given code. + fn build_krb_error(error_code: u32) -> Vec { + // Helper: wrap value in context tag. + fn context_tag(tag_num: u8, contents: &[u8]) -> Vec { + let mut out = vec![0xa0 | tag_num]; + push_der_length(&mut out, contents.len()); + out.extend_from_slice(contents); + out + } + + // Helper: encode a DER INTEGER. + fn der_integer(value: u32) -> Vec { + // Encode as minimal bytes. + let bytes = if value == 0 { + vec![0x00] + } else if value < 0x80 { + vec![value as u8] + } else if value < 0x8000 { + vec![(value >> 8) as u8, (value & 0xff) as u8] + } else if value < 0x800000 { + vec![ + (value >> 16) as u8, + (value >> 8) as u8, + (value & 0xff) as u8, + ] + } else { + vec![ + (value >> 24) as u8, + (value >> 16) as u8, + (value >> 8) as u8, + (value & 0xff) as u8, + ] + }; + let mut out = vec![0x02]; + push_der_length(&mut out, bytes.len()); + out.extend_from_slice(&bytes); + out + } + + fn push_der_length(out: &mut Vec, len: usize) { + if len < 0x80 { + out.push(len as u8); + } else if len < 0x100 { + out.push(0x81); + out.push(len as u8); + } else { + out.push(0x82); + out.push((len >> 8) as u8); + out.push((len & 0xff) as u8); + } + } + + // Build the SEQUENCE contents. + let pvno = context_tag(0, &der_integer(5)); + let msg_type = context_tag(1, &der_integer(30)); + // Skip [2] ctime and [3] cusec (optional). + // [4] stime: GeneralizedTime "20250101000000Z" + let stime_val = b"20250101000000Z"; + let mut stime_der = vec![0x18]; // GeneralizedTime tag + push_der_length(&mut stime_der, stime_val.len()); + stime_der.extend_from_slice(stime_val); + let stime = context_tag(4, &stime_der); + let susec = context_tag(5, &der_integer(0)); + let error_code_field = context_tag(6, &der_integer(error_code)); + + let mut seq_contents = Vec::new(); + seq_contents.extend_from_slice(&pvno); + seq_contents.extend_from_slice(&msg_type); + seq_contents.extend_from_slice(&stime); + seq_contents.extend_from_slice(&susec); + seq_contents.extend_from_slice(&error_code_field); + + // Wrap in SEQUENCE. + let mut seq = vec![0x30]; + push_der_length(&mut seq, seq_contents.len()); + seq.extend_from_slice(&seq_contents); + + // Wrap in APPLICATION [30]. + let mut msg = vec![0x7e]; + push_der_length(&mut msg, seq.len()); + msg.extend_from_slice(&seq); + + msg + } + + #[test] + fn is_response_too_big_detects_error_52() { + let error = build_krb_error(KRB_ERR_RESPONSE_TOO_BIG); + assert!(is_response_too_big(&error)); + } + + #[test] + fn is_response_too_big_ignores_other_errors() { + // Error code 6 = KDC_ERR_C_PRINCIPAL_UNKNOWN + let error = build_krb_error(6); + assert!(!is_response_too_big(&error)); + } + + #[test] + fn is_response_too_big_ignores_non_error_messages() { + // AS-REP starts with APPLICATION [11] = 0x6b + assert!(!is_response_too_big(&[0x6b, 0x03, 0x30, 0x01, 0x00])); + } + + #[test] + fn is_response_too_big_handles_empty_response() { + assert!(!is_response_too_big(&[])); + } + + #[test] + fn is_response_too_big_handles_truncated_response() { + // Just the APPLICATION tag and nothing else. + assert!(!is_response_too_big(&[0x7e])); + assert!(!is_response_too_big(&[0x7e, 0x00])); + } + + #[test] + fn extract_error_code_from_valid_krb_error() { + let error = build_krb_error(25); + assert_eq!(extract_krb_error_code(&error), Some(25)); + } + + #[test] + fn extract_error_code_returns_none_for_non_error() { + assert_eq!( + extract_krb_error_code(&[0x6b, 0x03, 0x30, 0x01, 0x00]), + None + ); + } + + // ── Address resolution tests ─────────────────────────────────── + + #[test] + fn resolve_address_adds_default_port() { + assert_eq!(resolve_address("kdc.example.com"), "kdc.example.com:88"); + } + + #[test] + fn resolve_address_preserves_explicit_port() { + assert_eq!( + resolve_address("kdc.example.com:8888"), + "kdc.example.com:8888" + ); + } + + #[test] + fn resolve_address_ip_no_port() { + assert_eq!(resolve_address("10.0.0.1"), "10.0.0.1:88"); + } + + #[test] + fn resolve_address_ip_with_port() { + assert_eq!(resolve_address("10.0.0.1:88"), "10.0.0.1:88"); + } + + // ── UDP transport tests ──────────────────────────────────────── + + #[tokio::test] + async fn udp_send_receive() { + // Set up a mock KDC that echoes the request back. + let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_addr = server.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let mut buf = vec![0u8; UDP_MAX_SIZE]; + let (n, src) = server.recv_from(&mut buf).await.unwrap(); + // Echo back the message. + server.send_to(&buf[..n], src).await.unwrap(); + }); + + let message = b"test-kerberos-message"; + let result = send_udp(&server_addr.to_string(), message, Duration::from_secs(5)).await; + + assert!( + result.is_ok(), + "UDP send/receive failed: {:?}", + result.err() + ); + assert_eq!(result.unwrap(), message); + + server_task.await.unwrap(); + } + + #[tokio::test] + async fn udp_timeout_on_no_response() { + // Bind a server socket but never read from it. + let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_addr = server.local_addr().unwrap(); + + // Use very short timeout and only 1 retry attempt to keep test fast. + // We can't change MAX_RETRIES, but we use a very short timeout so + // all 3 retries finish quickly. + let result = send_udp( + &server_addr.to_string(), + b"hello", + Duration::from_millis(50), + ) + .await; + + assert!(result.is_err()); + assert!( + matches!(result.as_ref().unwrap_err(), Error::Timeout), + "expected Timeout, got: {:?}", + result.unwrap_err() + ); + + drop(server); + } + + // ── TCP transport tests ──────────────────────────────────────── + + #[tokio::test] + async fn tcp_send_receive() { + // Set up a mock KDC that reads a length-prefixed message and + // sends back a length-prefixed response. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + // Read 4-byte length prefix. + let mut len_buf = [0u8; 4]; + stream.read_exact(&mut len_buf).await.unwrap(); + let msg_len = u32::from_be_bytes(len_buf) as usize; + + // Read the message body. + let mut msg = vec![0u8; msg_len]; + stream.read_exact(&mut msg).await.unwrap(); + + // Echo back with length prefix. + let response = b"kdc-response"; + let resp_len = (response.len() as u32).to_be_bytes(); + stream.write_all(&resp_len).await.unwrap(); + stream.write_all(response).await.unwrap(); + stream.flush().await.unwrap(); + }); + + let result = send_tcp(&addr.to_string(), b"test-request", Duration::from_secs(5)).await; + + assert!( + result.is_ok(), + "TCP send/receive failed: {:?}", + result.err() + ); + assert_eq!(result.unwrap(), b"kdc-response"); + + server_task.await.unwrap(); + } + + #[tokio::test] + async fn tcp_timeout_on_no_response() { + // Set up a server that accepts but never responds. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + // Hold the connection open but never respond. + tokio::time::sleep(Duration::from_secs(10)).await; + drop(stream); + }); + + let result = send_tcp_once(&addr.to_string(), b"hello", Duration::from_millis(100)).await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, Error::Timeout), + "expected Timeout, got: {err}" + ); + + server_task.abort(); + } + + #[tokio::test] + async fn tcp_truncated_response() { + // Server sends a length prefix saying 100 bytes, then disconnects. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + // Read the request (don't care about contents). + let mut len_buf = [0u8; 4]; + let _ = stream.read_exact(&mut len_buf).await; + let msg_len = u32::from_be_bytes(len_buf) as usize; + let mut discard = vec![0u8; msg_len]; + let _ = stream.read_exact(&mut discard).await; + + // Send response with length 100 but only 5 bytes of data, then close. + let resp_len = 100u32.to_be_bytes(); + stream.write_all(&resp_len).await.unwrap(); + stream + .write_all(&[0x01, 0x02, 0x03, 0x04, 0x05]) + .await + .unwrap(); + stream.shutdown().await.unwrap(); + }); + + let result = send_tcp_once(&addr.to_string(), b"hello", Duration::from_secs(5)).await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, Error::Disconnected), + "expected Disconnected for truncated response, got: {err}" + ); + + server_task.await.unwrap(); + } + + #[tokio::test] + async fn tcp_oversized_length_rejected() { + // Server sends a length prefix larger than MAX_KDC_FRAME_SIZE. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + // Read request. + let mut len_buf = [0u8; 4]; + let _ = stream.read_exact(&mut len_buf).await; + let msg_len = u32::from_be_bytes(len_buf) as usize; + let mut discard = vec![0u8; msg_len]; + let _ = stream.read_exact(&mut discard).await; + + // Send absurdly large length. + let resp_len = (MAX_KDC_FRAME_SIZE as u32 + 1).to_be_bytes(); + stream.write_all(&resp_len).await.unwrap(); + stream.flush().await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + }); + + let result = send_tcp_once(&addr.to_string(), b"hello", Duration::from_secs(5)).await; + + assert!(result.is_err()); + let err_str = result.unwrap_err().to_string(); + assert!( + err_str.contains("exceeds maximum"), + "expected 'exceeds maximum' error, got: {err_str}" + ); + + server_task.abort(); + } + + // ── send_to_kdc tests ────────────────────────────────────────── + + #[tokio::test] + async fn send_to_kdc_udp_success() { + // Set up a UDP mock KDC. + let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_addr = server.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + let mut buf = vec![0u8; UDP_MAX_SIZE]; + let (n, src) = server.recv_from(&mut buf).await.unwrap(); + // Respond with a fake AS-REP (not a KRB-ERROR). + let response = b"\x6b\x05\x30\x03\x02\x01\x05"; // Fake AS-REP-like + server.send_to(response, src).await.unwrap(); + drop(buf[..n].to_vec()); // acknowledge we received + }); + + let config = KdcConfig { + address: server_addr.to_string(), + timeout: Duration::from_secs(5), + }; + + let result = send_to_kdc(&config, b"as-req").await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), b"\x6b\x05\x30\x03\x02\x01\x05"); + + server_task.await.unwrap(); + } + + #[tokio::test] + async fn send_to_kdc_udp_too_big_falls_back_to_tcp() { + // Set up a UDP server that returns KRB_ERR_RESPONSE_TOO_BIG + // and a TCP server that returns a real response. The fallback + // path uses one `KdcConfig.address`, so both servers must share + // a port. + // + // Bind TCP first (more restrictive) and then UDP to its port. + // On Windows Server, the OS port allocator can hand out an + // ephemeral port that's in an excluded range for the other + // protocol (WSAEACCES / 10013). Retry a few times if so; + // a fresh `:0` lottery picks a different port each attempt. + let (udp_server, tcp_listener) = { + let mut last_err: Option = None; + let mut bound = None; + for _ in 0..10 { + let tcp = match TcpListener::bind("127.0.0.1:0").await { + Ok(l) => l, + Err(e) => { + last_err = Some(e); + continue; + } + }; + let port = tcp.local_addr().unwrap().port(); + match UdpSocket::bind(format!("127.0.0.1:{port}")).await { + Ok(udp) => { + bound = Some((udp, tcp)); + break; + } + Err(e) => { + last_err = Some(e); + // TCP listener drops here; try a new port. + } + } + } + bound.unwrap_or_else(|| { + panic!("could not co-bind UDP+TCP on a shared loopback port in 10 attempts: {last_err:?}") + }) + }; + let udp_addr = udp_server.local_addr().unwrap(); + + let udp_task = tokio::spawn(async move { + let mut buf = vec![0u8; UDP_MAX_SIZE]; + let (_, src) = udp_server.recv_from(&mut buf).await.unwrap(); + let error = build_krb_error(KRB_ERR_RESPONSE_TOO_BIG); + udp_server.send_to(&error, src).await.unwrap(); + }); + + let tcp_task = tokio::spawn(async move { + let (mut stream, _) = tcp_listener.accept().await.unwrap(); + // Read request. + let mut len_buf = [0u8; 4]; + stream.read_exact(&mut len_buf).await.unwrap(); + let msg_len = u32::from_be_bytes(len_buf) as usize; + let mut msg = vec![0u8; msg_len]; + stream.read_exact(&mut msg).await.unwrap(); + + // Send TCP response. + let response = b"tcp-kdc-response"; + let resp_len = (response.len() as u32).to_be_bytes(); + stream.write_all(&resp_len).await.unwrap(); + stream.write_all(response).await.unwrap(); + stream.flush().await.unwrap(); + }); + + let config = KdcConfig { + address: udp_addr.to_string(), + timeout: Duration::from_secs(5), + }; + + let result = send_to_kdc(&config, b"as-req-large").await; + assert!(result.is_ok(), "send_to_kdc failed: {:?}", result.err()); + assert_eq!(result.unwrap(), b"tcp-kdc-response"); + + udp_task.await.unwrap(); + tcp_task.await.unwrap(); + } + + // ── discover_kdc tests ───────────────────────────────────────── + + #[tokio::test] + async fn discover_kdc_returns_empty_placeholder() { + let result = discover_kdc("EXAMPLE.COM").await; + assert!(result.is_empty()); + } +} diff --git a/vendor/smb2/src/auth/kerberos/messages.rs b/vendor/smb2/src/auth/kerberos/messages.rs new file mode 100644 index 0000000..eec395e --- /dev/null +++ b/vendor/smb2/src/auth/kerberos/messages.rs @@ -0,0 +1,1631 @@ +//! Kerberos ASN.1/DER message encoding and decoding. +//! +//! Hand-rolled ASN.1/DER for the specific Kerberos message structures needed +//! by an SMB2 client. Follows the same pattern as `spnego.rs`. +//! +//! References: +//! - RFC 4120: The Kerberos Network Authentication Service (V5) +//! - MS-KILE: Kerberos Protocol Extensions + +use crate::auth::der::{der_tlv, parse_der_tlv}; +use crate::auth::kerberos::crypto::EncryptionType; +use crate::Error; + +// --------------------------------------------------------------------------- +// ASN.1 tag constants +// --------------------------------------------------------------------------- + +const TAG_INTEGER: u8 = 0x02; +const TAG_BIT_STRING: u8 = 0x03; +const TAG_OCTET_STRING: u8 = 0x04; +const TAG_GENERAL_STRING: u8 = 0x1b; +const TAG_GENERALIZED_TIME: u8 = 0x18; +const TAG_SEQUENCE: u8 = 0x30; + +// --------------------------------------------------------------------------- +// Core types +// --------------------------------------------------------------------------- + +/// Kerberos principal name. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrincipalName { + /// Name type: KRB_NT_PRINCIPAL=1, KRB_NT_SRV_INST=2, etc. + pub name_type: i32, + /// Name components: for example, `["user"]` or `["cifs", "server.domain.com"]`. + pub name_string: Vec, +} + +/// Kerberos ticket (opaque to the client: we don't decrypt it). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ticket { + /// Ticket version number (always 5). + pub tkt_vno: i32, + /// Realm of the ticket. + pub realm: String, + /// Service principal name. + pub sname: PrincipalName, + /// Encrypted part (opaque). + pub enc_part: EncryptedData, + /// Raw DER bytes of the ticket as received from the KDC. + /// Used to pass the ticket through to the AP-REQ verbatim, + /// avoiding re-encoding which could corrupt the encrypted data. + pub raw_bytes: Option>, +} + +/// Generic encrypted data envelope. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedData { + /// Encryption type identifier. + pub etype: i32, + /// Key version number (optional). + pub kvno: Option, + /// Ciphertext bytes. + pub cipher: Vec, +} + +/// Pre-authentication data element. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PaData { + /// Pre-authentication data type. + pub padata_type: i32, + /// Pre-authentication data value. + pub padata_value: Vec, +} + +/// Parsed KDC-REP (AS-REP or TGS-REP). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct KdcRep { + /// Message type: 11 = AS-REP, 13 = TGS-REP. + pub msg_type: i32, + /// Client realm. + pub crealm: String, + /// Client principal name. + pub cname: PrincipalName, + /// Ticket. + pub ticket: Ticket, + /// Encrypted part (to be decrypted by the client). + pub enc_part: EncryptedData, +} + +/// Parsed decrypted EncKDCRepPart. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncKdcRepPart { + /// Session key. + pub key: EncryptionKey, + /// Nonce from the request. + pub nonce: u32, + /// Ticket flags as a bit field. + pub flags: u32, + /// Authentication time. + pub authtime: String, + /// Ticket end time. + pub endtime: String, + /// Service realm. + pub srealm: String, + /// Service principal name. + pub sname: PrincipalName, +} + +/// Encryption key (keytype + key value). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptionKey { + /// Key type (etype number). + pub keytype: i32, + /// Key value bytes. + pub keyvalue: Vec, +} + +/// Parsed AP-REP (`APPLICATION [15]`). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ApRep { + /// Encrypted part (to be decrypted with the session key or subkey). + pub enc_part: EncryptedData, +} + +/// Parsed decrypted EncAPRepPart. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncApRepPart { + /// Optional sub-session key from the server. If present, this overrides + /// the client's subkey as the session key for the application (SMB). + pub subkey: Option, + /// Optional sequence number. + pub seq_number: Option, +} + +/// Parsed KRB-ERROR message. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct KrbError { + /// Error code. + pub error_code: i32, + /// Client realm (optional). + pub crealm: Option, + /// Server realm. + pub realm: String, + /// Server principal name. + pub sname: PrincipalName, + /// Error text (optional). + pub e_text: Option, + /// Error data (optional). + pub e_data: Option>, +} + +// Core DER encoding/decoding helpers (der_length, der_tlv, parse_der_length, +// parse_der_tlv) are in `crate::auth::der`. Imported at the top. + +// --------------------------------------------------------------------------- +// DER encoding helpers (Kerberos-specific) +// --------------------------------------------------------------------------- + +/// Encode a context-specific constructed tag: `[tag_num]`. +fn der_context(tag_num: u8, data: &[u8]) -> Vec { + der_tlv(0xa0 | tag_num, data) +} + +/// Encode an APPLICATION constructed tag: `[APPLICATION tag_num]`. +fn der_application(tag_num: u8, data: &[u8]) -> Vec { + der_tlv(0x60 | tag_num, data) +} + +/// Encode an ASN.1 INTEGER (signed, big-endian, minimal bytes). +fn der_integer(val: i32) -> Vec { + let bytes = val.to_be_bytes(); + // Find the first significant byte, keeping sign correct. + let mut start = 0; + if val >= 0 { + // Skip leading 0x00 bytes, but keep one if the next byte has the high bit set. + while start < 3 && bytes[start] == 0x00 && bytes[start + 1] & 0x80 == 0 { + start += 1; + } + } else { + // Skip leading 0xff bytes, but keep one if the next byte doesn't have the high bit set. + while start < 3 && bytes[start] == 0xff && bytes[start + 1] & 0x80 != 0 { + start += 1; + } + } + der_tlv(TAG_INTEGER, &bytes[start..]) +} + +/// Encode an unsigned 32-bit value as ASN.1 INTEGER. +fn der_integer_u32(val: u32) -> Vec { + // Treat as i64 to handle the full u32 range without sign issues. + let val64 = val as i64; + let bytes = val64.to_be_bytes(); + let mut start = 0; + while start < 7 && bytes[start] == 0x00 && bytes[start + 1] & 0x80 == 0 { + start += 1; + } + der_tlv(TAG_INTEGER, &bytes[start..]) +} + +/// Encode a DER OCTET STRING. +fn der_octet_string(data: &[u8]) -> Vec { + der_tlv(TAG_OCTET_STRING, data) +} + +/// Encode a DER GeneralString. +fn der_general_string(s: &str) -> Vec { + der_tlv(TAG_GENERAL_STRING, s.as_bytes()) +} + +/// Encode a DER GeneralizedTime (for example, `"20260408120000Z"`). +fn der_generalized_time(time: &str) -> Vec { + der_tlv(TAG_GENERALIZED_TIME, time.as_bytes()) +} + +/// Encode a DER BIT STRING. `bits` is the raw bytes; `unused` is the number +/// of unused bits in the last byte (usually 0 for 32-bit flags). +fn der_bit_string(bits: &[u8], unused: u8) -> Vec { + let mut data = vec![unused]; + data.extend_from_slice(bits); + der_tlv(TAG_BIT_STRING, &data) +} + +/// Encode a DER SEQUENCE from pre-encoded items. +fn der_sequence(items: &[&[u8]]) -> Vec { + let mut contents = Vec::new(); + for item in items { + contents.extend_from_slice(item); + } + der_tlv(TAG_SEQUENCE, &contents) +} + +// --------------------------------------------------------------------------- +// DER parsing helpers (Kerberos-specific) +// --------------------------------------------------------------------------- + +/// Parse all TLV elements in a SEQUENCE body, returning `(tag, value)` pairs. +fn parse_sequence_fields(data: &[u8]) -> Result)>, Error> { + let mut fields = Vec::new(); + let mut pos = 0; + while pos < data.len() { + let (tag, value, consumed) = parse_der_tlv(&data[pos..])?; + fields.push((tag, value.to_vec())); + pos += consumed; + } + Ok(fields) +} + +/// Parse a DER INTEGER value (already unwrapped from TLV), returning i32. +fn parse_integer_value(data: &[u8]) -> Result { + if data.is_empty() { + return Err(Error::invalid_data("Kerberos: empty INTEGER")); + } + // Sign-extend from arbitrary-length big-endian. + let negative = data[0] & 0x80 != 0; + let mut val: i64 = if negative { -1 } else { 0 }; + for &b in data { + val = (val << 8) | (b as i64); + } + Ok(val as i32) +} + +/// Parse a DER INTEGER TLV, returning i32. +fn parse_der_integer(data: &[u8]) -> Result { + let (tag, value, _) = parse_der_tlv(data)?; + if tag != TAG_INTEGER { + return Err(Error::invalid_data(format!( + "Kerberos: expected INTEGER (0x02), got 0x{tag:02x}" + ))); + } + parse_integer_value(value) +} + +/// Parse a DER INTEGER TLV, returning u32. +fn parse_der_integer_u32(data: &[u8]) -> Result { + let val = parse_der_integer(data)?; + Ok(val as u32) +} + +/// Parse a DER OCTET STRING TLV, returning the raw bytes. +fn parse_der_octet_string(data: &[u8]) -> Result, Error> { + let (tag, value, _) = parse_der_tlv(data)?; + if tag != TAG_OCTET_STRING { + return Err(Error::invalid_data(format!( + "Kerberos: expected OCTET STRING (0x04), got 0x{tag:02x}" + ))); + } + Ok(value.to_vec()) +} + +/// Parse a DER GeneralString TLV, returning the string. +fn parse_der_general_string(data: &[u8]) -> Result { + let (tag, value, _) = parse_der_tlv(data)?; + if tag != TAG_GENERAL_STRING { + return Err(Error::invalid_data(format!( + "Kerberos: expected GeneralString (0x1b), got 0x{tag:02x}" + ))); + } + String::from_utf8(value.to_vec()) + .map_err(|_| Error::invalid_data("Kerberos: invalid UTF-8 in GeneralString")) +} + +/// Parse a DER GeneralizedTime TLV, returning the time string. +fn parse_der_generalized_time(data: &[u8]) -> Result { + let (tag, value, _) = parse_der_tlv(data)?; + if tag != TAG_GENERALIZED_TIME { + return Err(Error::invalid_data(format!( + "Kerberos: expected GeneralizedTime (0x18), got 0x{tag:02x}" + ))); + } + String::from_utf8(value.to_vec()) + .map_err(|_| Error::invalid_data("Kerberos: invalid UTF-8 in GeneralizedTime")) +} + +/// Parse a DER BIT STRING TLV, returning the raw bit bytes (without the +/// unused-bits prefix byte) and the number of unused bits. +fn parse_der_bit_string(data: &[u8]) -> Result<(Vec, u8), Error> { + let (tag, value, _) = parse_der_tlv(data)?; + if tag != TAG_BIT_STRING { + return Err(Error::invalid_data(format!( + "Kerberos: expected BIT STRING (0x03), got 0x{tag:02x}" + ))); + } + if value.is_empty() { + return Err(Error::invalid_data("Kerberos: empty BIT STRING")); + } + let unused = value[0]; + Ok((value[1..].to_vec(), unused)) +} + +// --------------------------------------------------------------------------- +// Encoding compound types +// --------------------------------------------------------------------------- + +/// Encode a PrincipalName as DER. +fn encode_principal_name(name: &PrincipalName) -> Vec { + // PrincipalName ::= SEQUENCE { + // name-type [0] Int32, + // name-string [1] SEQUENCE OF KerberosString (GeneralString) + // } + let name_type = der_context(0, &der_integer(name.name_type)); + let name_strings: Vec> = name + .name_string + .iter() + .map(|s| der_general_string(s)) + .collect(); + let name_refs: Vec<&[u8]> = name_strings.iter().map(|v| v.as_slice()).collect(); + let name_seq = der_sequence(&name_refs); + let name_string = der_context(1, &name_seq); + der_sequence(&[&name_type, &name_string]) +} + +/// Encode an EncryptedData as DER. +fn encode_encrypted_data(ed: &EncryptedData) -> Vec { + // EncryptedData ::= SEQUENCE { + // etype [0] Int32, + // kvno [1] UInt32 OPTIONAL, + // cipher [2] OCTET STRING + // } + let etype = der_context(0, &der_integer(ed.etype)); + let cipher = der_context(2, &der_octet_string(&ed.cipher)); + if let Some(kvno) = ed.kvno { + let kvno_enc = der_context(1, &der_integer(kvno)); + der_sequence(&[&etype, &kvno_enc, &cipher]) + } else { + der_sequence(&[&etype, &cipher]) + } +} + +/// Encode a Ticket as DER (`APPLICATION [1]`). +fn encode_ticket(ticket: &Ticket) -> Vec { + // Ticket ::= [APPLICATION 1] SEQUENCE { + // tkt-vno [0] INTEGER (5), + // realm [1] Realm (GeneralString), + // sname [2] PrincipalName, + // enc-part [3] EncryptedData + // } + let tkt_vno = der_context(0, &der_integer(ticket.tkt_vno)); + let realm = der_context(1, &der_general_string(&ticket.realm)); + let sname = der_context(2, &encode_principal_name(&ticket.sname)); + let enc_part = der_context(3, &encode_encrypted_data(&ticket.enc_part)); + let seq = der_sequence(&[&tkt_vno, &realm, &sname, &enc_part]); + der_application(1, &seq) +} + +/// Encode a PaData as DER. +fn encode_pa_data(pa: &PaData) -> Vec { + // PA-DATA ::= SEQUENCE { + // padata-type [1] Int32, + // padata-value [2] OCTET STRING + // } + let padata_type = der_context(1, &der_integer(pa.padata_type)); + let padata_value = der_context(2, &der_octet_string(&pa.padata_value)); + der_sequence(&[&padata_type, &padata_value]) +} + +// --------------------------------------------------------------------------- +// Parsing compound types +// --------------------------------------------------------------------------- + +/// Parse a PrincipalName from DER bytes. +fn parse_principal_name(data: &[u8]) -> Result { + let (tag, seq_data, _) = parse_der_tlv(data)?; + if tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE for PrincipalName, got 0x{tag:02x}" + ))); + } + let fields = parse_sequence_fields(seq_data)?; + let mut name_type = None; + let mut name_string = Vec::new(); + for (ftag, fvalue) in &fields { + match ftag { + 0xa0 => name_type = Some(parse_der_integer(fvalue)?), + 0xa1 => { + // SEQUENCE OF GeneralString + let (stag, sdata, _) = parse_der_tlv(fvalue)?; + if stag != TAG_SEQUENCE { + return Err(Error::invalid_data( + "Kerberos: expected SEQUENCE for name-string", + )); + } + let mut pos = 0; + while pos < sdata.len() { + let (_, sv, consumed) = parse_der_tlv(&sdata[pos..])?; + name_string.push(String::from_utf8(sv.to_vec()).map_err(|_| { + Error::invalid_data("Kerberos: invalid UTF-8 in name-string") + })?); + pos += consumed; + } + } + _ => {} // ignore unknown fields + } + } + Ok(PrincipalName { + name_type: name_type + .ok_or_else(|| Error::invalid_data("Kerberos: missing name-type in PrincipalName"))?, + name_string, + }) +} + +/// Parse an EncryptedData from DER bytes. +fn parse_encrypted_data(data: &[u8]) -> Result { + let (tag, seq_data, _) = parse_der_tlv(data)?; + if tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE for EncryptedData, got 0x{tag:02x}" + ))); + } + let fields = parse_sequence_fields(seq_data)?; + let mut etype = None; + let mut kvno = None; + let mut cipher = None; + for (ftag, fvalue) in &fields { + match ftag { + 0xa0 => etype = Some(parse_der_integer(fvalue)?), + 0xa1 => kvno = Some(parse_der_integer(fvalue)?), + 0xa2 => cipher = Some(parse_der_octet_string(fvalue)?), + _ => {} + } + } + Ok(EncryptedData { + etype: etype + .ok_or_else(|| Error::invalid_data("Kerberos: missing etype in EncryptedData"))?, + kvno, + cipher: cipher + .ok_or_else(|| Error::invalid_data("Kerberos: missing cipher in EncryptedData"))?, + }) +} + +/// Parse a Ticket from DER bytes (expects `APPLICATION [1]` wrapper). +/// +/// Stores the raw DER bytes so the ticket can be passed through to the +/// AP-REQ verbatim. Re-encoding the ticket from parsed fields can produce +/// different DER bytes (e.g., different length encoding, field order), which +/// corrupts the encrypted data and causes the server to fail decryption. +pub fn parse_ticket(data: &[u8]) -> Result { + let (tag, inner, total_consumed) = parse_der_tlv(data)?; + // APPLICATION [1] = 0x61 + if tag != 0x61 { + return Err(Error::invalid_data(format!( + "Kerberos: expected APPLICATION [1] (0x61) for Ticket, got 0x{tag:02x}" + ))); + } + + // Store raw bytes for verbatim pass-through. + let raw_bytes = data[..total_consumed].to_vec(); + + let (seq_tag, seq_data, _) = parse_der_tlv(inner)?; + if seq_tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE in Ticket, got 0x{seq_tag:02x}" + ))); + } + let fields = parse_sequence_fields(seq_data)?; + let mut tkt_vno = None; + let mut realm = None; + let mut sname = None; + let mut enc_part = None; + for (ftag, fvalue) in &fields { + match ftag { + 0xa0 => tkt_vno = Some(parse_der_integer(fvalue)?), + 0xa1 => realm = Some(parse_der_general_string(fvalue)?), + 0xa2 => sname = Some(parse_principal_name(fvalue)?), + 0xa3 => enc_part = Some(parse_encrypted_data(fvalue)?), + _ => {} + } + } + Ok(Ticket { + tkt_vno: tkt_vno + .ok_or_else(|| Error::invalid_data("Kerberos: missing tkt-vno in Ticket"))?, + realm: realm.ok_or_else(|| Error::invalid_data("Kerberos: missing realm in Ticket"))?, + sname: sname.ok_or_else(|| Error::invalid_data("Kerberos: missing sname in Ticket"))?, + enc_part: enc_part + .ok_or_else(|| Error::invalid_data("Kerberos: missing enc-part in Ticket"))?, + raw_bytes: Some(raw_bytes), + }) +} + +/// Parse an EncryptionKey from DER bytes. +fn parse_encryption_key(data: &[u8]) -> Result { + let (tag, seq_data, _) = parse_der_tlv(data)?; + if tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE for EncryptionKey, got 0x{tag:02x}" + ))); + } + let fields = parse_sequence_fields(seq_data)?; + let mut keytype = None; + let mut keyvalue = None; + for (ftag, fvalue) in &fields { + match ftag { + 0xa0 => keytype = Some(parse_der_integer(fvalue)?), + 0xa1 => keyvalue = Some(parse_der_octet_string(fvalue)?), + _ => {} + } + } + Ok(EncryptionKey { + keytype: keytype + .ok_or_else(|| Error::invalid_data("Kerberos: missing keytype in EncryptionKey"))?, + keyvalue: keyvalue + .ok_or_else(|| Error::invalid_data("Kerberos: missing keyvalue in EncryptionKey"))?, + }) +} + +// --------------------------------------------------------------------------- +// Public API: encoding +// --------------------------------------------------------------------------- + +/// Encode a KRB_AS_REQ message (`APPLICATION [10]`). +pub fn encode_as_req( + cname: &PrincipalName, + realm: &str, + sname: &PrincipalName, + nonce: u32, + etypes: &[EncryptionType], + padata: &[PaData], +) -> Vec { + encode_kdc_req(10, Some(cname), realm, sname, nonce, etypes, padata) +} + +/// Encode the KDC-REQ-BODY for a TGS-REQ. +/// +/// Returns the DER-encoded body, which is needed for computing the +/// checksum in the Authenticator (per RFC 4120 section 7.2.2). +pub fn encode_tgs_req_body( + realm: &str, + sname: &PrincipalName, + nonce: u32, + etypes: &[EncryptionType], +) -> Vec { + encode_kdc_req_body(None, realm, sname, nonce, etypes) +} + +/// Encode a KRB_TGS_REQ message (`APPLICATION [12]`). +/// +/// The `tgt_ap_req` is an AP-REQ wrapping the TGT, placed in PA-TGS-REQ (padata type 1). +/// The `req_body` must be the same bytes returned by `encode_tgs_req_body` (used for +/// the Authenticator checksum). +pub fn encode_tgs_req( + realm: &str, + sname: &PrincipalName, + nonce: u32, + etypes: &[EncryptionType], + tgt_ap_req: &[u8], +) -> Vec { + let padata = [PaData { + padata_type: 1, // PA-TGS-REQ + padata_value: tgt_ap_req.to_vec(), + }]; + encode_kdc_req(12, None, realm, sname, nonce, etypes, &padata) +} + +/// Encode a KRB_AP_REQ message (`APPLICATION [14]`). +/// +/// When `mutual_required` is true, sets the mutual-required bit (bit 2) in +/// AP-OPTIONS, requesting the server to prove its identity via an AP-REP. +pub fn encode_ap_req( + ticket: &Ticket, + encrypted_authenticator: &EncryptedData, + mutual_required: bool, +) -> Vec { + // AP-REQ ::= [APPLICATION 14] SEQUENCE { + // pvno [0] INTEGER (5), + // msg-type [1] INTEGER (14), + // ap-options [2] APOptions (BIT STRING, 32 bits), + // ticket [3] Ticket, + // authenticator [4] EncryptedData + // } + let pvno = der_context(0, &der_integer(5)); + let msg_type = der_context(1, &der_integer(14)); + // AP-OPTIONS: bit 2 = mutual-required (0x20 in the first byte). + let opts_byte0 = if mutual_required { 0x20 } else { 0x00 }; + let ap_options = der_context(2, &der_bit_string(&[opts_byte0, 0x00, 0x00, 0x00], 0)); + // Use raw ticket bytes if available (preserves exact DER encoding from KDC). + // Re-encoding can produce different bytes and corrupt the encrypted ticket. + let ticket_raw = ticket + .raw_bytes + .as_ref() + .map(|b| der_context(3, b)) + .unwrap_or_else(|| der_context(3, &encode_ticket(ticket))); + let authenticator = der_context(4, &encode_encrypted_data(encrypted_authenticator)); + let seq = der_sequence(&[&pvno, &msg_type, &ap_options, &ticket_raw, &authenticator]); + der_application(14, &seq) +} + +/// Encode an Authenticator (`APPLICATION [2]`), to be encrypted before embedding in AP-REQ. +/// +/// The optional `cksum` parameter adds a checksum field `[3]`, used in TGS-REQ +/// to authenticate the KDC-REQ-BODY (RFC 4120 section 7.2.2). +pub fn encode_authenticator( + crealm: &str, + cname: &PrincipalName, + ctime: &str, + cusec: u32, + subkey: Option<(&[u8], i32)>, + seq_number: Option, + cksum: Option<(&[u8], i32)>, +) -> Vec { + // Authenticator ::= [APPLICATION 2] SEQUENCE { + // authenticator-vno [0] INTEGER (5), + // crealm [1] Realm (GeneralString), + // cname [2] PrincipalName, + // cksum [3] Checksum OPTIONAL, + // cusec [4] Microseconds (INTEGER), + // ctime [5] KerberosTime (GeneralizedTime), + // subkey [6] EncryptionKey OPTIONAL, + // seq-number [7] UInt32 OPTIONAL, + // } + let auth_vno = der_context(0, &der_integer(5)); + let crealm_enc = der_context(1, &der_general_string(crealm)); + let cname_enc = der_context(2, &encode_principal_name(cname)); + + let mut items: Vec> = vec![auth_vno, crealm_enc, cname_enc]; + + if let Some((checksum_data, checksum_type)) = cksum { + // Checksum ::= SEQUENCE { cksumtype [0] Int32, checksum [1] OCTET STRING } + let cktype = der_context(0, &der_integer(checksum_type)); + let ckval = der_context(1, &der_octet_string(checksum_data)); + let ck = der_sequence(&[&cktype, &ckval]); + items.push(der_context(3, &ck)); + } + + let cusec_enc = der_context(4, &der_integer_u32(cusec)); + let ctime_enc = der_context(5, &der_generalized_time(ctime)); + items.push(cusec_enc); + items.push(ctime_enc); + + if let Some((key_value, key_type)) = subkey { + // EncryptionKey ::= SEQUENCE { keytype [0], keyvalue [1] } + let kt = der_context(0, &der_integer(key_type)); + let kv = der_context(1, &der_octet_string(key_value)); + let ek = der_sequence(&[&kt, &kv]); + items.push(der_context(6, &ek)); + } + + if let Some(seq) = seq_number { + items.push(der_context(7, &der_integer_u32(seq))); + } + + let item_refs: Vec<&[u8]> = items.iter().map(|v| v.as_slice()).collect(); + let seq = der_sequence(&item_refs); + der_application(2, &seq) +} + +/// Encode a PA-ENC-TIMESTAMP pre-authentication data (the plaintext to be encrypted). +/// +/// Returns the DER encoding of `PA-ENC-TS-ENC ::= SEQUENCE { patimestamp [0] GeneralizedTime, pausec [1] Microseconds }`. +pub fn encode_pa_enc_timestamp(ctime: &str, cusec: u32) -> Vec { + let patimestamp = der_context(0, &der_generalized_time(ctime)); + let pausec = der_context(1, &der_integer_u32(cusec)); + der_sequence(&[&patimestamp, &pausec]) +} + +// --------------------------------------------------------------------------- +// Public API: parsing +// --------------------------------------------------------------------------- + +/// Parse a KRB_AS_REP (`APPLICATION [11]`) or KRB_TGS_REP (`APPLICATION [13]`) message. +pub fn parse_kdc_rep(data: &[u8]) -> Result { + let (tag, inner, _) = parse_der_tlv(data)?; + // APPLICATION [11] = 0x6b, APPLICATION [13] = 0x6d + let expected_msg_type = match tag { + 0x6b => 11, + 0x6d => 13, + _ => { + return Err(Error::invalid_data(format!( + "Kerberos: expected APPLICATION [11] or [13] for KDC-REP, got 0x{tag:02x}" + ))); + } + }; + + let (seq_tag, seq_data, _) = parse_der_tlv(inner)?; + if seq_tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE in KDC-REP, got 0x{seq_tag:02x}" + ))); + } + let fields = parse_sequence_fields(seq_data)?; + let mut msg_type = None; + let mut crealm = None; + let mut cname = None; + let mut ticket = None; + let mut enc_part = None; + + for (ftag, fvalue) in &fields { + match ftag { + // RFC 4120 section 5.4.2: KDC-REP fields + 0xa0 => { + // pvno [0] — skip validation + } + 0xa1 => msg_type = Some(parse_der_integer(fvalue)?), + // [2] padata — skip + 0xa3 => crealm = Some(parse_der_general_string(fvalue)?), + 0xa4 => cname = Some(parse_principal_name(fvalue)?), + 0xa5 => ticket = Some(parse_ticket(fvalue)?), + 0xa6 => enc_part = Some(parse_encrypted_data(fvalue)?), + _ => {} + } + } + + let msg_type = + msg_type.ok_or_else(|| Error::invalid_data("Kerberos: missing msg-type in KDC-REP"))?; + if msg_type != expected_msg_type { + return Err(Error::invalid_data(format!( + "Kerberos: KDC-REP msg-type mismatch: tag says {expected_msg_type}, field says {msg_type}" + ))); + } + + Ok(KdcRep { + msg_type, + crealm: crealm.ok_or_else(|| Error::invalid_data("Kerberos: missing crealm in KDC-REP"))?, + cname: cname.ok_or_else(|| Error::invalid_data("Kerberos: missing cname in KDC-REP"))?, + ticket: ticket.ok_or_else(|| Error::invalid_data("Kerberos: missing ticket in KDC-REP"))?, + enc_part: enc_part + .ok_or_else(|| Error::invalid_data("Kerberos: missing enc-part in KDC-REP"))?, + }) +} + +/// Parse the decrypted EncKDCRepPart. +/// +/// This can be wrapped in `APPLICATION [25]` (EncASRepPart) or `APPLICATION [26]` (EncTGSRepPart), +/// or may appear as a bare SEQUENCE (some implementations). +pub fn parse_enc_kdc_rep_part(data: &[u8]) -> Result { + let (tag, inner, _) = parse_der_tlv(data)?; + + // APPLICATION [25] = 0x79, APPLICATION [26] = 0x7a, or bare SEQUENCE + let seq_data = if tag == 0x79 || tag == 0x7a { + let (seq_tag, sd, _) = parse_der_tlv(inner)?; + if seq_tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE in EncKDCRepPart, got 0x{seq_tag:02x}" + ))); + } + sd + } else if tag == TAG_SEQUENCE { + inner + } else { + return Err(Error::invalid_data(format!( + "Kerberos: expected APPLICATION [25/26] or SEQUENCE for EncKDCRepPart, got 0x{tag:02x}" + ))); + }; + + let fields = parse_sequence_fields(seq_data)?; + let mut key = None; + let mut nonce = None; + let mut flags = None; + let mut authtime = None; + let mut endtime = None; + let mut srealm = None; + let mut sname = None; + + for (ftag, fvalue) in &fields { + match ftag { + 0xa0 => key = Some(parse_encryption_key(fvalue)?), + // [1] last-req — skip + 0xa2 => nonce = Some(parse_der_integer_u32(fvalue)?), + // [3] key-expiration — skip + 0xa4 => { + let (bits, _unused) = parse_der_bit_string(fvalue)?; + if bits.len() >= 4 { + flags = Some(u32::from_be_bytes([bits[0], bits[1], bits[2], bits[3]])); + } + } + 0xa5 => authtime = Some(parse_der_generalized_time(fvalue)?), + // [6] starttime — skip + 0xa7 => endtime = Some(parse_der_generalized_time(fvalue)?), + // [8] renew-till — skip + 0xa9 => srealm = Some(parse_der_general_string(fvalue)?), + 0xaa => sname = Some(parse_principal_name(fvalue)?), + _ => {} + } + } + + Ok(EncKdcRepPart { + key: key.ok_or_else(|| Error::invalid_data("Kerberos: missing key in EncKDCRepPart"))?, + nonce: nonce + .ok_or_else(|| Error::invalid_data("Kerberos: missing nonce in EncKDCRepPart"))?, + flags: flags.unwrap_or(0), + authtime: authtime + .ok_or_else(|| Error::invalid_data("Kerberos: missing authtime in EncKDCRepPart"))?, + endtime: endtime + .ok_or_else(|| Error::invalid_data("Kerberos: missing endtime in EncKDCRepPart"))?, + srealm: srealm + .ok_or_else(|| Error::invalid_data("Kerberos: missing srealm in EncKDCRepPart"))?, + sname: sname + .ok_or_else(|| Error::invalid_data("Kerberos: missing sname in EncKDCRepPart"))?, + }) +} + +/// Parse a KRB-ERROR message (`APPLICATION [30]`). +pub fn parse_krb_error(data: &[u8]) -> Result { + let (tag, inner, _) = parse_der_tlv(data)?; + // APPLICATION [30] = 0x7e + if tag != 0x7e { + return Err(Error::invalid_data(format!( + "Kerberos: expected APPLICATION [30] (0x7e) for KRB-ERROR, got 0x{tag:02x}" + ))); + } + let (seq_tag, seq_data, _) = parse_der_tlv(inner)?; + if seq_tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE in KRB-ERROR, got 0x{seq_tag:02x}" + ))); + } + let fields = parse_sequence_fields(seq_data)?; + + let mut error_code = None; + let mut crealm = None; + let mut realm = None; + let mut sname = None; + let mut e_text = None; + let mut e_data = None; + + for (ftag, fvalue) in &fields { + match ftag { + // [0] pvno — skip + // [1] msg-type — skip + // [2] ctime — skip + // [3] cusec — skip + // [4] stime — skip + // [5] susec — skip + 0xa6 => error_code = Some(parse_der_integer(fvalue)?), + 0xa7 => crealm = Some(parse_der_general_string(fvalue)?), + 0xa8 => { + // cname — skip (we don't need it in the error struct, but parse to validate) + } + 0xa9 => realm = Some(parse_der_general_string(fvalue)?), + 0xaa => sname = Some(parse_principal_name(fvalue)?), + 0xab => e_text = Some(parse_der_general_string(fvalue)?), + 0xac => e_data = Some(parse_der_octet_string(fvalue)?), + _ => {} + } + } + + Ok(KrbError { + error_code: error_code + .ok_or_else(|| Error::invalid_data("Kerberos: missing error-code in KRB-ERROR"))?, + crealm, + realm: realm.ok_or_else(|| Error::invalid_data("Kerberos: missing realm in KRB-ERROR"))?, + sname: sname.ok_or_else(|| Error::invalid_data("Kerberos: missing sname in KRB-ERROR"))?, + e_text, + e_data, + }) +} + +/// Unwrap a GSS-API token: `APPLICATION [0] { OID, inner-data }`. +/// +/// Returns the inner data after the OID as a `Vec`. +pub fn parse_gss_api_wrapper(data: &[u8]) -> Result<(Vec, Vec, usize), Error> { + let (tag, inner, total) = parse_der_tlv(data)?; + if tag != 0x60 { + return Err(Error::invalid_data(format!( + "Kerberos: expected GSS-API wrapper (0x60), got 0x{tag:02x}" + ))); + } + // Skip the OID TLV. + let (_oid_tag, oid_data, oid_consumed) = parse_der_tlv(inner)?; + let oid = oid_data.to_vec(); + let rest = inner[oid_consumed..].to_vec(); + Ok((oid, rest, total)) +} + +/// Parse a KRB_AP_REP message (`APPLICATION [15]`). +/// +/// Handles both bare AP-REP and GSS-API wrapped tokens (`APPLICATION [0]` +/// containing an OID followed by the AP-REP). +pub fn parse_ap_rep(data: &[u8]) -> Result { + let (tag, inner, _) = parse_der_tlv(data)?; + + // If wrapped in GSS-API APPLICATION [0], unwrap first. + let inner = if tag == 0x60 { + // APPLICATION [0] { OID, AP-REP } + // Skip the OID TLV to get to the AP-REP. + let (_oid_tag, _oid_data, oid_consumed) = parse_der_tlv(inner)?; + let ap_rep_data = &inner[oid_consumed..]; + let (ap_tag, ap_inner, _) = parse_der_tlv(ap_rep_data)?; + if ap_tag != 0x6f { + return Err(Error::invalid_data(format!( + "Kerberos: expected AP-REP (0x6f) inside GSS wrapper, got 0x{ap_tag:02x}" + ))); + } + ap_inner + } else if tag == 0x6f { + inner + } else { + return Err(Error::invalid_data(format!( + "Kerberos: expected APPLICATION [15] (0x6f) or GSS wrapper (0x60) for AP-REP, got 0x{tag:02x}" + ))); + }; + let (seq_tag, seq_data, _) = parse_der_tlv(inner)?; + if seq_tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE in AP-REP, got 0x{seq_tag:02x}" + ))); + } + let fields = parse_sequence_fields(seq_data)?; + + let mut enc_part = None; + for (ftag, fvalue) in &fields { + // [0] pvno — skip, [1] msg-type — skip + if ftag == &0xa2 { + enc_part = Some(parse_encrypted_data(fvalue)?); + } + } + + Ok(ApRep { + enc_part: enc_part + .ok_or_else(|| Error::invalid_data("Kerberos: missing enc-part in AP-REP"))?, + }) +} + +/// Parse the decrypted EncAPRepPart (`APPLICATION [27]`). +pub fn parse_enc_ap_rep_part(data: &[u8]) -> Result { + let (tag, inner, _) = parse_der_tlv(data)?; + // APPLICATION [27] = 0x7b, or bare SEQUENCE + let seq_data = match tag { + 0x7b => { + let (seq_tag, seq_data, _) = parse_der_tlv(inner)?; + if seq_tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "Kerberos: expected SEQUENCE in EncAPRepPart, got 0x{seq_tag:02x}" + ))); + } + seq_data + } + TAG_SEQUENCE => inner, + _ => { + return Err(Error::invalid_data(format!( + "Kerberos: expected APPLICATION [27] or SEQUENCE for EncAPRepPart, got 0x{tag:02x}" + ))); + } + }; + + let fields = parse_sequence_fields(seq_data)?; + + let mut subkey = None; + let mut seq_number = None; + for (ftag, fvalue) in &fields { + match ftag { + // [0] ctime — skip + // [1] cusec — skip + 0xa2 => subkey = Some(parse_encryption_key(fvalue)?), + 0xa3 => seq_number = Some(parse_der_integer_u32(fvalue)?), + _ => {} + } + } + + Ok(EncApRepPart { subkey, seq_number }) +} + +// --------------------------------------------------------------------------- +// Internal: KDC-REQ encoding (shared by AS-REQ and TGS-REQ) +// --------------------------------------------------------------------------- + +/// Encode just the KDC-REQ-BODY portion of a KDC-REQ. +fn encode_kdc_req_body( + cname: Option<&PrincipalName>, + realm: &str, + sname: &PrincipalName, + nonce: u32, + etypes: &[EncryptionType], +) -> Vec { + let kdc_options = der_context(0, &der_bit_string(&[0x40, 0x81, 0x00, 0x10], 0)); + let mut body_items: Vec> = vec![kdc_options]; + + if let Some(cn) = cname { + body_items.push(der_context(1, &encode_principal_name(cn))); + } + body_items.push(der_context(2, &der_general_string(realm))); + body_items.push(der_context(3, &encode_principal_name(sname))); + // till: set far in the future + body_items.push(der_context(5, &der_generalized_time("20370913024805Z"))); + body_items.push(der_context(7, &der_integer_u32(nonce))); + + // etype: SEQUENCE OF INTEGER + let etype_ints: Vec> = etypes.iter().map(|e| der_integer(*e as i32)).collect(); + let etype_refs: Vec<&[u8]> = etype_ints.iter().map(|v| v.as_slice()).collect(); + let etype_seq = der_sequence(&etype_refs); + body_items.push(der_context(8, &etype_seq)); + + let body_refs: Vec<&[u8]> = body_items.iter().map(|v| v.as_slice()).collect(); + der_sequence(&body_refs) +} + +fn encode_kdc_req( + msg_type_val: i32, + cname: Option<&PrincipalName>, + realm: &str, + sname: &PrincipalName, + nonce: u32, + etypes: &[EncryptionType], + padata: &[PaData], +) -> Vec { + let req_body = encode_kdc_req_body(cname, realm, sname, nonce, etypes); + + // KDC-REQ + let pvno = der_context(1, &der_integer(5)); + let msg_type = der_context(2, &der_integer(msg_type_val)); + + let mut kdc_req_items: Vec> = vec![pvno, msg_type]; + + if !padata.is_empty() { + let pa_items: Vec> = padata.iter().map(encode_pa_data).collect(); + let pa_refs: Vec<&[u8]> = pa_items.iter().map(|v| v.as_slice()).collect(); + let pa_seq = der_sequence(&pa_refs); + kdc_req_items.push(der_context(3, &pa_seq)); + } + + kdc_req_items.push(der_context(4, &req_body)); + + let kdc_req_refs: Vec<&[u8]> = kdc_req_items.iter().map(|v| v.as_slice()).collect(); + let kdc_req_seq = der_sequence(&kdc_req_refs); + + // APPLICATION tag for the message type + let app_tag = match msg_type_val { + 10 => 10, // AS-REQ + 12 => 12, // TGS-REQ + _ => msg_type_val as u8, + }; + der_application(app_tag, &kdc_req_seq) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + // ----------------------------------------------------------------------- + // DER helper tests + // ----------------------------------------------------------------------- + + #[test] + fn test_der_integer_positive() { + // 5 should encode as 02 01 05 + let encoded = der_integer(5); + assert_eq!(encoded, vec![0x02, 0x01, 0x05]); + } + + #[test] + fn test_der_integer_zero() { + // 0 should encode as 02 01 00 + let encoded = der_integer(0); + assert_eq!(encoded, vec![0x02, 0x01, 0x00]); + } + + #[test] + fn test_der_integer_negative() { + // -1 should encode as 02 01 ff + let encoded = der_integer(-1); + assert_eq!(encoded, vec![0x02, 0x01, 0xff]); + } + + #[test] + fn test_der_integer_128() { + // 128 needs leading 0x00: 02 02 00 80 + let encoded = der_integer(128); + assert_eq!(encoded, vec![0x02, 0x02, 0x00, 0x80]); + } + + #[test] + fn test_der_integer_256() { + // 256 = 0x0100: 02 02 01 00 + let encoded = der_integer(256); + assert_eq!(encoded, vec![0x02, 0x02, 0x01, 0x00]); + } + + #[test] + fn test_der_integer_large_positive() { + // 65536 = 0x10000: 02 03 01 00 00 + let encoded = der_integer(65536); + assert_eq!(encoded, vec![0x02, 0x03, 0x01, 0x00, 0x00]); + } + + #[test] + fn test_der_integer_u32_max() { + // u32::MAX = 0xFFFFFFFF: needs 02 05 00 ff ff ff ff + let encoded = der_integer_u32(u32::MAX); + assert_eq!(encoded, vec![0x02, 0x05, 0x00, 0xff, 0xff, 0xff, 0xff]); + } + + #[test] + fn test_der_generalized_time() { + let encoded = der_generalized_time("20260408120000Z"); + assert_eq!(encoded[0], TAG_GENERALIZED_TIME); + assert_eq!(encoded[1], 15); // length + assert_eq!(&encoded[2..], b"20260408120000Z"); + } + + #[test] + fn test_der_bit_string_32bit_flags() { + let encoded = der_bit_string(&[0x40, 0x81, 0x00, 0x10], 0); + assert_eq!(encoded[0], TAG_BIT_STRING); + assert_eq!(encoded[1], 5); // 1 unused-bits byte + 4 bytes + assert_eq!(encoded[2], 0); // 0 unused bits + assert_eq!(&encoded[3..], &[0x40, 0x81, 0x00, 0x10]); + } + + #[test] + fn test_der_general_string() { + let encoded = der_general_string("EXAMPLE.COM"); + assert_eq!(encoded[0], TAG_GENERAL_STRING); + assert_eq!(encoded[1], 11); + assert_eq!(&encoded[2..], b"EXAMPLE.COM"); + } + + // DER primitive tests (der_length, der_tlv, parse_der_length, parse_der_tlv) + // live in `auth::der::tests`. + + // ----------------------------------------------------------------------- + // Parse helper tests (Kerberos-specific) + // ----------------------------------------------------------------------- + + #[test] + fn test_parse_der_integer_roundtrip() { + for val in [0, 1, 5, 127, 128, 255, 256, 1000, -1, -128, -129] { + let encoded = der_integer(val); + let parsed = parse_der_integer(&encoded).unwrap(); + assert_eq!(parsed, val, "roundtrip failed for {val}"); + } + } + + #[test] + fn test_parse_der_octet_string_roundtrip() { + let data = vec![0x01, 0x02, 0x03, 0xff]; + let encoded = der_octet_string(&data); + let parsed = parse_der_octet_string(&encoded).unwrap(); + assert_eq!(parsed, data); + } + + #[test] + fn test_parse_der_general_string_roundtrip() { + let encoded = der_general_string("EXAMPLE.COM"); + let parsed = parse_der_general_string(&encoded).unwrap(); + assert_eq!(parsed, "EXAMPLE.COM"); + } + + #[test] + fn test_parse_der_generalized_time_roundtrip() { + let encoded = der_generalized_time("20260408120000Z"); + let parsed = parse_der_generalized_time(&encoded).unwrap(); + assert_eq!(parsed, "20260408120000Z"); + } + + #[test] + fn test_parse_der_bit_string_roundtrip() { + let bits = vec![0x40, 0x81, 0x00, 0x10]; + let encoded = der_bit_string(&bits, 0); + let (parsed_bits, unused) = parse_der_bit_string(&encoded).unwrap(); + assert_eq!(parsed_bits, bits); + assert_eq!(unused, 0); + } + + // ----------------------------------------------------------------------- + // Encoding tests + // ----------------------------------------------------------------------- + + #[test] + fn test_encode_as_req_application_tag() { + let cname = PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }; + let sname = PrincipalName { + name_type: 2, + name_string: vec!["krbtgt".to_string(), "EXAMPLE.COM".to_string()], + }; + let encoded = encode_as_req( + &cname, + "EXAMPLE.COM", + &sname, + 12345, + &[EncryptionType::Aes256CtsHmacSha196], + &[], + ); + // APPLICATION [10] = 0x6a + assert_eq!(encoded[0], 0x6a, "AS-REQ must start with APPLICATION [10]"); + } + + #[test] + fn test_encode_as_req_contains_pvno_and_msg_type() { + let cname = PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }; + let sname = PrincipalName { + name_type: 2, + name_string: vec!["krbtgt".to_string(), "EXAMPLE.COM".to_string()], + }; + let encoded = encode_as_req( + &cname, + "EXAMPLE.COM", + &sname, + 12345, + &[EncryptionType::Aes256CtsHmacSha196], + &[], + ); + // Should contain pvno=5 somewhere: a1 03 02 01 05 + let pvno_pattern = [0xa1, 0x03, 0x02, 0x01, 0x05]; + assert!( + contains_subsequence(&encoded, &pvno_pattern), + "AS-REQ must contain pvno=5" + ); + // Should contain msg-type=10: a2 03 02 01 0a + let msg_type_pattern = [0xa2, 0x03, 0x02, 0x01, 0x0a]; + assert!( + contains_subsequence(&encoded, &msg_type_pattern), + "AS-REQ must contain msg-type=10" + ); + } + + #[test] + fn test_encode_tgs_req_application_tag() { + let sname = PrincipalName { + name_type: 2, + name_string: vec!["cifs".to_string(), "server.example.com".to_string()], + }; + let fake_ap_req = vec![0x6e, 0x03, 0x01, 0x02, 0x03]; + let encoded = encode_tgs_req( + "EXAMPLE.COM", + &sname, + 54321, + &[EncryptionType::Aes256CtsHmacSha196], + &fake_ap_req, + ); + // APPLICATION [12] = 0x6c + assert_eq!(encoded[0], 0x6c, "TGS-REQ must start with APPLICATION [12]"); + } + + #[test] + fn test_encode_tgs_req_contains_msg_type_12() { + let sname = PrincipalName { + name_type: 2, + name_string: vec!["cifs".to_string(), "server.example.com".to_string()], + }; + let fake_ap_req = vec![0x6e, 0x03, 0x01, 0x02, 0x03]; + let encoded = encode_tgs_req( + "EXAMPLE.COM", + &sname, + 54321, + &[EncryptionType::Aes256CtsHmacSha196], + &fake_ap_req, + ); + // msg-type=12: a2 03 02 01 0c + let msg_type_pattern = [0xa2, 0x03, 0x02, 0x01, 0x0c]; + assert!( + contains_subsequence(&encoded, &msg_type_pattern), + "TGS-REQ must contain msg-type=12" + ); + } + + #[test] + fn test_encode_ap_req_application_tag() { + let ticket = make_test_ticket(); + let auth = EncryptedData { + etype: 18, + kvno: None, + cipher: vec![0xaa, 0xbb], + }; + let encoded = encode_ap_req(&ticket, &auth, false); + // APPLICATION [14] = 0x6e + assert_eq!(encoded[0], 0x6e, "AP-REQ must start with APPLICATION [14]"); + } + + #[test] + fn test_encode_ap_req_contains_pvno_and_msg_type() { + let ticket = make_test_ticket(); + let auth = EncryptedData { + etype: 18, + kvno: None, + cipher: vec![0xaa, 0xbb], + }; + let encoded = encode_ap_req(&ticket, &auth, false); + // pvno=5: a0 03 02 01 05 + let pvno_pattern = [0xa0, 0x03, 0x02, 0x01, 0x05]; + assert!( + contains_subsequence(&encoded, &pvno_pattern), + "AP-REQ must contain pvno=5" + ); + // msg-type=14: a1 03 02 01 0e + let msg_type_pattern = [0xa1, 0x03, 0x02, 0x01, 0x0e]; + assert!( + contains_subsequence(&encoded, &msg_type_pattern), + "AP-REQ must contain msg-type=14" + ); + } + + #[test] + fn test_encode_authenticator_application_tag() { + let cname = PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }; + let encoded = encode_authenticator( + "EXAMPLE.COM", + &cname, + "20260408120000Z", + 123456, + None, + None, + None, + ); + // APPLICATION [2] = 0x62 + assert_eq!( + encoded[0], 0x62, + "Authenticator must start with APPLICATION [2]" + ); + } + + #[test] + fn test_encode_authenticator_with_subkey() { + let cname = PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }; + let subkey_value = vec![0x01; 32]; + let encoded = encode_authenticator( + "EXAMPLE.COM", + &cname, + "20260408120000Z", + 0, + Some((&subkey_value, 18)), + Some(42), + None, + ); + assert_eq!(encoded[0], 0x62); + // Should contain the subkey context tag [6] = 0xa6 + assert!( + contains_subsequence(&encoded, &[0xa6]), + "Authenticator with subkey must contain [6]" + ); + // Should contain seq-number context tag [7] = 0xa7 + assert!( + contains_subsequence(&encoded, &[0xa7]), + "Authenticator with seq-number must contain [7]" + ); + } + + #[test] + fn test_encode_pa_enc_timestamp() { + let encoded = encode_pa_enc_timestamp("20260408120000Z", 123456); + // Should be a SEQUENCE starting with 0x30 + assert_eq!(encoded[0], TAG_SEQUENCE); + // Should contain [0] with GeneralizedTime + assert!(contains_subsequence(&encoded, &[0xa0])); + // Should contain [1] with INTEGER + assert!(contains_subsequence(&encoded, &[0xa1])); + } + + // ----------------------------------------------------------------------- + // Parsing tests + // ----------------------------------------------------------------------- + + #[test] + fn test_parse_kdc_rep_as_rep() { + let rep_bytes = build_test_kdc_rep(11); + let rep = parse_kdc_rep(&rep_bytes).unwrap(); + assert_eq!(rep.msg_type, 11); + assert_eq!(rep.crealm, "EXAMPLE.COM"); + assert_eq!(rep.cname.name_type, 1); + assert_eq!(rep.cname.name_string, vec!["user"]); + assert_eq!(rep.ticket.realm, "EXAMPLE.COM"); + assert_eq!(rep.enc_part.etype, 18); + } + + #[test] + fn test_parse_kdc_rep_tgs_rep() { + let rep_bytes = build_test_kdc_rep(13); + let rep = parse_kdc_rep(&rep_bytes).unwrap(); + assert_eq!(rep.msg_type, 13); + } + + #[test] + fn test_parse_krb_error() { + let err_bytes = build_test_krb_error(25); // KDC_ERR_PREAUTH_REQUIRED + let err = parse_krb_error(&err_bytes).unwrap(); + assert_eq!(err.error_code, 25); + assert_eq!(err.realm, "EXAMPLE.COM"); + assert_eq!(err.sname.name_type, 2); + assert_eq!(err.sname.name_string, vec!["krbtgt", "EXAMPLE.COM"]); + } + + #[test] + fn test_parse_ticket_roundtrip() { + let ticket = make_test_ticket(); + let encoded = encode_ticket(&ticket); + let parsed = parse_ticket(&encoded).unwrap(); + assert_eq!(parsed.tkt_vno, 5); + assert_eq!(parsed.realm, "EXAMPLE.COM"); + assert_eq!(parsed.sname.name_type, 2); + assert_eq!(parsed.sname.name_string, vec!["krbtgt", "EXAMPLE.COM"]); + assert_eq!(parsed.enc_part.etype, 18); + assert_eq!(parsed.enc_part.cipher, vec![0xde, 0xad, 0xbe, 0xef]); + } + + #[test] + fn test_parse_enc_kdc_rep_part() { + let part_bytes = build_test_enc_kdc_rep_part(); + let part = parse_enc_kdc_rep_part(&part_bytes).unwrap(); + assert_eq!(part.key.keytype, 18); + assert_eq!(part.key.keyvalue, vec![0x01; 32]); + assert_eq!(part.nonce, 12345); + assert_eq!(part.authtime, "20260408120000Z"); + assert_eq!(part.endtime, "20260409120000Z"); + assert_eq!(part.srealm, "EXAMPLE.COM"); + assert_eq!(part.sname.name_type, 2); + } + + // ----------------------------------------------------------------------- + // Roundtrip tests + // ----------------------------------------------------------------------- + + #[test] + fn test_principal_name_roundtrip() { + let name = PrincipalName { + name_type: 2, + name_string: vec!["cifs".to_string(), "server.example.com".to_string()], + }; + let encoded = encode_principal_name(&name); + let parsed = parse_principal_name(&encoded).unwrap(); + assert_eq!(parsed, name); + } + + #[test] + fn test_encrypted_data_roundtrip() { + let ed = EncryptedData { + etype: 17, + kvno: Some(3), + cipher: vec![0x01, 0x02, 0x03, 0x04], + }; + let encoded = encode_encrypted_data(&ed); + let parsed = parse_encrypted_data(&encoded).unwrap(); + assert_eq!(parsed, ed); + } + + #[test] + fn test_encrypted_data_no_kvno_roundtrip() { + let ed = EncryptedData { + etype: 23, + kvno: None, + cipher: vec![0xff; 16], + }; + let encoded = encode_encrypted_data(&ed); + let parsed = parse_encrypted_data(&encoded).unwrap(); + assert_eq!(parsed, ed); + } + + #[test] + fn test_ticket_roundtrip() { + let ticket = make_test_ticket(); + let encoded = encode_ticket(&ticket); + let parsed = parse_ticket(&encoded).unwrap(); + // Compare fields (raw_bytes differs: None vs Some). + assert_eq!(parsed.tkt_vno, ticket.tkt_vno); + assert_eq!(parsed.realm, ticket.realm); + assert_eq!(parsed.sname, ticket.sname); + assert_eq!(parsed.enc_part, ticket.enc_part); + // Parsed ticket should have raw_bytes. + assert!(parsed.raw_bytes.is_some()); + assert_eq!(parsed.raw_bytes.as_ref().unwrap(), &encoded); + } + + // ----------------------------------------------------------------------- + // Test helpers + // ----------------------------------------------------------------------- + + fn contains_subsequence(haystack: &[u8], needle: &[u8]) -> bool { + haystack + .windows(needle.len()) + .any(|window| window == needle) + } + + fn make_test_ticket() -> Ticket { + Ticket { + tkt_vno: 5, + realm: "EXAMPLE.COM".to_string(), + sname: PrincipalName { + name_type: 2, + name_string: vec!["krbtgt".to_string(), "EXAMPLE.COM".to_string()], + }, + enc_part: EncryptedData { + etype: 18, + kvno: Some(2), + cipher: vec![0xde, 0xad, 0xbe, 0xef], + }, + raw_bytes: None, + } + } + + /// Build a test KDC-REP (AS-REP or TGS-REP) in DER. + fn build_test_kdc_rep(msg_type_val: i32) -> Vec { + // RFC 4120 section 5.4.2: KDC-REP fields start at [0] + let pvno = der_context(0, &der_integer(5)); + let msg_type = der_context(1, &der_integer(msg_type_val)); + let crealm = der_context(3, &der_general_string("EXAMPLE.COM")); + + let cname_inner = encode_principal_name(&PrincipalName { + name_type: 1, + name_string: vec!["user".to_string()], + }); + let cname = der_context(4, &cname_inner); + + let ticket = der_context(5, &encode_ticket(&make_test_ticket())); + + let enc_part_inner = encode_encrypted_data(&EncryptedData { + etype: 18, + kvno: Some(1), + cipher: vec![0xca, 0xfe], + }); + let enc_part = der_context(6, &enc_part_inner); + + let seq = der_sequence(&[&pvno, &msg_type, &crealm, &cname, &ticket, &enc_part]); + + let app_tag = match msg_type_val { + 11 => 11, // AS-REP + 13 => 13, // TGS-REP + _ => panic!("unexpected msg_type_val"), + }; + der_application(app_tag, &seq) + } + + /// Build a test KRB-ERROR in DER. + fn build_test_krb_error(error_code_val: i32) -> Vec { + let pvno = der_context(0, &der_integer(5)); + let msg_type = der_context(1, &der_integer(30)); + let stime = der_context(4, &der_generalized_time("20260408120000Z")); + let susec = der_context(5, &der_integer(0)); + let error_code = der_context(6, &der_integer(error_code_val)); + let realm = der_context(9, &der_general_string("EXAMPLE.COM")); + let sname_inner = encode_principal_name(&PrincipalName { + name_type: 2, + name_string: vec!["krbtgt".to_string(), "EXAMPLE.COM".to_string()], + }); + let sname = der_context(10, &sname_inner); + + let seq = der_sequence(&[ + &pvno, + &msg_type, + &stime, + &susec, + &error_code, + &realm, + &sname, + ]); + // APPLICATION [30] = 0x7e + der_application(30, &seq) + } + + /// Build a test EncKDCRepPart in DER. + fn build_test_enc_kdc_rep_part() -> Vec { + // key [0]: EncryptionKey { keytype=18, keyvalue=0x01*32 } + let kt = der_context(0, &der_integer(18)); + let kv = der_context(1, &der_octet_string(&[0x01; 32])); + let key_seq = der_sequence(&[&kt, &kv]); + let key = der_context(0, &key_seq); + + // last-req [1]: minimal (empty sequence) + let last_req = der_context(1, &der_sequence(&[])); + + // nonce [2] + let nonce = der_context(2, &der_integer_u32(12345)); + + // flags [4]: BIT STRING + let flags = der_context(4, &der_bit_string(&[0x50, 0x80, 0x00, 0x00], 0)); + + // authtime [5] + let authtime = der_context(5, &der_generalized_time("20260408120000Z")); + + // endtime [7] + let endtime = der_context(7, &der_generalized_time("20260409120000Z")); + + // srealm [9] + let srealm = der_context(9, &der_general_string("EXAMPLE.COM")); + + // sname [10] + let sname_inner = encode_principal_name(&PrincipalName { + name_type: 2, + name_string: vec!["krbtgt".to_string(), "EXAMPLE.COM".to_string()], + }); + let sname = der_context(10, &sname_inner); + + let seq = der_sequence(&[ + &key, &last_req, &nonce, &flags, &authtime, &endtime, &srealm, &sname, + ]); + + // Wrap in APPLICATION [25] (EncASRepPart) + der_application(25, &seq) + } +} diff --git a/vendor/smb2/src/auth/kerberos/mod.rs b/vendor/smb2/src/auth/kerberos/mod.rs new file mode 100644 index 0000000..7140e30 --- /dev/null +++ b/vendor/smb2/src/auth/kerberos/mod.rs @@ -0,0 +1,21 @@ +//! Kerberos authentication support. +//! +//! Implements the cryptographic operations needed for Kerberos authentication +//! (etypes 17, 18, 23): string-to-key, key derivation, AES-CTS encryption, +//! RC4-HMAC encryption, and checksum computation. +//! +//! The [`KerberosAuthenticator`] wires all building blocks together into +//! a full Kerberos authentication flow: AS exchange, TGS exchange, and +//! AP-REQ construction for SMB2 SESSION_SETUP. +//! +//! The [`ccache`] module supports reading MIT Kerberos credential caches, +//! enabling authentication from cached TGTs or service tickets (for example, +//! from `kinit`) without requiring a password. + +pub mod ccache; +pub mod crypto; +pub mod kdc; +pub mod messages; + +mod authenticator; +pub use authenticator::{KerberosAuthenticator, KerberosCredentials}; diff --git a/vendor/smb2/src/auth/mod.rs b/vendor/smb2/src/auth/mod.rs new file mode 100644 index 0000000..159f797 --- /dev/null +++ b/vendor/smb2/src/auth/mod.rs @@ -0,0 +1,15 @@ +//! Authentication mechanisms for SMB2. +//! +//! Supports NTLM authentication (MS-NLMP) and Kerberos authentication +//! (RFC 4120, MS-KILE). +//! +//! Most users don't need this module directly -- [`SmbClient`](crate::SmbClient) +//! handles authentication during [`connect`](crate::connect). + +pub(crate) mod der; +pub mod kerberos; +pub mod ntlm; +pub mod spnego; + +pub use kerberos::{KerberosAuthenticator, KerberosCredentials}; +pub use ntlm::{NtlmAuthenticator, NtlmCredentials}; diff --git a/vendor/smb2/src/auth/ntlm.rs b/vendor/smb2/src/auth/ntlm.rs new file mode 100644 index 0000000..6355a71 --- /dev/null +++ b/vendor/smb2/src/auth/ntlm.rs @@ -0,0 +1,1410 @@ +//! NTLM authentication (MS-NLMP). +//! +//! Implements the 3-message NTLM exchange: +//! 1. Client sends NEGOTIATE_MESSAGE (Type 1) +//! 2. Server sends CHALLENGE_MESSAGE (Type 2) +//! 3. Client sends AUTHENTICATE_MESSAGE (Type 3) +//! +//! Only NTLMv2 is supported. NTLMv1 is insecure and not implemented. + +use log::{debug, trace}; + +use crate::Error; +use digest::{Digest, KeyInit}; +use hmac::{Hmac, Mac}; + +type HmacMd5 = Hmac; + +// --------------------------------------------------------------------------- +// NTLM signature and message types +// --------------------------------------------------------------------------- + +/// The 8-byte NTLM signature: `"NTLMSSP\0"`. +const NTLM_SIGNATURE: &[u8; 8] = b"NTLMSSP\0"; + +/// NEGOTIATE_MESSAGE type. +const MSG_TYPE_NEGOTIATE: u32 = 0x0000_0001; +/// CHALLENGE_MESSAGE type. +const MSG_TYPE_CHALLENGE: u32 = 0x0000_0002; +/// AUTHENTICATE_MESSAGE type. +const MSG_TYPE_AUTHENTICATE: u32 = 0x0000_0003; + +// --------------------------------------------------------------------------- +// Negotiate flags (section 2.2.2.5) +// --------------------------------------------------------------------------- + +const NTLMSSP_NEGOTIATE_UNICODE: u32 = 0x0000_0001; +const NTLMSSP_REQUEST_TARGET: u32 = 0x0000_0004; +const NTLMSSP_NEGOTIATE_SIGN: u32 = 0x0000_0010; +const NTLMSSP_NEGOTIATE_SEAL: u32 = 0x0000_0020; +const NTLMSSP_NEGOTIATE_NTLM: u32 = 0x0000_0200; +const NTLMSSP_NEGOTIATE_ALWAYS_SIGN: u32 = 0x0000_8000; +const NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY: u32 = 0x0008_0000; +const NTLMSSP_NEGOTIATE_TARGET_INFO: u32 = 0x0080_0000; +const NTLMSSP_NEGOTIATE_128: u32 = 0x2000_0000; +const NTLMSSP_NEGOTIATE_KEY_EXCH: u32 = 0x4000_0000; +const NTLMSSP_NEGOTIATE_56: u32 = 0x8000_0000; + +/// Default flags the client sends in the NEGOTIATE_MESSAGE. +const DEFAULT_NEGOTIATE_FLAGS: u32 = NTLMSSP_NEGOTIATE_UNICODE + | NTLMSSP_REQUEST_TARGET + | NTLMSSP_NEGOTIATE_NTLM + | NTLMSSP_NEGOTIATE_ALWAYS_SIGN + | NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY + | NTLMSSP_NEGOTIATE_TARGET_INFO + | NTLMSSP_NEGOTIATE_128 + | NTLMSSP_NEGOTIATE_KEY_EXCH + | NTLMSSP_NEGOTIATE_56 + | NTLMSSP_NEGOTIATE_SIGN + | NTLMSSP_NEGOTIATE_SEAL; + +// --------------------------------------------------------------------------- +// AV_PAIR types (section 2.2.2.1) +// --------------------------------------------------------------------------- + +/// End of AV_PAIR list. +const MSV_AV_EOL: u16 = 0x0000; +/// NetBIOS computer name. +#[cfg(test)] +const MSV_AV_NB_COMPUTER_NAME: u16 = 0x0001; +/// NetBIOS domain name. +#[cfg(test)] +const MSV_AV_NB_DOMAIN_NAME: u16 = 0x0002; +/// DNS computer name. +#[allow(dead_code)] +const MSV_AV_DNS_COMPUTER_NAME: u16 = 0x0003; +/// DNS domain name. +#[allow(dead_code)] +const MSV_AV_DNS_DOMAIN_NAME: u16 = 0x0004; +/// Flags. +const MSV_AV_FLAGS: u16 = 0x0006; +/// Timestamp (FILETIME). +const MSV_AV_TIMESTAMP: u16 = 0x0007; +/// Target name (SPN). +#[allow(dead_code)] +const MSV_AV_TARGET_NAME: u16 = 0x0009; + +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- + +/// Credentials for NTLM authentication. +pub struct NtlmCredentials { + /// The username. + pub username: String, + /// The password. + pub password: String, + /// The domain (can be empty for local accounts). + pub domain: String, +} + +/// Stateful NTLM authenticator that manages the 3-message exchange. +/// +/// Usage: +/// 1. Call [`negotiate()`](Self::negotiate) to get the NEGOTIATE_MESSAGE bytes. +/// 2. Send those bytes in SESSION_SETUP, receive the server's CHALLENGE_MESSAGE. +/// 3. Call [`authenticate()`](Self::authenticate) with the challenge bytes to +/// get the AUTHENTICATE_MESSAGE bytes. +/// 4. After authenticate succeeds, [`session_key()`](Self::session_key) returns +/// the exported session key for signing/encryption. +pub struct NtlmAuthenticator { + credentials: NtlmCredentials, + /// Retained for MIC computation. + negotiate_bytes: Option>, + /// Retained for MIC computation. + challenge_bytes: Option>, + /// The exported session key, available after authenticate(). + session_key: Option>, + /// Override for the client challenge (for testing with known values). + #[cfg(test)] + test_client_challenge: Option<[u8; 8]>, + /// Override for the random session key (for testing with known values). + #[cfg(test)] + test_random_session_key: Option<[u8; 16]>, + /// Override for the timestamp (for testing with known values). + #[cfg(test)] + test_timestamp: Option, +} + +impl NtlmAuthenticator { + /// Create a new authenticator with the given credentials. + pub fn new(credentials: NtlmCredentials) -> Self { + Self { + credentials, + negotiate_bytes: None, + challenge_bytes: None, + session_key: None, + #[cfg(test)] + test_client_challenge: None, + #[cfg(test)] + test_random_session_key: None, + #[cfg(test)] + test_timestamp: None, + } + } + + /// Build the NEGOTIATE_MESSAGE (Type 1). + /// + /// Returns the raw bytes to embed in SESSION_SETUP's security buffer. + pub fn negotiate(&mut self) -> Vec { + let mut buf = Vec::with_capacity(32); + + // Signature (8 bytes) + buf.extend_from_slice(NTLM_SIGNATURE); + // MessageType (4 bytes) + buf.extend_from_slice(&MSG_TYPE_NEGOTIATE.to_le_bytes()); + // NegotiateFlags (4 bytes) + buf.extend_from_slice(&DEFAULT_NEGOTIATE_FLAGS.to_le_bytes()); + // DomainNameFields: Len(2) + MaxLen(2) + Offset(4) = all zeros (no domain supplied in negotiate) + buf.extend_from_slice(&[0u8; 8]); + // WorkstationFields: Len(2) + MaxLen(2) + Offset(4) = all zeros + buf.extend_from_slice(&[0u8; 8]); + + debug!("ntlm: negotiate message built, len={}", buf.len()); + self.negotiate_bytes = Some(buf.clone()); + buf + } + + /// Process the CHALLENGE_MESSAGE (Type 2) from the server and build the + /// AUTHENTICATE_MESSAGE (Type 3). + /// + /// Returns the raw bytes for the next SESSION_SETUP. + pub fn authenticate(&mut self, challenge_bytes: &[u8]) -> Result, Error> { + debug!("ntlm: processing challenge, len={}", challenge_bytes.len()); + self.challenge_bytes = Some(challenge_bytes.to_vec()); + + // Parse the CHALLENGE_MESSAGE + let challenge = parse_challenge_message(challenge_bytes)?; + trace!( + "ntlm: challenge flags=0x{:08x}, target_info_len={}", + challenge.negotiate_flags, + challenge.target_info.len() + ); + + // Compute NTLMv2 response + let nt_hash = compute_nt_hash(&self.credentials.password); + let ntlmv2_hash = compute_ntlmv2_hash( + &nt_hash, + &self.credentials.username, + &self.credentials.domain, + ); + + // Get timestamp from challenge TargetInfo, or use current time + let timestamp = self.get_timestamp(&challenge); + + // Get client challenge + let client_challenge = self.get_client_challenge(); + + // Check if MsvAvTimestamp is present (determines if MIC is required) + let has_timestamp = find_av_pair(&challenge.target_info, MSV_AV_TIMESTAMP).is_some(); + + // Build the modified target info for the authenticate message + let auth_target_info = build_auth_target_info(&challenge.target_info, has_timestamp); + + // Build temp blob (section 3.3.2) + let temp = build_temp(timestamp, &client_challenge, &auth_target_info); + + // NTProofStr = HMAC_MD5(NTLMv2_Hash, server_challenge + temp) + let nt_proof_str = { + let mut mac = + HmacMd5::new_from_slice(&ntlmv2_hash).expect("HMAC accepts any key length"); + mac.update(&challenge.server_challenge); + mac.update(&temp); + mac.finalize().into_bytes().to_vec() + }; + + // NtChallengeResponse = NTProofStr + temp + let mut nt_challenge_response = nt_proof_str.clone(); + nt_challenge_response.extend_from_slice(&temp); + + // SessionBaseKey = HMAC_MD5(NTLMv2_Hash, NTProofStr) + let session_base_key = { + let mut mac = + HmacMd5::new_from_slice(&ntlmv2_hash).expect("HMAC accepts any key length"); + mac.update(&nt_proof_str); + mac.finalize().into_bytes().to_vec() + }; + + // Key exchange: if KEY_EXCH is negotiated, generate random session key + let negotiate_flags = challenge.negotiate_flags; + let key_exch = (negotiate_flags & NTLMSSP_NEGOTIATE_KEY_EXCH) != 0 + && ((negotiate_flags & NTLMSSP_NEGOTIATE_SIGN) != 0 + || (negotiate_flags & NTLMSSP_NEGOTIATE_SEAL) != 0); + + let (exported_session_key, encrypted_random_session_key) = if key_exch { + let random_key = self.get_random_session_key(); + let encrypted = rc4_encrypt(&session_base_key, &random_key); + (random_key.to_vec(), encrypted) + } else { + (session_base_key.clone(), Vec::new()) + }; + + // LmChallengeResponse: if timestamp present, send Z(24); otherwise compute LMv2 + let lm_challenge_response = if has_timestamp { + vec![0u8; 24] + } else { + // LMv2: HMAC_MD5(ntlmv2_hash, server_challenge + client_challenge) + client_challenge + let mut mac = + HmacMd5::new_from_slice(&ntlmv2_hash).expect("HMAC accepts any key length"); + mac.update(&challenge.server_challenge); + mac.update(&client_challenge); + let proof = mac.finalize().into_bytes(); + let mut resp = proof.to_vec(); + resp.extend_from_slice(&client_challenge); + resp + }; + + // Build the AUTHENTICATE_MESSAGE + let auth_msg = build_authenticate_message( + negotiate_flags, + &self.credentials.domain, + &self.credentials.username, + &lm_challenge_response, + &nt_challenge_response, + &encrypted_random_session_key, + has_timestamp, + ); + + // If MIC is required, compute it and patch it in + let final_msg = if has_timestamp { + let negotiate_bytes = self.negotiate_bytes.as_ref().ok_or_else(|| { + Error::invalid_data("negotiate() must be called before authenticate()") + })?; + + let mic = compute_mic( + &exported_session_key, + negotiate_bytes, + challenge_bytes, + &auth_msg, + ); + + let mut patched = auth_msg; + // MIC is at offset 72 (after signature(8) + type(4) + 6 fields * 8 + flags(4) + version(8)) + // = 8 + 4 + 48 + 4 + 8 = 72 + patched[72..88].copy_from_slice(&mic); + patched + } else { + auth_msg + }; + + self.session_key = Some(exported_session_key); + debug!( + "ntlm: authenticate message built, len={}, mic={}", + final_msg.len(), + has_timestamp + ); + Ok(final_msg) + } + + /// Get the session key (available after authenticate()). + pub fn session_key(&self) -> Option<&[u8]> { + self.session_key.as_deref() + } + + /// Get the timestamp to use. If the challenge contains MsvAvTimestamp, use it. + /// Otherwise use current time (or test override). + fn get_timestamp(&self, challenge: &ChallengeMessage) -> u64 { + #[cfg(test)] + if let Some(ts) = self.test_timestamp { + return ts; + } + + if let Some(ts_bytes) = find_av_pair(&challenge.target_info, MSV_AV_TIMESTAMP) { + if ts_bytes.len() == 8 { + return u64::from_le_bytes(ts_bytes.try_into().unwrap()); + } + } + + // Current time as Windows FILETIME (100-ns intervals since 1601-01-01) + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default(); + // UNIX epoch is 11644473600 seconds after FILETIME epoch + (now.as_secs() + 11_644_473_600) * 10_000_000 + u64::from(now.subsec_nanos()) / 100 + } + + /// Get the client challenge (random 8 bytes, or test override). + fn get_client_challenge(&self) -> [u8; 8] { + #[cfg(test)] + if let Some(cc) = self.test_client_challenge { + return cc; + } + + let mut challenge = [0u8; 8]; + getrandom::fill(&mut challenge).expect("system RNG failed"); + challenge + } + + /// Get the random session key (random 16 bytes, or test override). + /// + /// This MUST be cryptographically secure -- the ExportedSessionKey + /// is used for all subsequent signing and encryption. A predictable + /// key would let an attacker forge messages and decrypt traffic. + fn get_random_session_key(&self) -> [u8; 16] { + #[cfg(test)] + if let Some(rsk) = self.test_random_session_key { + return rsk; + } + + let mut key = [0u8; 16]; + getrandom::fill(&mut key).expect("system RNG failed"); + key + } +} + +// --------------------------------------------------------------------------- +// Parsed CHALLENGE_MESSAGE +// --------------------------------------------------------------------------- + +/// Parsed fields from a CHALLENGE_MESSAGE (Type 2). +struct ChallengeMessage { + /// The server's negotiate flags. + negotiate_flags: u32, + /// The 8-byte server challenge. + server_challenge: [u8; 8], + /// Raw TargetInfo bytes (sequence of AV_PAIRs). + target_info: Vec, +} + +/// Parse a CHALLENGE_MESSAGE from raw bytes. +fn parse_challenge_message(data: &[u8]) -> Result { + if data.len() < 32 { + return Err(Error::invalid_data("CHALLENGE_MESSAGE too short")); + } + + // Verify signature + if &data[0..8] != NTLM_SIGNATURE { + return Err(Error::invalid_data( + "invalid NTLM signature in CHALLENGE_MESSAGE", + )); + } + + // Verify message type + let msg_type = u32::from_le_bytes(data[8..12].try_into().unwrap()); + if msg_type != MSG_TYPE_CHALLENGE { + return Err(Error::invalid_data(format!( + "expected CHALLENGE_MESSAGE type 2, got {}", + msg_type + ))); + } + + // TargetNameFields at offset 12: Len(2) + MaxLen(2) + Offset(4) + // We don't need the target name for authentication, but we parse past it. + + // NegotiateFlags at offset 20 + let negotiate_flags = u32::from_le_bytes(data[20..24].try_into().unwrap()); + + // ServerChallenge at offset 24 (8 bytes) + let mut server_challenge = [0u8; 8]; + server_challenge.copy_from_slice(&data[24..32]); + + // Reserved at offset 32 (8 bytes) - skip + + // TargetInfoFields at offset 40: Len(2) + MaxLen(2) + Offset(4) + let target_info = if data.len() >= 48 { + let ti_len = u16::from_le_bytes(data[40..42].try_into().unwrap()) as usize; + let ti_offset = u32::from_le_bytes(data[44..48].try_into().unwrap()) as usize; + if ti_len > 0 && ti_offset + ti_len <= data.len() { + data[ti_offset..ti_offset + ti_len].to_vec() + } else { + Vec::new() + } + } else { + Vec::new() + }; + + Ok(ChallengeMessage { + negotiate_flags, + server_challenge, + target_info, + }) +} + +// --------------------------------------------------------------------------- +// AV_PAIR parsing and building +// --------------------------------------------------------------------------- + +/// Find an AV_PAIR with the given AvId in a TargetInfo byte sequence. +/// Returns the value bytes if found, or None. +fn find_av_pair(target_info: &[u8], av_id: u16) -> Option> { + let mut offset = 0; + while offset + 4 <= target_info.len() { + let id = u16::from_le_bytes(target_info[offset..offset + 2].try_into().unwrap()); + let len = + u16::from_le_bytes(target_info[offset + 2..offset + 4].try_into().unwrap()) as usize; + + if id == av_id { + if offset + 4 + len <= target_info.len() { + return Some(target_info[offset + 4..offset + 4 + len].to_vec()); + } + return None; + } + + if id == MSV_AV_EOL { + break; + } + + offset += 4 + len; + } + None +} + +/// Parse all AV_PAIRs from a TargetInfo byte sequence. +/// Returns a list of (AvId, Value) pairs. +fn parse_av_pairs(target_info: &[u8]) -> Vec<(u16, Vec)> { + let mut pairs = Vec::new(); + let mut offset = 0; + while offset + 4 <= target_info.len() { + let id = u16::from_le_bytes(target_info[offset..offset + 2].try_into().unwrap()); + let len = + u16::from_le_bytes(target_info[offset + 2..offset + 4].try_into().unwrap()) as usize; + + if id == MSV_AV_EOL { + pairs.push((id, Vec::new())); + break; + } + + if offset + 4 + len > target_info.len() { + break; + } + + pairs.push((id, target_info[offset + 4..offset + 4 + len].to_vec())); + offset += 4 + len; + } + pairs +} + +/// Build the TargetInfo for the AUTHENTICATE_MESSAGE. +/// +/// If `has_timestamp` is true, adds MsvAvFlags with bit 0x2 set (MIC present). +/// Removes the trailing MsvAvEOL, adds new pairs, then re-adds MsvAvEOL. +fn build_auth_target_info(challenge_target_info: &[u8], has_timestamp: bool) -> Vec { + let pairs = parse_av_pairs(challenge_target_info); + let mut result = Vec::new(); + + // Copy all existing pairs except MsvAvEOL and MsvAvFlags (we'll re-add flags if needed) + for (id, value) in &pairs { + if *id == MSV_AV_EOL { + continue; + } + if *id == MSV_AV_FLAGS && has_timestamp { + // We'll add our own flags entry + continue; + } + result.extend_from_slice(&id.to_le_bytes()); + result.extend_from_slice(&(value.len() as u16).to_le_bytes()); + result.extend_from_slice(value); + } + + // If MIC is required, add MsvAvFlags with bit 0x2 + if has_timestamp { + // Check if there was an existing flags value to preserve other bits + let existing_flags = pairs + .iter() + .find(|(id, _)| *id == MSV_AV_FLAGS) + .map(|(_, v)| { + if v.len() >= 4 { + u32::from_le_bytes(v[..4].try_into().unwrap()) + } else { + 0 + } + }) + .unwrap_or(0); + let flags = existing_flags | 0x0000_0002; // MIC present + result.extend_from_slice(&MSV_AV_FLAGS.to_le_bytes()); + result.extend_from_slice(&4u16.to_le_bytes()); + result.extend_from_slice(&flags.to_le_bytes()); + } + + // Terminate with MsvAvEOL + result.extend_from_slice(&MSV_AV_EOL.to_le_bytes()); + result.extend_from_slice(&0u16.to_le_bytes()); + + result +} + +// --------------------------------------------------------------------------- +// Crypto helpers +// --------------------------------------------------------------------------- + +/// Compute the NT hash: MD4(UTF-16LE(password)). +fn compute_nt_hash(password: &str) -> Vec { + let unicode_password: Vec = password + .encode_utf16() + .flat_map(|u| u.to_le_bytes()) + .collect(); + let mut hasher = md4::Md4::new(); + hasher.update(&unicode_password); + hasher.finalize().to_vec() +} + +/// Compute the NTLMv2 hash: HMAC_MD5(NT_Hash, uppercase(UTF-16LE(username)) + UTF-16LE(domain)). +fn compute_ntlmv2_hash(nt_hash: &[u8], username: &str, domain: &str) -> Vec { + let user_upper: Vec = username + .to_uppercase() + .encode_utf16() + .flat_map(|u| u.to_le_bytes()) + .collect(); + let domain_unicode: Vec = domain + .encode_utf16() + .flat_map(|u| u.to_le_bytes()) + .collect(); + + let mut mac = HmacMd5::new_from_slice(nt_hash).expect("HMAC accepts any key length"); + mac.update(&user_upper); + mac.update(&domain_unicode); + mac.finalize().into_bytes().to_vec() +} + +/// Build the temp blob for NTLMv2 (section 3.3.2). +/// +/// ```text +/// temp = 0x01 0x01 + Z(6) + Time(8) + ClientChallenge(8) + Z(4) + ServerName + Z(4) +/// ``` +/// +/// Here `ServerName` is the AV_PAIR sequence (target_info for the authenticate message). +fn build_temp(timestamp: u64, client_challenge: &[u8; 8], target_info: &[u8]) -> Vec { + let mut temp = Vec::new(); + temp.push(0x01); // Responserversion + temp.push(0x01); // HiResponserversion + temp.extend_from_slice(&[0u8; 6]); // Z(6) + temp.extend_from_slice(×tamp.to_le_bytes()); // Time (8 bytes) + temp.extend_from_slice(client_challenge); // ClientChallenge (8 bytes) + temp.extend_from_slice(&[0u8; 4]); // Z(4) + temp.extend_from_slice(target_info); // ServerName (AV_PAIRs) + temp.extend_from_slice(&[0u8; 4]); // Z(4) - trailing padding + temp +} + +/// RC4 encryption (symmetric -- encrypt and decrypt are the same operation). +fn rc4_encrypt(key: &[u8], data: &[u8]) -> Vec { + let mut s: Vec = (0..=255).collect(); + let mut j: u8 = 0; + for i in 0..256 { + j = j.wrapping_add(s[i]).wrapping_add(key[i % key.len()]); + s.swap(i, j as usize); + } + let mut i: u8 = 0; + j = 0; + data.iter() + .map(|&byte| { + i = i.wrapping_add(1); + j = j.wrapping_add(s[i as usize]); + s.swap(i as usize, j as usize); + byte ^ s[s[i as usize].wrapping_add(s[j as usize]) as usize] + }) + .collect() +} + +/// Compute the MIC: HMAC_MD5(ExportedSessionKey, negotiate || challenge || authenticate). +fn compute_mic( + exported_session_key: &[u8], + negotiate_bytes: &[u8], + challenge_bytes: &[u8], + authenticate_bytes: &[u8], +) -> Vec { + let mut mac = + HmacMd5::new_from_slice(exported_session_key).expect("HMAC accepts any key length"); + mac.update(negotiate_bytes); + mac.update(challenge_bytes); + mac.update(authenticate_bytes); + mac.finalize().into_bytes().to_vec() +} + +/// Encode a string as UTF-16LE bytes. +fn encode_utf16le(s: &str) -> Vec { + s.encode_utf16().flat_map(|u| u.to_le_bytes()).collect() +} + +// --------------------------------------------------------------------------- +// AUTHENTICATE_MESSAGE construction +// --------------------------------------------------------------------------- + +/// Build the AUTHENTICATE_MESSAGE (Type 3). +/// +/// The MIC field (16 bytes at offset 72) is initially zeroed. +/// The caller must patch it in if MIC is required. +fn build_authenticate_message( + negotiate_flags: u32, + domain: &str, + username: &str, + lm_challenge_response: &[u8], + nt_challenge_response: &[u8], + encrypted_random_session_key: &[u8], + include_mic: bool, +) -> Vec { + let domain_bytes = encode_utf16le(domain); + let user_bytes = encode_utf16le(username); + let workstation_bytes: Vec = Vec::new(); // Empty workstation + + // Fixed header size: + // Signature(8) + MessageType(4) + 6 * Fields(8 each) + NegotiateFlags(4) + Version(8) + // = 8 + 4 + 48 + 4 + 8 = 72 + // + MIC(16) if included = 88 + let header_size = if include_mic { 88 } else { 72 }; + + // Payload offsets (payload starts after the fixed header) + let domain_offset = header_size; + let user_offset = domain_offset + domain_bytes.len(); + let workstation_offset = user_offset + user_bytes.len(); + let lm_offset = workstation_offset + workstation_bytes.len(); + let nt_offset = lm_offset + lm_challenge_response.len(); + let session_key_offset = nt_offset + nt_challenge_response.len(); + + let mut buf = Vec::with_capacity(session_key_offset + encrypted_random_session_key.len()); + + // Signature (8 bytes) + buf.extend_from_slice(NTLM_SIGNATURE); + // MessageType (4 bytes) + buf.extend_from_slice(&MSG_TYPE_AUTHENTICATE.to_le_bytes()); + + // LmChallengeResponseFields (8 bytes) + buf.extend_from_slice(&(lm_challenge_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(lm_challenge_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(lm_offset as u32).to_le_bytes()); + + // NtChallengeResponseFields (8 bytes) + buf.extend_from_slice(&(nt_challenge_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(nt_challenge_response.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(nt_offset as u32).to_le_bytes()); + + // DomainNameFields (8 bytes) + buf.extend_from_slice(&(domain_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(domain_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(domain_offset as u32).to_le_bytes()); + + // UserNameFields (8 bytes) + buf.extend_from_slice(&(user_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(user_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(user_offset as u32).to_le_bytes()); + + // WorkstationFields (8 bytes) + buf.extend_from_slice(&(workstation_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(workstation_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(workstation_offset as u32).to_le_bytes()); + + // EncryptedRandomSessionKeyFields (8 bytes) + buf.extend_from_slice(&(encrypted_random_session_key.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(encrypted_random_session_key.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(session_key_offset as u32).to_le_bytes()); + + // NegotiateFlags (4 bytes) + buf.extend_from_slice(&negotiate_flags.to_le_bytes()); + + // Version (8 bytes) - zeros (no NTLMSSP_NEGOTIATE_VERSION flag set) + buf.extend_from_slice(&[0u8; 8]); + + // MIC (16 bytes) - zeroed, caller patches if needed + if include_mic { + buf.extend_from_slice(&[0u8; 16]); + } + + // Payload + buf.extend_from_slice(&domain_bytes); + buf.extend_from_slice(&user_bytes); + buf.extend_from_slice(&workstation_bytes); + buf.extend_from_slice(lm_challenge_response); + buf.extend_from_slice(nt_challenge_response); + buf.extend_from_slice(encrypted_random_session_key); + + buf +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + // ======================================================================= + // Test vectors from MS-NLMP section 4.2.1 (Common Values) + // ======================================================================= + + const TEST_USER: &str = "User"; + const TEST_PASSWORD: &str = "Password"; + const TEST_DOMAIN: &str = "Domain"; + const TEST_SERVER_CHALLENGE: [u8; 8] = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]; + const TEST_CLIENT_CHALLENGE: [u8; 8] = [0xaa; 8]; + const TEST_RANDOM_SESSION_KEY: [u8; 16] = [0x55; 16]; + const TEST_TIME: u64 = 0; // All zeros in the test vectors + + // ======================================================================= + // NT hash tests + // ======================================================================= + + #[test] + fn nt_hash_of_password() { + // From section 4.2.2.1.2: NTOWFv1("Password", ...) = + // a4 f4 9c 40 65 10 bd ca b6 82 4e e7 c3 0f d8 52 + let expected = [ + 0xa4, 0xf4, 0x9c, 0x40, 0x65, 0x10, 0xbd, 0xca, 0xb6, 0x82, 0x4e, 0xe7, 0xc3, 0x0f, + 0xd8, 0x52, + ]; + let hash = compute_nt_hash(TEST_PASSWORD); + assert_eq!(hash, expected); + } + + // ======================================================================= + // NTLMv2 hash tests + // ======================================================================= + + #[test] + fn ntlmv2_hash_computation() { + // From section 4.2.4.1.1: NTOWFv2("Password", "User", "Domain") = + // 0c 86 8a 40 3b fd 7a 93 a3 00 1e f2 2e f0 2e 3f + let expected = [ + 0x0c, 0x86, 0x8a, 0x40, 0x3b, 0xfd, 0x7a, 0x93, 0xa3, 0x00, 0x1e, 0xf2, 0x2e, 0xf0, + 0x2e, 0x3f, + ]; + let nt_hash = compute_nt_hash(TEST_PASSWORD); + let ntlmv2_hash = compute_ntlmv2_hash(&nt_hash, TEST_USER, TEST_DOMAIN); + assert_eq!(ntlmv2_hash, expected); + } + + // ======================================================================= + // NTProofStr and SessionBaseKey tests (section 4.2.4) + // ======================================================================= + + #[test] + fn nt_proof_str_computation() { + // From section 4.2.4.2.2: NTLMv2 Response starts with NTProofStr = + // 68 cd 0a b8 51 e5 1c 96 aa bc 92 7b eb ef 6a 1c + let expected_nt_proof_str = [ + 0x68, 0xcd, 0x0a, 0xb8, 0x51, 0xe5, 0x1c, 0x96, 0xaa, 0xbc, 0x92, 0x7b, 0xeb, 0xef, + 0x6a, 0x1c, + ]; + + let nt_hash = compute_nt_hash(TEST_PASSWORD); + let ntlmv2_hash = compute_ntlmv2_hash(&nt_hash, TEST_USER, TEST_DOMAIN); + + // Build the target info that matches the test vectors: + // AV_PAIR: MsvAvNbDomainName(2) = "Domain" + // AV_PAIR: MsvAvNbComputerName(1) = "Server" + // AV_PAIR: MsvAvEOL(0) + let target_info = build_test_target_info(); + let temp = build_temp(TEST_TIME, &TEST_CLIENT_CHALLENGE, &target_info); + + let mut mac = HmacMd5::new_from_slice(&ntlmv2_hash).expect("HMAC accepts any key length"); + mac.update(&TEST_SERVER_CHALLENGE); + mac.update(&temp); + let nt_proof_str = mac.finalize().into_bytes().to_vec(); + + assert_eq!(nt_proof_str, expected_nt_proof_str); + } + + #[test] + fn session_base_key_computation() { + // From section 4.2.4.1.2: SessionBaseKey = + // 8d e4 0c ca db c1 4a 82 f1 5c b0 ad 0d e9 5c a3 + let expected = [ + 0x8d, 0xe4, 0x0c, 0xca, 0xdb, 0xc1, 0x4a, 0x82, 0xf1, 0x5c, 0xb0, 0xad, 0x0d, 0xe9, + 0x5c, 0xa3, + ]; + + let nt_hash = compute_nt_hash(TEST_PASSWORD); + let ntlmv2_hash = compute_ntlmv2_hash(&nt_hash, TEST_USER, TEST_DOMAIN); + + let target_info = build_test_target_info(); + let temp = build_temp(TEST_TIME, &TEST_CLIENT_CHALLENGE, &target_info); + + // NTProofStr + let mut mac = HmacMd5::new_from_slice(&ntlmv2_hash).expect("HMAC accepts any key length"); + mac.update(&TEST_SERVER_CHALLENGE); + mac.update(&temp); + let nt_proof_str = mac.finalize().into_bytes().to_vec(); + + // SessionBaseKey = HMAC_MD5(ntlmv2_hash, NTProofStr) + let mut mac = HmacMd5::new_from_slice(&ntlmv2_hash).expect("HMAC accepts any key length"); + mac.update(&nt_proof_str); + let session_base_key = mac.finalize().into_bytes().to_vec(); + + assert_eq!(session_base_key, expected); + } + + // ======================================================================= + // RC4 / Encrypted Session Key tests + // ======================================================================= + + #[test] + fn rc4_encrypted_session_key() { + // From section 4.2.4.2.3: RC4(SessionBaseKey, RandomSessionKey) = + // c5 da d2 54 4f c9 79 90 94 ce 1c e9 0b c9 d0 3e + let expected = [ + 0xc5, 0xda, 0xd2, 0x54, 0x4f, 0xc9, 0x79, 0x90, 0x94, 0xce, 0x1c, 0xe9, 0x0b, 0xc9, + 0xd0, 0x3e, + ]; + + let session_base_key = [ + 0x8d, 0xe4, 0x0c, 0xca, 0xdb, 0xc1, 0x4a, 0x82, 0xf1, 0x5c, 0xb0, 0xad, 0x0d, 0xe9, + 0x5c, 0xa3, + ]; + + let result = rc4_encrypt(&session_base_key, &TEST_RANDOM_SESSION_KEY); + assert_eq!(result, expected); + } + + #[test] + fn rc4_roundtrip() { + let key = b"test key"; + let data = b"hello, world!"; + let encrypted = rc4_encrypt(key, data); + let decrypted = rc4_encrypt(key, &encrypted); + assert_eq!(decrypted, data); + } + + // ======================================================================= + // AV_PAIR tests + // ======================================================================= + + #[test] + fn parse_av_pairs_from_target_info() { + let target_info = build_test_target_info(); + let pairs = parse_av_pairs(&target_info); + + assert_eq!(pairs.len(), 3); // NbDomainName, NbComputerName, EOL + + // First pair: MsvAvNbDomainName = "Domain" + assert_eq!(pairs[0].0, MSV_AV_NB_DOMAIN_NAME); + let domain = String::from_utf16( + &pairs[0] + .1 + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect::>(), + ) + .unwrap(); + assert_eq!(domain, "Domain"); + + // Second pair: MsvAvNbComputerName = "Server" + assert_eq!(pairs[1].0, MSV_AV_NB_COMPUTER_NAME); + + // Last pair: MsvAvEOL + assert_eq!(pairs[2].0, MSV_AV_EOL); + } + + #[test] + fn find_av_pair_present() { + let target_info = build_test_target_info(); + let domain = find_av_pair(&target_info, MSV_AV_NB_DOMAIN_NAME); + assert!(domain.is_some()); + } + + #[test] + fn find_av_pair_absent() { + let target_info = build_test_target_info(); + let timestamp = find_av_pair(&target_info, MSV_AV_TIMESTAMP); + assert!(timestamp.is_none()); + } + + #[test] + fn detect_timestamp_in_target_info() { + // Build a target info with MsvAvTimestamp present + let mut target_info = Vec::new(); + // MsvAvNbDomainName = "Domain" + let domain_bytes = encode_utf16le("Domain"); + target_info.extend_from_slice(&MSV_AV_NB_DOMAIN_NAME.to_le_bytes()); + target_info.extend_from_slice(&(domain_bytes.len() as u16).to_le_bytes()); + target_info.extend_from_slice(&domain_bytes); + // MsvAvTimestamp + target_info.extend_from_slice(&MSV_AV_TIMESTAMP.to_le_bytes()); + target_info.extend_from_slice(&8u16.to_le_bytes()); + target_info.extend_from_slice(&0u64.to_le_bytes()); + // MsvAvEOL + target_info.extend_from_slice(&MSV_AV_EOL.to_le_bytes()); + target_info.extend_from_slice(&0u16.to_le_bytes()); + + assert!(find_av_pair(&target_info, MSV_AV_TIMESTAMP).is_some()); + } + + // ======================================================================= + // NEGOTIATE_MESSAGE tests + // ======================================================================= + + #[test] + fn negotiate_message_has_correct_signature() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + let msg = auth.negotiate(); + + // Signature + assert_eq!(&msg[0..8], NTLM_SIGNATURE); + } + + #[test] + fn negotiate_message_has_correct_type() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + let msg = auth.negotiate(); + + let msg_type = u32::from_le_bytes(msg[8..12].try_into().unwrap()); + assert_eq!(msg_type, MSG_TYPE_NEGOTIATE); + } + + #[test] + fn negotiate_message_has_expected_flags() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + let msg = auth.negotiate(); + + let flags = u32::from_le_bytes(msg[12..16].try_into().unwrap()); + // Check that key flags are set + assert_ne!(flags & NTLMSSP_NEGOTIATE_UNICODE, 0); + assert_ne!(flags & NTLMSSP_NEGOTIATE_NTLM, 0); + assert_ne!(flags & NTLMSSP_NEGOTIATE_KEY_EXCH, 0); + assert_ne!(flags & NTLMSSP_NEGOTIATE_128, 0); + } + + #[test] + fn negotiate_message_minimum_size() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: String::new(), + password: String::new(), + domain: String::new(), + }); + let msg = auth.negotiate(); + + // Minimum: signature(8) + type(4) + flags(4) + domain fields(8) + workstation fields(8) + assert_eq!(msg.len(), 32); + } + + // ======================================================================= + // CHALLENGE_MESSAGE parsing tests + // ======================================================================= + + #[test] + fn parse_challenge_message_from_spec() { + // Challenge message from section 4.2.4.3 + let challenge_bytes = build_test_challenge_message(); + let challenge = parse_challenge_message(&challenge_bytes).unwrap(); + + assert_eq!(challenge.server_challenge, TEST_SERVER_CHALLENGE); + assert!(!challenge.target_info.is_empty()); + } + + #[test] + fn parse_challenge_message_rejects_wrong_signature() { + let mut bad = build_test_challenge_message(); + bad[0] = 0x00; // Corrupt signature + assert!(parse_challenge_message(&bad).is_err()); + } + + #[test] + fn parse_challenge_message_rejects_wrong_type() { + let mut bad = build_test_challenge_message(); + // Change message type from 2 to 1 + bad[8] = 0x01; + assert!(parse_challenge_message(&bad).is_err()); + } + + #[test] + fn parse_challenge_message_rejects_too_short() { + assert!(parse_challenge_message(&[0u8; 16]).is_err()); + } + + // ======================================================================= + // Full flow tests + // ======================================================================= + + #[test] + fn full_negotiate_authenticate_flow_no_timestamp() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + // Step 1: Negotiate + let _negotiate = auth.negotiate(); + + // Step 2: Build a challenge message (no timestamp = no MIC) + let challenge_bytes = build_test_challenge_message(); + + // Step 3: Authenticate + let authenticate = auth.authenticate(&challenge_bytes).unwrap(); + + // Verify the authenticate message + assert_eq!(&authenticate[0..8], NTLM_SIGNATURE); + let msg_type = u32::from_le_bytes(authenticate[8..12].try_into().unwrap()); + assert_eq!(msg_type, MSG_TYPE_AUTHENTICATE); + + // Session key should be available + assert!(auth.session_key().is_some()); + assert_eq!(auth.session_key().unwrap().len(), 16); + } + + #[test] + fn full_flow_with_timestamp_includes_mic() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + let _negotiate = auth.negotiate(); + + // Challenge with MsvAvTimestamp + let challenge_bytes = build_test_challenge_message_with_timestamp(); + + let authenticate = auth.authenticate(&challenge_bytes).unwrap(); + + // Verify signature and type + assert_eq!(&authenticate[0..8], NTLM_SIGNATURE); + let msg_type = u32::from_le_bytes(authenticate[8..12].try_into().unwrap()); + assert_eq!(msg_type, MSG_TYPE_AUTHENTICATE); + + // MIC field at offset 72 should NOT be all zeros (it was patched) + let mic = &authenticate[72..88]; + assert_ne!( + mic, &[0u8; 16], + "MIC should be non-zero when timestamp is present" + ); + + // Session key should be available + assert!(auth.session_key().is_some()); + } + + #[test] + fn session_key_not_available_before_authenticate() { + let auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + assert!(auth.session_key().is_none()); + } + + #[test] + fn authenticate_without_negotiate_and_timestamp_still_works() { + // If there's no timestamp, MIC isn't required, so negotiate_bytes not needed + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + // Skip negotiate, go straight to authenticate with no-timestamp challenge + let challenge_bytes = build_test_challenge_message(); + let result = auth.authenticate(&challenge_bytes); + assert!(result.is_ok()); + } + + #[test] + fn authenticate_with_timestamp_requires_negotiate() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + // Skip negotiate, try authenticate with timestamp challenge + let challenge_bytes = build_test_challenge_message_with_timestamp(); + let result = auth.authenticate(&challenge_bytes); + // Should fail because negotiate_bytes is needed for MIC + assert!(result.is_err()); + } + + // ======================================================================= + // Edge case tests + // ======================================================================= + + #[test] + fn empty_domain() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: String::new(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + let _negotiate = auth.negotiate(); + let challenge_bytes = build_test_challenge_message(); + let result = auth.authenticate(&challenge_bytes); + assert!(result.is_ok()); + } + + #[test] + fn unicode_username_with_special_characters() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: "Us\u{00e9}r".to_string(), // "User" with e-acute + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + let _negotiate = auth.negotiate(); + let challenge_bytes = build_test_challenge_message(); + let result = auth.authenticate(&challenge_bytes); + assert!(result.is_ok()); + } + + #[test] + fn build_auth_target_info_adds_flags_when_timestamp_present() { + // Target info with timestamp + let mut target_info = Vec::new(); + let domain_bytes = encode_utf16le("Domain"); + target_info.extend_from_slice(&MSV_AV_NB_DOMAIN_NAME.to_le_bytes()); + target_info.extend_from_slice(&(domain_bytes.len() as u16).to_le_bytes()); + target_info.extend_from_slice(&domain_bytes); + target_info.extend_from_slice(&MSV_AV_TIMESTAMP.to_le_bytes()); + target_info.extend_from_slice(&8u16.to_le_bytes()); + target_info.extend_from_slice(&0u64.to_le_bytes()); + target_info.extend_from_slice(&MSV_AV_EOL.to_le_bytes()); + target_info.extend_from_slice(&0u16.to_le_bytes()); + + let auth_info = build_auth_target_info(&target_info, true); + let pairs = parse_av_pairs(&auth_info); + + // Should contain MsvAvFlags with bit 0x2 set + let flags_pair = pairs.iter().find(|(id, _)| *id == MSV_AV_FLAGS); + assert!(flags_pair.is_some(), "MsvAvFlags should be present"); + let flags_value = u32::from_le_bytes(flags_pair.unwrap().1[..4].try_into().unwrap()); + assert_ne!(flags_value & 0x2, 0, "MIC bit should be set in MsvAvFlags"); + } + + #[test] + fn build_auth_target_info_no_flags_when_no_timestamp() { + let target_info = build_test_target_info(); + let auth_info = build_auth_target_info(&target_info, false); + let pairs = parse_av_pairs(&auth_info); + + // Should NOT contain MsvAvFlags + let flags_pair = pairs.iter().find(|(id, _)| *id == MSV_AV_FLAGS); + assert!(flags_pair.is_none()); + } + + #[test] + fn lm_challenge_response_is_zeroed_when_timestamp_present() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + let _negotiate = auth.negotiate(); + let challenge_bytes = build_test_challenge_message_with_timestamp(); + let authenticate = auth.authenticate(&challenge_bytes).unwrap(); + + // Parse LmChallengeResponseFields from the authenticate message + let lm_len = u16::from_le_bytes(authenticate[12..14].try_into().unwrap()) as usize; + let lm_offset = u32::from_le_bytes(authenticate[16..20].try_into().unwrap()) as usize; + + // LM response should be 24 bytes of zeros + assert_eq!(lm_len, 24); + let lm_data = &authenticate[lm_offset..lm_offset + lm_len]; + assert_eq!(lm_data, &[0u8; 24]); + } + + #[test] + fn authenticate_message_contains_correct_domain_and_user() { + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: TEST_USER.to_string(), + password: TEST_PASSWORD.to_string(), + domain: TEST_DOMAIN.to_string(), + }); + auth.test_client_challenge = Some(TEST_CLIENT_CHALLENGE); + auth.test_random_session_key = Some(TEST_RANDOM_SESSION_KEY); + auth.test_timestamp = Some(TEST_TIME); + + let _negotiate = auth.negotiate(); + let challenge_bytes = build_test_challenge_message(); + let authenticate = auth.authenticate(&challenge_bytes).unwrap(); + + // DomainNameFields at offset 28 + let domain_len = u16::from_le_bytes(authenticate[28..30].try_into().unwrap()) as usize; + let domain_offset = u32::from_le_bytes(authenticate[32..36].try_into().unwrap()) as usize; + let domain_bytes = &authenticate[domain_offset..domain_offset + domain_len]; + let domain = String::from_utf16( + &domain_bytes + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect::>(), + ) + .unwrap(); + assert_eq!(domain, TEST_DOMAIN); + + // UserNameFields at offset 36 + let user_len = u16::from_le_bytes(authenticate[36..38].try_into().unwrap()) as usize; + let user_offset = u32::from_le_bytes(authenticate[40..44].try_into().unwrap()) as usize; + let user_bytes = &authenticate[user_offset..user_offset + user_len]; + let user = String::from_utf16( + &user_bytes + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect::>(), + ) + .unwrap(); + assert_eq!(user, TEST_USER); + } + + // ======================================================================= + // Known-answer test: full NTLMv2 from spec section 4.2.4 + // ======================================================================= + + #[test] + fn ntlmv2_full_known_answer_lm_response() { + // From section 4.2.4.2.1: LMv2 Response + let expected = [ + 0x86, 0xc3, 0x50, 0x97, 0xac, 0x9c, 0xec, 0x10, 0x25, 0x54, 0x76, 0x4a, 0x57, 0xcc, + 0xcc, 0x19, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, + ]; + + let nt_hash = compute_nt_hash(TEST_PASSWORD); + let ntlmv2_hash = compute_ntlmv2_hash(&nt_hash, TEST_USER, TEST_DOMAIN); + + // LMv2: HMAC_MD5(ntlmv2_hash, server_challenge + client_challenge) + client_challenge + let mut mac = HmacMd5::new_from_slice(&ntlmv2_hash).expect("HMAC accepts any key length"); + mac.update(&TEST_SERVER_CHALLENGE); + mac.update(&TEST_CLIENT_CHALLENGE); + let proof = mac.finalize().into_bytes(); + let mut resp = proof.to_vec(); + resp.extend_from_slice(&TEST_CLIENT_CHALLENGE); + + assert_eq!(resp, expected); + } + + // ======================================================================= + // Test helpers + // ======================================================================= + + /// Build a test target info matching the NTLMv2 test vectors from section 4.2.4. + /// Contains: MsvAvNbDomainName("Domain"), MsvAvNbComputerName("Server"), MsvAvEOL. + fn build_test_target_info() -> Vec { + let mut info = Vec::new(); + + // MsvAvNbDomainName = "Domain" + let domain_bytes = encode_utf16le("Domain"); + info.extend_from_slice(&MSV_AV_NB_DOMAIN_NAME.to_le_bytes()); + info.extend_from_slice(&(domain_bytes.len() as u16).to_le_bytes()); + info.extend_from_slice(&domain_bytes); + + // MsvAvNbComputerName = "Server" + let server_bytes = encode_utf16le("Server"); + info.extend_from_slice(&MSV_AV_NB_COMPUTER_NAME.to_le_bytes()); + info.extend_from_slice(&(server_bytes.len() as u16).to_le_bytes()); + info.extend_from_slice(&server_bytes); + + // MsvAvEOL + info.extend_from_slice(&MSV_AV_EOL.to_le_bytes()); + info.extend_from_slice(&0u16.to_le_bytes()); + + info + } + + /// Build a CHALLENGE_MESSAGE matching the NTLMv2 test vectors (no MsvAvTimestamp). + fn build_test_challenge_message() -> Vec { + let target_info = build_test_target_info(); + let target_name = encode_utf16le("Server"); + + // NTLMv2 challenge flags from section 4.2.4 + let flags: u32 = 0xe28a8233; + + build_challenge_message_bytes(flags, &target_name, &target_info) + } + + /// Build a CHALLENGE_MESSAGE with MsvAvTimestamp present (triggers MIC). + fn build_test_challenge_message_with_timestamp() -> Vec { + let mut target_info = Vec::new(); + + // MsvAvNbDomainName = "Domain" + let domain_bytes = encode_utf16le("Domain"); + target_info.extend_from_slice(&MSV_AV_NB_DOMAIN_NAME.to_le_bytes()); + target_info.extend_from_slice(&(domain_bytes.len() as u16).to_le_bytes()); + target_info.extend_from_slice(&domain_bytes); + + // MsvAvNbComputerName = "Server" + let server_bytes = encode_utf16le("Server"); + target_info.extend_from_slice(&MSV_AV_NB_COMPUTER_NAME.to_le_bytes()); + target_info.extend_from_slice(&(server_bytes.len() as u16).to_le_bytes()); + target_info.extend_from_slice(&server_bytes); + + // MsvAvTimestamp + target_info.extend_from_slice(&MSV_AV_TIMESTAMP.to_le_bytes()); + target_info.extend_from_slice(&8u16.to_le_bytes()); + target_info.extend_from_slice(&0u64.to_le_bytes()); // timestamp = 0 + + // MsvAvEOL + target_info.extend_from_slice(&MSV_AV_EOL.to_le_bytes()); + target_info.extend_from_slice(&0u16.to_le_bytes()); + + let target_name = encode_utf16le("Server"); + let flags: u32 = 0xe28a8233; + + build_challenge_message_bytes(flags, &target_name, &target_info) + } + + /// Helper to construct a raw CHALLENGE_MESSAGE. + fn build_challenge_message_bytes( + flags: u32, + target_name: &[u8], + target_info: &[u8], + ) -> Vec { + // Fixed header: 56 bytes (up to and including version) + // Payload starts at offset 56 (no VERSION in our simplified messages) + // Actually, the challenge message layout: + // Signature(8) + Type(4) + TargetNameFields(8) + Flags(4) + ServerChallenge(8) + // + Reserved(8) + TargetInfoFields(8) + Version(8) + // = 56 bytes header + let header_size = 56; + let target_name_offset = header_size; + let target_info_offset = target_name_offset + target_name.len(); + + let mut buf = Vec::with_capacity(target_info_offset + target_info.len()); + + // Signature + buf.extend_from_slice(NTLM_SIGNATURE); + // MessageType + buf.extend_from_slice(&MSG_TYPE_CHALLENGE.to_le_bytes()); + // TargetNameFields + buf.extend_from_slice(&(target_name.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(target_name.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(target_name_offset as u32).to_le_bytes()); + // NegotiateFlags + buf.extend_from_slice(&flags.to_le_bytes()); + // ServerChallenge + buf.extend_from_slice(&TEST_SERVER_CHALLENGE); + // Reserved + buf.extend_from_slice(&[0u8; 8]); + // TargetInfoFields + buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(target_info_offset as u32).to_le_bytes()); + // Version (8 bytes) + buf.extend_from_slice(&[0x06, 0x00, 0x70, 0x17, 0x00, 0x00, 0x00, 0x0f]); + + // Payload + buf.extend_from_slice(target_name); + buf.extend_from_slice(target_info); + + buf + } +} diff --git a/vendor/smb2/src/auth/spnego.rs b/vendor/smb2/src/auth/spnego.rs new file mode 100644 index 0000000..984daea --- /dev/null +++ b/vendor/smb2/src/auth/spnego.rs @@ -0,0 +1,808 @@ +//! SPNEGO (Simple and Protected GSS-API Negotiation Mechanism) token wrapping. +//! +//! Implements the thin ASN.1/DER wrapper that SMB2 requires around authentication +//! tokens (NTLM, Kerberos). The client sends a NegTokenInit with supported +//! mechanism OIDs and the first mechanism's token, the server responds with +//! NegTokenResp indicating the selected mechanism and its response token, and +//! subsequent client messages use NegTokenResp as well. +//! +//! References: +//! - RFC 4178 (SPNEGO) +//! - MS-SPNG (Microsoft SPNEGO Extension) + +use super::der::{der_tlv, parse_der_tlv}; +use crate::Error; + +// --------------------------------------------------------------------------- +// OID constants (DER-encoded, including tag and length bytes) +// --------------------------------------------------------------------------- + +/// SPNEGO OID: 1.3.6.1.5.5.2 +pub const OID_SPNEGO: &[u8] = &[0x06, 0x06, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x02]; + +/// NTLM (NTLMSSP) OID: 1.3.6.1.4.1.311.2.2.10 +pub const OID_NTLMSSP: &[u8] = &[ + 0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x02, 0x02, 0x0a, +]; + +/// Kerberos OID: 1.2.840.113554.1.2.2 (standard, RFC 4121) +pub const OID_KERBEROS: &[u8] = &[ + 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x12, 0x01, 0x02, 0x02, +]; + +/// Microsoft Kerberos OID: 1.2.840.48018.1.2.2 (MS-KILE, used by Windows SPNEGO) +/// +/// Windows expects this OID as the primary mechanism in SPNEGO NegTokenInit. +/// Using the standard Kerberos OID causes Windows to reject the AP-REQ. +pub const OID_MS_KERBEROS: &[u8] = &[ + 0x06, 0x09, 0x2a, 0x86, 0x48, 0x82, 0xf7, 0x12, 0x01, 0x02, 0x02, +]; + +// --------------------------------------------------------------------------- +// ASN.1 DER tag constants +// --------------------------------------------------------------------------- + +/// SEQUENCE tag (constructed). +const TAG_SEQUENCE: u8 = 0x30; +/// OCTET STRING tag. +const TAG_OCTET_STRING: u8 = 0x04; +/// ENUMERATED tag. +const TAG_ENUMERATED: u8 = 0x0a; +/// APPLICATION [0] (constructed) -- wraps the initial NegotiationToken. +const TAG_APPLICATION_0: u8 = 0x60; +/// Context-specific [0] (constructed). +const TAG_CONTEXT_0: u8 = 0xa0; +/// Context-specific [1] (constructed). +const TAG_CONTEXT_1: u8 = 0xa1; +/// Context-specific [2] (constructed). +const TAG_CONTEXT_2: u8 = 0xa2; +/// Context-specific [3] (constructed). +const TAG_CONTEXT_3: u8 = 0xa3; + +// --------------------------------------------------------------------------- +// NegState enum +// --------------------------------------------------------------------------- + +/// SPNEGO negotiation state from NegTokenResp. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NegState { + /// Authentication completed successfully. + AcceptCompleted, + /// Authentication is in progress (more tokens needed). + AcceptIncomplete, + /// Authentication was rejected. + Reject, +} + +impl NegState { + /// Parse from the DER enumerated value. + fn from_value(v: u8) -> Option { + match v { + 0 => Some(NegState::AcceptCompleted), + 1 => Some(NegState::AcceptIncomplete), + 2 => Some(NegState::Reject), + _ => None, + } + } +} + +// --------------------------------------------------------------------------- +// NegTokenResp struct +// --------------------------------------------------------------------------- + +/// Parsed SPNEGO NegTokenResp from the server. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegTokenResp { + /// The negotiation state. + pub neg_state: Option, + /// The selected mechanism OID (raw DER-encoded OID TLV). + pub supported_mech: Option>, + /// The mechanism-specific response token. + pub response_token: Option>, + /// The mechanism list MIC. + pub mech_list_mic: Option>, +} + +// DER encoding/decoding helpers are in `super::der`. Imported at the top. + +// --------------------------------------------------------------------------- +// Public API: wrapping +// --------------------------------------------------------------------------- + +/// Wrap a mechanism token in a SPNEGO NegTokenInit. +/// +/// The initial token sent by the client. Wraps the raw NTLM or Kerberos +/// token with mechanism OID negotiation. +/// +/// Structure (RFC 4178 section 4.2): +/// ```text +/// APPLICATION [0] { +/// OID_SPNEGO, +/// [0] { -- NegTokenInit choice tag +/// SEQUENCE { +/// [0] { SEQUENCE { mechOID1, mechOID2, ... } }, -- mechTypes +/// [2] { OCTET STRING { mechToken } } -- mechToken +/// } +/// } +/// } +/// ``` +pub fn wrap_neg_token_init(mech_oids: &[&[u8]], mech_token: &[u8]) -> Vec { + // Build mechTypes: SEQUENCE OF OID + let mut mech_list_contents = Vec::new(); + for oid in mech_oids { + mech_list_contents.extend_from_slice(oid); + } + let mech_list_seq = der_tlv(TAG_SEQUENCE, &mech_list_contents); + let mech_types = der_tlv(TAG_CONTEXT_0, &mech_list_seq); + + // Build mechToken: [2] OCTET STRING + let mech_token_octet = der_tlv(TAG_OCTET_STRING, mech_token); + let mech_token_ctx = der_tlv(TAG_CONTEXT_2, &mech_token_octet); + + // NegTokenInit SEQUENCE + let mut init_contents = Vec::new(); + init_contents.extend_from_slice(&mech_types); + init_contents.extend_from_slice(&mech_token_ctx); + let init_seq = der_tlv(TAG_SEQUENCE, &init_contents); + + // Wrap in context [0] (NegotiationToken CHOICE for negTokenInit) + let choice = der_tlv(TAG_CONTEXT_0, &init_seq); + + // Wrap in APPLICATION [0] with SPNEGO OID + let mut app_contents = Vec::new(); + app_contents.extend_from_slice(OID_SPNEGO); + app_contents.extend_from_slice(&choice); + der_tlv(TAG_APPLICATION_0, &app_contents) +} + +/// Wrap a mechanism token in a SPNEGO NegTokenResp. +/// +/// Used by the client in the second round-trip (for example, the NTLM +/// AUTHENTICATE_MESSAGE). Only the responseToken field is set. +/// +/// Structure: +/// ```text +/// [1] { -- NegotiationToken CHOICE for negTokenResp +/// SEQUENCE { +/// [2] { OCTET STRING { mechToken } } -- responseToken +/// } +/// } +/// ``` +pub fn wrap_neg_token_resp(mech_token: &[u8]) -> Vec { + // Build responseToken: [2] OCTET STRING + let mech_token_octet = der_tlv(TAG_OCTET_STRING, mech_token); + let response_token_ctx = der_tlv(TAG_CONTEXT_2, &mech_token_octet); + + // NegTokenResp SEQUENCE + let resp_seq = der_tlv(TAG_SEQUENCE, &response_token_ctx); + + // Wrap in context [1] (NegotiationToken CHOICE for negTokenResp) + der_tlv(TAG_CONTEXT_1, &resp_seq) +} + +// --------------------------------------------------------------------------- +// Public API: parsing +// --------------------------------------------------------------------------- + +/// Parse a SPNEGO NegTokenResp from the server. +/// +/// The input can be either: +/// - A bare `[1] { SEQUENCE { ... } }` NegTokenResp +/// - An `APPLICATION [0] { OID, [0] { ... } }` wrapping a NegTokenInit2 +/// (server-initiated SPNEGO, which we parse the inner token from) +/// +/// Extracts the negotiation state, selected mechanism, and response token. +pub fn parse_neg_token_resp(data: &[u8]) -> Result { + if data.is_empty() { + return Err(Error::invalid_data("SPNEGO: empty token")); + } + + // Check if this is an APPLICATION [0] wrapper (server-initiated NegTokenInit2) + // or a NegTokenResp [1] wrapper. + let (tag, value, _) = parse_der_tlv(data)?; + + match tag { + TAG_CONTEXT_1 => { + // Standard NegTokenResp: [1] { SEQUENCE { ... } } + parse_neg_token_resp_inner(value) + } + TAG_APPLICATION_0 => { + // APPLICATION [0] { OID_SPNEGO, [0] { NegTokenInit2 } } + // or could contain a [1] { NegTokenResp } + // Skip the SPNEGO OID + let (oid_tag, _, oid_total) = parse_der_tlv(value)?; + if oid_tag != 0x06 { + return Err(Error::invalid_data(format!( + "SPNEGO: expected OID in APPLICATION [0], got tag 0x{oid_tag:02x}" + ))); + } + let remaining = &value[oid_total..]; + let (inner_tag, inner_value, _) = parse_der_tlv(remaining)?; + match inner_tag { + TAG_CONTEXT_0 => { + // NegTokenInit2 wrapped in [0]: parse as NegTokenInit2 + // to extract mechTypes (as supportedMech) and mechToken + parse_neg_token_init2_as_resp(inner_value) + } + TAG_CONTEXT_1 => { + // NegTokenResp wrapped inside APPLICATION [0] + parse_neg_token_resp_inner(inner_value) + } + _ => Err(Error::invalid_data(format!( + "SPNEGO: unexpected tag 0x{inner_tag:02x} inside APPLICATION [0]" + ))), + } + } + _ => Err(Error::invalid_data(format!( + "SPNEGO: expected NegTokenResp [1] or APPLICATION [0], got tag 0x{tag:02x}" + ))), + } +} + +/// Parse the inner SEQUENCE of a NegTokenResp. +fn parse_neg_token_resp_inner(data: &[u8]) -> Result { + // Expect SEQUENCE + let (tag, seq_data, _) = parse_der_tlv(data)?; + if tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "SPNEGO: expected SEQUENCE in NegTokenResp, got tag 0x{tag:02x}" + ))); + } + + let mut neg_state = None; + let mut supported_mech = None; + let mut response_token = None; + let mut mech_list_mic = None; + + let mut pos = 0; + while pos < seq_data.len() { + let (ctx_tag, ctx_value, ctx_total) = parse_der_tlv(&seq_data[pos..])?; + match ctx_tag { + TAG_CONTEXT_0 => { + // negState: ENUMERATED + let (enum_tag, enum_value, _) = parse_der_tlv(ctx_value)?; + if enum_tag != TAG_ENUMERATED { + return Err(Error::invalid_data(format!( + "SPNEGO: expected ENUMERATED for negState, got tag 0x{enum_tag:02x}" + ))); + } + if enum_value.is_empty() { + return Err(Error::invalid_data("SPNEGO: empty ENUMERATED for negState")); + } + neg_state = NegState::from_value(enum_value[0]); + if neg_state.is_none() { + return Err(Error::invalid_data(format!( + "SPNEGO: unknown negState value: {}", + enum_value[0] + ))); + } + } + TAG_CONTEXT_1 => { + // supportedMech: OID (the full TLV) + supported_mech = Some(ctx_value.to_vec()); + } + TAG_CONTEXT_2 => { + // responseToken: OCTET STRING + let (oct_tag, oct_value, _) = parse_der_tlv(ctx_value)?; + if oct_tag != TAG_OCTET_STRING { + return Err(Error::invalid_data(format!( + "SPNEGO: expected OCTET STRING for responseToken, got tag 0x{oct_tag:02x}" + ))); + } + response_token = Some(oct_value.to_vec()); + } + TAG_CONTEXT_3 => { + // mechListMIC: OCTET STRING + let (oct_tag, oct_value, _) = parse_der_tlv(ctx_value)?; + if oct_tag != TAG_OCTET_STRING { + return Err(Error::invalid_data(format!( + "SPNEGO: expected OCTET STRING for mechListMIC, got tag 0x{oct_tag:02x}" + ))); + } + mech_list_mic = Some(oct_value.to_vec()); + } + _ => { + // Unknown context tag, skip it (forward compatibility). + } + } + pos += ctx_total; + } + + Ok(NegTokenResp { + neg_state, + supported_mech, + response_token, + mech_list_mic, + }) +} + +/// Parse a NegTokenInit2 (server-initiated) and return it as a NegTokenResp. +/// +/// NegTokenInit2 has mechTypes at [0] and mechToken at [2]. We map the +/// first mechType to supportedMech and mechToken to responseToken. +fn parse_neg_token_init2_as_resp(data: &[u8]) -> Result { + let (tag, seq_data, _) = parse_der_tlv(data)?; + if tag != TAG_SEQUENCE { + return Err(Error::invalid_data(format!( + "SPNEGO: expected SEQUENCE in NegTokenInit2, got tag 0x{tag:02x}" + ))); + } + + let mut supported_mech = None; + let mut response_token = None; + + let mut pos = 0; + while pos < seq_data.len() { + let (ctx_tag, ctx_value, ctx_total) = parse_der_tlv(&seq_data[pos..])?; + match ctx_tag { + TAG_CONTEXT_0 => { + // mechTypes: SEQUENCE OF OID -- take the first one + let (seq_tag, mech_list_data, _) = parse_der_tlv(ctx_value)?; + if seq_tag != TAG_SEQUENCE { + return Err(Error::invalid_data( + "SPNEGO: expected SEQUENCE for mechTypes", + )); + } + if !mech_list_data.is_empty() { + // Take the first OID TLV as the supported mech + let (oid_tag, _, oid_total) = parse_der_tlv(mech_list_data)?; + if oid_tag == 0x06 { + supported_mech = Some(mech_list_data[..oid_total].to_vec()); + } + } + } + TAG_CONTEXT_2 => { + // mechToken: OCTET STRING + let (oct_tag, oct_value, _) = parse_der_tlv(ctx_value)?; + if oct_tag == TAG_OCTET_STRING { + response_token = Some(oct_value.to_vec()); + } + } + _ => { + // Skip reqFlags [1], negHints [3], mechListMIC [4] + } + } + pos += ctx_total; + } + + Ok(NegTokenResp { + neg_state: None, + supported_mech, + response_token, + mech_list_mic: None, + }) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + // DER primitive tests (der_length, der_tlv, parse_der_length, parse_der_tlv) + // live in `auth::der::tests`. + + // ======================================================================= + // NegTokenInit wrapping tests + // ======================================================================= + + #[test] + fn neg_token_init_starts_with_application_tag() { + let token = wrap_neg_token_init(&[OID_NTLMSSP], b"NTLMSSP\0test"); + assert_eq!( + token[0], TAG_APPLICATION_0, + "must start with APPLICATION [0]" + ); + } + + #[test] + fn neg_token_init_contains_spnego_oid() { + let token = wrap_neg_token_init(&[OID_NTLMSSP], b"NTLMSSP\0test"); + // The SPNEGO OID value bytes (without the 0x06 tag and 0x06 length) + let oid_value = &OID_SPNEGO[2..]; // skip tag+length + assert!( + token.windows(oid_value.len()).any(|w| w == oid_value), + "token must contain SPNEGO OID" + ); + } + + #[test] + fn neg_token_init_contains_mech_oid() { + let token = wrap_neg_token_init(&[OID_NTLMSSP], b"test"); + // The NTLMSSP OID value bytes (without the 0x06 tag) + let oid_value = &OID_NTLMSSP[2..]; // skip tag+length + assert!( + token.windows(oid_value.len()).any(|w| w == oid_value), + "token must contain NTLMSSP OID" + ); + } + + #[test] + fn neg_token_init_contains_mech_token() { + let mech_token = b"NTLMSSP\0negotiate_payload_here"; + let token = wrap_neg_token_init(&[OID_NTLMSSP], mech_token); + assert!( + token.windows(mech_token.len()).any(|w| w == mech_token), + "token must contain the raw mech token" + ); + } + + #[test] + fn neg_token_init_multiple_mechs() { + let token = wrap_neg_token_init(&[OID_NTLMSSP, OID_KERBEROS], b"tok"); + // Both OIDs should be present + let ntlm_oid_value = &OID_NTLMSSP[2..]; + let kerb_oid_value = &OID_KERBEROS[2..]; + assert!( + token + .windows(ntlm_oid_value.len()) + .any(|w| w == ntlm_oid_value), + "must contain NTLMSSP OID" + ); + assert!( + token + .windows(kerb_oid_value.len()) + .any(|w| w == kerb_oid_value), + "must contain Kerberos OID" + ); + } + + #[test] + fn neg_token_init_structure_is_valid_der() { + let token = wrap_neg_token_init(&[OID_NTLMSSP], b"test_token"); + // Parse the outer APPLICATION [0] + let (tag, value, total) = parse_der_tlv(&token).unwrap(); + assert_eq!(tag, TAG_APPLICATION_0); + assert_eq!(total, token.len(), "entire token should be consumed"); + + // Inside: OID_SPNEGO followed by [0] { SEQUENCE { ... } } + let (oid_tag, _, oid_total) = parse_der_tlv(value).unwrap(); + assert_eq!(oid_tag, 0x06, "first element should be OID"); + + let (choice_tag, _, _) = parse_der_tlv(&value[oid_total..]).unwrap(); + assert_eq!(choice_tag, TAG_CONTEXT_0, "second element should be [0]"); + } + + #[test] + fn neg_token_init_parseable_structure() { + // Wrap a token and verify we can walk the entire structure + let mech_token = b"the_raw_ntlm_token"; + let token = wrap_neg_token_init(&[OID_NTLMSSP], mech_token); + + // APPLICATION [0] + let (_, app_value, _) = parse_der_tlv(&token).unwrap(); + // Skip SPNEGO OID + let (_, _, oid_total) = parse_der_tlv(app_value).unwrap(); + // [0] CHOICE + let (_, choice_value, _) = parse_der_tlv(&app_value[oid_total..]).unwrap(); + // SEQUENCE + let (_, seq_value, _) = parse_der_tlv(choice_value).unwrap(); + // [0] mechTypes + let (tag0, ctx0_value, ctx0_total) = parse_der_tlv(seq_value).unwrap(); + assert_eq!(tag0, TAG_CONTEXT_0); + // SEQUENCE OF OID inside mechTypes + let (_, mech_list, _) = parse_der_tlv(ctx0_value).unwrap(); + // First OID should be NTLMSSP + assert_eq!(&mech_list[..OID_NTLMSSP.len()], OID_NTLMSSP); + + // [2] mechToken + let (tag2, ctx2_value, _) = parse_der_tlv(&seq_value[ctx0_total..]).unwrap(); + assert_eq!(tag2, TAG_CONTEXT_2); + // OCTET STRING + let (_, oct_value, _) = parse_der_tlv(ctx2_value).unwrap(); + assert_eq!(oct_value, mech_token); + } + + // ======================================================================= + // NegTokenResp wrapping tests + // ======================================================================= + + #[test] + fn neg_token_resp_wrap_starts_with_context_1() { + let token = wrap_neg_token_resp(b"auth_token"); + assert_eq!(token[0], TAG_CONTEXT_1, "must start with [1]"); + } + + #[test] + fn neg_token_resp_wrap_contains_mech_token() { + let mech_token = b"NTLMSSP\0authenticate_payload"; + let token = wrap_neg_token_resp(mech_token); + assert!( + token.windows(mech_token.len()).any(|w| w == mech_token), + "wrapped token must contain the raw mech token" + ); + } + + #[test] + fn neg_token_resp_wrap_valid_structure() { + let mech_token = b"authenticate_me"; + let token = wrap_neg_token_resp(mech_token); + + // [1] + let (tag, ctx1_value, _) = parse_der_tlv(&token).unwrap(); + assert_eq!(tag, TAG_CONTEXT_1); + // SEQUENCE + let (tag, seq_value, _) = parse_der_tlv(ctx1_value).unwrap(); + assert_eq!(tag, TAG_SEQUENCE); + // [2] responseToken + let (tag, ctx2_value, _) = parse_der_tlv(seq_value).unwrap(); + assert_eq!(tag, TAG_CONTEXT_2); + // OCTET STRING + let (tag, oct_value, _) = parse_der_tlv(ctx2_value).unwrap(); + assert_eq!(tag, TAG_OCTET_STRING); + assert_eq!(oct_value, mech_token); + } + + // ======================================================================= + // NegTokenResp parsing tests + // ======================================================================= + + /// Build a NegTokenResp with known fields for testing. + fn build_test_neg_token_resp( + neg_state: Option, + supported_mech: Option<&[u8]>, + response_token: Option<&[u8]>, + mech_list_mic: Option<&[u8]>, + ) -> Vec { + let mut seq_contents = Vec::new(); + + if let Some(state) = neg_state { + let enumerated = der_tlv(TAG_ENUMERATED, &[state]); + seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_0, &enumerated)); + } + + if let Some(oid) = supported_mech { + seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_1, oid)); + } + + if let Some(tok) = response_token { + let octet = der_tlv(TAG_OCTET_STRING, tok); + seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_2, &octet)); + } + + if let Some(mic) = mech_list_mic { + let octet = der_tlv(TAG_OCTET_STRING, mic); + seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_3, &octet)); + } + + let seq = der_tlv(TAG_SEQUENCE, &seq_contents); + der_tlv(TAG_CONTEXT_1, &seq) + } + + #[test] + fn parse_neg_token_resp_accept_incomplete() { + let token = build_test_neg_token_resp( + Some(1), // accept-incomplete + Some(OID_NTLMSSP), + Some(b"challenge_token"), + None, + ); + + let resp = parse_neg_token_resp(&token).unwrap(); + assert_eq!(resp.neg_state, Some(NegState::AcceptIncomplete)); + assert_eq!(resp.supported_mech.as_deref(), Some(OID_NTLMSSP)); + assert_eq!( + resp.response_token.as_deref(), + Some(&b"challenge_token"[..]) + ); + assert!(resp.mech_list_mic.is_none()); + } + + #[test] + fn parse_neg_token_resp_accept_completed() { + let token = build_test_neg_token_resp(Some(0), None, None, None); + + let resp = parse_neg_token_resp(&token).unwrap(); + assert_eq!(resp.neg_state, Some(NegState::AcceptCompleted)); + assert!(resp.supported_mech.is_none()); + assert!(resp.response_token.is_none()); + } + + #[test] + fn parse_neg_token_resp_reject() { + let token = build_test_neg_token_resp(Some(2), None, None, None); + + let resp = parse_neg_token_resp(&token).unwrap(); + assert_eq!(resp.neg_state, Some(NegState::Reject)); + } + + #[test] + fn parse_neg_token_resp_all_fields() { + let token = build_test_neg_token_resp( + Some(1), + Some(OID_NTLMSSP), + Some(b"response_data"), + Some(b"mic_data"), + ); + + let resp = parse_neg_token_resp(&token).unwrap(); + assert_eq!(resp.neg_state, Some(NegState::AcceptIncomplete)); + assert_eq!(resp.supported_mech.as_deref(), Some(OID_NTLMSSP)); + assert_eq!(resp.response_token.as_deref(), Some(&b"response_data"[..])); + assert_eq!(resp.mech_list_mic.as_deref(), Some(&b"mic_data"[..])); + } + + #[test] + fn parse_neg_token_resp_no_fields() { + // All fields optional + let token = build_test_neg_token_resp(None, None, None, None); + + let resp = parse_neg_token_resp(&token).unwrap(); + assert!(resp.neg_state.is_none()); + assert!(resp.supported_mech.is_none()); + assert!(resp.response_token.is_none()); + assert!(resp.mech_list_mic.is_none()); + } + + #[test] + fn parse_neg_token_resp_empty_data_error() { + let result = parse_neg_token_resp(&[]); + assert!(result.is_err()); + } + + #[test] + fn parse_neg_token_resp_truncated_error() { + // Just a tag byte, no length + let result = parse_neg_token_resp(&[TAG_CONTEXT_1]); + assert!(result.is_err()); + } + + #[test] + fn parse_neg_token_resp_wrong_tag_error() { + // SEQUENCE tag instead of [1] + let data = der_tlv(TAG_SEQUENCE, &[0x00]); + let result = parse_neg_token_resp(&data); + assert!(result.is_err()); + } + + #[test] + fn parse_neg_token_resp_unknown_neg_state_error() { + let token = build_test_neg_token_resp(Some(99), None, None, None); + let result = parse_neg_token_resp(&token); + assert!(result.is_err()); + } + + // ======================================================================= + // Cross-validation: construct a realistic server response + // ======================================================================= + + #[test] + fn parse_realistic_server_challenge_response() { + // Simulate a typical Samba/Windows SPNEGO response to the first + // SESSION_SETUP: accept-incomplete with NTLMSSP OID and an NTLM + // challenge token. + let ntlm_challenge = b"NTLMSSP\0\x02\x00\x00\x00fake_challenge_data"; + + let token = build_test_neg_token_resp( + Some(1), // accept-incomplete + Some(OID_NTLMSSP), + Some(ntlm_challenge), + None, + ); + + let resp = parse_neg_token_resp(&token).unwrap(); + assert_eq!(resp.neg_state, Some(NegState::AcceptIncomplete)); + assert_eq!(resp.response_token.as_deref(), Some(&ntlm_challenge[..])); + } + + #[test] + fn parse_realistic_server_accept_with_mic() { + // Final server response: accept-completed with mechListMIC + let mic = [0xaa; 16]; + let token = build_test_neg_token_resp(Some(0), None, None, Some(&mic)); + + let resp = parse_neg_token_resp(&token).unwrap(); + assert_eq!(resp.neg_state, Some(NegState::AcceptCompleted)); + assert_eq!(resp.mech_list_mic.as_deref(), Some(&mic[..])); + } + + // ======================================================================= + // Roundtrip: wrap and parse NegTokenResp + // ======================================================================= + + #[test] + fn neg_token_resp_wrap_then_parse() { + let mech_token = b"roundtrip_test_token"; + let wrapped = wrap_neg_token_resp(mech_token); + let parsed = parse_neg_token_resp(&wrapped).unwrap(); + + // Wrapped with only responseToken, so: + assert!(parsed.neg_state.is_none()); + assert!(parsed.supported_mech.is_none()); + assert_eq!(parsed.response_token.as_deref(), Some(&mech_token[..])); + assert!(parsed.mech_list_mic.is_none()); + } + + // ======================================================================= + // Wire capture cross-validation + // ======================================================================= + + #[test] + fn parse_hand_constructed_wire_bytes() { + // Hand-constructed NegTokenResp matching what a Windows/Samba server + // sends after receiving NegTokenInit with NTLMSSP: + // + // a1 XX -- [1] NegTokenResp + // 30 XX -- SEQUENCE + // a0 03 -- [0] negState + // 0a 01 01 -- ENUMERATED accept-incomplete (1) + // a1 0c -- [1] supportedMech + // 06 0a 2b 06 01 04 01 82 37 02 02 0a -- NTLMSSP OID + // a2 XX -- [2] responseToken + // 04 XX -- OCTET STRING + // + let ntlm_challenge = b"NTLMSSP\0fake"; + + // Build by hand + let neg_state_enum = vec![0x0a, 0x01, 0x01]; // ENUMERATED 1 + let neg_state_ctx = der_tlv(TAG_CONTEXT_0, &neg_state_enum); + + let mech_ctx = der_tlv(TAG_CONTEXT_1, OID_NTLMSSP); + + let resp_octet = der_tlv(TAG_OCTET_STRING, ntlm_challenge); + let resp_ctx = der_tlv(TAG_CONTEXT_2, &resp_octet); + + let mut seq_content = Vec::new(); + seq_content.extend_from_slice(&neg_state_ctx); + seq_content.extend_from_slice(&mech_ctx); + seq_content.extend_from_slice(&resp_ctx); + let seq = der_tlv(TAG_SEQUENCE, &seq_content); + let wire_bytes = der_tlv(TAG_CONTEXT_1, &seq); + + let parsed = parse_neg_token_resp(&wire_bytes).unwrap(); + assert_eq!(parsed.neg_state, Some(NegState::AcceptIncomplete)); + assert_eq!(parsed.supported_mech.as_deref(), Some(OID_NTLMSSP)); + assert_eq!(parsed.response_token.as_deref(), Some(&ntlm_challenge[..])); + } + + // ======================================================================= + // OID constant verification + // ======================================================================= + + #[test] + fn oid_constants_are_valid_der() { + // Each OID constant should parse as a valid DER TLV with tag 0x06 + for (name, oid) in [ + ("SPNEGO", OID_SPNEGO), + ("NTLMSSP", OID_NTLMSSP), + ("Kerberos", OID_KERBEROS), + ] { + let (tag, _, total) = + parse_der_tlv(oid).unwrap_or_else(|e| panic!("{name} OID is not valid DER: {e}")); + assert_eq!(tag, 0x06, "{name} OID tag should be 0x06"); + assert_eq!(total, oid.len(), "{name} OID should be fully consumed"); + } + } + + // ======================================================================= + // Large token handling + // ======================================================================= + + #[test] + fn neg_token_init_with_large_mech_token() { + // Kerberos tokens can be several KB + let large_token = vec![0xab; 4096]; + let wrapped = wrap_neg_token_init(&[OID_KERBEROS], &large_token); + + // Should parse without error + let (tag, _, total) = parse_der_tlv(&wrapped).unwrap(); + assert_eq!(tag, TAG_APPLICATION_0); + assert_eq!(total, wrapped.len()); + + // The large token should be embedded + assert!( + wrapped.windows(100).any(|w| w == &large_token[..100]), + "large token content must be present" + ); + } + + #[test] + fn neg_token_resp_with_large_response_token() { + let large_token = vec![0xcd; 4096]; + let built = build_test_neg_token_resp(Some(1), None, Some(&large_token), None); + let parsed = parse_neg_token_resp(&built).unwrap(); + assert_eq!(parsed.response_token.as_deref(), Some(&large_token[..])); + } +} diff --git a/vendor/smb2/src/client/CLAUDE.md b/vendor/smb2/src/client/CLAUDE.md new file mode 100644 index 0000000..5fea22c --- /dev/null +++ b/vendor/smb2/src/client/CLAUDE.md @@ -0,0 +1,182 @@ +# Client -- high-level SMB2 API + +Entry point for most users. `SmbClient` wraps `Connection` + `Session` and provides convenience methods for file operations. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | `SmbClient`, `ClientConfig`, `connect()` shorthand | +| `connection.rs` | `Connection` -- credit tracking, message sequencing, signing, encryption, `execute` / `execute_compound` | +| `session.rs` | `Session::setup()` -- NTLM auth, key derivation, signing/encryption activation | +| `tree.rs` | `Tree` -- share connection, file CRUD, compound and pipelined I/O | +| `stream.rs` | `FileDownload` / `FileUpload` / `FileWriter` (owns `Connection` + `Arc`, `'static`) / `open_file_writer` -- streaming I/O with progress | +| `watcher.rs` | `Watcher` -- directory change notifications via CHANGE_NOTIFY long-poll | +| `pipeline.rs` | `Pipeline` / `Op` / `OpResult` -- batched concurrent operations (the core feature) | +| `shares.rs` | Share enumeration via IPC$ + srvsvc RPC | +| `dfs.rs` | DFS referral IOCTL helper, `DfsResolver` with TTL-based referral cache | + +## Layering + +``` +SmbClient (owns Connection + Session, stores credentials for reconnect) + Connection (TCP transport, credits, message IDs, signing, encryption) + Session (NTLM auth, key derivation -- setup mutates Connection) + Tree (share-level ops, borrows &mut Connection for each call) + extra_connections (HashMap for DFS cross-server) + dfs_resolver (DfsResolver with TTL-based referral cache) +``` + +All `Tree` methods take `&mut Connection` as a parameter. `SmbClient` convenience methods use `connection_for_tree(tree)` to route through the correct connection (primary or DFS extra connection) based on the tree's `server` field. + +## Connection and credits + +- Connection starts with 1 credit (from negotiate). Requests 256 credits in every message. +- Multi-credit requests (reads/writes > 64 KB) consume `ceil(payload_size / 65536)` credits and use that many consecutive `MessageId` values. Gaps in `MessageId` sequences cause the server to drop the connection. +- Credits flow back from responses via `CreditResponse` header field. The connection tracks available credits and blocks if exhausted. +- `STATUS_PENDING` interim responses carry credits but the request isn't done -- keep waiting. + +## Compound requests + +`Connection::execute_compound(&[CompoundOp])` packs multiple operations into a single transport frame. Each sub-request is 8-byte aligned, linked via `NextCommand`. Subsequent related operations use `FileId::SENTINEL` (the server substitutes the real handle from the first CREATE). + +- **Read compound**: CREATE + READ + CLOSE (3 ops, 1 round-trip). Default for `read_file`. +- **Write compound**: CREATE + WRITE + FLUSH + CLOSE (4 ops, 1 round-trip). Default for `write_file`. +- **Delete compound**: CREATE (DELETE_ON_CLOSE) + CLOSE (2 ops, 1 round-trip). Default for `delete_file` / `delete_directory`. +- **Rename compound**: CREATE + SET_INFO + CLOSE (3 ops, 1 round-trip). Default for `rename`. +- **Stat compound**: CREATE + QUERY_INFO (basic) + QUERY_INFO (standard) + CLOSE (4 ops, 1 round-trip). Default for `stat`. +- **Fs-info compound**: CREATE + QUERY_INFO (FileFsFullSizeInformation) + CLOSE (3 ops, 1 round-trip). Default for `fs_info`. +- If CREATE succeeds but a later op fails, the client issues a standalone CLOSE to avoid leaking the handle. + +### Receiving compound responses + +`execute_compound` returns `Result>>`. The outer `Result` is "did the compound hit the wire"; the inner one is per-sub-op (waiter-level: session expired, signature verify, connection dropped mid-await). Sub-op protocol status codes (`STATUS_OBJECT_NAME_NOT_FOUND` etc.) ride in the inner frame's `header.status`, not the inner `Result`. Per MS-SMB2 3.3.4.1.3 the server MAY split the compound response across multiple transport frames (Samba, QNAP, Windows Server in some cases); the receiver task routes each sub-response by `MessageId` so the per-waiter `oneshot::Receiver`s resolve independently and `execute_compound` reassembles the result vector in submission order. + +Most callers use a small `all_or_first_err` helper (see `tree.rs`) that propagates the first inner `Err` as the outer `Err` (matching the pre-Phase-3 shortcircuit behavior) and hands back a `Vec` indexable per sub-op. Tolerating partial failure (for example, CREATE ok, READ fails → issue standalone CLOSE with the create's returned `FileId`) keeps the individual inner `Result`s. + +## Batch operations + +`delete_files`, `rename_files`, and `stat_files` issue one `execute_compound` per file. Partial failures are independent — if 3 of 50 files fail, the other 47 still succeed. Each method returns `Vec>` in the same order as the input. + +Decision/Why — sequential execute vs parallel: pre-Phase-3 these methods did "phase 1 send all compounds, phase 2 receive all" for wire-level pipelining. With the new API a caller can re-create that shape by spawning `tokio::spawn` tasks over `conn.clone()`s, each calling `execute_compound`. For cmdr's "delete 50 files" flows the sequential-compound cost is small (one round-trip per file) so we chose simplicity. If a workload needs the extra parallelism later, the refactor is local to each batch method. + +## DFS (Distributed File System) resolution + +Reactive DFS resolution with multi-target failover. When a convenience method gets `STATUS_PATH_NOT_COVERED` (mapped to `ErrorKind::DfsReferral`), it: + +1. Calls `handle_dfs_redirect()` which resolves the referral via `DfsResolver` (cache or IOCTL) +2. Tries each target in the referral response (multi-target failover) +3. Creates a new connection + session for cross-server targets via `ensure_connection()` +4. Tree-connects to the target share via `ensure_tree()` +5. Updates the caller's `&mut Tree` in-place to point to the new server/share +6. Retries the operation with the resolved remaining path + +**Key design decisions:** +- Convenience methods take `&mut Tree` (not `&Tree`) so DFS can update the tree in-place +- `disconnect_share` stays as `&Tree` (no redirect on teardown) +- Streaming methods (`download`, `upload`) keep `&Tree` because they return handles that borrow the tree for their lifetime +- `watch` now returns an *owned* `Watcher` (no lifetime); see the [Watcher pipelining](#watcher-pipelining) section +- Batch methods (`delete_files`, `rename_files`, `stat_files`) don't retry per-file; the caller should trigger one single-file operation first to resolve the redirect +- `dfs_enabled` flag on `ClientConfig` (default `true`) gates all DFS resolution +- Borrow checker requires inlining the connection lookup in `handle_dfs_redirect` to avoid double `&mut self` borrows + +## Watcher pipelining + +`Watcher` keeps **one CHANGE_NOTIFY request pre-issued on the wire at all times** after the first `next_events()` call. The wire never sits idle between responses. This closes the response→re-arm loss window that strict servers (older Samba builds, NAS firmware) drop events through. + +Shape: `Watcher` owns a cloned `Connection` (cheap `Arc::clone`, all clones multiplex over the same SMB session) and a `Tree` clone — no lifetime parameter, no borrow against the caller's `Connection`. `next_events` dispatches the next request via `Connection::dispatch` (a sibling to `execute` that returns once `transport.send().await` completes, handing back the `oneshot::Receiver` for the response) *before* awaiting the previous response. So when control returns to the consumer, the server already has somewhere to put new events. + +Decision/Why — eager-send `dispatch` vs `tokio::spawn(conn.execute(...))`: the spawn-based approach defers the send to when the spawned task is polled, which under tokio's `current_thread` scheduler may not happen until the spawning task yields. That left a gap where the simulator-modeled strict server dropped events. `dispatch` awaits transport.send() inline, so the eager-send guarantee is "after `.await` returns, the request is on the wire" — independent of scheduler. + +Pinned by `client::watcher::loss_window_tests::watcher_does_not_lose_events_between_consecutive_requests`: a strict-server simulator drops events that arrive with no outstanding request. Pre-fix: 5/5 gap events dropped. Post-fix: 0/5 dropped. + +## Pipelined I/O + +For large files, `read_file_pipelined` / `write_file_pipelined` issue multiple `execute_with_credits` calls concurrently on cloned connections via `futures_util::stream::FuturesUnordered`. The sliding window stays at 32 in-flight requests, credits are checked per launch via `conn.credits()`. Chunk size is `min(512 KB, max_read_size)`. This is the core performance feature -- without it, throughput is ~10x worse. + +`FileWriter` owns its `Connection` (cheap `Arc::clone`) and `Arc` — no lifetime parameter, no borrow against the `SmbClient` that built it. It keeps an owned `FuturesUnordered` field — `launch_wire_chunk` pushes a boxed `execute_with_credits` future, `drain_one` awaits `in_flight.next()`, and the public `write_chunk` / `finish` / `abort` drive that state machine. + +FileWriter provides push-based pipelined writes. The consumer pushes chunks at their own pace via `write_chunk`, with the sliding window handling backpressure. Complement to FileDownload (read streaming). Build one via `open_file_writer(tree, conn, path)` (free function), `Tree::create_file_writer(&Arc, conn, path)`, or `SmbClient::create_file_writer(&self, tree, path)` — the last clones the client's primary connection internally for convenience. + +## Streaming download entry points + +Two symmetric ways to start a `FileDownload`: + +- `SmbClient::download(&mut self, &Tree, path)` — convenience wrapper that borrows the client's internal `Connection`. +- `Tree::download(&self, &mut Connection, path)` — takes the `Connection` directly. Use this when you hold a + `conn.clone()` and want to drive concurrent downloads on the same SMB session (each clone pairs with one outstanding + download; the receiver task multiplexes responses by `MessageId`). `SmbClient::download` delegates here. + +For full control, `Tree::open_file` (returns `(FileId, u64)`) plus `FileDownload::new` let callers build custom chunk +loops with non-default `chunk_size`. Most users shouldn't need this — `read_file_compound` (1 RTT) handles small files +and `Tree::download` / `SmbClient::download` handle the streaming case. + +FileWriter has two terminal operations: +- `finish()` — send all buffered data, drain in-flight WRITEs, FLUSH (fsync on the server), CLOSE. Use on normal completion. +- `abort()` — discard unsent data, drain in-flight WRITEs to keep credits/message-ids in sync, skip FLUSH, best-effort CLOSE. Use on cancellation or error paths where the partial remote file is going to be deleted anyway — `abort()` saves the fsync round-trip. The caller is responsible for deleting the partial remote file. + +Both consume `self` so write-after-close/abort is a compile error. `Drop` logs a debug warning if neither was called (handle leaks). + +## Session setup flow + +1. Send NTLM NEGOTIATE in SESSION_SETUP +2. Receive STATUS_MORE_PROCESSING_REQUIRED with challenge, update preauth hash +3. Send NTLM AUTHENTICATE in SESSION_SETUP, update preauth hash with request only +4. Receive STATUS_SUCCESS (do NOT include in preauth hash) +5. Derive signing/encryption keys via SP800-108 KDF +6. Activate signing on the connection +7. If session or share requires encryption, activate encryption (TRANSFORM_HEADER wrapping with AEAD) + +## Encryption + +Encryption is activated when the session flags include `ENCRYPT_DATA` or a share has `SMB2_SHAREFLAG_ENCRYPT_DATA`. When active: +- Outgoing messages are wrapped in TRANSFORM_HEADER (protocol ID 0xFD) with a monotonic nonce +- Incoming messages with 0xFD are decrypted before processing +- Signing is skipped (AEAD provides authentication) +- Compound chains are encrypted as one unit (pitfall #9) + +Tree-level encryption: `connect_share()` checks the share's encrypt flag and activates encryption on the connection if needed, even if the session didn't require it. + +## Reconnection + +`SmbClient::reconnect()` creates a fresh TCP connection, re-negotiates, and re-authenticates using stored credentials. All previous `Tree` handles and `FileId` values are invalidated. The caller must `connect_share` again. + +## Connection internals: receiver task + `oneshot` routing + +`Connection::execute` / `execute_compound` is the primary API. A background receiver task (spawned per `Connection` at `from_transport`) owns the transport's read half and routes each sub-frame to a per-request `oneshot::Sender` by `MessageId`. + +- `Connection` is `Clone` and holds just `Arc`. `Inner` owns `waiters: Mutex>>>`, `credits: AtomicU32`, `next_message_id: AtomicU64`, the transport send half (via `Arc`), the receiver task's `JoinHandle`, and crypto state. All state is behind atomics or short-critical-section `std::sync::Mutex`. +- `execute(command, body, tree_id)` allocates a `MessageId` (`AtomicU64::fetch_add(credit_charge)`), registers a `oneshot::Sender` in `waiters` atomically under the waiters lock (re-checks `disconnected` there to rule out a TOCTOU where the receiver task has already shut down and drained the map), packs the frame, signs/encrypts/compresses as needed, and writes through `TransportSend::send`. Then it awaits the local `oneshot::Receiver`. Returns `Result`. +- `execute_compound(&[CompoundOp])` does the same per sub-op, building one compound transport frame with `NextCommand` offsets, then awaits each per-sub-op receiver sequentially. Each receiver resolves independently (the receiver task splits the server's response by `NextCommand` and routes each sub-response by its `MessageId`). The outer `Result` is "did the compound hit the wire"; the inner `Vec>` has one entry per sub-op. +- **Cancellation-by-drop is safe by construction.** If a caller's future is aborted (`tokio::spawn` + `JoinHandle::abort()` is the common path in consumers), the locally-owned `oneshot::Receiver` drops; the receiver task's `Sender::send` then fails silently when the late frame arrives; the frame is discarded. Credits are still applied in the receiver task so dropped-caller frames don't starve throughput. +- **Transport drop** fans `Err(Disconnected)` to every pending `oneshot::Sender` and sets `disconnected=true` under the waiters lock. Subsequent `execute` / `execute_compound` sees `disconnected=true` and returns `Err(Disconnected)` without inserting (no leaked waiters). + +Gotcha/Why — pre-Phase-3 `send_request` / `receive_response` split API was removed in Phase 3 Stage A.3. The test-mode `set_orphan_filter_enabled(false)` escape hatch is gone too; tests that build mocks without going through `setup_connection` call `mock.enable_auto_rewrite_msg_id()` instead, which rewrites each queued response's zero-msg_id to match the next pending sent msg_id in FIFO order. + +Full design in [docs/specs/connection-actor.md](../../docs/specs/connection-actor.md). + +## Key decisions + +- **Owned `FileWriter`: N concurrent streamed writes over one Connection without external locking**: `FileWriter` owns its `Connection` (cheap `Arc::clone`) and `Arc` instead of borrowing `&'a mut Connection` from the `SmbClient`. Built via the free `open_file_writer(tree: Arc, conn: Connection, path: &str)` or one of the two convenience wrappers (`Tree::create_file_writer`, `SmbClient::create_file_writer`). Multiple writers built from clones of the same `Connection` pipeline their WRITEs over one SMB session — the receiver task multiplexes responses by `MessageId`. The borrowed variant was the root cause of a production-reproducing deadlock in the cmdr SMB volume's `write_from_stream` (Phase C QNAP test, 200 × 7 MB concurrent overwrites): the consumer had to hold its session mutex for the entire upload because the writer borrowed `&'a mut Connection`. Owning the connection removes the lock from the hot path entirely. +- **`execute` / `execute_compound` take `&self`**: `Connection: Clone` supports concurrent ops per connection — clone freely across tasks, the receiver task multiplexes responses by `MessageId`. `Tree::*` methods still take `&mut Connection` because session-setup mutators (`activate_signing`, `set_session_id`) keep `&mut self`; Tree code calls both, so `&mut` at that layer is the least-churn choice. +- **Sender work stays on the caller thread, only the receiver is a task**: The send path already uses an internal Mutex on the transport write half for ordering; adding a second task just to drive sends would add latency without correctness gain. The receiver bug (orphan/dropped-caller frames corrupting the wire) only existed on the receive side, so only the receive side needed a task. +- **Compound reads as default**: One round-trip for small files. Saves 2 RTTs vs sequential CREATE/READ/CLOSE. +- **512 KB pipeline chunks**: Balances between too many small requests (overhead) and too few large ones (credit starvation). Gives ~20 chunks per 10 MB file. +- **Password stored in `SmbClient`**: Enables reconnect without re-prompting. Not encrypted in memory. Drop when done. + +## Gotchas + +- **Preauth hash excludes the final success response**: Only STATUS_MORE_PROCESSING_REQUIRED responses are hashed. Including the success response produces wrong keys. (MS-SMB2 3.2.5.3.1) +- **Oplock break notifications arrive with MessageId 0xFFFFFFFFFFFFFFFF**: The receiver task detects these and skips them without invoking a waiter lookup. +- **Register-waiter must be atomic with `disconnected` check**: The waiters lock covers both reading `disconnected` and inserting the `oneshot::Sender`. If the check and insert were racy, a receiver-task failure mid-send could leave an orphan `Sender` in the map that never gets routed — caller would hang on `rx.await` forever. Same goes for `fan_error_to_waiters`: it sets `disconnected=true` UNDER the same waiters lock before draining, so new sends strictly either succeed-and-get-drained or fail at the insert check. +- **Unrecoverable frame errors tear down the connection** (Phase 3 P3.4): decrypt failure, decompress failure, or a malformed sub-frame header that survives `split_compound` all cause the receiver task to call `fan_error_to_waiters(Err(Disconnected))` and exit. The alternative — log-and-continue — would leave the matching waiter hanging forever, because the msg_id isn't recoverable from an unparseable frame. The connection is also out of sync after one bad frame, so reconnect is the right move anyway. Counted via `MetricsSnapshot::{decrypt_failures, decompress_failures, malformed_frames}`. +- **STATUS_PENDING loop**: CHANGE_NOTIFY and other long-poll operations get STATUS_PENDING first. The receiver task keeps the waiter registered on PENDING and does NOT forward the interim response. Credits from PENDING are still applied so the caller's `conn.credits()` reflects them. Counted via `MetricsSnapshot::status_pending_loops`. +- **Signing and encryption are mutually exclusive on the wire**: When encrypting, zero the signature field (AEAD provides integrity). On receive, skip signature verification if decryption succeeded. +- **Compound encryption wraps the entire chain**: One TRANSFORM_HEADER for all sub-requests concatenated, not per sub-request. +- **Share-level encryption**: If a share has `SMB2_SHAREFLAG_ENCRYPT_DATA`, encryption is activated even if the session didn't require it. +- **FileDownload/FileUpload can leak handles on drop**: Rust has no async drop. If not consumed fully, the file handle leaks. The types log a warning. +- **FileWriter can leak handles on drop**: Same as FileDownload/FileUpload. Rust has no async drop. If not consumed via `finish()` or `abort()`, the file handle leaks. The type logs a debug warning. +- **DFS paths must include server\share prefix**: When `SMB2_FLAGS_DFS_OPERATIONS` is set, the server expects the path to start with `server\share\` (MS-SMB2 3.2.4.3). `Tree::format_path()` handles this automatically for DFS shares. Without the prefix, Samba strips the first two path components, leading to wrong file opens. +- **DFS redirect changes the tree in-place**: After a DFS redirect, `tree.server`, `tree.share_name`, and `tree.tree_id` all change. Subsequent operations on the same tree use the target server directly -- they must use target-relative paths, not the original DFS paths. +- **tree.server stores addr:port**: The `server` field on `Tree` stores the full `addr:port` string (not just hostname) so `connection_for_tree` can distinguish servers that share the same hostname but use different ports. +- **Servers MAY split compound responses**: MS-SMB2 section 3.3.4.1.3 says the server SHOULD compound responses but is not required to. Samba (and QNAP firmware built on it) is known to split compound chains into separate frames in some scenarios; Windows Server does too under certain conditions. Compound-using methods (`read_file_compound`, `write_file_compound`, `fs_info`, `stat`, `rename`, `delete_file`, batch `*_files`) call `Connection::receive_compound_expected(n)` instead of `receive_compound()`, which transparently gathers additional frames if the server splits. Logged at DEBUG, not WARN -- it's a spec edge case, not a problem. diff --git a/vendor/smb2/src/client/connection.rs b/vendor/smb2/src/client/connection.rs new file mode 100644 index 0000000..e78f457 --- /dev/null +++ b/vendor/smb2/src/client/connection.rs @@ -0,0 +1,3413 @@ +//! Connection state and message exchange with actor-based routing. +//! +//! The [`Connection`] type manages a single TCP connection to an SMB server. +//! A background receiver task owns the transport's read half, demultiplexes +//! incoming frames by `MessageId`, and routes each response to the matching +//! per-request `oneshot::Sender`. The caller-thread path holds the write +//! half (guarded by its own Mutex via the transport trait) and pushes a +//! per-request `oneshot::Receiver` onto a FIFO that `receive_response` +//! pops from. +//! +//! See `docs/specs/connection-actor.md` for the full design (Phase 2). + +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex as StdMutex, OnceLock}; +use std::time::Duration; + +use log::{debug, info, trace, warn}; +use tokio::sync::oneshot; + +use crate::crypto::compression::{compress_message, decompress_message, CompressedMessage}; +use crate::crypto::encryption::{self, Cipher, NonceGenerator}; +use crate::crypto::kdf::PreauthHasher; +use crate::crypto::signing::{self, SigningAlgorithm}; +use crate::error::{Error, Result}; +use crate::msg::header::Header; +use crate::msg::negotiate::{ + NegotiateContext, NegotiateRequest, NegotiateResponse, CIPHER_AES_128_CCM, CIPHER_AES_128_GCM, + CIPHER_AES_256_CCM, CIPHER_AES_256_GCM, COMPRESSION_LZ4, HASH_ALGORITHM_SHA512, + SIGNING_AES_CMAC, SIGNING_AES_GMAC, SIGNING_HMAC_SHA256, +}; +use crate::msg::transform::{ + CompressionTransformHeader, TransformHeader, COMPRESSION_ALGORITHM_LZ4, + COMPRESSION_PROTOCOL_ID, SMB2_COMPRESSION_FLAG_NONE, TRANSFORM_PROTOCOL_ID, +}; +use crate::pack::{Guid, Pack, ReadCursor, Unpack, WriteCursor}; +use crate::transport::{TcpTransport, TransportReceive, TransportSend}; +use crate::types::flags::{Capabilities, HeaderFlags, SecurityMode}; +use crate::types::status::NtStatus; +use crate::types::{Command, CreditCharge, Dialect, MessageId, SessionId, TreeId}; + +/// Parameters established during negotiate. +#[derive(Debug, Clone)] +pub struct NegotiatedParams { + /// The dialect both sides agreed on. + pub dialect: Dialect, + /// Maximum read size the server supports. + pub max_read_size: u32, + /// Maximum write size the server supports. + pub max_write_size: u32, + /// Maximum transact size the server supports. + pub max_transact_size: u32, + /// The server's GUID. + pub server_guid: Guid, + /// Whether the server requires signing. + pub signing_required: bool, + /// Server capabilities. + pub capabilities: Capabilities, + /// Whether AES-GMAC signing was negotiated (SMB 3.1.1). + pub gmac_negotiated: bool, + /// The cipher negotiated for encryption (SMB 3.x). + pub cipher: Option, + /// Whether compression was negotiated with the server. + pub compression_supported: bool, +} + +/// A received SMB2 sub-response, post-decrypt / post-decompress / post-header-parse. +/// +/// This is what `Connection::execute` / `execute_with_credits` return on +/// success (and what each inner `Result` in `execute_compound`'s return +/// vector wraps). The three fields cover every downstream parse need: +/// +/// - `header`: the parsed SMB2 header. Includes `status`, `command`, +/// `message_id`, `credits`, `tree_id`, etc. +/// - `body`: the sub-frame bytes after the header (i.e. +/// `raw[Header::SIZE..]`). Most callers unpack this via `ReadCursor` + +/// `Unpack`. +/// - `raw`: the full sub-frame bytes, header included. Kept for preauth +/// hash updates and any caller that wants to re-verify signatures or +/// inspect the original wire bytes. +/// +/// Callers receive one `Frame` per matched `MessageId`. Frames are owned; +/// the receiver task allocates fresh `Vec`s for `body` / `raw` as it splits +/// compound frames, so you can store or mutate them freely. +#[derive(Debug)] +pub struct Frame { + /// Parsed SMB2 header of this sub-response. + pub header: Header, + /// Sub-frame bytes after the header (body portion only). + pub body: Vec, + /// Full sub-frame bytes including the header. + pub raw: Vec, +} + +/// One sub-operation in a compound request, as passed to +/// [`Connection::execute_compound`]. +/// +/// Each `CompoundOp` describes a single SMB2 operation (CREATE, READ, +/// CLOSE, etc.) that the receiver side pairs with a [`Frame`] response +/// by `MessageId`. The server MAY split compound responses into multiple +/// transport frames — the receiver task handles that transparently; each +/// sub-op still gets routed to its own waiter by msg_id. +/// +/// Field-by-field: +/// +/// - `command`: the SMB2 command code (`Create`, `Read`, `Write`, etc.). +/// - `body`: the packed request body as a `&dyn Pack`. Typical callers +/// pass `&MyRequest { ... }` — the trait object lets one compound +/// chain hold heterogeneous request types. +/// - `tree_id`: the `TreeId` to stamp into the header, or `None` when +/// the op predates tree connect (for example, SESSION_SETUP in a +/// compound). For ordinary file ops, pass `Some(tree.tree_id)`. +/// - `credit_charge`: the number of credits (and consecutive MessageIds) +/// this op consumes. Most ops use `CreditCharge(1)`. Large READ / WRITE +/// ops consume `ceil(payload_size / 65536)` — see the docs on +/// [`execute_with_credits`](Connection::execute_with_credits) for details. +pub struct CompoundOp<'a> { + /// The SMB2 command code. + pub command: Command, + /// The packed request body, as a `&dyn Pack`. + pub body: &'a dyn Pack, + /// `Some(tree_id)` for tree-scoped ops, `None` for connection-level ones. + pub tree_id: Option, + /// Credit charge (and consecutive-MessageId count) for this sub-op. + pub credit_charge: CreditCharge, +} + +impl<'a> CompoundOp<'a> { + /// Build a `CompoundOp` with the default single-credit charge. + /// + /// Equivalent to setting `credit_charge: CreditCharge(1)`. For reads + /// or writes larger than 64 KB, construct the struct directly with + /// the right charge. + pub fn new(command: Command, body: &'a dyn Pack, tree_id: Option) -> Self { + Self { + command, + body, + tree_id, + credit_charge: CreditCharge(1), + } + } +} + +/// Crypto state shared between the caller thread (sending) and receiver task +/// (verifying signatures, decrypting). +/// +/// Uses `std::sync::Mutex` because the critical sections are short and never +/// hold the lock across an `.await`. Mutation is rare (once at session setup), +/// reads happen once per frame on either side. +struct CryptoState { + signing_key: Option>, + signing_algorithm: Option, + should_sign: bool, + encryption_key: Option>, + decryption_key: Option>, + encryption_cipher: Option, + should_encrypt: bool, + nonce_gen: Option, + session_id: SessionId, +} + +impl CryptoState { + fn new() -> Self { + Self { + signing_key: None, + signing_algorithm: None, + should_sign: false, + encryption_key: None, + decryption_key: None, + encryption_cipher: None, + should_encrypt: false, + nonce_gen: None, + session_id: SessionId::NONE, + } + } +} + +/// Shared connection state held in an `Arc` by the caller-facing `Connection` +/// (including all its clones) and the spawned receiver task. +/// +/// Phase 3 Stage A.1 moved all connection-wide state here so `Connection` +/// can be `Clone`: each clone shares the same `Arc` and therefore +/// sees the same credits, session id, negotiated params, and crypto state. +/// Phase 3 Stage A.3 removed the legacy caller-local FIFO and orphan-filter +/// fallback channel; `execute` / `execute_compound` own their per-call +/// `oneshot::Receiver`s locally, so there is no per-clone bookkeeping at +/// all now — `Connection` is just a handle to `Arc`. +struct Inner { + /// Per-request routing: msg_id → oneshot sender waiting for its response. + waiters: StdMutex>>>, + /// Credits available to the caller. Updated by the receiver task on every + /// frame (orphans included), read by the caller thread for pre-send checks. + credits: AtomicU32, + /// Next message id to allocate. Incremented by caller on send. + next_message_id: AtomicU64, + /// Crypto state for signing / encryption. + crypto: StdMutex, + /// Set to true when the receiver task exits (transport error / EOF). + /// New `execute` / `execute_compound` calls short-circuit to + /// `Err(Disconnected)` once this flips so they don't register waiters + /// into a dead map. + disconnected: AtomicBool, + + /// Shared transport send handle. `TransportSend::send` takes `&self` so + /// this can be called from any clone without a wrapping mutex — the + /// transport's implementation already serializes writes internally. + sender: Arc, + /// Handle for the background receiver task. Aborted when the last clone + /// of `Connection` drops (via `Inner`'s `Drop`). The transport's read + /// half's EOF also stops the task; the abort is a safety net. + receiver_task: StdMutex>>, + + /// Server name (hostname or IP) used for UNC paths. Set at construction + /// and never mutated. + server_name: String, + /// Negotiated parameters, populated once by `negotiate`. + params: OnceLock, + /// Estimated round-trip time measured during negotiate. + estimated_rtt: StdMutex>, + /// Whether compression is active on this connection (negotiated). + compression_enabled: AtomicBool, + /// Whether the client wants compression (from config). + compression_requested: AtomicBool, + /// Preauth integrity hash (for SMB 3.1.1 key derivation). Mutated during + /// negotiate and session setup; both happen on one task before any clone + /// is expected to observe it. + preauth_hasher: StdMutex, + /// Tree IDs that have DFS capability (auto-set `SMB2_FLAGS_DFS_OPERATIONS`). + dfs_trees: StdMutex>, + /// Counters for diagnostics. Snapshotted via [`Inner::metrics_snapshot`]. + /// Survives connection teardown — counters are read off the still-alive + /// `Arc` after the receiver task has exited. + metrics: Metrics, +} + +impl Inner { + fn new(sender: Arc, server_name: String) -> Self { + Self { + waiters: StdMutex::new(HashMap::new()), + credits: AtomicU32::new(1), + next_message_id: AtomicU64::new(0), + crypto: StdMutex::new(CryptoState::new()), + disconnected: AtomicBool::new(false), + sender, + receiver_task: StdMutex::new(None), + server_name, + params: OnceLock::new(), + estimated_rtt: StdMutex::new(None), + compression_enabled: AtomicBool::new(false), + compression_requested: AtomicBool::new(true), + preauth_hasher: StdMutex::new(PreauthHasher::new()), + dfs_trees: StdMutex::new(HashSet::new()), + metrics: Metrics::default(), + } + } + + /// Send raw wire bytes through the transport and bump the + /// `wire_bytes_sent` counter. The single funnel for every outbound + /// frame — keeps `wire_bytes_sent` from drifting as new send sites + /// are added. + async fn send_and_count(&self, bytes: &[u8]) -> Result<()> { + self.metrics + .wire_bytes_sent + .fetch_add(bytes.len() as u64, Ordering::Relaxed); + self.sender.send(bytes).await + } + + /// Snapshot the counters into a plain-value `MetricsSnapshot`. + /// + /// M2 promotes this to `Connection::diagnostics()`'s caller — until + /// then it's crate-internal so M1 tests can assert counter ticks + /// without committing to the public snapshot API shape. + pub(crate) fn metrics_snapshot(&self) -> crate::client::diagnostics::MetricsSnapshot { + self.metrics.snapshot() + } +} + +/// Per-connection counters, all `AtomicU64`, all `Relaxed` reads/writes. +/// +/// Lives on [`Inner`] and outlives the receiver task; a snapshot taken +/// after the connection has torn down returns the final values at the +/// moment of death. +/// +/// See `docs/specs/diagnostics-plan.md` § Counters for the rationale +/// behind each field and the disjoint partition of the receive-side +/// routing branches. +#[derive(Default)] +pub(crate) struct Metrics { + // Send path + pub requests_sent: AtomicU64, + pub compound_requests_sent: AtomicU64, + pub wire_bytes_sent: AtomicU64, + pub explicit_cancels_sent: AtomicU64, + + // Receive path: four disjoint routing outcomes + pub responses_routed_ok: AtomicU64, + pub responses_routed_err: AtomicU64, + pub responses_late_after_drop: AtomicU64, + pub responses_stray: AtomicU64, + pub wire_bytes_received: AtomicU64, + + // Protocol events + pub status_pending_loops: AtomicU64, + pub unsolicited_notifications_received: AtomicU64, + pub signature_failures: AtomicU64, + pub decrypt_failures: AtomicU64, + pub decompress_failures: AtomicU64, + pub malformed_frames: AtomicU64, + pub session_expired_events: AtomicU64, + + // Caller-observed outcomes + pub requests_returned_err: AtomicU64, +} + +impl Metrics { + fn snapshot(&self) -> crate::client::diagnostics::MetricsSnapshot { + use std::sync::atomic::Ordering::Relaxed; + crate::client::diagnostics::MetricsSnapshot { + requests_sent: self.requests_sent.load(Relaxed), + compound_requests_sent: self.compound_requests_sent.load(Relaxed), + wire_bytes_sent: self.wire_bytes_sent.load(Relaxed), + explicit_cancels_sent: self.explicit_cancels_sent.load(Relaxed), + responses_routed_ok: self.responses_routed_ok.load(Relaxed), + responses_routed_err: self.responses_routed_err.load(Relaxed), + responses_late_after_drop: self.responses_late_after_drop.load(Relaxed), + responses_stray: self.responses_stray.load(Relaxed), + wire_bytes_received: self.wire_bytes_received.load(Relaxed), + status_pending_loops: self.status_pending_loops.load(Relaxed), + unsolicited_notifications_received: self + .unsolicited_notifications_received + .load(Relaxed), + signature_failures: self.signature_failures.load(Relaxed), + decrypt_failures: self.decrypt_failures.load(Relaxed), + decompress_failures: self.decompress_failures.load(Relaxed), + malformed_frames: self.malformed_frames.load(Relaxed), + session_expired_events: self.session_expired_events.load(Relaxed), + requests_returned_err: self.requests_returned_err.load(Relaxed), + } + } +} + +impl Drop for Inner { + fn drop(&mut self) { + // Last `Arc` dropping: abort the receiver task if still alive. + if let Some(handle) = self.receiver_task.lock().unwrap().take() { + handle.abort(); + } + } +} + +/// Low-level connection with actor-based response routing. +/// +/// Manages credit tracking, message ID sequencing, preauth integrity hash, +/// message signing, and encryption. A background receiver task owns the +/// transport's read half and routes each incoming frame to the +/// `oneshot::Sender` registered for its `MessageId`. Callers go through +/// [`execute`](Self::execute) / [`execute_compound`](Self::execute_compound) +/// which register the waiter, send the frame, and await the matching +/// `oneshot::Receiver` — all owned locally by the future, so dropping the +/// future mid-flight is safe (the late arrival is discarded on the receiver +/// task when the `Sender` fails to deliver). +/// +/// `Connection` is `Clone`; cloning is a cheap `Arc::clone` bump. All clones +/// share the same receiver task, credits, and waiters map, so concurrent +/// `execute` calls on different clones multiplex over the same SMB session. +#[derive(Clone)] +pub struct Connection { + /// Shared state (credits, waiters, crypto, transport sender, negotiated + /// params, receiver task) behind `Arc`. `clone()` bumps this. + inner: Arc, +} + +impl Connection { + /// Create a connection from an existing transport (for testing with mock). + pub fn from_transport( + sender: Box, + receiver: Box, + server_name: impl Into, + ) -> Self { + let sender: Arc = Arc::from(sender); + let inner = Arc::new(Inner::new(sender, server_name.into())); + let inner_for_task = Arc::clone(&inner); + let handle = tokio::spawn(async move { + receiver_loop(receiver, inner_for_task).await; + }); + *inner.receiver_task.lock().unwrap() = Some(handle); + Self { inner } + } + + /// Connect to an SMB server over TCP. + pub async fn connect(addr: &str, timeout: Duration) -> Result { + let server_name = addr.split(':').next().unwrap_or(addr).to_string(); + let transport = TcpTransport::connect(addr, timeout).await?; + info!("connection: connected to {}", addr); + let transport = Arc::new(transport); + Ok(Self::from_transport( + Box::new(Arc::clone(&transport)), + Box::new(transport), + server_name, + )) + } + + /// Perform the SMB2 NEGOTIATE exchange. + pub async fn negotiate(&mut self) -> Result<()> { + debug!("negotiate: sending request, dialects={:?}", Dialect::ALL); + let client_guid = generate_guid(); + + let mut negotiate_contexts = vec![ + NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: generate_salt(), + }, + NegotiateContext::Encryption { + ciphers: vec![ + CIPHER_AES_128_GCM, + CIPHER_AES_128_CCM, + CIPHER_AES_256_GCM, + CIPHER_AES_256_CCM, + ], + }, + NegotiateContext::Signing { + algorithms: vec![SIGNING_AES_GMAC, SIGNING_AES_CMAC, SIGNING_HMAC_SHA256], + }, + ]; + + if self.inner.compression_requested.load(Ordering::Acquire) { + negotiate_contexts.push(NegotiateContext::Compression { + flags: 0, + algorithms: vec![COMPRESSION_LZ4], + }); + } + + let request = NegotiateRequest { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::new( + Capabilities::DFS | Capabilities::LEASING | Capabilities::LARGE_MTU, + ), + client_guid, + dialects: Dialect::ALL.to_vec(), + negotiate_contexts, + }; + + // Register a waiter for msg_id=0 (negotiate is always first). + let mut header = Header::new_request(Command::Negotiate); + let msg_id = self.allocate_msg_id(1); + header.message_id = msg_id; + header.credits = 1; + let req_bytes = pack_message(&header, &request); + + // Update preauth hash with request bytes. + self.inner.preauth_hasher.lock().unwrap().update(&req_bytes); + + let rx = self.register_waiter(msg_id)?; + + let rtt_start = std::time::Instant::now(); + if let Err(e) = self.inner.send_and_count(&req_bytes).await { + self.remove_waiter(msg_id); + return Err(e); + } + + let frame = await_frame(rx).await?; + *self.inner.estimated_rtt.lock().unwrap() = Some(rtt_start.elapsed()); + + // Preauth hash update with response bytes. + self.inner.preauth_hasher.lock().unwrap().update(&frame.raw); + + let resp_header = &frame.header; + if !resp_header.is_response() { + return Err(Error::invalid_data("expected a response but got a request")); + } + if resp_header.command != Command::Negotiate { + return Err(Error::invalid_data(format!( + "expected Negotiate response, got {:?}", + resp_header.command + ))); + } + + if resp_header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: resp_header.status, + command: Command::Negotiate, + }); + } + + // Parse the body. + let mut cursor = ReadCursor::new(&frame.body); + let resp = NegotiateResponse::unpack(&mut cursor)?; + + if !Dialect::ALL.contains(&resp.dialect_revision) { + return Err(Error::invalid_data(format!( + "server selected dialect 0x{:04X} which we did not offer", + u16::from(resp.dialect_revision) + ))); + } + if resp.max_read_size < 65536 { + return Err(Error::invalid_data(format!( + "MaxReadSize {} is below minimum 65536", + resp.max_read_size + ))); + } + if resp.max_write_size < 65536 { + return Err(Error::invalid_data(format!( + "MaxWriteSize {} is below minimum 65536", + resp.max_write_size + ))); + } + + let mut gmac_negotiated = false; + let mut cipher = None; + let mut compression_supported = false; + + for ctx in &resp.negotiate_contexts { + match ctx { + NegotiateContext::Signing { algorithms } + if algorithms.contains(&SIGNING_AES_GMAC) => + { + gmac_negotiated = true; + } + NegotiateContext::Encryption { ciphers } => { + if let Some(&c) = ciphers.first() { + cipher = match c { + CIPHER_AES_128_CCM => Some(Cipher::Aes128Ccm), + CIPHER_AES_128_GCM => Some(Cipher::Aes128Gcm), + CIPHER_AES_256_CCM => Some(Cipher::Aes256Ccm), + CIPHER_AES_256_GCM => Some(Cipher::Aes256Gcm), + _ => None, + }; + } + } + NegotiateContext::Compression { algorithms, .. } + if algorithms.contains(&COMPRESSION_LZ4) => + { + compression_supported = true; + } + _ => {} + } + } + + let signing_required = resp.security_mode.signing_required(); + let compression_enabled = + self.inner.compression_requested.load(Ordering::Acquire) && compression_supported; + self.inner + .compression_enabled + .store(compression_enabled, Ordering::Release); + + // OnceLock: set is idempotent-first-writer-wins. Re-negotiation isn't + // a supported flow; if this ever fails it means negotiate was called + // twice on the same connection. + let _ = self.inner.params.set(NegotiatedParams { + dialect: resp.dialect_revision, + max_read_size: resp.max_read_size, + max_write_size: resp.max_write_size, + max_transact_size: resp.max_transact_size, + server_guid: resp.server_guid, + signing_required, + capabilities: resp.capabilities, + gmac_negotiated, + cipher, + compression_supported, + }); + + info!( + "negotiate: dialect={}, signing_required={}, capabilities={:?}", + resp.dialect_revision, signing_required, resp.capabilities + ); + debug!( + "negotiate: max_read={}, max_write={}, max_transact={}, server_guid={:?}, cipher={:?}, gmac={}, compression={}", + resp.max_read_size, resp.max_write_size, resp.max_transact_size, + resp.server_guid, cipher, gmac_negotiated, compression_enabled + ); + + Ok(()) + } + + /// Get the estimated round-trip time. + pub fn estimated_rtt(&self) -> Option { + *self.inner.estimated_rtt.lock().unwrap() + } + + /// Get the negotiated parameters. + pub fn params(&self) -> Option<&NegotiatedParams> { + self.inner.params.get() + } + + /// Get a clone of the preauth hasher's current state. + /// + /// The hasher lives behind a lock (shared across `Connection` clones + /// now that the type is `Clone`). Callers that want to derive per-session + /// keys — see `session.rs` — take a snapshot via this method and feed + /// their own session-specific updates into it without disturbing the + /// shared connection-level hasher. Returning an owned clone is ~a few + /// hundred bytes of SHA-512 state; cheaper than the actual KDF it feeds. + pub fn preauth_hasher(&self) -> PreauthHasher { + self.inner.preauth_hasher.lock().unwrap().clone() + } + + /// Run a closure with a mutable borrow of the preauth hasher. + /// + /// The hasher lives behind a lock now that `Connection` is `Clone`; a + /// naked `&mut PreauthHasher` can no longer be handed out. Closure-based + /// access keeps the lock scoped to the caller's update. + #[doc(hidden)] // unused outside the crate; kept for crate-internal parity. + pub fn with_preauth_hasher_mut(&self, f: impl FnOnce(&mut PreauthHasher) -> R) -> R { + let mut h = self.inner.preauth_hasher.lock().unwrap(); + f(&mut h) + } + + /// Set the session ID. + pub fn set_session_id(&mut self, id: SessionId) { + self.inner.crypto.lock().unwrap().session_id = id; + } + + /// Get the current session ID. + pub fn session_id(&self) -> SessionId { + self.inner.crypto.lock().unwrap().session_id + } + + /// Activate signing with the given key and algorithm. + pub fn activate_signing(&mut self, key: Vec, algorithm: SigningAlgorithm) { + debug!( + "signing: activated, algo={:?}, key_len={}", + algorithm, + key.len() + ); + let mut c = self.inner.crypto.lock().unwrap(); + c.signing_key = Some(key); + c.signing_algorithm = Some(algorithm); + c.should_sign = true; + } + + /// Activate encryption with the given keys and cipher. + pub fn activate_encryption(&mut self, enc_key: Vec, dec_key: Vec, cipher: Cipher) { + debug!( + "encryption: activated, cipher={:?}, enc_key_len={}, dec_key_len={}", + cipher, + enc_key.len(), + dec_key.len() + ); + let mut c = self.inner.crypto.lock().unwrap(); + c.encryption_key = Some(enc_key); + c.decryption_key = Some(dec_key); + c.encryption_cipher = Some(cipher); + c.nonce_gen = Some(NonceGenerator::new()); + c.should_encrypt = true; + } + + /// Whether encryption is active on this connection. + pub fn should_encrypt(&self) -> bool { + self.inner.crypto.lock().unwrap().should_encrypt + } + + /// Get the current number of available credits. + pub fn credits(&self) -> u16 { + self.inner.credits.load(Ordering::Acquire) as u16 + } + + /// The `MessageId` that will be assigned to the next request. + /// + /// Starts at 0 on a fresh connection and increments by `credit_charge` + /// per allocation. Pre-first-send this is `0`; after a single + /// single-credit `execute` it's `1`. + pub fn next_message_id(&self) -> u64 { + self.inner.next_message_id.load(Ordering::Acquire) + } + + /// Get the server name. + pub fn server_name(&self) -> &str { + &self.inner.server_name + } + + /// Set whether the client wants compression. + pub fn set_compression_requested(&mut self, requested: bool) { + self.inner + .compression_requested + .store(requested, Ordering::Release); + } + + /// Whether compression is active on this connection. + pub fn compression_enabled(&self) -> bool { + self.inner.compression_enabled.load(Ordering::Acquire) + } + + /// Send a single SMB2 request and wait for its response. + /// + /// Takes `&self` so multiple clones of a `Connection` can call `execute` + /// concurrently from different tasks — the receiver task routes each + /// response to its own `oneshot::Sender` by `MessageId`. Cancellation + /// by drop is safe by construction: if the caller's future is dropped + /// before the response arrives, the `oneshot::Receiver` drops, and + /// the receiver task discards the late frame silently on arrival + /// (credits still apply). + /// + /// Equivalent to `execute_with_credits(command, body, tree_id, CreditCharge(1))`. + /// For large READ / WRITE ops (> 64 KB payload), use `execute_with_credits` + /// with a charge of `ceil(payload_size / 65536)` — each credit consumed + /// also consumes one consecutive `MessageId`, and gaps in the id + /// sequence cause the server to drop the connection. + pub async fn execute( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + ) -> Result { + self.execute_with_credits(command, body, tree_id, CreditCharge(1)) + .await + } + + /// Crate-internal variant of [`execute`] that also returns the plaintext + /// request bytes that were packed on the wire (before any encryption). + /// + /// Only `session.rs` needs this: its SESSION_SETUP rounds feed the + /// *request* bytes into the session-local preauth hasher for key + /// derivation, and the signed/encrypted wire form would break the + /// hash because preauth covers the plaintext. Rather than forcing + /// session.rs to re-pack messages with a predicted msg_id, we let + /// `execute_with_credits_capturing_request` hand them back. + pub(crate) async fn execute_capturing_request( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + ) -> Result<(Frame, Vec)> { + self.execute_with_credits_capturing_request(command, body, tree_id, CreditCharge(1)) + .await + } + + /// See [`Self::execute_capturing_request`]. + pub(crate) async fn execute_with_credits_capturing_request( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + credit_charge: CreditCharge, + ) -> Result<(Frame, Vec)> { + let result = self + .execute_with_credits_capturing_request_inner(command, body, tree_id, credit_charge) + .await; + if result.is_err() { + self.inner + .metrics + .requests_returned_err + .fetch_add(1, Ordering::Relaxed); + } + result + } + + async fn execute_with_credits_capturing_request_inner( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + credit_charge: CreditCharge, + ) -> Result<(Frame, Vec)> { + if self.inner.disconnected.load(Ordering::Acquire) { + return Err(Error::Disconnected); + } + let charge = credit_charge.0.max(1); + let msg_id = self.allocate_msg_id(charge as u64); + + let mut header = Header::new_request(command); + header.message_id = msg_id; + header.credits = 256; + header.credit_charge = CreditCharge(charge); + header.session_id = self.session_id(); + if let Some(tid) = tree_id { + header.tree_id = Some(tid); + } + + let (should_sign, should_encrypt) = { + let c = self.inner.crypto.lock().unwrap(); + (c.should_sign, c.should_encrypt) + }; + + if should_sign && !should_encrypt { + header.flags.set_signed(); + } + if self.should_set_dfs_flag(tree_id) { + header.flags |= HeaderFlags::new(HeaderFlags::DFS_OPERATIONS); + } + + let mut msg_bytes = pack_message(&header, body); + let captured = msg_bytes.clone(); + + let rx = self.register_waiter(msg_id)?; + + let wire_bytes = if should_encrypt { + match self.encrypt_bytes(&msg_bytes) { + Ok(enc) => enc, + Err(e) => { + self.remove_waiter(msg_id); + return Err(e); + } + } + } else { + if should_sign { + let c = self.inner.crypto.lock().unwrap(); + if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) { + if let Err(e) = + signing::sign_message(&mut msg_bytes, key, *algo, msg_id.0, false) + { + drop(c); + self.remove_waiter(msg_id); + return Err(e); + } + } + } + msg_bytes + }; + + if let Err(e) = self.inner.send_and_count(&wire_bytes).await { + self.remove_waiter(msg_id); + return Err(e); + } + debug!( + "execute_cap: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, encrypted={}", + command, msg_id.0, charge, tree_id, should_sign, should_encrypt + ); + let frame = await_frame(rx).await?; + Ok((frame, captured)) + } + + /// Send a single SMB2 request with a caller-specified credit charge. + /// + /// Same semantics as [`execute`](Self::execute) — see that method's doc + /// for the concurrency / cancellation invariants — but lets the caller + /// set `credit_charge` directly. Use `CreditCharge(ceil(payload_size / + /// 65536))` for READ / WRITE ops larger than 64 KB. + /// + /// On the wire this is the same as `send_request_with_credits` + + /// `receive_response` — the difference is that this method owns its + /// `oneshot::Receiver` locally (not in a caller-shared FIFO), so + /// it's safe to call from multiple tasks on clones of the same + /// `Connection`. + pub async fn execute_with_credits( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + credit_charge: CreditCharge, + ) -> Result { + let result = self + .execute_with_credits_inner(command, body, tree_id, credit_charge) + .await; + if result.is_err() { + self.inner + .metrics + .requests_returned_err + .fetch_add(1, Ordering::Relaxed); + } + result + } + + async fn execute_with_credits_inner( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + credit_charge: CreditCharge, + ) -> Result { + if self.inner.disconnected.load(Ordering::Acquire) { + return Err(Error::Disconnected); + } + let charge = credit_charge.0.max(1); + let msg_id = self.allocate_msg_id(charge as u64); + + let mut header = Header::new_request(command); + header.message_id = msg_id; + header.credits = 256; + header.credit_charge = CreditCharge(charge); + header.session_id = self.session_id(); + if let Some(tid) = tree_id { + header.tree_id = Some(tid); + } + + let (should_sign, should_encrypt) = { + let c = self.inner.crypto.lock().unwrap(); + (c.should_sign, c.should_encrypt) + }; + + if should_sign && !should_encrypt { + header.flags.set_signed(); + } + if self.should_set_dfs_flag(tree_id) { + header.flags |= HeaderFlags::new(HeaderFlags::DFS_OPERATIONS); + } + + let mut msg_bytes = pack_message(&header, body); + + // Register waiter BEFORE send so the receiver task can match any + // fast-arriving response. `register_waiter` atomically rechecks + // `disconnected` under the waiters lock, so a receiver-task + // teardown between the early fast-path check above and this + // insertion returns `Err(Disconnected)` instead of leaving a + // ghost Sender that never gets routed. + let rx = self.register_waiter(msg_id)?; + + // Build the wire bytes with encryption / signing / compression. + let wire_bytes = if should_encrypt { + match self.encrypt_bytes(&msg_bytes) { + Ok(enc) => enc, + Err(e) => { + self.remove_waiter(msg_id); + return Err(e); + } + } + } else { + if should_sign { + let c = self.inner.crypto.lock().unwrap(); + if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) { + if let Err(e) = + signing::sign_message(&mut msg_bytes, key, *algo, msg_id.0, false) + { + drop(c); + self.remove_waiter(msg_id); + return Err(e); + } + } + } + if self.compression_enabled() && msg_bytes.len() > Header::SIZE { + if let Some(compressed) = compress_message(&msg_bytes, Header::SIZE) { + let framed = build_compressed_frame(&compressed); + match self.inner.send_and_count(&framed).await { + Ok(()) => { + debug!( + "execute: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, compressed {}->{} bytes", + command, msg_id.0, charge, tree_id, should_sign, + msg_bytes.len(), framed.len() + ); + return await_frame(rx).await; + } + Err(e) => { + self.remove_waiter(msg_id); + return Err(e); + } + } + } + } + msg_bytes + }; + + if let Err(e) = self.inner.send_and_count(&wire_bytes).await { + self.remove_waiter(msg_id); + return Err(e); + } + debug!( + "execute: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, encrypted={}, len={}", + command, msg_id.0, charge, tree_id, should_sign, should_encrypt, wire_bytes.len() + ); + await_frame(rx).await + } + + /// Send a request and return its response receiver without awaiting it. + /// + /// Same wire-level work as [`execute`](Self::execute) — allocate + /// `MessageId`, register waiter, sign / encrypt, send bytes — but + /// stops as soon as `transport.send().await` returns and hands back + /// the `oneshot::Receiver` for the response. Use this for pipelining: + /// dispatch the next request before awaiting the previous response, + /// keeping the wire continuously armed. + /// + /// **Eager-send guarantee**: when this future resolves to `Ok(rx)`, + /// the request bytes have been handed to the transport. The caller + /// can rely on "after this `.await` completes, the request is on + /// the wire." + /// + /// The returned `Receiver` follows the same drop-safety contract as + /// [`execute`]'s internal one: dropping it without awaiting causes + /// the receiver task to discard the late response silently when it + /// arrives (credits still apply). + /// + /// Currently used by [`Watcher`](crate::Watcher) to pre-issue the + /// next CHANGE_NOTIFY before awaiting the current one. Other call + /// sites should prefer [`execute`] / [`execute_with_credits`] unless + /// they specifically need the pipelining shape. + pub(crate) async fn dispatch( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + ) -> Result>> { + self.dispatch_with_credits(command, body, tree_id, CreditCharge(1)) + .await + } + + /// Variant of [`dispatch`](Self::dispatch) with a caller-specified credit charge. + pub(crate) async fn dispatch_with_credits( + &self, + command: Command, + body: &dyn Pack, + tree_id: Option, + credit_charge: CreditCharge, + ) -> Result>> { + if self.inner.disconnected.load(Ordering::Acquire) { + return Err(Error::Disconnected); + } + let charge = credit_charge.0.max(1); + let msg_id = self.allocate_msg_id(charge as u64); + + let mut header = Header::new_request(command); + header.message_id = msg_id; + header.credits = 256; + header.credit_charge = CreditCharge(charge); + header.session_id = self.session_id(); + if let Some(tid) = tree_id { + header.tree_id = Some(tid); + } + + let (should_sign, should_encrypt) = { + let c = self.inner.crypto.lock().unwrap(); + (c.should_sign, c.should_encrypt) + }; + + if should_sign && !should_encrypt { + header.flags.set_signed(); + } + if self.should_set_dfs_flag(tree_id) { + header.flags |= HeaderFlags::new(HeaderFlags::DFS_OPERATIONS); + } + + let mut msg_bytes = pack_message(&header, body); + + let rx = self.register_waiter(msg_id)?; + + let wire_bytes = if should_encrypt { + match self.encrypt_bytes(&msg_bytes) { + Ok(enc) => enc, + Err(e) => { + self.remove_waiter(msg_id); + return Err(e); + } + } + } else { + if should_sign { + let c = self.inner.crypto.lock().unwrap(); + if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) { + if let Err(e) = + signing::sign_message(&mut msg_bytes, key, *algo, msg_id.0, false) + { + drop(c); + self.remove_waiter(msg_id); + return Err(e); + } + } + } + if self.compression_enabled() && msg_bytes.len() > Header::SIZE { + if let Some(compressed) = compress_message(&msg_bytes, Header::SIZE) { + let framed = build_compressed_frame(&compressed); + match self.inner.send_and_count(&framed).await { + Ok(()) => { + debug!( + "dispatch: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, compressed {}->{} bytes", + command, msg_id.0, charge, tree_id, should_sign, + msg_bytes.len(), framed.len() + ); + return Ok(rx); + } + Err(e) => { + self.remove_waiter(msg_id); + return Err(e); + } + } + } + } + msg_bytes + }; + + if let Err(e) = self.inner.send_and_count(&wire_bytes).await { + self.remove_waiter(msg_id); + return Err(e); + } + debug!( + "dispatch: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, encrypted={}, len={}", + command, msg_id.0, charge, tree_id, should_sign, should_encrypt, wire_bytes.len() + ); + Ok(rx) + } + + /// Send a compound SMB2 request (multiple operations in one transport + /// frame) and return the per-sub-op responses. + /// + /// Takes `&self`. Each [`CompoundOp`] is assigned its own `MessageId` + /// and its own `oneshot::Sender` registered in the waiters map. The + /// server MAY split the compound response into multiple transport + /// frames (MS-SMB2 § 3.3.4.1.3) — the receiver task's per-`MessageId` + /// routing handles that transparently; each sub-op's waiter resolves + /// independently. + /// + /// Return shape (per decision E3 in `docs/specs/connection-actor.md`): + /// + /// - Outer `Result`: `Err` if the compound didn't make it onto the wire + /// (encryption failed, signing failed, transport send failed, or the + /// connection was already disconnected). On this path no waiter + /// observes a response — we clean them up before returning. + /// - Inner `Vec>`: one entry per sub-op, in the same + /// order as `ops`. `Ok(frame)` with the server's response, including + /// non-success statuses encoded in `frame.header.status`. `Err` when + /// a sub-op hit a waiter-level error (session expired, signature + /// verify failure, connection dropped mid-await). Compound partial + /// failure is protocol-normal — for example, CREATE may succeed but + /// a later READ fail — so callers typically match on each inner + /// result individually. + pub async fn execute_compound(&self, ops: &[CompoundOp<'_>]) -> Result>> { + self.inner + .metrics + .compound_requests_sent + .fetch_add(1, Ordering::Relaxed); + let result = self.execute_compound_inner(ops).await; + if result.is_err() { + self.inner + .metrics + .requests_returned_err + .fetch_add(1, Ordering::Relaxed); + } + result + } + + async fn execute_compound_inner(&self, ops: &[CompoundOp<'_>]) -> Result>> { + if ops.is_empty() { + return Err(Error::invalid_data( + "compound request must have at least one operation", + )); + } + if self.inner.disconnected.load(Ordering::Acquire) { + return Err(Error::Disconnected); + } + + let (should_sign, should_encrypt) = { + let c = self.inner.crypto.lock().unwrap(); + (c.should_sign, c.should_encrypt) + }; + + let session_id = self.session_id(); + let mut message_ids: Vec = Vec::with_capacity(ops.len()); + let mut sub_requests: Vec> = Vec::with_capacity(ops.len()); + + for (i, op) in ops.iter().enumerate() { + let charge = op.credit_charge.0.max(1); + let msg_id = self.allocate_msg_id(charge as u64); + + let mut header = Header::new_request(op.command); + header.message_id = msg_id; + header.credits = 256; + header.credit_charge = CreditCharge(charge); + header.session_id = session_id; + header.tree_id = op.tree_id; + + if i > 0 { + header.flags.set_related(); + } + if should_sign && !should_encrypt { + header.flags.set_signed(); + } + if self.should_set_dfs_flag(op.tree_id) { + header.flags |= HeaderFlags::new(HeaderFlags::DFS_OPERATIONS); + } + + message_ids.push(msg_id); + sub_requests.push(pack_message(&header, op.body)); + } + + // 8-byte align all but the last sub-request, then wire up + // `NextCommand` offsets. + let last_idx = sub_requests.len() - 1; + for sub_req in sub_requests.iter_mut().take(last_idx) { + let rem = sub_req.len() % 8; + if rem != 0 { + let pad = 8 - rem; + let new_len = sub_req.len() + pad; + sub_req.resize(new_len, 0); + } + } + for sub_req in sub_requests.iter_mut().take(last_idx) { + let next_cmd = sub_req.len() as u32; + sub_req[20..24].copy_from_slice(&next_cmd.to_le_bytes()); + } + + if should_sign && !should_encrypt { + let c = self.inner.crypto.lock().unwrap(); + if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) { + for (i, sub_req) in sub_requests.iter_mut().enumerate() { + signing::sign_message(sub_req, key, *algo, message_ids[i].0, false)?; + } + } + } + + // Register one oneshot::Receiver per sub-op BEFORE the send, + // collected in the same order as `ops` / `message_ids`. On any + // registration error, unregister the ones we already inserted. + let mut receivers: Vec>> = + Vec::with_capacity(message_ids.len()); + let mut registered: Vec = Vec::with_capacity(message_ids.len()); + for id in &message_ids { + match self.register_waiter(*id) { + Ok(rx) => { + receivers.push(rx); + registered.push(*id); + } + Err(e) => { + for done in ®istered { + self.remove_waiter(*done); + } + return Err(e); + } + } + } + + let total_len: usize = sub_requests.iter().map(|r| r.len()).sum(); + let mut compound_buf = Vec::with_capacity(total_len); + for sub_req in &sub_requests { + compound_buf.extend_from_slice(sub_req); + } + + let send_result = if should_encrypt { + match self.encrypt_bytes(&compound_buf) { + Ok(enc) => self.inner.send_and_count(&enc).await, + Err(e) => { + for id in ®istered { + self.remove_waiter(*id); + } + return Err(e); + } + } + } else { + self.inner.send_and_count(&compound_buf).await + }; + + if let Err(e) = send_result { + for id in ®istered { + self.remove_waiter(*id); + } + return Err(e); + } + + debug!( + "execute_compound: {} operations, total_len={}, msg_ids={:?}, signed={}, encrypted={}", + ops.len(), + compound_buf.len(), + message_ids.iter().map(|m| m.0).collect::>(), + should_sign, + should_encrypt, + ); + + // Collect per-sub-op results in submission order. Each `rx.await` + // resolves independently — the receiver task splits the response + // frame by `NextCommand` and routes each sub-response to its own + // waiter, so we can await them sequentially without blocking any + // of them (they may already all be resolved by the time we loop). + let mut results: Vec> = Vec::with_capacity(receivers.len()); + for rx in receivers { + results.push(await_frame(rx).await); + } + Ok(results) + } + + /// Send a CANCEL request for an outstanding operation. + pub async fn send_cancel( + &mut self, + original_msg_id: MessageId, + async_id: Option, + ) -> Result<()> { + use crate::msg::cancel::CancelRequest; + + self.inner + .metrics + .explicit_cancels_sent + .fetch_add(1, Ordering::Relaxed); + + let (should_sign, should_encrypt) = { + let c = self.inner.crypto.lock().unwrap(); + (c.should_sign, c.should_encrypt) + }; + let session_id = self.session_id(); + + let mut header = Header::new_request(Command::Cancel); + header.message_id = original_msg_id; + header.credit_charge = CreditCharge(0); + header.credits = 0; + header.session_id = session_id; + + if let Some(aid) = async_id { + header.flags.set_async(); + header.async_id = Some(aid); + header.tree_id = None; + } + if should_sign && !should_encrypt { + header.flags.set_signed(); + } + + let body = CancelRequest; + let mut msg_bytes = pack_message(&header, &body); + + if should_encrypt { + let encrypted = self.encrypt_bytes(&msg_bytes)?; + self.inner.send_and_count(&encrypted).await?; + debug!( + "send_cancel: msg_id={}, async_id={:?}, encrypted", + original_msg_id.0, async_id + ); + } else { + if should_sign { + let c = self.inner.crypto.lock().unwrap(); + if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) { + signing::sign_message(&mut msg_bytes, key, *algo, original_msg_id.0, false)?; + } + } + self.inner.send_and_count(&msg_bytes).await?; + debug!( + "send_cancel: msg_id={}, async_id={:?}, signed={}", + original_msg_id.0, async_id, should_sign + ); + } + Ok(()) + } + + /// Encrypt plaintext into a TRANSFORM_HEADER + ciphertext frame. + fn encrypt_bytes(&self, plaintext: &[u8]) -> Result> { + let mut c = self.inner.crypto.lock().unwrap(); + let enc_key = c + .encryption_key + .as_ref() + .ok_or_else(|| Error::invalid_data("encryption active but no encryption key"))? + .clone(); + let cipher = c + .encryption_cipher + .ok_or_else(|| Error::invalid_data("encryption active but no cipher"))?; + let session_id = c.session_id.0; + let nonce = c + .nonce_gen + .as_mut() + .ok_or_else(|| Error::invalid_data("encryption active but no nonce generator"))? + .next(cipher); + drop(c); + + let (transform_header, ciphertext) = + encryption::encrypt_message(plaintext, &enc_key, cipher, &nonce, session_id)?; + + let mut encrypted = transform_header; + encrypted.extend_from_slice(&ciphertext); + + trace!( + "encrypt: plaintext={} bytes, encrypted={} bytes, nonce={:02X?}", + plaintext.len(), + encrypted.len(), + &nonce[..cipher.nonce_len()] + ); + + Ok(encrypted) + } + + /// Register a tree as DFS-enabled. + pub fn register_dfs_tree(&mut self, tree_id: TreeId) { + self.inner.dfs_trees.lock().unwrap().insert(tree_id); + } + + /// Deregister a tree from DFS tracking. + pub fn deregister_dfs_tree(&mut self, tree_id: TreeId) { + self.inner.dfs_trees.lock().unwrap().remove(&tree_id); + } + + fn should_set_dfs_flag(&self, tree_id: Option) -> bool { + tree_id.is_some_and(|id| self.inner.dfs_trees.lock().unwrap().contains(&id)) + } + + /// Allocate `charge` consecutive MessageIds and return the first. + /// + /// Also bumps the `requests_sent` metric — this is the single funnel + /// every send path (`negotiate`, `execute`, `execute_with_credits`, + /// `execute_capturing_request`, `dispatch`, `execute_compound`'s loop) + /// goes through, so counting here can't drift as new send sites land. + /// `send_cancel` reuses an existing msg_id and is counted separately by + /// `explicit_cancels_sent`. + fn allocate_msg_id(&self, charge: u64) -> MessageId { + let first = self + .inner + .next_message_id + .fetch_add(charge, Ordering::SeqCst); + self.inner + .metrics + .requests_sent + .fetch_add(1, Ordering::Relaxed); + MessageId(first) + } + + /// Register a waiter in the shared map and return the Receiver. + /// + /// Atomically checks `disconnected` under the waiters lock. If the + /// connection died between `send_request`'s fast-path check and + /// this call, returns `Err(Disconnected)` without inserting — + /// prevents a TOCTOU where the receiver task has already drained + /// the waiters map but we'd insert a new entry that no one will + /// ever route to, leaving the caller hanging on `rx.await`. + /// + /// `fan_error_to_waiters` sets `disconnected = true` under the + /// same lock, making the two paths strictly ordered. + fn register_waiter(&self, msg_id: MessageId) -> Result>> { + let mut waiters = self.inner.waiters.lock().unwrap(); + if self.inner.disconnected.load(Ordering::Acquire) { + return Err(Error::Disconnected); + } + let (tx, rx) = oneshot::channel(); + waiters.insert(msg_id, tx); + trace!("register_waiter: msg_id={}", msg_id.0); + Ok(rx) + } + + /// Remove a waiter from the map (used on send error). + fn remove_waiter(&self, msg_id: MessageId) { + self.inner.waiters.lock().unwrap().remove(&msg_id); + trace!("remove_waiter: msg_id={}", msg_id.0); + } + + #[cfg(test)] + pub(crate) fn set_test_params(&mut self, params: NegotiatedParams) { + // OnceLock: first setter wins. Tests sometimes stage params on a + // fresh connection; ignore any collision. + let _ = self.inner.params.set(params); + } + + #[cfg(test)] + pub(crate) fn set_credits(&mut self, credits: u16) { + self.inner.credits.store(credits as u32, Ordering::Release); + } + + #[cfg(test)] + pub(crate) fn set_next_message_id(&mut self, id: u64) { + self.inner.next_message_id.store(id, Ordering::Release); + } + + /// Snapshot the diagnostics counters on this connection. + /// + /// Crate-internal; the public surface is [`Self::diagnostics`]. + pub(crate) fn metrics(&self) -> crate::client::diagnostics::MetricsSnapshot { + self.inner.metrics_snapshot() + } + + /// Capture a snapshot of this connection's state and counters. + /// + /// **Eventually consistent.** Each field is loaded independently — + /// `credits.available` and `credits.in_flight` are sampled at slightly + /// different moments, so their sum is *not* invariant. Documented on + /// [`crate::client::diagnostics::CreditInfo`]. + /// + /// **Survives teardown.** Counters live on the `Arc` that + /// outlives the receiver task; calling this on a torn-down connection + /// (`disconnected: true`) returns final values. + /// + /// **Lock order.** Internally takes the `crypto`, `waiters`, + /// `dfs_trees`, and `estimated_rtt` locks one at a time, in that + /// order, and only as long as it takes to copy primitives out. No + /// lock is held across an `.await`. + pub fn diagnostics(&self) -> crate::client::diagnostics::ConnectionDiagnostics { + use crate::client::diagnostics::{ + CompressionInfo, ConnectionDiagnostics, CreditInfo, EncryptionInfo, NegotiatedSummary, + SigningInfo, + }; + + // ── 1. crypto lock: signing / encryption snapshot ──────────────── + let (signing, encryption) = { + let c = self.inner.crypto.lock().unwrap(); + ( + SigningInfo { + active: c.should_sign, + algorithm: c.signing_algorithm, + }, + EncryptionInfo { + active: c.should_encrypt, + cipher: c.encryption_cipher, + }, + ) + }; + + // ── 2. waiters lock: in-flight count ──────────────────────────── + let in_flight = self.inner.waiters.lock().unwrap().len(); + + // ── 3. dfs_trees lock: cloned snapshot ────────────────────────── + let dfs_trees: Vec = self + .inner + .dfs_trees + .lock() + .unwrap() + .iter() + .copied() + .collect(); + + // ── 4. estimated_rtt lock: cloned snapshot ────────────────────── + let rtt_estimate = *self.inner.estimated_rtt.lock().unwrap(); + + // Wait-free reads. + let credits = CreditInfo { + available: (self.inner.credits.load(Ordering::Acquire) & 0xFFFF) as u16, + in_flight, + next_message_id: self.inner.next_message_id.load(Ordering::Acquire), + }; + let disconnected = self.inner.disconnected.load(Ordering::Acquire); + let compression = CompressionInfo { + requested: self.inner.compression_requested.load(Ordering::Acquire), + negotiated: self.inner.compression_enabled.load(Ordering::Acquire), + }; + + let negotiated = self.inner.params.get().map(|p| NegotiatedSummary { + dialect: p.dialect, + max_read_size: p.max_read_size, + max_write_size: p.max_write_size, + max_transact_size: p.max_transact_size, + server_guid: p.server_guid, + signing_required: p.signing_required, + capabilities: p.capabilities, + gmac_negotiated: p.gmac_negotiated, + cipher: p.cipher, + compression_supported: p.compression_supported, + }); + + ConnectionDiagnostics { + server: self.inner.server_name.clone(), + negotiated, + credits, + signing, + encryption, + compression, + rtt_estimate, + disconnected, + dfs_trees, + session: None, // populated by SmbClient when assembling the full tree + metrics: self.metrics(), + } + } +} + +// `Connection`'s teardown lives on `Inner::drop`: the receiver task is +// aborted only when the last clone drops (the last `Arc` goes away). + +/// Receiver task loop: owns the transport receive half, routes each frame +/// to its waiter. +async fn receiver_loop(transport_recv: Box, inner: Arc) { + loop { + let raw = match transport_recv.receive().await { + Ok(bytes) => bytes, + Err(e) => { + debug!("receiver_loop: transport error: {}, shutting down", e); + let count = inner.waiters.lock().unwrap().len(); + fan_error_to_waiters(&inner, &e); + // Idle teardown (no in-flight requests) is routine: the server + // or OS reaps a session that's been quiet long enough. Real + // disconnects with pending waiters stay at WARN because they + // affect callers. The decrypt / decompress / malformed-frame + // teardowns below stay WARN regardless of waiter count — those + // are protocol corruption, always worth surfacing. + if count == 0 { + debug!("receiver_loop: idle teardown (no waiters)"); + } else { + warn!( + "receiver_loop: exiting after fan-error to {} waiters", + count + ); + } + return; + } + }; + inner + .metrics + .wire_bytes_received + .fetch_add(raw.len() as u64, Ordering::Relaxed); + trace!("receiver_loop: received {} bytes", raw.len()); + trace!( + "receiver_loop: tick, waiters={}", + inner.waiters.lock().unwrap().len() + ); + + // Decrypt if TRANSFORM_HEADER. Per P3.4 / decision E6: on an + // unrecoverable frame error (decrypt auth tag mismatch, decompress + // failure, malformed sub-frame structure) we tear the connection + // down instead of log-and-continue. The msg_id isn't recoverable + // from an unparseable frame, so there's no targeted waiter to + // notify; log-and-continue would leave the matching waiter + // hanging forever. Teardown fans Err(Disconnected) to every + // pending waiter; the caller reconnects. + let (decoded, was_encrypted) = if raw.len() >= 4 && raw[0..4] == TRANSFORM_PROTOCOL_ID { + match decrypt_frame(&raw, &inner) { + Ok(plain) => (plain, true), + Err(e) => { + inner + .metrics + .decrypt_failures + .fetch_add(1, Ordering::Relaxed); + warn!( + "receiver_loop: decrypt failed: {}; tearing down connection", + e + ); + let count = inner.waiters.lock().unwrap().len(); + fan_error_to_waiters(&inner, &e); + warn!( + "receiver_loop: exiting after fan-error to {} waiters", + count + ); + return; + } + } + } else { + (raw, false) + }; + + // Decompress if COMPRESSION_HEADER. + let decoded = if decoded.len() >= 4 && decoded[0..4] == COMPRESSION_PROTOCOL_ID { + match decompress_response(&decoded) { + Ok(plain) => plain, + Err(e) => { + inner + .metrics + .decompress_failures + .fetch_add(1, Ordering::Relaxed); + warn!( + "receiver_loop: decompress failed: {}; tearing down connection", + e + ); + let count = inner.waiters.lock().unwrap().len(); + fan_error_to_waiters(&inner, &e); + warn!( + "receiver_loop: exiting after fan-error to {} waiters", + count + ); + return; + } + } + } else { + decoded + }; + + // Split by NextCommand. + let sub_frames = match split_compound(&decoded) { + Ok(subs) => subs, + Err(e) => { + inner + .metrics + .malformed_frames + .fetch_add(1, Ordering::Relaxed); + warn!( + "receiver_loop: malformed frame: {}; tearing down connection", + e + ); + let count = inner.waiters.lock().unwrap().len(); + fan_error_to_waiters(&inner, &e); + warn!( + "receiver_loop: exiting after fan-error to {} waiters", + count + ); + return; + } + }; + + // Produce a list of routable entries for this transport frame. + // SubFrameAction::Skip frames (oplock break, STATUS_PENDING) are + // dropped silently. A parse error from prepare_sub_frame is fatal: + // the compound split succeeded (framing looked valid) but a header + // inside is corrupt — the connection is out of sync and we can't + // recover. Tear down so pending waiters see Err(Disconnected) + // rather than hanging forever. + let mut routable: Vec<(MessageId, Result)> = Vec::new(); + for sub in sub_frames { + match prepare_sub_frame(&sub, was_encrypted, &inner) { + Ok(SubFrameAction::Route(msg_id, result)) => routable.push((msg_id, result)), + Ok(SubFrameAction::Skip) => { /* oplock break / STATUS_PENDING */ } + Err(e) => { + inner + .metrics + .malformed_frames + .fetch_add(1, Ordering::Relaxed); + warn!( + "receiver_loop: sub-frame parse failed: {}; tearing down connection", + e + ); + let count = inner.waiters.lock().unwrap().len(); + fan_error_to_waiters(&inner, &e); + warn!( + "receiver_loop: exiting after fan-error to {} waiters", + count + ); + return; + } + } + } + + if routable.is_empty() { + continue; + } + + for (msg_id, result) in routable { + let maybe_tx = inner.waiters.lock().unwrap().remove(&msg_id); + match maybe_tx { + Some(tx) => { + let was_err = result.is_err(); + match &result { + Ok(frame) => debug!( + "recv: routed msg_id={}, status={:?}, cmd={:?}", + msg_id.0, frame.header.status, frame.header.command + ), + Err(e) => debug!("recv: routed error msg_id={}, err={}", msg_id.0, e), + } + // Bump the routing counter BEFORE handing the response to + // the caller. The caller's `await` resumes as soon as + // `tx.send` lands, so a fetch_add ordered after the send + // races: tests that snapshot metrics right after the + // awaited call would see the increment land late. If the + // send fails (caller dropped its Receiver), rebalance: + // subtract from ok/err and credit `late_after_drop`. The + // rebalance window is only observable to another thread + // snapshotting mid-send during a caller-drop, which is + // benign (eventually consistent counters by design). + let counter = if was_err { + &inner.metrics.responses_routed_err + } else { + &inner.metrics.responses_routed_ok + }; + counter.fetch_add(1, Ordering::Relaxed); + if tx.send(result).is_err() { + // Caller's oneshot::Receiver was dropped — typical + // spawn/abort pattern. Counted distinctly from + // stray frames (None branch below). + counter.fetch_sub(1, Ordering::Relaxed); + inner + .metrics + .responses_late_after_drop + .fetch_add(1, Ordering::Relaxed); + trace!("recv: late arrival for dropped waiter, msg_id={}", msg_id.0); + } + } + None => { + // True orphan: msg_id never registered (server sent + // something for an id we didn't allocate, or a + // send-error cleanup raced with arrival). + inner + .metrics + .responses_stray + .fetch_add(1, Ordering::Relaxed); + match &result { + Ok(frame) => debug!( + "recv: orphan dropped, msg_id={}, status={:?}, cmd={:?}", + msg_id.0, frame.header.status, frame.header.command + ), + Err(e) => debug!( + "recv: orphan dropped (error) msg_id={}, err={}", + msg_id.0, e + ), + } + } + } + } + } +} + +/// Outcome of preparing a single sub-frame. +pub(crate) enum SubFrameAction { + /// Route this response to the waiter for `msg_id`. + /// + /// The inner `Result` lets us deliver a per-sub-op error (signature + /// verification failure, session expired) targeted at its matching + /// waiter without disturbing others. + Route(MessageId, std::result::Result), + /// Skip silently — not forwarded to any waiter. + /// Used for oplock-break notifications (MessageId=UNSOLICITED) and + /// STATUS_PENDING interim responses (keep the waiter alive). + Skip, +} + +/// Prepare a routable sub-frame from raw bytes. +/// +/// Returns `Ok(SubFrameAction::Route(...))` for a normal response (possibly +/// carrying a sub-op error), `Ok(SubFrameAction::Skip)` for oplock/PENDING +/// frames that the caller should drop silently, and `Err(e)` for +/// unrecoverable errors where the connection is now out of sync +/// (header parse failure on a sub-frame the compound-splitter claimed was +/// valid — the receiver loop fans the error to all waiters and exits). +fn prepare_sub_frame(sub: &[u8], was_encrypted: bool, inner: &Inner) -> Result { + // Parse the header. A failure here means split_compound produced a + // chunk that doesn't start with a valid SMB2 header — the framing is + // corrupt and we can't know where the next sub-frame begins. Fatal + // to the connection. + let mut cursor = ReadCursor::new(sub); + let header = match Header::unpack(&mut cursor) { + Ok(h) => h, + Err(e) => { + return Err(Error::invalid_data(format!( + "sub-frame header parse failed: {}", + e + ))); + } + }; + + // Always update credits. + if header.credits > 0 { + let prev = inner.credits.load(Ordering::Relaxed) as u16; + let next = prev.saturating_add(header.credits); + inner.credits.store(next as u32, Ordering::Release); + } + + // Oplock break notification: MessageId=UNSOLICITED. Skip silently. + if header.message_id == MessageId::UNSOLICITED { + inner + .metrics + .unsolicited_notifications_received + .fetch_add(1, Ordering::Relaxed); + debug!( + "recv: skipping unsolicited oplock break notification, cmd={:?}", + header.command + ); + return Ok(SubFrameAction::Skip); + } + + // STATUS_PENDING is an interim response — don't forward, keep waiter. + if header.status.is_pending() { + inner + .metrics + .status_pending_loops + .fetch_add(1, Ordering::Relaxed); + debug!( + "recv: STATUS_PENDING (interim), cmd={:?}, msg_id={}", + header.command, header.message_id.0 + ); + return Ok(SubFrameAction::Skip); + } + + // Consume credit_charge (or 1 if zero). + let consume = header.credit_charge.0.max(1); + let prev = inner.credits.load(Ordering::Relaxed) as u16; + inner + .credits + .store(prev.saturating_sub(consume) as u32, Ordering::Release); + + // Verify signature if signing is active and not encrypted. + let (should_sign, signing_key, signing_algorithm) = { + let c = inner.crypto.lock().unwrap(); + (c.should_sign, c.signing_key.clone(), c.signing_algorithm) + }; + if should_sign && !was_encrypted && sub.len() >= Header::SIZE { + let flags = u32::from_le_bytes(sub[16..20].try_into().unwrap()); + let is_signed = (flags & HeaderFlags::SIGNED) != 0; + let status = u32::from_le_bytes(sub[8..12].try_into().unwrap()); + let is_pending = status == NtStatus::PENDING.0; + if is_signed && !is_pending { + if let (Some(key), Some(algo)) = (signing_key, signing_algorithm) { + if let Err(e) = + signing::verify_signature(sub, &key, algo, header.message_id.0, false) + { + inner + .metrics + .signature_failures + .fetch_add(1, Ordering::Relaxed); + warn!( + "recv: sub-frame produced error for msg_id={}, reason=signature verify failed: {}", + header.message_id.0, e + ); + return Ok(SubFrameAction::Route(header.message_id, Err(e))); + } + } + } + } + + // Special status handling: session expired → error. + if header.status == NtStatus::NETWORK_SESSION_EXPIRED { + inner + .metrics + .session_expired_events + .fetch_add(1, Ordering::Relaxed); + warn!( + "recv: session expired (STATUS_NETWORK_SESSION_EXPIRED), cmd={:?}, msg_id={}", + header.command, header.message_id.0 + ); + warn!( + "recv: sub-frame produced error for msg_id={}, reason=session expired", + header.message_id.0 + ); + return Ok(SubFrameAction::Route( + header.message_id, + Err(Error::SessionExpired), + )); + } + + let body = if sub.len() > Header::SIZE { + sub[Header::SIZE..].to_vec() + } else { + Vec::new() + }; + let raw = sub.to_vec(); + let msg_id = header.message_id; + Ok(SubFrameAction::Route( + msg_id, + Ok(Frame { header, body, raw }), + )) +} + +/// Fan the given error (as best we can clone it) to every pending waiter +/// and clear the waiters map. Marks the connection as disconnected so +/// new sends fail-fast. +/// +/// `disconnected` is set UNDER the waiters lock so `register_waiter` sees +/// either "still alive → insert succeeds" or "dead → insert rejected", +/// never "inserted but already drained" (which would leave the caller +/// hanging on `rx.await`). +fn fan_error_to_waiters(inner: &Inner, e: &Error) { + let drained: Vec<(MessageId, oneshot::Sender>)> = { + let mut waiters = inner.waiters.lock().unwrap(); + inner.disconnected.store(true, Ordering::Release); + waiters.drain().collect() + }; + for (_id, tx) in drained { + let _ = tx.send(Err(clone_err_as_disconnected(e))); + } +} + +/// Best-effort error clone: `Error` isn't `Clone` (Io holds std::io::Error). +/// Everything maps to `Error::Disconnected` for waiter-fan-out purposes — +/// waiters only need to know "the connection died". +fn clone_err_as_disconnected(_e: &Error) -> Error { + Error::Disconnected +} + +fn decrypt_frame(data: &[u8], inner: &Inner) -> Result> { + let c = inner.crypto.lock().unwrap(); + let dec_key = c + .decryption_key + .as_ref() + .ok_or_else(|| Error::invalid_data("received encrypted message but no decryption key"))? + .clone(); + let cipher = c + .encryption_cipher + .ok_or_else(|| Error::invalid_data("received encrypted message but no cipher"))?; + drop(c); + + if data.len() < TransformHeader::SIZE { + return Err(Error::invalid_data( + "encrypted message too short for TransformHeader", + )); + } + + let transform_header = &data[..TransformHeader::SIZE]; + let ciphertext = &data[TransformHeader::SIZE..]; + let plaintext = encryption::decrypt_message(transform_header, ciphertext, &dec_key, cipher)?; + Ok(plaintext) +} + +/// Split a preprocessed frame into sub-frames by `NextCommand` offsets. +/// Returns the raw byte slices (as owned Vec) for each sub-frame. +pub(crate) fn split_compound(data: &[u8]) -> Result>> { + let mut results = Vec::new(); + let mut offset = 0usize; + + loop { + if offset + Header::SIZE > data.len() { + return Err(Error::invalid_data(format!( + "compound response truncated at offset {}: need {} bytes for header, but only {} remain", + offset, + Header::SIZE, + data.len() - offset, + ))); + } + + if !results.is_empty() && offset % 8 != 0 { + return Err(Error::invalid_data(format!( + "compound response at offset {} is not 8-byte aligned -- must disconnect", + offset, + ))); + } + + // Parse NextCommand directly from header bytes 20..24. + let next_cmd = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap()); + let sub_end = if next_cmd > 0 { + offset + next_cmd as usize + } else { + data.len() + }; + + if sub_end > data.len() { + return Err(Error::invalid_data(format!( + "compound NextCommand offset {} at position {} exceeds response length {}", + next_cmd, + offset, + data.len(), + ))); + } + + results.push(data[offset..sub_end].to_vec()); + if next_cmd == 0 { + break; + } + offset += next_cmd as usize; + } + Ok(results) +} + +/// Await a per-request `oneshot::Receiver` and translate the three +/// outcomes into a `Result`: +/// +/// - `Ok(Ok(frame))` — the receiver task routed a successful response. +/// - `Ok(Err(e))` — the receiver task delivered a targeted error for +/// this `MessageId` (signature-verify failure, session expired, etc.). +/// - `Err(_)` on the outer await means the `oneshot::Sender` was dropped +/// without sending, which happens on connection teardown (see +/// `fan_error_to_waiters` — it calls `send(Err(Disconnected))` for +/// every pending waiter, so we only see a raw canceled channel if +/// the whole map was dropped without that call, i.e. Arc teardown). +/// Map it to `Error::Disconnected`. +pub(crate) async fn await_frame(rx: oneshot::Receiver>) -> Result { + match rx.await { + Ok(Ok(frame)) => Ok(frame), + Ok(Err(e)) => Err(e), + Err(_canceled) => Err(Error::Disconnected), + } +} + +/// Pack a header + body into raw SMB2 message bytes. +pub(crate) fn pack_message(header: &Header, body: &dyn Pack) -> Vec { + let mut cursor = WriteCursor::new(); + header.pack(&mut cursor); + body.pack(&mut cursor); + cursor.into_inner() +} + +fn generate_guid() -> Guid { + let mut bytes = [0u8; 16]; + getrandom::fill(&mut bytes).expect("failed to generate random GUID"); + Guid { + data1: u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), + data2: u16::from_le_bytes([bytes[4], bytes[5]]), + data3: u16::from_le_bytes([bytes[6], bytes[7]]), + data4: [ + bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], + ], + } +} + +fn generate_salt() -> Vec { + let mut salt = vec![0u8; 32]; + getrandom::fill(&mut salt).expect("failed to generate random salt"); + salt +} + +fn build_compressed_frame(compressed: &CompressedMessage) -> Vec { + let header = CompressionTransformHeader { + original_compressed_segment_size: compressed.original_size, + compression_algorithm: COMPRESSION_ALGORITHM_LZ4, + flags: SMB2_COMPRESSION_FLAG_NONE, + offset_or_length: compressed.offset, + }; + let mut cursor = WriteCursor::new(); + header.pack(&mut cursor); + let mut frame = cursor.into_inner(); + frame.extend_from_slice(&compressed.uncompressed_prefix); + frame.extend_from_slice(&compressed.compressed_data); + frame +} + +fn decompress_response(data: &[u8]) -> Result> { + if data.len() < CompressionTransformHeader::SIZE { + return Err(Error::invalid_data( + "compressed response too short for CompressionTransformHeader", + )); + } + let mut cursor = ReadCursor::new(data); + let header = CompressionTransformHeader::unpack(&mut cursor)?; + if header.compression_algorithm != COMPRESSION_ALGORITHM_LZ4 { + return Err(Error::invalid_data(format!( + "unsupported compression algorithm 0x{:04X}, only LZ4 (0x{:04X}) is supported", + header.compression_algorithm, COMPRESSION_ALGORITHM_LZ4 + ))); + } + if header.flags != SMB2_COMPRESSION_FLAG_NONE { + return Err(Error::invalid_data(format!( + "unsupported compression flags 0x{:04X}, only unchained (0x0000) is supported", + header.flags + ))); + } + let offset = header.offset_or_length as usize; + let remaining = &data[CompressionTransformHeader::SIZE..]; + if offset > remaining.len() { + return Err(Error::invalid_data(format!( + "compression offset {} exceeds remaining data length {}", + offset, + remaining.len() + ))); + } + let uncompressed_prefix = &remaining[..offset]; + let compressed_data = &remaining[offset..]; + decompress_message( + uncompressed_prefix, + compressed_data, + header.original_compressed_segment_size, + ) +} + +// Arc-based TransportSend/TransportReceive for TcpTransport sharing. +#[async_trait::async_trait] +impl TransportSend for Arc { + async fn send(&self, data: &[u8]) -> Result<()> { + (**self).send(data).await + } +} + +#[async_trait::async_trait] +impl TransportReceive for Arc { + async fn receive(&self) -> Result> { + (**self).receive().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::msg::negotiate::{NegotiateContext, HASH_ALGORITHM_SHA512}; + use crate::transport::MockTransport; + use crate::types::flags::HeaderFlags; + + /// Pack a set of SMB2 sub-responses into one compound transport frame + /// by wiring up `NextCommand` offsets and 8-byte-padding each sub + /// except the last. Used by compound execute tests below. + fn build_compound_response_frame(responses: &[Vec]) -> Vec { + let mut padded: Vec> = Vec::new(); + for (i, resp) in responses.iter().enumerate() { + let mut r = resp.clone(); + let is_last = i == responses.len() - 1; + if !is_last { + let remainder = r.len() % 8; + if remainder != 0 { + r.resize(r.len() + (8 - remainder), 0); + } + let next_cmd = r.len() as u32; + r[20..24].copy_from_slice(&next_cmd.to_le_bytes()); + } + padded.push(r); + } + let mut frame = Vec::new(); + for r in &padded { + frame.extend_from_slice(r); + } + frame + } + + /// Build a canned negotiate response with the given dialect. + fn build_negotiate_response(dialect: Dialect) -> Vec { + let resp_header = { + let mut h = Header::new_request(Command::Negotiate); + h.flags.set_response(); + h.credits = 32; + h + }; + let resp_body = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: dialect, + server_guid: Guid::ZERO, + capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING), + max_transact_size: 65536, + max_read_size: 65536, + max_write_size: 65536, + system_time: 132_000_000_000_000_000, + server_start_time: 131_000_000_000_000_000, + security_buffer: vec![0x60, 0x00], + negotiate_contexts: if dialect == Dialect::Smb3_1_1 { + vec![NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xBB; 32], + }] + } else { + vec![] + }, + }; + pack_message(&resp_header, &resp_body) + } + + #[tokio::test] + async fn negotiate_stores_params_correctly() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.negotiate().await.unwrap(); + + let params = conn.params().unwrap(); + assert_eq!(params.dialect, Dialect::Smb3_1_1); + assert_eq!(params.max_read_size, 65536); + assert_eq!(params.max_write_size, 65536); + assert_eq!(params.max_transact_size, 65536); + assert!(!params.signing_required); + } + + #[tokio::test] + async fn negotiate_updates_credits() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb3_0)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.negotiate().await.unwrap(); + + // Server granted 32 credits, minus 1 consumed for our request. + assert_eq!(conn.credits(), 32); + } + + #[tokio::test] + async fn negotiate_increments_message_id() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb2_0_2)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + assert_eq!(conn.next_message_id(), 0); + conn.negotiate().await.unwrap(); + assert_eq!(conn.next_message_id(), 1); + } + + #[tokio::test] + async fn negotiate_updates_preauth_hash() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + let initial_hash = *conn.preauth_hasher().value(); + conn.negotiate().await.unwrap(); + assert_ne!(conn.preauth_hasher().value(), &initial_hash); + } + + #[tokio::test] + async fn negotiate_rejects_invalid_max_read_size() { + let resp_header = { + let mut h = Header::new_request(Command::Negotiate); + h.flags.set_response(); + h.credits = 1; + h + }; + let resp_body = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: Dialect::Smb2_0_2, + server_guid: Guid::ZERO, + capabilities: Capabilities::default(), + max_transact_size: 65536, + max_read_size: 1024, // Too small + max_write_size: 65536, + system_time: 0, + server_start_time: 0, + security_buffer: vec![], + negotiate_contexts: vec![], + }; + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(pack_message(&resp_header, &resp_body)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + let result = conn.negotiate().await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("MaxReadSize")); + } + + #[tokio::test] + async fn message_id_increments_on_send_request() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Manually set past negotiate. + conn.set_next_message_id(5); + + use crate::msg::tree_disconnect::TreeDisconnectRequest; + let body = TreeDisconnectRequest; + // With execute(), the msg_id is an internal allocation — we peek it + // via next_message_id() before sending. Use a timeout so the test + // doesn't wait for a response the mock never produces. + assert_eq!(conn.next_message_id(), 5); + let _ = tokio::time::timeout( + std::time::Duration::from_millis(50), + conn.execute(Command::TreeDisconnect, &body, None), + ) + .await; + assert_eq!(conn.next_message_id(), 6); + } + + #[tokio::test] + async fn signing_applied_to_outgoing_messages() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Activate signing. + let key = vec![0xAA; 16]; + conn.activate_signing(key, SigningAlgorithm::HmacSha256); + conn.set_session_id(SessionId(0x1234)); + + use crate::msg::tree_disconnect::TreeDisconnectRequest; + let body = TreeDisconnectRequest; + // execute() awaits the response, but we only care about the sent bytes. + // Spawn+abort the future so the send runs but the await doesn't block on + // a response that never comes. + let _ = tokio::time::timeout( + std::time::Duration::from_millis(50), + conn.execute(Command::TreeDisconnect, &body, None), + ) + .await; + + let msg_bytes = mock.sent_message(0).expect("one send recorded"); + // Verify the signed flag is set in the header. + let flags = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap()); + assert!(flags & HeaderFlags::SIGNED != 0, "message should be signed"); + + // Verify signature is non-zero. + let sig = &msg_bytes[48..64]; + assert_ne!(sig, &[0u8; 16], "signature should not be all zeros"); + } + + #[tokio::test] + async fn negotiate_with_smb2_dialect() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb2_0_2)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.negotiate().await.unwrap(); + + let params = conn.params().unwrap(); + assert_eq!(params.dialect, Dialect::Smb2_0_2); + assert!(!params.gmac_negotiated); + assert!(params.cipher.is_none()); + } + + #[tokio::test] + async fn negotiate_sends_all_five_dialects() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.negotiate().await.unwrap(); + + // Verify the sent request contains all 5 dialects. + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = NegotiateRequest::unpack(&mut cursor).unwrap(); + assert_eq!(req.dialects.len(), 5); + assert!(req.dialects.contains(&Dialect::Smb2_0_2)); + assert!(req.dialects.contains(&Dialect::Smb2_1)); + assert!(req.dialects.contains(&Dialect::Smb3_0)); + assert!(req.dialects.contains(&Dialect::Smb3_0_2)); + assert!(req.dialects.contains(&Dialect::Smb3_1_1)); + } + + // (Compound-specific send/receive tests removed — execute_compound tests live below.) + + // ── Compression tests ──────────────────────────────────────────── + + use crate::msg::negotiate::COMPRESSION_LZ4; + use crate::msg::transform::{ + CompressionTransformHeader, COMPRESSION_ALGORITHM_LZ4, COMPRESSION_PROTOCOL_ID, + SMB2_COMPRESSION_FLAG_NONE, + }; + + /// Build a negotiate response that includes a compression context with LZ4. + fn build_negotiate_response_with_compression(dialect: Dialect) -> Vec { + let resp_header = { + let mut h = Header::new_request(Command::Negotiate); + h.flags.set_response(); + h.credits = 32; + h + }; + let resp_body = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: dialect, + server_guid: Guid::ZERO, + capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING), + max_transact_size: 65536, + max_read_size: 65536, + max_write_size: 65536, + system_time: 132_000_000_000_000_000, + server_start_time: 131_000_000_000_000_000, + security_buffer: vec![0x60, 0x00], + negotiate_contexts: vec![ + NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xBB; 32], + }, + NegotiateContext::Compression { + flags: 0, + algorithms: vec![COMPRESSION_LZ4], + }, + ], + }; + pack_message(&resp_header, &resp_body) + } + + #[tokio::test] + async fn negotiate_detects_compression_support() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response_with_compression(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.negotiate().await.unwrap(); + + let params = conn.params().unwrap(); + assert!(params.compression_supported); + assert!(conn.compression_enabled()); + } + + #[tokio::test] + async fn negotiate_without_compression_context_disables_compression() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.negotiate().await.unwrap(); + + let params = conn.params().unwrap(); + assert!(!params.compression_supported); + assert!(!conn.compression_enabled()); + } + + #[tokio::test] + async fn compression_disabled_when_client_config_says_no() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response_with_compression(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_compression_requested(false); + conn.negotiate().await.unwrap(); + + // Server supports it, but client disabled it. + let params = conn.params().unwrap(); + assert!(params.compression_supported); + assert!(!conn.compression_enabled()); + } + + #[tokio::test] + async fn negotiate_offers_compression_context_when_requested() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response_with_compression(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + // compression_requested defaults to true. + conn.negotiate().await.unwrap(); + + // Parse the sent negotiate request and check for compression context. + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = NegotiateRequest::unpack(&mut cursor).unwrap(); + + let has_compression = req.negotiate_contexts.iter().any(|ctx| { + matches!(ctx, NegotiateContext::Compression { algorithms, .. } + if algorithms.contains(&COMPRESSION_LZ4)) + }); + assert!( + has_compression, + "negotiate request should include compression context with LZ4" + ); + } + + #[tokio::test] + async fn negotiate_does_not_offer_compression_when_disabled() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_compression_requested(false); + conn.negotiate().await.unwrap(); + + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = NegotiateRequest::unpack(&mut cursor).unwrap(); + + let has_compression = req + .negotiate_contexts + .iter() + .any(|ctx| matches!(ctx, NegotiateContext::Compression { .. })); + assert!( + !has_compression, + "negotiate request should not include compression context" + ); + } + + #[test] + fn build_compressed_frame_roundtrip() { + // Create a message with a compressible payload. + let mut message = vec![0xFE; Header::SIZE]; // header-like prefix + let payload: Vec = b"COMPRESS_ME_".iter().copied().cycle().take(2048).collect(); + message.extend_from_slice(&payload); + + let compressed = compress_message(&message, Header::SIZE).expect("should compress"); + let framed = build_compressed_frame(&compressed); + + // Verify the frame starts with compression protocol ID. + assert_eq!(&framed[0..4], &COMPRESSION_PROTOCOL_ID); + + // Decompress and verify roundtrip. + let decompressed = decompress_response(&framed).expect("should decompress"); + assert_eq!(decompressed, message); + } + + #[test] + fn decompress_response_rejects_unsupported_algorithm() { + // Build a compression transform header with an unsupported algorithm. + let header = CompressionTransformHeader { + original_compressed_segment_size: 100, + compression_algorithm: 0x0001, // LZNT1, not LZ4 + flags: SMB2_COMPRESSION_FLAG_NONE, + offset_or_length: 0, + }; + let mut cursor = WriteCursor::new(); + header.pack(&mut cursor); + let mut frame = cursor.into_inner(); + frame.extend_from_slice(&[0u8; 10]); // bogus data + + let result = decompress_response(&frame); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("unsupported compression algorithm")); + } + + #[test] + fn decompress_response_rejects_chained_compression() { + let header = CompressionTransformHeader { + original_compressed_segment_size: 100, + compression_algorithm: COMPRESSION_ALGORITHM_LZ4, + flags: 0x0001, // chained + offset_or_length: 0, + }; + let mut cursor = WriteCursor::new(); + header.pack(&mut cursor); + let mut frame = cursor.into_inner(); + frame.extend_from_slice(&[0u8; 10]); + + let result = decompress_response(&frame); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("unchained")); + } + + #[test] + fn decompress_response_rejects_too_short_data() { + let result = decompress_response(&[0xFC, b'S', b'M']); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("too short")); + } + + // ── Unsolicited oplock break tests ───────────────────────────── + + // ── Phase 2 (actor + oneshot routing) red tests ───────────────── + // + // These tests pin the invariants the Phase 2 refactor must establish. + // They target the cancellation-by-drop failure mode that Phase 1's + // `HashSet` demux cannot solve: when a caller's future is + // dropped mid-flight (for example, by `tokio::task::JoinHandle::abort()`), + // the in-flight MessageIds stay in `pending`; server responses for those + // ids then get handed to the next caller as if they were legitimate. + // + // Post-Phase-2, each in-flight request carries its own `oneshot::Sender`; + // when the caller's `Receiver` is dropped (future aborted), the receiver + // task discards the response silently on arrival. + // + // These tests fail against current code (Phase 1). They must pass after + // Phase 2 lands. See `docs/specs/connection-actor.md`. + + // ── Phase 3 (silent-discard fix) red test ─────────────────────── + // + // Pins the invariant that an unrecoverable frame-level error + // (decrypt failure, decompress failure, malformed header after + // decryption) MUST NOT silently discard the frame and leave the + // matching waiter hanging forever. The Phase 2 receiver task + // currently `log-at-WARN + continue`s on decrypt failure — the + // msg_id isn't recoverable from an unparseable frame, so there's + // no waiter to notify targeted; the only correct behavior is to + // tear down the connection and fan `Err(Disconnected)` to all + // pending waiters. + // + // This test uses `tokio::time::timeout` to detect the hang: if + // the waiter doesn't resolve within 2 seconds, it's hung (bug + // present, test fails). Post-P3.4 fix, the waiter resolves with + // an error before the timeout. + + #[tokio::test] + async fn phase3_decrypt_failure_errors_waiter_not_hangs() { + use crate::crypto::encryption::Cipher; + + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_credits(10); + + // Activate encryption with a key that WON'T match what the + // malformed frame was "encrypted" with — decrypt will fail auth. + let enc_key = vec![0x42; 16]; + let dec_key = vec![0x99; 16]; // deliberately wrong decryption key + conn.activate_encryption(enc_key, dec_key, Cipher::Aes128Gcm); + + // Register a waiter manually so we can inject a bad frame without + // racing with a real send. + let rx = conn.register_waiter(MessageId(4)).unwrap(); + + // Build a frame that starts with TRANSFORM_PROTOCOL_ID so the + // receiver task takes the decrypt path, but whose ciphertext + // is garbage that will fail the GCM auth tag check. We craft a + // "valid-shape" transform header (52 bytes) plus ~64 bytes of + // garbage ciphertext. The receiver task's decrypt_frame call + // returns Err; currently it's log+continue (the bug). + let mut frame = Vec::new(); + frame.extend_from_slice(&TRANSFORM_PROTOCOL_ID); // 0xFD 'S' 'M' 'B' + frame.extend_from_slice(&[0u8; 16]); // signature + frame.extend_from_slice(&[0u8; 16]); // nonce + frame.extend_from_slice(&64u32.to_le_bytes()); // original_message_size + frame.extend_from_slice(&0u16.to_le_bytes()); // reserved + frame.extend_from_slice(&1u16.to_le_bytes()); // flags (Encrypted) + frame.extend_from_slice(&0xDEADu64.to_le_bytes()); // session_id + // Garbage ciphertext — will fail GCM auth on decrypt. + frame.extend_from_slice(&[0xAAu8; 64]); + mock.queue_response(frame); + + // Await the waiter with a short timeout. If Phase 3's fix is in + // place, the receiver task tears down on decrypt failure and the + // waiter resolves with Err(Disconnected) quickly. Without the + // fix, the receiver task `log+continue`s, the waiter hangs, and + // the timeout fires (test fails). + let result = tokio::time::timeout(Duration::from_secs(2), await_frame(rx)).await; + + assert!( + result.is_ok(), + "waiter hung forever on a decrypt-failed frame — Phase 3's silent-discard \ + fix must tear down the connection on unrecoverable frame errors and propagate \ + Err(Disconnected) to pending waiters. Instead the receiver task silently discards \ + the frame and the waiter never resolves. (P3.4 fixes this.)" + ); + let waiter_result = result.unwrap(); + assert!( + waiter_result.is_err(), + "waiter should return an error on decrypt failure, not Ok" + ); + } + + // ── CANCEL tests (pitfall #7) ──────────────────────────────────── + + #[tokio::test] + async fn send_cancel_does_not_consume_credit_or_advance_message_id() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_next_message_id(10); + conn.set_credits(5); + + conn.send_cancel(MessageId(7), None).await.unwrap(); + + // MessageId should NOT have advanced. + assert_eq!(conn.next_message_id(), 10); + // Credits should NOT have been consumed. + assert_eq!(conn.credits(), 5); + } + + #[tokio::test] + async fn send_cancel_sync_uses_original_message_id() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_session_id(SessionId(0xAAAA)); + + conn.send_cancel(MessageId(42), None).await.unwrap(); + + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let header = Header::unpack(&mut cursor).unwrap(); + + assert_eq!(header.command, Command::Cancel); + assert_eq!(header.message_id, MessageId(42)); + assert_eq!(header.credit_charge, CreditCharge(0)); + assert_eq!(header.credits, 0); + assert_eq!(header.session_id, SessionId(0xAAAA)); + assert!(!header.flags.is_async()); + + // Body should be CancelRequest: StructureSize=4, Reserved=0. + assert_eq!(sent.len(), Header::SIZE + 4); + let body_structure_size = u16::from_le_bytes(sent[64..66].try_into().unwrap()); + assert_eq!(body_structure_size, 4); + } + + #[tokio::test] + async fn send_cancel_async_sets_async_flag_and_async_id() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_session_id(SessionId(0xBBBB)); + + let async_id = 0x1234_5678_9ABC_DEF0u64; + conn.send_cancel(MessageId(99), Some(async_id)) + .await + .unwrap(); + + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let header = Header::unpack(&mut cursor).unwrap(); + + assert_eq!(header.command, Command::Cancel); + assert_eq!(header.message_id, MessageId(99)); + assert!(header.flags.is_async()); + assert_eq!(header.async_id, Some(async_id)); + assert_eq!(header.tree_id, None); + assert_eq!(header.credit_charge, CreditCharge(0)); + assert_eq!(header.credits, 0); + } + + #[tokio::test] + async fn send_cancel_signs_message_when_signing_active() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + let key = vec![0xCC; 16]; + conn.activate_signing(key, SigningAlgorithm::HmacSha256); + conn.set_session_id(SessionId(0xDDDD)); + + conn.send_cancel(MessageId(50), None).await.unwrap(); + + let sent = mock.sent_message(0).unwrap(); + + // Verify the signed flag is set. + let flags = u32::from_le_bytes(sent[16..20].try_into().unwrap()); + assert!(flags & HeaderFlags::SIGNED != 0, "CANCEL should be signed"); + + // Verify the signature is non-zero. + let sig = &sent[48..64]; + assert_ne!(sig, &[0u8; 16], "signature should not be all zeros"); + } + + // ── Encryption tests ───────────────────────────────────────────── + + #[tokio::test] + async fn no_encryption_when_not_activated() { + use crate::msg::echo::EchoRequest; + + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_test_params(NegotiatedParams { + dialect: Dialect::Smb3_1_1, + max_read_size: 65536, + max_write_size: 65536, + max_transact_size: 65536, + server_guid: Guid::ZERO, + signing_required: false, + capabilities: Capabilities::default(), + gmac_negotiated: false, + cipher: Some(Cipher::Aes128Gcm), + compression_supported: false, + }); + conn.set_session_id(SessionId(1)); + conn.set_credits(5); + + let _ = tokio::time::timeout( + std::time::Duration::from_millis(50), + conn.execute(Command::Echo, &EchoRequest, None), + ) + .await; + + // Without encryption activated, the sent bytes should start with + // the normal SMB2 protocol ID (0xFE). + let sent = mock.sent_message(0).unwrap(); + assert_eq!( + sent[0], 0xFE, + "without encryption, message must start with 0xFE" + ); + } + + #[tokio::test] + async fn activate_encryption_sets_state() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + assert!(!conn.should_encrypt()); + + conn.activate_encryption(vec![0x42; 16], vec![0x42; 16], Cipher::Aes128Gcm); + + assert!(conn.should_encrypt()); + } + + // ── DFS flag tests ───────────────────────────────────────────────── + + #[tokio::test] + async fn dfs_flag_set_for_registered_tree() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_credits(256); + + let tree_id = TreeId(7); + conn.register_dfs_tree(tree_id); + + use crate::msg::echo::EchoRequest; + let body = EchoRequest; + // Fire execute with a short timeout so the test doesn't block on a + // response that never comes — we only care about the sent bytes. + let _ = tokio::time::timeout( + std::time::Duration::from_millis(50), + conn.execute_with_credits(Command::Echo, &body, Some(tree_id), CreditCharge(1)), + ) + .await; + let msg_bytes = mock.sent_message(0).expect("one send recorded"); + + // Header flags are at bytes 16..20 (little-endian u32). + let flags_raw = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap()); + assert_ne!( + flags_raw & HeaderFlags::DFS_OPERATIONS, + 0, + "DFS_OPERATIONS flag must be set for registered tree" + ); + } + + #[tokio::test] + async fn dfs_flag_not_set_for_unregistered_tree() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_credits(256); + + use crate::msg::echo::EchoRequest; + let body = EchoRequest; + let _ = tokio::time::timeout( + std::time::Duration::from_millis(50), + conn.execute_with_credits(Command::Echo, &body, Some(TreeId(7)), CreditCharge(1)), + ) + .await; + let msg_bytes = mock.sent_message(0).expect("one send recorded"); + + let flags_raw = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap()); + assert_eq!( + flags_raw & HeaderFlags::DFS_OPERATIONS, + 0, + "DFS_OPERATIONS flag must NOT be set for unregistered tree" + ); + } + + #[tokio::test] + async fn dfs_flag_cleared_after_deregister() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_credits(256); + + let tree_id = TreeId(7); + conn.register_dfs_tree(tree_id); + conn.deregister_dfs_tree(tree_id); + + use crate::msg::echo::EchoRequest; + let body = EchoRequest; + let _ = tokio::time::timeout( + std::time::Duration::from_millis(50), + conn.execute_with_credits(Command::Echo, &body, Some(tree_id), CreditCharge(1)), + ) + .await; + let msg_bytes = mock.sent_message(0).expect("one send recorded"); + + let flags_raw = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap()); + assert_eq!( + flags_raw & HeaderFlags::DFS_OPERATIONS, + 0, + "DFS_OPERATIONS flag must NOT be set after deregister" + ); + } + + // ── Phase 3 A.1: Connection: Clone ─────────────────────────────── + + /// Confirms clones share the same connection-wide state via `Arc`. + /// + /// Design note (Option A from `docs/specs/connection-actor.md` review): + /// a cloned `Connection` starts with an EMPTY caller-local `pending_fifo`. + /// `oneshot::Receiver` isn't `Clone`, and in-flight waiters belong to + /// the task that sent the request — a new clone is a fresh sender + /// handle to the same actor, not a snapshot. Credits, session id, + /// negotiated params, and crypto state are shared. + #[tokio::test] + async fn connection_is_cloneable_and_clones_share_state() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut original = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Mutate shared state on the original. + original.set_credits(42); + original.set_session_id(SessionId(0x1234_5678_9ABC_DEF0)); + original.set_next_message_id(100); + + // Clone and verify the clone sees the same shared state. + let cloned = original.clone(); + assert_eq!(cloned.credits(), 42); + assert_eq!(cloned.session_id(), SessionId(0x1234_5678_9ABC_DEF0)); + assert_eq!(cloned.next_message_id(), 100); + assert_eq!(cloned.server_name(), "test-server"); + + // Mutate via the clone and verify the original observes it too. + cloned.inner.credits.store(7, Ordering::Release); + assert_eq!(original.credits(), 7); + + // Phase 3 A.3 removed the caller-local `pending_fifo`; there is no + // per-clone state anymore. Clones share `Arc` exclusively. + } + + // ── Phase 3 A.2: `execute` / `execute_with_credits` / `execute_compound` ── + // + // These tests exercise the additive concurrent-op API. All callers take + // `&self`, so the orphan filter stays ENABLED (production behavior). Mock + // responses hardcode the MessageIds that `execute` allocates, starting at 0 + // by default (or a specific `set_next_message_id` for multi-op tests). + + /// Build an ECHO response with a specific MessageId. + fn build_echo_response_with_msg_id(msg_id: MessageId) -> Vec { + let mut h = Header::new_request(Command::Echo); + h.flags.set_response(); + h.credits = 10; + h.message_id = msg_id; + pack_message(&h, &crate::msg::echo::EchoResponse) + } + + /// Queue a response AFTER the spawned task has sent its request (and + /// thus registered its waiter). Using `multi_thread` so the receiver + /// task can race the test task — catching any regression where the + /// orphan filter silently drops the response. + #[tokio::test(flavor = "multi_thread")] + async fn execute_returns_correct_frame_for_sent_request() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Spawn the execute first. `execute` allocates msg_id=0. + let c = conn.clone(); + let handle = tokio::spawn(async move { + c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None) + .await + }); + + // Wait for the send to land, then queue the response. + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while mock.sent_count() < 1 { + if std::time::Instant::now() > deadline { + panic!("execute task did not send its request in 5s"); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + mock.queue_response(build_echo_response_with_msg_id(MessageId(0))); + + let frame = handle.await.unwrap().unwrap(); + + assert_eq!(frame.header.command, Command::Echo); + assert_eq!(frame.header.message_id, MessageId(0)); + assert!(frame.header.is_response()); + // Body should unpack as EchoResponse. + let mut cursor = ReadCursor::new(&frame.body); + crate::msg::echo::EchoResponse::unpack(&mut cursor).unwrap(); + + mock.assert_fully_consumed(); + } + + /// N concurrent `execute` calls on clones of the same `Connection` all + /// succeed — the receiver task's per-MessageId routing delivers each + /// response to its own waiter. Needs a multi-threaded runtime so the + /// receiver task can make progress while the task-under-test runs. + /// + /// Gotcha/Why: we MUST spawn the tasks first and wait for all N sends + /// to register waiters before queuing responses. The receiver task + /// starts reading `mock` immediately after `from_transport`. If we + /// pre-queue all N responses, the receiver races the spawned tasks — + /// any response whose msg_id hasn't had its waiter registered yet is + /// dropped by the orphan filter (enabled by default in production + /// mode), and the task hangs forever waiting for a response that's + /// already been discarded. This ordering reflects the production + /// reality: responses always arrive AFTER the client sent them. + #[tokio::test(flavor = "multi_thread")] + async fn concurrent_execute_on_one_connection_all_succeed() { + const N: u64 = 20; + + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Spawn into a JoinSet so a timeout-side panic can introspect + // which tasks haven't returned yet (`set.len()`). Plain + // `Vec` moves into the await loop and we lose that. + let mut set = tokio::task::JoinSet::new(); + for _ in 0..N { + let c = conn.clone(); + set.spawn(async move { + c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None) + .await + }); + } + + // Wait until all N requests have been sent AND all waiters are + // registered. Poll `sent_count` rather than hardcode a sleep. + // `execute` registers the waiter BEFORE calling `sender.send`, + // so `sent_count >= N` implies all N waiters are live. + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while mock.sent_count() < N as usize { + if std::time::Instant::now() > deadline { + panic!( + "tasks did not send all {} requests in 5s (got {})", + N, + mock.sent_count() + ); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + // Queue N responses with msg_id=0 so auto-rewrite pairs each one + // with the FIFO head of `pending_sent_msg_ids` — i.e. with whatever + // msg_id was actually sent on the wire in that position. Hard-coding + // 0..N here was the original bug: the multi_thread runtime can send + // requests in any order, so the spawned tasks' allocated msg_ids + // don't line up with 0..N in send order. Auto-rewrite would then + // overwrite the i=0 response with FIFO[0] (say msg_id=5) and route + // it to waiter 5, then keep i=5's hard-coded msg_id=5 and re-route + // to the now-removed waiter 5 — the second arrival counted as + // stray, and the original waiter-5's task hung forever waiting for + // a response that was redirected to waiter 0 (which doesn't exist + // in the FIFO ordering). + for _ in 0..N { + mock.queue_response(build_echo_response_with_msg_id(MessageId(0))); + } + + // Drain the set with a hard timeout. The test has hung on multiple + // CI runners (Ubuntu stable, Windows-2025 rust 1.85, macos-1.85 + // historically). Until the root cause is understood, a hang turns + // into a 30 s clean failure instead of a multi-hour CI stall — and + // the panic dumps the smoking-gun state so the next failure has + // diagnostic value. + let mut got_ids: Vec = Vec::with_capacity(N as usize); + let drain = tokio::time::timeout(Duration::from_secs(30), async { + while let Some(joined) = set.join_next().await { + let frame = joined.unwrap().unwrap(); + got_ids.push(frame.header.message_id.0); + } + }) + .await; + + if drain.is_err() { + let pending = set.len(); + let m = conn.metrics(); + let waiters: Vec = { + let g = conn.inner.waiters.lock().unwrap(); + g.keys().map(|mid| mid.0).collect() + }; + let receiver_alive = !conn.inner.disconnected.load(Ordering::Acquire); + set.abort_all(); + panic!( + "concurrent_execute_on_one_connection_all_succeed exceeded 30 s.\n\ + {pending} of {N} execute() futures still pending.\n\ + receiver_alive={receiver_alive}\n\ + mock.sent_count={sent}, mock.pending_responses={pending_resps}\n\ + still-registered waiters (msg_ids): {waiters:?}\n\ + counters: requests_sent={req_sent} \ + responses_routed_ok={ok} responses_routed_err={err} \ + responses_late_after_drop={late} responses_stray={stray} \ + status_pending_loops={pl} unsolicited_notifications_received={uns}", + sent = mock.sent_count(), + pending_resps = mock.pending_responses(), + req_sent = m.requests_sent, + ok = m.responses_routed_ok, + err = m.responses_routed_err, + late = m.responses_late_after_drop, + stray = m.responses_stray, + pl = m.status_pending_loops, + uns = m.unsolicited_notifications_received, + ); + } + + got_ids.sort_unstable(); + assert_eq!(got_ids, (0..N).collect::>()); + + mock.assert_fully_consumed(); + } + + /// Dropping 2 of 5 execute futures before their responses arrive does + /// NOT corrupt the other 3: the receiver task silently discards the + /// frames routed to dropped oneshots, and the 3 surviving tasks see + /// their own responses. + #[tokio::test(flavor = "multi_thread")] + async fn dropped_execute_future_does_not_affect_others() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Spawn 5 tasks. Each allocates its own MessageId in submission + // order: 0, 1, 2, 3, 4. To make allocation deterministic on the + // multi_thread runtime, wait for each task's send to land before + // spawning the next. `yield_now` alone isn't enough — on a + // multi-worker runtime, the next spawn can race the previous + // task's send and reorder msg_id allocation. + let mut handles = Vec::new(); + for idx in 0..5 { + let c = conn.clone(); + let h = tokio::spawn(async move { + c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None) + .await + }); + handles.push(h); + + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while mock.sent_count() < idx + 1 { + if std::time::Instant::now() > deadline { + panic!( + "task {} did not send its request in 5s (sent_count={})", + idx, + mock.sent_count() + ); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + + // All 5 tasks have sent; waiters registered; msg_ids = 0..5. + assert_eq!(mock.sent_count(), 5); + + // Abort tasks at indices 1 and 3 (msg_ids 1 and 3). + handles[1].abort(); + handles[3].abort(); + + // Now queue responses for all 5 msg_ids. The 2 aborted-task + // responses route to closed oneshots and get silently discarded; + // the 3 live tasks get their responses. + for i in 0..5u64 { + mock.queue_response(build_echo_response_with_msg_id(MessageId(i))); + } + + // Collect results: tasks 0, 2, 4 should complete OK; tasks 1, 3 + // return JoinError (they were aborted). + for (idx, h) in handles.into_iter().enumerate() { + let res = h.await; + if idx == 1 || idx == 3 { + assert!(res.is_err(), "task {} should have been aborted", idx); + } else { + let frame = res.unwrap().unwrap(); + assert_eq!(frame.header.command, Command::Echo); + assert_eq!(frame.header.message_id, MessageId(idx as u64)); + } + } + + // All 5 responses were consumed by the receiver task (even the 2 + // whose waiters were dropped — the task reads every frame off the + // mock regardless of waiter state). + mock.assert_fully_consumed(); + } + + /// Compound partial failure: op 1 succeeds, op 2 returns an error + /// status, op 3 succeeds. Outer result is `Ok(vec)`; inner is + /// `[Ok, Ok(with-error-status), Ok]` — the per-sub-op error is + /// encoded in `frame.header.status`, not in the inner `Result`, + /// because the server returned a well-formed frame for every op. + #[tokio::test(flavor = "multi_thread")] + async fn execute_compound_partial_failure_routes_correctly() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + + // 3-op compound. `execute_compound` allocates msg_ids 0, 1, 2. + let echo_ok_0 = build_echo_response_with_msg_id(MessageId(0)); + let mut err_hdr = Header::new_request(Command::Echo); + err_hdr.flags.set_response(); + err_hdr.credits = 10; + err_hdr.message_id = MessageId(1); + err_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let err_body = pack_message( + &err_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + let echo_ok_2 = build_echo_response_with_msg_id(MessageId(2)); + + let compound_response = build_compound_response_frame(&[echo_ok_0, err_body, echo_ok_2]); + + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + let c = conn.clone(); + let handle = tokio::spawn(async move { + let ops = [ + CompoundOp::new(Command::Echo, &crate::msg::echo::EchoRequest, None), + CompoundOp::new(Command::Echo, &crate::msg::echo::EchoRequest, None), + CompoundOp::new(Command::Echo, &crate::msg::echo::EchoRequest, None), + ]; + c.execute_compound(&ops).await + }); + + // Wait for the compound request to land on the wire — one send + // for all 3 sub-ops — then queue the response. All 3 waiters + // are registered before the send, so the single compound-reply + // frame routes to all of them. + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while mock.sent_count() < 1 { + if std::time::Instant::now() > deadline { + panic!("execute_compound did not send in 5s"); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + mock.queue_response(compound_response); + + let results = handle.await.unwrap().unwrap(); + + assert_eq!(results.len(), 3); + let f0 = results[0].as_ref().expect("op 0 should be Ok"); + assert_eq!(f0.header.status, NtStatus::SUCCESS); + assert_eq!(f0.header.message_id, MessageId(0)); + + let f1 = results[1] + .as_ref() + .expect("op 1 still carries a Frame — error status in header"); + assert_eq!(f1.header.status, NtStatus::OBJECT_NAME_NOT_FOUND); + assert_eq!(f1.header.message_id, MessageId(1)); + + let f2 = results[2].as_ref().expect("op 2 should be Ok"); + assert_eq!(f2.header.status, NtStatus::SUCCESS); + assert_eq!(f2.header.message_id, MessageId(2)); + + mock.assert_fully_consumed(); + } + + /// Using a clone after the original is dropped: the `Arc` keeps + /// the receiver task alive. Specifically for `execute` (the A.1 test + /// only exercised direct `sender.send`). + #[tokio::test(flavor = "multi_thread")] + async fn execute_on_clone_works_after_original_dropped() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + + let original = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + let cloned = original.clone(); + drop(original); + + let c = cloned.clone(); + let handle = tokio::spawn(async move { + c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None) + .await + }); + + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while mock.sent_count() < 1 { + if std::time::Instant::now() > deadline { + panic!("execute on clone did not send in 5s"); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + mock.queue_response(build_echo_response_with_msg_id(MessageId(0))); + + let frame = handle.await.unwrap().unwrap(); + assert_eq!(frame.header.command, Command::Echo); + assert_eq!(frame.header.message_id, MessageId(0)); + + mock.assert_fully_consumed(); + } + + /// A clone'd `Connection` survives the original being dropped: the + /// receiver task and transport sender are behind `Arc`, so + /// dropping the last Arc (not the first) is what aborts the task. + #[tokio::test] + async fn connection_is_cloneable_clone_outlives_original() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut original = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + original.set_credits(9); + + let cloned = original.clone(); + drop(original); + + // Shared state still accessible — the receiver task is still live + // because the clone holds an `Arc`. + assert_eq!(cloned.credits(), 9); + assert_eq!(cloned.server_name(), "test-server"); + + // Send should still work: the transport's send half lives on Inner. + // We won't register a waiter (no response queued), just verify the + // send path doesn't panic on a dead-task-map. + // (Easier: send_cancel has no waiter registration.) + cloned + .inner + .sender + .send(b"\x00\x00\x00\x10ignore-me") + .await + .unwrap(); + assert_eq!(mock.sent_count(), 1); + } +} diff --git a/vendor/smb2/src/client/dfs.rs b/vendor/smb2/src/client/dfs.rs new file mode 100644 index 0000000..46da44e --- /dev/null +++ b/vendor/smb2/src/client/dfs.rs @@ -0,0 +1,884 @@ +//! DFS referral IOCTL helper and path resolver with referral cache. +//! +//! Sends `FSCTL_DFS_GET_REFERRALS` via IOCTL to resolve DFS paths. Connects +//! to IPC$ for the IOCTL exchange, similar to how `shares.rs` does for RPC. +//! +//! The [`DfsResolver`] caches referral responses with TTL and resolves UNC +//! paths using longest-prefix matching. All string comparisons are +//! case-insensitive (DFS paths are case-insensitive per MS-DFSC). + +// DFS resolver is used by SmbClient for reactive DFS path resolution. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use log::debug; + +use crate::client::connection::Connection; +use crate::error::Result; +use crate::msg::dfs::{ReqGetDfsReferral, RespGetDfsReferral}; +use crate::msg::ioctl::{ + IoctlRequest, IoctlResponse, FSCTL_DFS_GET_REFERRALS, SMB2_0_IOCTL_IS_FSCTL, +}; +use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse}; +use crate::msg::tree_disconnect::TreeDisconnectRequest; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::status::NtStatus; +use crate::types::{Command, FileId, TreeId}; +use crate::Error; + +/// Maximum output buffer size for DFS referral responses (8 KiB). +const DFS_MAX_OUTPUT_RESPONSE: u32 = 8192; + +/// Send a DFS referral request and return the parsed response. +/// +/// Connects to IPC$ (or reuses an existing tree), sends +/// `FSCTL_DFS_GET_REFERRALS` via IOCTL with `FileId::SENTINEL`, and +/// parses the response. +/// +/// The `path` should be a UNC-style path with a single leading backslash +/// (for example, `\server\share\dir`). +pub(crate) async fn get_dfs_referral( + conn: &mut Connection, + path: &str, +) -> Result { + // 1. Tree-connect to IPC$ + let tree_id = tree_connect_ipc(conn).await?; + + // Send the IOCTL, then clean up regardless of outcome + let result = send_dfs_ioctl(conn, tree_id, path).await; + + // Tree-disconnect IPC$ (best-effort -- don't mask the real error) + let _ = tree_disconnect(conn, tree_id).await; + + result +} + +/// Connect to the IPC$ share, returning the tree ID. +async fn tree_connect_ipc(conn: &mut Connection) -> Result { + let server = conn.server_name().to_string(); + let unc_path = format!(r"\\{}\IPC$", server); + + let req = TreeConnectRequest { + flags: TreeConnectRequestFlags::default(), + path: unc_path, + }; + + let frame = conn.execute(Command::TreeConnect, &req, None).await?; + + if frame.header.command != Command::TreeConnect { + return Err(Error::invalid_data(format!( + "expected TreeConnect response, got {:?}", + frame.header.command + ))); + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::TreeConnect, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let _resp = TreeConnectResponse::unpack(&mut cursor)?; + + let tree_id = frame + .header + .tree_id + .ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?; + + debug!("dfs: connected to IPC$, tree_id={}", tree_id); + Ok(tree_id) +} + +/// Build and send the FSCTL_DFS_GET_REFERRALS IOCTL, parse the response. +async fn send_dfs_ioctl( + conn: &mut Connection, + tree_id: TreeId, + path: &str, +) -> Result { + // Build the referral request payload + let referral_req = ReqGetDfsReferral { + max_referral_level: 4, + request_file_name: path.to_string(), + }; + let mut req_cursor = WriteCursor::new(); + referral_req.pack(&mut req_cursor); + let input_data = req_cursor.into_inner(); + + debug!( + "dfs: sending FSCTL_DFS_GET_REFERRALS for {:?} ({} bytes input)", + path, + input_data.len() + ); + + // Build the IOCTL request + let ioctl_req = IoctlRequest { + ctl_code: FSCTL_DFS_GET_REFERRALS, + file_id: FileId::SENTINEL, + max_input_response: 0, + max_output_response: DFS_MAX_OUTPUT_RESPONSE, + flags: SMB2_0_IOCTL_IS_FSCTL, + input_data, + }; + + let frame = conn + .execute(Command::Ioctl, &ioctl_req, Some(tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Ioctl, + }); + } + + // Parse the IOCTL response envelope + let mut cursor = ReadCursor::new(&frame.body); + let ioctl_resp = IoctlResponse::unpack(&mut cursor)?; + + debug!( + "dfs: received IOCTL response ({} bytes output)", + ioctl_resp.output_data.len() + ); + + // Parse the DFS referral from the output buffer + let mut ref_cursor = ReadCursor::new(&ioctl_resp.output_data); + let referral_resp = RespGetDfsReferral::unpack(&mut ref_cursor)?; + + debug!( + "dfs: parsed {} referral entries (path_consumed={})", + referral_resp.entries.len(), + referral_resp.path_consumed + ); + + Ok(referral_resp) +} + +/// Disconnect from a tree. +async fn tree_disconnect(conn: &mut Connection, tree_id: TreeId) -> Result<()> { + let body = TreeDisconnectRequest; + let frame = conn + .execute(Command::TreeDisconnect, &body, Some(tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::TreeDisconnect, + }); + } + + debug!("dfs: disconnected from IPC$"); + Ok(()) +} + +// ── DFS resolver types ─────────────────────────────────────────────── + +/// A resolved DFS path ready for connection. +#[derive(Debug, Clone)] +pub(crate) struct ResolvedPath { + /// Server hostname (or IP) to connect to. + pub server: String, + /// Port to connect on (default 445). + pub port: u16, + /// Share name to tree-connect. + pub share: String, + /// Remaining path within the share (may be empty). + pub remaining_path: String, +} + +/// A single DFS target from a referral response. +#[derive(Debug, Clone)] +struct DfsTarget { + /// Server hostname from the network_address field. + server: String, + /// Share name from the network_address field. + share: String, + /// Any remaining path suffix from the network_address. + remaining_prefix: String, +} + +/// A cached DFS referral entry with TTL. +#[derive(Debug, Clone)] +struct CachedReferral { + /// The DFS path prefix this referral covers (lowercase for matching). + dfs_path_prefix: String, + /// Available targets (first is preferred). + targets: Vec, + /// When this entry expires. + expires_at: Instant, +} + +/// DFS referral cache and path resolver. +/// +/// Maintains a cache of DFS referral responses keyed by path prefix. +/// Resolves UNC paths by longest-prefix matching against the cache, +/// falling back to an IOCTL referral request on cache miss. +pub(crate) struct DfsResolver { + cache: HashMap, + /// Counters surfaced through [`SmbClient::diagnostics`]. + cache_hits: AtomicU64, + referrals_resolved: AtomicU64, +} + +impl DfsResolver { + /// Create a new empty resolver. + pub fn new() -> Self { + Self { + cache: HashMap::new(), + cache_hits: AtomicU64::new(0), + referrals_resolved: AtomicU64::new(0), + } + } + + /// `(cache_hits, referrals_resolved)` for diagnostics. + pub(crate) fn counters(&self) -> (u64, u64) { + ( + self.cache_hits.load(Ordering::Relaxed), + self.referrals_resolved.load(Ordering::Relaxed), + ) + } + + /// Iterate the cache entries (including expired ones — eviction is + /// lazy). Used by [`SmbClient::diagnostics`]. + pub(crate) fn cache_entries(&self) -> Vec { + let now = Instant::now(); + self.cache + .values() + .map(|e| crate::client::diagnostics::DfsCacheEntry { + path_prefix: e.dfs_path_prefix.clone(), + target_count: e.targets.len(), + expires_in: if e.expires_at > now { + Some(e.expires_at - now) + } else { + None + }, + }) + .collect() + } + + /// Resolve a UNC path by checking the cache first, then querying the server. + /// + /// `unc_path` should be like `\\server\share\path\to\file`. + /// `conn` is the connection to the server that returned `STATUS_PATH_NOT_COVERED`. + pub async fn resolve( + &mut self, + conn: &mut Connection, + unc_path: &str, + ) -> Result> { + // 1. Check cache (longest prefix match) + if let Some(resolved) = self.resolve_from_cache(unc_path) { + self.cache_hits.fetch_add(1, Ordering::Relaxed); + debug!("dfs: cache hit for {:?}", unc_path); + return Ok(resolved); + } + + // 2. Send referral request. + // Convert \\server\share\path to \server\share\path (single leading + // backslash for the IOCTL). + let referral_path = if unc_path.starts_with("\\\\") { + &unc_path[1..] // strip one leading backslash + } else { + unc_path + }; + + debug!("dfs: cache miss, sending referral for {:?}", referral_path); + let resp = get_dfs_referral(conn, referral_path).await?; + self.referrals_resolved.fetch_add(1, Ordering::Relaxed); + + // 3. Cache the result + self.cache_referral(&resp); + + // 4. Resolve from the freshly cached entry + self.resolve_from_cache(unc_path).ok_or_else(|| { + Error::invalid_data("DFS referral response did not match the requested path") + }) + } + + /// Try to resolve a path from the cache. Returns `None` on cache miss or + /// expiry. Returns a `Vec` of [`ResolvedPath`]s (multiple targets for + /// failover). + pub(crate) fn resolve_from_cache(&self, unc_path: &str) -> Option> { + let normalized = unc_path.to_lowercase().replace('/', "\\"); + + // Longest prefix match + let mut best_match: Option<&CachedReferral> = None; + for entry in self.cache.values() { + if normalized.starts_with(&entry.dfs_path_prefix) + && entry.expires_at > Instant::now() + && best_match.is_none_or(|b| entry.dfs_path_prefix.len() > b.dfs_path_prefix.len()) + { + best_match = Some(entry); + } + } + + let entry = best_match?; + + // Strip the consumed prefix and build ResolvedPaths + let remaining = &normalized[entry.dfs_path_prefix.len()..]; + let remaining = remaining.trim_start_matches('\\'); + + let resolved: Vec = entry + .targets + .iter() + .map(|target| { + let full_remaining = if target.remaining_prefix.is_empty() { + remaining.to_string() + } else if remaining.is_empty() { + target.remaining_prefix.clone() + } else { + format!("{}\\{}", target.remaining_prefix, remaining) + }; + + ResolvedPath { + server: target.server.clone(), + port: 445, + share: target.share.clone(), + remaining_path: full_remaining, + } + }) + .collect(); + + Some(resolved) + } + + /// Store a referral response in the cache. + fn cache_referral(&mut self, resp: &RespGetDfsReferral) { + if resp.entries.is_empty() { + return; + } + + // Use the dfs_path from the first entry as the cache key. + // Normalize to lowercase backslash form with `\\` prefix (UNC canonical). + let mut dfs_path_prefix = resp.entries[0].dfs_path.to_lowercase().replace('/', "\\"); + if !dfs_path_prefix.starts_with("\\\\") { + if let Some(stripped) = dfs_path_prefix.strip_prefix('\\') { + dfs_path_prefix = format!("\\\\{stripped}"); + } + } + + // Parse targets from entries + let targets: Vec = resp + .entries + .iter() + .filter_map(|e| parse_unc_target(&e.network_address)) + .collect(); + + if targets.is_empty() { + return; + } + + let ttl = resp.entries[0].ttl.max(1); // At least 1 second + + debug!( + "dfs: caching {:?} with {} targets, ttl={}s", + dfs_path_prefix, + targets.len(), + ttl + ); + + self.cache.insert( + dfs_path_prefix.clone(), + CachedReferral { + dfs_path_prefix, + targets, + expires_at: Instant::now() + Duration::from_secs(ttl as u64), + }, + ); + } +} + +/// Parse a UNC network_address into server, share, and remaining path. +/// +/// Input: `\\server\share` or `\\server\share\path`. +/// Returns `None` if the format is invalid. +fn parse_unc_target(network_address: &str) -> Option { + let path = network_address.trim_start_matches('\\'); + let mut parts = path.splitn(3, '\\'); + let server = parts.next()?.to_string(); + let share = parts.next()?.to_string(); + let remaining_prefix = parts.next().unwrap_or("").to_string(); + + if server.is_empty() || share.is_empty() { + return None; + } + + Some(DfsTarget { + server, + share, + remaining_prefix, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::connection::pack_message; + use crate::client::test_helpers::{build_tree_connect_response, setup_connection}; + use crate::msg::header::{ErrorResponse, Header}; + use crate::msg::ioctl::IoctlResponse as IoctlResp; + use crate::msg::tree_connect::ShareType; + use crate::msg::tree_disconnect::TreeDisconnectResponse; + use crate::transport::MockTransport; + use crate::types::TreeId; + use std::sync::Arc; + + /// Build an IOCTL response containing the given output data. + fn build_ioctl_response(output_data: Vec) -> Vec { + let mut h = Header::new_request(Command::Ioctl); + h.flags.set_response(); + h.credits = 32; + + let body = IoctlResp { + ctl_code: FSCTL_DFS_GET_REFERRALS, + file_id: FileId::SENTINEL, + flags: SMB2_0_IOCTL_IS_FSCTL, + output_data, + }; + + pack_message(&h, &body) + } + + /// Build an IOCTL error response with the given status. + fn build_ioctl_error_response(status: NtStatus) -> Vec { + let mut h = Header::new_request(Command::Ioctl); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + pack_message(&h, &body) + } + + /// Build a TREE_DISCONNECT response. + fn build_tree_disconnect_response() -> Vec { + let mut h = Header::new_request(Command::TreeDisconnect); + h.flags.set_response(); + h.credits = 32; + pack_message(&h, &TreeDisconnectResponse) + } + + /// Pack a known DFS referral response into bytes. + /// + /// Builds a V3 referral with the given entries. + fn pack_dfs_referral_response( + path_consumed: u16, + header_flags: u32, + entries: &[(&str, &str, &str, u32)], // (dfs_path, alt_path, net_addr, ttl) + ) -> Vec { + // We build a V3 referral response manually. + // Entry fixed size: 4 (version+size) + 2+2+4 (server_type+flags+ttl) + // + 2+2+2 (offsets) + 16 (guid) = 34 bytes + let entry_fixed_size: u16 = 34; + let num_entries = entries.len() as u16; + let total_fixed = entry_fixed_size * num_entries; + + // Pre-compute all string bytes + let entry_strings: Vec<(Vec, Vec, Vec)> = entries + .iter() + .map(|(dfs, alt, net, _)| { + ( + encode_null_utf16(dfs), + encode_null_utf16(alt), + encode_null_utf16(net), + ) + }) + .collect(); + + // Compute cumulative string offsets relative to each entry's start. + // All strings come after all fixed entries. The offset for entry i + // is relative to entry i's start position. + let mut buf = Vec::new(); + + // Response header (8 bytes) + buf.extend_from_slice(&path_consumed.to_le_bytes()); + buf.extend_from_slice(&num_entries.to_le_bytes()); + buf.extend_from_slice(&header_flags.to_le_bytes()); + + // Calculate where strings start (after all fixed entries, but + // offsets are measured from the start of the entry data, not from + // the response header -- since RespGetDfsReferral::unpack reads + // the header first and then works with the remaining bytes). + // + // Actually, offsets in V3 entries are relative to the entry start + // within the entry data buffer. + + // Accumulate string buffer contents and compute per-entry offsets. + let mut string_buf = Vec::new(); + let mut per_entry_offsets = Vec::new(); + + for (i, (dfs_bytes, alt_bytes, net_bytes)) in entry_strings.iter().enumerate() { + let entry_start = i as u16 * entry_fixed_size; + let strings_base = total_fixed + string_buf.len() as u16; + + let dfs_offset = strings_base - entry_start; + let alt_offset = dfs_offset + dfs_bytes.len() as u16; + let net_offset = alt_offset + alt_bytes.len() as u16; + + per_entry_offsets.push((dfs_offset, alt_offset, net_offset)); + + string_buf.extend_from_slice(dfs_bytes); + string_buf.extend_from_slice(alt_bytes); + string_buf.extend_from_slice(net_bytes); + } + + // Write fixed entries + for (i, (_, _, _, ttl)) in entries.iter().enumerate() { + let (dfs_off, alt_off, net_off) = per_entry_offsets[i]; + + buf.extend_from_slice(&3u16.to_le_bytes()); // version = 3 + buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size + buf.extend_from_slice(&0u16.to_le_bytes()); // server_type + buf.extend_from_slice(&0u16.to_le_bytes()); // referral_entry_flags + buf.extend_from_slice(&ttl.to_le_bytes()); // ttl + buf.extend_from_slice(&dfs_off.to_le_bytes()); + buf.extend_from_slice(&alt_off.to_le_bytes()); + buf.extend_from_slice(&net_off.to_le_bytes()); + buf.extend_from_slice(&[0u8; 16]); // service_site_guid + } + + // Write string buffer + buf.extend_from_slice(&string_buf); + + buf + } + + /// Encode a string as null-terminated UTF-16LE bytes. + fn encode_null_utf16(s: &str) -> Vec { + let mut out = Vec::new(); + for cu in s.encode_utf16() { + out.extend_from_slice(&cu.to_le_bytes()); + } + out.extend_from_slice(&[0x00, 0x00]); + out + } + + #[tokio::test] + async fn dfs_referral_ioctl_flow() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + let tree_id = TreeId(99); + + // Build the DFS referral payload + let referral_bytes = pack_dfs_referral_response( + 48, // path_consumed + 0x02, // header_flags (StorageServers) + &[ + ( + r"\domain\dfs\docs", + r"\domain\dfs\docs", + r"\server1\share", + 600, + ), + ( + r"\domain\dfs\docs", + r"\domain\dfs\docs", + r"\server2\share", + 300, + ), + ], + ); + + // Queue responses: TreeConnect, IOCTL, TreeDisconnect + mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe)); + mock.queue_response(build_ioctl_response(referral_bytes)); + mock.queue_response(build_tree_disconnect_response()); + + let resp = get_dfs_referral(&mut conn, r"\domain\dfs\docs") + .await + .unwrap(); + + assert_eq!(resp.path_consumed, 48); + assert_eq!(resp.header_flags, 0x02); + assert_eq!(resp.entries.len(), 2); + + assert_eq!(resp.entries[0].version, 3); + assert_eq!(resp.entries[0].dfs_path, r"\domain\dfs\docs"); + assert_eq!(resp.entries[0].network_address, r"\server1\share"); + assert_eq!(resp.entries[0].ttl, 600); + + assert_eq!(resp.entries[1].network_address, r"\server2\share"); + assert_eq!(resp.entries[1].ttl, 300); + + // Should have sent 3 messages: TreeConnect, IOCTL, TreeDisconnect + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn dfs_referral_ioctl_error() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + let tree_id = TreeId(99); + + // Queue responses: TreeConnect, IOCTL error, TreeDisconnect + mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe)); + mock.queue_response(build_ioctl_error_response(NtStatus::NOT_FOUND)); + mock.queue_response(build_tree_disconnect_response()); + + let result = get_dfs_referral(&mut conn, r"\nonexistent\path").await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + match &err { + Error::Protocol { status, command } => { + assert_eq!(*status, NtStatus::NOT_FOUND); + assert_eq!(*command, Command::Ioctl); + } + other => panic!("expected Protocol error, got: {other:?}"), + } + + // Should still send TreeDisconnect even after IOCTL error + assert_eq!(mock.sent_count(), 3); + } + + // ── parse_unc_target tests ─────────────────────────────────────── + + #[test] + fn parse_unc_target_basic() { + let t = parse_unc_target(r"\\server\share").unwrap(); + assert_eq!(t.server, "server"); + assert_eq!(t.share, "share"); + assert_eq!(t.remaining_prefix, ""); + } + + #[test] + fn parse_unc_target_with_path() { + let t = parse_unc_target(r"\\server\share\path\to").unwrap(); + assert_eq!(t.server, "server"); + assert_eq!(t.share, "share"); + assert_eq!(t.remaining_prefix, r"path\to"); + } + + #[test] + fn parse_unc_target_invalid() { + assert!(parse_unc_target(r"\\").is_none()); + assert!(parse_unc_target("").is_none()); + assert!(parse_unc_target(r"\\server").is_none()); + // Single backslash + server but no share + assert!(parse_unc_target(r"\server").is_none()); + } + + #[test] + fn parse_unc_target_single_backslash_prefix() { + // Network addresses with single backslash prefix should also work. + let t = parse_unc_target(r"\server\share").unwrap(); + assert_eq!(t.server, "server"); + assert_eq!(t.share, "share"); + assert_eq!(t.remaining_prefix, ""); + } + + #[test] + fn parse_unc_target_triple_backslash() { + // Extra leading backslashes are stripped. + let t = parse_unc_target(r"\\\server\share\path").unwrap(); + assert_eq!(t.server, "server"); + assert_eq!(t.share, "share"); + assert_eq!(t.remaining_prefix, "path"); + } + + #[test] + fn parse_unc_target_ip_address() { + // IP addresses as server names. + let t = parse_unc_target(r"\\192.168.1.100\data").unwrap(); + assert_eq!(t.server, "192.168.1.100"); + assert_eq!(t.share, "data"); + assert_eq!(t.remaining_prefix, ""); + } + + #[test] + fn parse_unc_target_deep_path() { + // The remaining prefix captures everything after server\share. + let t = parse_unc_target(r"\\server\share\a\b\c\d").unwrap(); + assert_eq!(t.server, "server"); + assert_eq!(t.share, "share"); + assert_eq!(t.remaining_prefix, r"a\b\c\d"); + } + + #[test] + fn parse_unc_target_empty_components() { + // Empty server or share should return None. + assert!(parse_unc_target(r"\\\\share").is_none()); // empty server + assert!(parse_unc_target(r"\\\").is_none()); // server is empty after strip + } + + // ── DfsResolver tests ──────────────────────────────────────────── + + /// Helper: build a RespGetDfsReferral for cache tests. + fn make_referral( + dfs_path: &str, + entries: &[(&str, u32)], // (network_address, ttl) + ) -> RespGetDfsReferral { + use crate::msg::dfs::DfsReferralEntry; + + let referral_entries: Vec = entries + .iter() + .map(|(net_addr, ttl)| DfsReferralEntry { + version: 3, + server_type: 0, + referral_entry_flags: 0, + ttl: *ttl, + dfs_path: dfs_path.to_string(), + dfs_alternate_path: dfs_path.to_string(), + network_address: net_addr.to_string(), + }) + .collect(); + + RespGetDfsReferral { + path_consumed: 0, + header_flags: 0, + entries: referral_entries, + } + } + + #[test] + fn resolver_cache_hit() { + let mut resolver = DfsResolver::new(); + + let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server1\share", 600)]); + resolver.cache_referral(&resp); + + let result = resolver.resolve_from_cache(r"\\domain\dfs\docs\file.txt"); + assert!(result.is_some()); + let paths = result.unwrap(); + assert_eq!(paths.len(), 1); + assert_eq!(paths[0].server, "server1"); + assert_eq!(paths[0].share, "share"); + assert_eq!(paths[0].port, 445); + assert_eq!(paths[0].remaining_path, "file.txt"); + } + + #[test] + fn resolver_cache_miss() { + let resolver = DfsResolver::new(); + + let result = resolver.resolve_from_cache(r"\\server\share\file.txt"); + assert!(result.is_none()); + } + + #[test] + fn resolver_cache_expired() { + let mut resolver = DfsResolver::new(); + + // Insert with TTL=0 -- cache_referral clamps to 1s, so we need to + // manually insert an already-expired entry. + let targets = vec![DfsTarget { + server: "srv".to_string(), + share: "data".to_string(), + remaining_prefix: String::new(), + }]; + resolver.cache.insert( + r"\domain\dfs".to_string(), + CachedReferral { + dfs_path_prefix: r"\domain\dfs".to_string(), + targets, + expires_at: Instant::now() - Duration::from_secs(1), + }, + ); + + let result = resolver.resolve_from_cache(r"\\domain\dfs\file.txt"); + assert!(result.is_none(), "expired entry should not match"); + } + + #[test] + fn resolver_cache_longest_prefix() { + let mut resolver = DfsResolver::new(); + + // Insert a short prefix + let short = make_referral(r"\domain\dfs", &[(r"\\server1\root", 600)]); + resolver.cache_referral(&short); + + // Insert a longer prefix + let long = make_referral(r"\domain\dfs\docs", &[(r"\\server2\docs", 600)]); + resolver.cache_referral(&long); + + // Should match the longer prefix + let result = resolver + .resolve_from_cache(r"\\domain\dfs\docs\file.txt") + .unwrap(); + assert_eq!(result[0].server, "server2"); + assert_eq!(result[0].share, "docs"); + assert_eq!(result[0].remaining_path, "file.txt"); + + // A path that only matches the short prefix + let result2 = resolver + .resolve_from_cache(r"\\domain\dfs\other\file.txt") + .unwrap(); + assert_eq!(result2[0].server, "server1"); + assert_eq!(result2[0].share, "root"); + assert_eq!(result2[0].remaining_path, r"other\file.txt"); + } + + #[test] + fn resolver_multiple_targets() { + let mut resolver = DfsResolver::new(); + + let resp = make_referral( + r"\domain\dfs\docs", + &[(r"\\server1\share", 600), (r"\\server2\share", 300)], + ); + resolver.cache_referral(&resp); + + let result = resolver + .resolve_from_cache(r"\\domain\dfs\docs\file.txt") + .unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].server, "server1"); + assert_eq!(result[1].server, "server2"); + // Both should have the same remaining path + assert_eq!(result[0].remaining_path, "file.txt"); + assert_eq!(result[1].remaining_path, "file.txt"); + } + + #[test] + fn resolver_path_normalization() { + let mut resolver = DfsResolver::new(); + + // Cache with backslash-separated DFS path + let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server\share", 600)]); + resolver.cache_referral(&resp); + + // Resolve with double-backslash prefix and mixed case + let result = resolver + .resolve_from_cache(r"\\DOMAIN\DFS\DOCS\Sub\File.txt") + .unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].server, "server"); + assert_eq!(result[0].share, "share"); + // remaining_path is lowercased because we normalize the full input + assert_eq!(result[0].remaining_path, r"sub\file.txt"); + + // Forward slashes should also work + let result2 = resolver + .resolve_from_cache(r"\\domain/dfs/docs/other.txt") + .unwrap(); + assert_eq!(result2[0].remaining_path, "other.txt"); + } + + #[test] + fn resolver_remaining_prefix_from_target() { + let mut resolver = DfsResolver::new(); + + // Target has a remaining prefix (network_address includes a subpath) + let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server\share\subdir", 600)]); + resolver.cache_referral(&resp); + + // With additional path after the DFS prefix + let result = resolver + .resolve_from_cache(r"\\domain\dfs\docs\file.txt") + .unwrap(); + assert_eq!(result[0].remaining_path, r"subdir\file.txt"); + + // Without additional path -- just the target's remaining prefix + let result2 = resolver.resolve_from_cache(r"\\domain\dfs\docs").unwrap(); + assert_eq!(result2[0].remaining_path, "subdir"); + } +} diff --git a/vendor/smb2/src/client/diagnostics.rs b/vendor/smb2/src/client/diagnostics.rs new file mode 100644 index 0000000..6a82173 --- /dev/null +++ b/vendor/smb2/src/client/diagnostics.rs @@ -0,0 +1,1048 @@ +//! Diagnostics: an in-process observability surface for a running [`SmbClient`](crate::SmbClient). +//! +//! Call [`SmbClient::diagnostics`](crate::SmbClient::diagnostics) to capture +//! a point-in-time tree of the client's negotiated parameters, credits, +//! in-flight requests, per-connection counters, and DFS cache state. Call +//! [`Connection::diagnostics`](crate::client::Connection::diagnostics) for +//! the per-connection slice. +//! +//! ## Consistency model +//! +//! Snapshots are **eventually consistent**. Each field is loaded +//! independently: the available-credits gauge, the in-flight count +//! (`waiters.len()`), and each counter are sampled at slightly different +//! moments. Sums of related fields (for example `credits.available + +//! credits.in_flight`) are **not** invariant — read each field for what it +//! says about itself, not as a coupled tuple. A consumer that wants +//! atomicity quiesces operations first. +//! +//! ## Snapshot lock order +//! +//! The snapshot acquires these locks, one at a time, in this order, never +//! across an `.await`: `crypto → waiters → dfs_trees → estimated_rtt`. +//! Each is held only as long as it takes to copy primitives out and +//! release. `params` is an `OnceLock` (wait-free read). `preauth_hasher` +//! and `receiver_task` are not touched by the snapshot. +//! +//! If you add a field that touches a new lock, **extend** this order, don't +//! reshuffle it. +//! +//! ## Counters survive teardown +//! +//! Counters live on `Arc`, which outlives the receiver task. A +//! snapshot taken on a torn-down connection (`disconnected: true`) returns +//! the final counter values at the moment of death. +//! +//! ## Counters reset on reconnect +//! +//! [`SmbClient::reconnect`](crate::SmbClient::reconnect) builds a fresh +//! [`Connection`](crate::client::Connection) with a fresh `Inner`, so +//! per-connection counters return to zero. Client-level counters (the +//! [`ClientMetricsSnapshot`] on [`Diagnostics::client`]) survive — `reconnects` +//! is monotonic across the client's lifetime. +//! +//! See `docs/specs/diagnostics-plan.md` for the design rationale. + +use std::fmt; +use std::time::Duration; + +use crate::crypto::encryption::Cipher; +use crate::crypto::signing::SigningAlgorithm; +use crate::pack::Guid; +use crate::types::flags::Capabilities; +use crate::types::{Dialect, SessionId, TreeId}; + +/// Top-level diagnostics tree, captured by [`SmbClient::diagnostics`](crate::SmbClient::diagnostics). +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct Diagnostics { + /// Client-level configuration and counters. + pub client: ClientInfo, + /// The primary connection. Its `session` field carries the primary + /// session (or `None` until session setup runs). + pub primary: ConnectionDiagnostics, + /// DFS cross-server connections, each with its own session. Each + /// extra entry was authenticated separately. + pub extra_connections: Vec, + /// DFS referral cache snapshot (one entry per cached path prefix). + pub dfs_cache: Vec, +} + +/// Client-level configuration + counters. +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct ClientInfo { + /// The server address the client was constructed with (`host:port`). + pub primary_server: String, + /// Connection timeout from [`ClientConfig`](crate::ClientConfig). + pub timeout: Duration, + /// Whether the client was configured for auto-reconnect on loss. + pub auto_reconnect: bool, + /// Whether DFS resolution is enabled. + pub dfs_enabled: bool, + /// Client-level counters (survive `reconnect`). + pub metrics: ClientMetricsSnapshot, +} + +/// Per-connection snapshot, captured by +/// [`Connection::diagnostics`](crate::client::Connection::diagnostics) and +/// included in [`Diagnostics::primary`] / [`Diagnostics::extra_connections`]. +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct ConnectionDiagnostics { + /// Server hostname or IP this connection talks to. + pub server: String, + /// Negotiated parameters, or `None` until `negotiate()` runs. + pub negotiated: Option, + /// Credit gauge + in-flight count + next-MessageId. + pub credits: CreditInfo, + /// Signing state. + pub signing: SigningInfo, + /// Encryption state. + pub encryption: EncryptionInfo, + /// Compression state. + pub compression: CompressionInfo, + /// RTT measured during `negotiate`, if it ran. + pub rtt_estimate: Option, + /// `true` after the receiver task has torn down (transport error, + /// decrypt failure, etc.). + pub disconnected: bool, + /// Tree IDs that have DFS capability on this connection. + pub dfs_trees: Vec, + /// Session on this connection, or `None` until session setup runs. + pub session: Option, + /// Per-connection counters. + pub metrics: MetricsSnapshot, +} + +/// Snapshot of [`NegotiatedParams`](crate::client::NegotiatedParams) for +/// the diagnostics tree. Same fields, copied (not borrowed). +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct NegotiatedSummary { + /// Negotiated dialect. + pub dialect: Dialect, + /// Maximum read size the server supports. + pub max_read_size: u32, + /// Maximum write size the server supports. + pub max_write_size: u32, + /// Maximum transact size the server supports. + pub max_transact_size: u32, + /// Server GUID. + pub server_guid: Guid, + /// Whether the server requires signing. + pub signing_required: bool, + /// Server capabilities. + /// + /// With the `serde` feature on, this serializes as the underlying + /// `u32` bits (not a JSON object of named flags). + pub capabilities: Capabilities, + /// Whether AES-GMAC signing was negotiated (SMB 3.1.1). + pub gmac_negotiated: bool, + /// The negotiated encryption cipher (SMB 3.x). + pub cipher: Option, + /// Whether compression was negotiated with the server. + pub compression_supported: bool, +} + +/// Credit gauge for the connection. +/// +/// All three fields are sampled independently — `available + in_flight` is +/// **not** invariant. See the module-level eventual-consistency note. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct CreditInfo { + /// Credits currently available to spend on new requests. + pub available: u16, + /// Number of `MessageId`s currently waiting for a response (i.e. + /// `waiters.len()`). + pub in_flight: usize, + /// The `MessageId` that will be assigned to the next request. + pub next_message_id: u64, +} + +/// Signing state. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct SigningInfo { + /// `true` when outgoing requests are being signed. + pub active: bool, + /// Negotiated signing algorithm, or `None` if signing isn't active. + pub algorithm: Option, +} + +/// Encryption state. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct EncryptionInfo { + /// `true` when outgoing requests are being encrypted with + /// `TRANSFORM_HEADER`. + pub active: bool, + /// Negotiated encryption cipher, or `None` if encryption isn't active. + pub cipher: Option, +} + +/// Compression state. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct CompressionInfo { + /// Whether the client requested compression in `ClientConfig`. + pub requested: bool, + /// Whether compression was actually negotiated and is active. + pub negotiated: bool, +} + +/// Per-connection session snapshot. Each +/// [`ConnectionDiagnostics`] has its own — DFS extra connections each +/// authenticate separately, so they each carry a distinct session. +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct SessionDiagnostics { + /// SMB session ID assigned by the server. + pub session_id: SessionId, + /// `true` when the session requires signing. + pub should_sign: bool, + /// `true` when the session requires encryption. + pub should_encrypt: bool, + /// Signing algorithm derived for this session. + pub signing_algorithm: SigningAlgorithm, +} + +/// One entry in the DFS referral cache. +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct DfsCacheEntry { + /// The DFS path prefix this entry covers. Lowercased UNC form (the + /// internal normalization used for case-insensitive matching). + pub path_prefix: String, + /// Number of failover targets the server returned. + pub target_count: usize, + /// Remaining time-to-live. `None` if the entry has already expired + /// (cache eviction is lazy: expired entries linger until the next + /// `resolve()` for an overlapping prefix). + pub expires_in: Option, +} + +/// Per-connection counter snapshot, taken atomically at the field level +/// but not as a whole (fields may skew — see [module docs](self)). +/// +/// Counters are monotonic across the connection's lifetime. To compute a +/// rate, take two snapshots and subtract. +#[non_exhaustive] +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct MetricsSnapshot { + /// Every `MessageId` allocated for a request. Includes negotiate, + /// session-setup, every `execute`, every `execute_with_credits`, every + /// `dispatch` (Watcher's pre-arm CHANGE_NOTIFY), and every sub-op of + /// every `execute_compound`. Does *not* include CANCEL (see + /// [`Self::explicit_cancels_sent`]). + pub requests_sent: u64, + /// Every successful `execute_compound` call — the chain itself, not + /// the per-sub-op count (those tick `requests_sent`). + pub compound_requests_sent: u64, + /// Bytes handed to `Transport::send` — the wire-layer count, after + /// any sign / encrypt / compress on the send side. The byte count a + /// packet capture would observe. + pub wire_bytes_sent: u64, + /// `Connection::send_cancel` invocations. CANCEL is the only SMB op + /// today that we send proactively; cancellation-by-drop is invisible + /// here (the drop never reaches the wire). + pub explicit_cancels_sent: u64, + + /// Sub-frames where the receiver task found the waiter in the map + /// and successfully delivered `Ok(frame)` to it. The normal happy + /// path. + pub responses_routed_ok: u64, + /// Sub-frames where the receiver task found the waiter in the map + /// and successfully delivered `Err(_)` to it. Today this is the + /// union: [`Self::signature_failures`] + [`Self::session_expired_events`]. + /// Don't sum *those* with this counter — they're a partition of it. + pub responses_routed_err: u64, + /// Sub-frames where the receiver task found the waiter in the map + /// but the caller's `oneshot::Receiver` was already dropped. Typical + /// for `tokio::spawn` + `JoinHandle::abort()` patterns where the + /// caller's future was cancelled mid-flight. The frame is discarded + /// silently; credits already applied. + pub responses_late_after_drop: u64, + /// Sub-frames where the receiver task did **not** find the waiter + /// in the map. The genuine orphan: server sent a frame for a + /// `MessageId` we never allocated, or a send-error cleanup raced + /// with arrival. Should be near-zero in normal operation. + pub responses_stray: u64, + /// Bytes received from `Transport::receive` — wire-layer, before + /// any decrypt / decompress. + pub wire_bytes_received: u64, + + /// Interim STATUS_PENDING sub-frames the receiver kept the waiter + /// alive on (CHANGE_NOTIFY long-polls, slow IOCTLs). + pub status_pending_loops: u64, + /// Sub-frames with `MessageId::UNSOLICITED` (today: oplock breaks; + /// the same magic id is reserved for future lease-break and other + /// server-initiated notifications). Counted, logged at DEBUG, + /// skipped — no waiter to route to. + pub unsolicited_notifications_received: u64, + /// Sub-frames whose signature verification failed. The error is + /// routed to the matching waiter (also ticks + /// [`Self::responses_routed_err`]); the connection continues. + pub signature_failures: u64, + /// Frames the receiver task could not decrypt (auth-tag mismatch, + /// missing decryption key, malformed `TransformHeader`). Counted + /// once before the connection tears down — the receiver task + /// fans `Err(Disconnected)` to every pending waiter and exits. + pub decrypt_failures: u64, + /// Frames the receiver task could not decompress. Same teardown + /// behavior as decrypt failures. + pub decompress_failures: u64, + /// Frames the receiver task could not parse (compound split, + /// sub-frame header parse). Same teardown behavior. Covers both + /// the `split_compound` parse-failure branch and the + /// `prepare_sub_frame` header-parse branch. + pub malformed_frames: u64, + /// Sub-frames with `STATUS_NETWORK_SESSION_EXPIRED`. Counted + /// per-sub-frame, not per session-event: a compound of N expired + /// sub-ops ticks N times. For the event-shaped signal "did we + /// reconnect", use [`ClientMetricsSnapshot::reconnects`]. Subset + /// of [`Self::responses_routed_err`]; don't sum. + pub session_expired_events: u64, + + /// `execute` / `execute_with_credits` / `execute_compound` returned + /// an outer `Err` to a caller that polled to completion. Per-call, + /// not per-sub-op: an `execute_compound` whose inner `Vec` contains + /// errors but whose outer `Result` is `Ok` does **not** tick this. + /// + /// Caller-drop (the spawn/abort pattern) is captured by + /// [`Self::responses_late_after_drop`], not here — a dropped future + /// never polls to a return value. + pub requests_returned_err: u64, +} + +/// Client-level counter snapshot. Lives on [`SmbClient`](crate::SmbClient) +/// (above the per-connection layer) and survives +/// [`SmbClient::reconnect`](crate::SmbClient::reconnect). +#[non_exhaustive] +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct ClientMetricsSnapshot { + /// `SmbClient::reconnect` invocations. The event-shaped signal "did + /// we reconnect" — pair with + /// [`MetricsSnapshot::session_expired_events`] if you want both. + pub reconnects: u64, + /// DFS path resolutions that resulted in a referral IOCTL to the + /// server (cache miss). + pub dfs_referrals_resolved: u64, + /// DFS path resolutions served from the in-process referral cache + /// (cache hit). + pub dfs_cache_hits: u64, +} + +impl fmt::Display for Diagnostics { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let c = &self.client; + writeln!(f, "SMB client → {}", c.primary_server)?; + writeln!( + f, + " reconnects: {} dfs: {} (hits: {}, referrals resolved: {}, cache entries: {})", + c.metrics.reconnects, + if c.dfs_enabled { "enabled" } else { "disabled" }, + c.metrics.dfs_cache_hits, + c.metrics.dfs_referrals_resolved, + self.dfs_cache.len(), + )?; + writeln!(f)?; + writeln!(f, "Primary connection ({})", self.primary.server)?; + fmt_connection_body(&self.primary, f)?; + + if !self.extra_connections.is_empty() { + writeln!(f)?; + writeln!( + f, + "DFS extra connections: ({})", + self.extra_connections.len() + )?; + for c in &self.extra_connections { + writeln!(f)?; + writeln!(f, " ↳ {}", c.server)?; + fmt_connection_body(c, f)?; + } + } else { + writeln!(f)?; + writeln!(f, "DFS extra connections: (0)")?; + } + Ok(()) + } +} + +fn fmt_connection_body(c: &ConnectionDiagnostics, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let m = &c.metrics; + match &c.negotiated { + Some(n) => { + let rtt = c + .rtt_estimate + .map(|d| format!("{:.1} ms", d.as_secs_f64() * 1000.0)) + .unwrap_or_else(|| "—".to_string()); + writeln!(f, " dialect: {:?} rtt: {}", n.dialect, rtt)?; + writeln!( + f, + " signing: {} encryption: {} compression: {}", + fmt_signing(&c.signing), + fmt_encryption(&c.encryption), + fmt_compression(&c.compression), + )?; + } + None => { + writeln!( + f, + " (pre-negotiate — no dialect / signing / encryption yet)" + )?; + } + } + writeln!( + f, + " credits: {} available · {} in flight · next msg_id {}", + c.credits.available, c.credits.in_flight, c.credits.next_message_id + )?; + writeln!( + f, + " wire bytes: {} sent · {} received", + m.wire_bytes_sent, m.wire_bytes_received + )?; + writeln!( + f, + " responses: {} ok · {} wire-err · {} late · {} stray (sent: {}, caller-err: {})", + m.responses_routed_ok, + m.responses_routed_err, + m.responses_late_after_drop, + m.responses_stray, + m.requests_sent, + m.requests_returned_err, + )?; + writeln!( + f, + " protocol events: {} status-pending · {} unsolicited · {} compound chains · {} cancels", + m.status_pending_loops, + m.unsolicited_notifications_received, + m.compound_requests_sent, + m.explicit_cancels_sent, + )?; + writeln!( + f, + " errors: {} signature · {} decrypt · {} decompress · {} malformed · {} session-expired", + m.signature_failures, + m.decrypt_failures, + m.decompress_failures, + m.malformed_frames, + m.session_expired_events, + )?; + if c.disconnected { + writeln!(f, " status: DISCONNECTED")?; + } + Ok(()) +} + +fn fmt_signing(s: &SigningInfo) -> String { + match (s.active, s.algorithm) { + (true, Some(algo)) => format!("active ({:?})", algo), + (true, None) => "active".to_string(), + (false, _) => "inactive".to_string(), + } +} + +fn fmt_encryption(e: &EncryptionInfo) -> String { + match (e.active, e.cipher) { + (true, Some(c)) => format!("active ({:?})", c), + (true, None) => "active".to_string(), + (false, _) => "inactive".to_string(), + } +} + +fn fmt_compression(c: &CompressionInfo) -> String { + match (c.requested, c.negotiated) { + (true, true) => "active".to_string(), + (true, false) => "requested, not negotiated".to_string(), + (false, true) => "active (not requested)".to_string(), + (false, false) => "off".to_string(), + } +} + +// ── M3: optional serde derives ─────────────────────────────────────────── +// Each diagnostics type carries `#[cfg_attr(feature = "serde", derive(Serialize))]` +// directly on its definition (above). `Capabilities` has a manual `Serialize` +// impl in `types/flags.rs` that emits the underlying u32 bits. + +#[cfg(test)] +mod tests { + //! Per-counter unit tests for M1. + //! + //! Each test exercises one counter against a `MockTransport`, asserting + //! it ticks the expected number of times. The disjoint-partition test + //! at the bottom checks the four routing-outcome counters sum to the + //! total sub-frames the receiver routed. + + use std::sync::Arc; + use std::time::Duration; + + use crate::client::connection::Connection; + use crate::msg::echo::{EchoRequest, EchoResponse}; + use crate::msg::header::Header; + use crate::pack::Pack; + use crate::transport::mock::MockTransport; + use crate::types::status::NtStatus; + use crate::types::{Command, MessageId}; + + /// Build a packed message (header + body) — mirrors `pack_message` in + /// `connection.rs`, kept inline to avoid widening that helper's + /// visibility just for tests. + fn pack(header: &Header, body: &dyn Pack) -> Vec { + let mut cursor = crate::pack::WriteCursor::with_capacity(64 + 16); + header.pack(&mut cursor); + body.pack(&mut cursor); + cursor.into_inner() + } + + fn echo_response(msg_id: MessageId, status: NtStatus) -> Vec { + let mut h = Header::new_request(Command::Echo); + h.flags.set_response(); + h.credits = 10; + h.message_id = msg_id; + h.status = status; + pack(&h, &EchoResponse) + } + + fn echo_ok(msg_id: MessageId) -> Vec { + echo_response(msg_id, NtStatus::SUCCESS) + } + + /// Wait until at least `n` messages have been recorded as sent on + /// the mock. Times out after 5 s. + async fn wait_for_sent(mock: &MockTransport, n: usize) { + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while mock.sent_count() < n { + if std::time::Instant::now() > deadline { + panic!("expected {n} sent messages, got {}", mock.sent_count()); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + + /// A bare `Connection` over a mock transport, with auto-rewrite ON. + /// Mirrors the existing `execute_returns_correct_frame_for_sent_request` + /// setup but returns the mock so the test can queue / inspect. + fn fresh_conn() -> (Connection, Arc) { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + (conn, mock) + } + + #[tokio::test(flavor = "multi_thread")] + async fn requests_sent_and_wire_bytes_sent_tick_for_one_execute() { + let (conn, mock) = fresh_conn(); + + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + + wait_for_sent(&mock, 1).await; + mock.queue_response(echo_ok(MessageId(0))); + handle.await.unwrap().unwrap(); + + let m = conn.metrics(); + assert_eq!(m.requests_sent, 1, "one msg_id allocated → one request"); + assert!(m.wire_bytes_sent > 0, "send wrote some bytes to the wire"); + assert!( + m.wire_bytes_received > 0, + "receive read some bytes from the wire" + ); + assert_eq!(m.responses_routed_ok, 1); + assert_eq!(m.responses_routed_err, 0); + assert_eq!(m.responses_late_after_drop, 0); + assert_eq!(m.responses_stray, 0); + assert_eq!(m.requests_returned_err, 0); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn requests_sent_ticks_per_sub_op_in_compound_and_compound_chain_counted() { + use crate::client::connection::CompoundOp; + + let (conn, mock) = fresh_conn(); + + let c = conn.clone(); + let handle = tokio::spawn(async move { + let ops = vec![ + CompoundOp::new(Command::Echo, &EchoRequest, None), + CompoundOp::new(Command::Echo, &EchoRequest, None), + CompoundOp::new(Command::Echo, &EchoRequest, None), + ]; + c.execute_compound(&ops).await + }); + + wait_for_sent(&mock, 1).await; + // Three sub-frames → three responses. Auto-rewrite handles the + // msg_id pairing per sub-frame. + mock.queue_response(echo_ok(MessageId(0))); + mock.queue_response(echo_ok(MessageId(0))); + mock.queue_response(echo_ok(MessageId(0))); + handle.await.unwrap().unwrap(); + + let m = conn.metrics(); + assert_eq!(m.requests_sent, 3, "three sub-ops → requests_sent += 3"); + assert_eq!(m.compound_requests_sent, 1, "one compound chain"); + assert_eq!(m.responses_routed_ok, 3); + assert_eq!(m.requests_returned_err, 0); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn requests_returned_err_ticks_on_outer_err_to_completed_caller() { + let (conn, mock) = fresh_conn(); + + // Close before sending → execute returns Err(Disconnected) once the + // receiver task's transport-error branch fans to the waiter. + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + + wait_for_sent(&mock, 1).await; + mock.close(); + let result = handle.await.unwrap(); + assert!(result.is_err(), "execute should error after close"); + + // Receiver-task tear-down may take a beat to propagate; loop briefly. + let deadline = std::time::Instant::now() + Duration::from_secs(2); + while conn.metrics().requests_returned_err == 0 && std::time::Instant::now() < deadline { + tokio::time::sleep(Duration::from_millis(10)).await; + } + + assert_eq!(conn.metrics().requests_returned_err, 1); + } + + #[tokio::test(flavor = "multi_thread")] + async fn responses_late_after_drop_ticks_when_caller_dropped() { + let (conn, mock) = fresh_conn(); + + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + + wait_for_sent(&mock, 1).await; + // Drop the caller's future BEFORE the response arrives. The waiter + // is still in the map; the oneshot::Receiver gets dropped. + handle.abort(); + let _ = handle.await; // observe the JoinError, don't unwrap + + // Now queue the response. The receiver task finds the waiter, + // tries to send, sees the dropped Receiver, bumps + // responses_late_after_drop (and NOT responses_stray). + mock.queue_response(echo_ok(MessageId(0))); + + // Wait until the counter ticks (the receiver task drives this). + let deadline = std::time::Instant::now() + Duration::from_secs(2); + while conn.metrics().responses_late_after_drop == 0 && std::time::Instant::now() < deadline + { + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let m = conn.metrics(); + assert_eq!(m.responses_late_after_drop, 1, "caller-drop should tick"); + assert_eq!(m.responses_stray, 0, "stray is for unregistered ids only"); + assert_eq!(m.responses_routed_ok, 0); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn responses_stray_ticks_for_unregistered_msg_id() { + let (conn, mock) = fresh_conn(); + + // Don't call execute at all. Queue a response for a msg_id no one + // allocated. Auto-rewrite would normally pair with a sent msg_id, + // but there is none — so we use the *non*-auto path: build a + // response with an explicit non-zero msg_id and drop into the + // queue. Auto-rewrite's "keep non-zero, still consume one id" + // logic would block on `send_notify` forever. So disable it + // first by NOT enabling on a fresh second mock. + let _ = (conn, mock); // discarded — we use a non-auto-rewrite mock below + let plain_mock = Arc::new(MockTransport::new()); + let conn = Connection::from_transport( + Box::new(plain_mock.clone()), + Box::new(plain_mock.clone()), + "test-server", + ); + + plain_mock.queue_response(echo_ok(MessageId(999_999))); + + // Poll the counter — `pending_responses() == 0` only proves the + // transport drained, not that the receiver finished processing the + // frame and bumped `responses_stray`. The latter is the actual + // signal we're testing. + let deadline = std::time::Instant::now() + Duration::from_secs(2); + while conn.metrics().responses_stray == 0 && std::time::Instant::now() < deadline { + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let m = conn.metrics(); + assert_eq!(m.responses_stray, 1); + assert_eq!(m.responses_late_after_drop, 0); + assert_eq!(m.responses_routed_ok, 0); + + plain_mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn unsolicited_notifications_received_ticks_for_unsolicited_msg_id() { + let mock = Arc::new(MockTransport::new()); + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + let mut h = Header::new_request(Command::OplockBreak); + h.flags.set_response(); + h.credits = 0; + h.message_id = MessageId::UNSOLICITED; + let frame = pack(&h, &EchoResponse); // body shape doesn't matter; it's skipped + mock.queue_response(frame); + + // Wait for consumption. + let deadline = std::time::Instant::now() + Duration::from_secs(2); + while conn.metrics().unsolicited_notifications_received == 0 + && std::time::Instant::now() < deadline + { + tokio::time::sleep(Duration::from_millis(10)).await; + } + + assert_eq!(conn.metrics().unsolicited_notifications_received, 1); + // UNSOLICITED is skipped — it does NOT tick the routing counters. + assert_eq!(conn.metrics().responses_routed_ok, 0); + assert_eq!(conn.metrics().responses_stray, 0); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn status_pending_loops_ticks_for_interim_pending_then_final() { + // Don't use auto_rewrite — we need TWO responses paired with ONE sent + // msg_id. The first execute on a fresh connection always allocates + // msg_id=0, so we can hardcode that in both responses. + let mock = Arc::new(MockTransport::new()); + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + + wait_for_sent(&mock, 1).await; + // Interim STATUS_PENDING with msg_id=0, then final SUCCESS with msg_id=0. + mock.queue_response(echo_response(MessageId(0), NtStatus::PENDING)); + mock.queue_response(echo_response(MessageId(0), NtStatus::SUCCESS)); + + handle.await.unwrap().unwrap(); + + let m = conn.metrics(); + assert_eq!(m.status_pending_loops, 1, "one interim PENDING observed"); + assert_eq!(m.responses_routed_ok, 1, "one final response routed"); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn session_expired_events_ticks_and_also_routes_err() { + let mock = Arc::new(MockTransport::new()); + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + + wait_for_sent(&mock, 1).await; + mock.queue_response(echo_response( + MessageId(0), + NtStatus::NETWORK_SESSION_EXPIRED, + )); + + let result = handle.await.unwrap(); + assert!(result.is_err(), "session-expired should surface as Err"); + + let m = conn.metrics(); + assert_eq!(m.session_expired_events, 1); + assert_eq!( + m.responses_routed_err, 1, + "session_expired_events is a subset of responses_routed_err" + ); + assert_eq!(m.responses_routed_ok, 0); + assert_eq!( + m.requests_returned_err, 1, + "caller polled to completion and got Err" + ); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn explicit_cancels_sent_ticks_on_send_cancel() { + let mock = Arc::new(MockTransport::new()); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + conn.send_cancel(MessageId(42), None).await.unwrap(); + + assert_eq!(conn.metrics().explicit_cancels_sent, 1); + // CANCEL does NOT allocate a msg_id — it reuses the original. + assert_eq!(conn.metrics().requests_sent, 0); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn dispatch_path_is_counted() { + // `dispatch` is the watcher's pre-arm path — funnel-counted via + // allocate_msg_id, same as `execute`. + let (conn, mock) = fresh_conn(); + + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.dispatch(Command::Echo, &EchoRequest, None).await }); + + wait_for_sent(&mock, 1).await; + let rx = handle.await.unwrap().unwrap(); + mock.queue_response(echo_ok(MessageId(0))); + // Drive the response so the receiver processes it (and the awaiter + // sees the result). + let _ = rx.await.unwrap().unwrap(); + + let m = conn.metrics(); + assert_eq!(m.requests_sent, 1, "dispatch funnel-counts via allocate"); + assert!(m.wire_bytes_sent > 0); + assert_eq!(m.responses_routed_ok, 1); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn counters_survive_teardown() { + let (conn, mock) = fresh_conn(); + + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + + wait_for_sent(&mock, 1).await; + mock.queue_response(echo_ok(MessageId(0))); + handle.await.unwrap().unwrap(); + + let before = conn.metrics(); + assert_eq!(before.responses_routed_ok, 1); + + // Tear down. + mock.close(); + // Give the receiver task a tick to observe Err and fan. + tokio::time::sleep(Duration::from_millis(50)).await; + + // Counters still readable. + let after = conn.metrics(); + assert_eq!(after.responses_routed_ok, before.responses_routed_ok); + assert_eq!(after.requests_sent, before.requests_sent); + } + + // ── M2 / M3: full Diagnostics tree + Display + serde ────────────── + + fn fake_client(conn: Connection, session: crate::client::Session) -> crate::SmbClient { + let cfg = crate::ClientConfig { + addr: conn.server_name().to_string(), + timeout: Duration::from_secs(30), + username: String::new(), + password: String::new(), + domain: String::new(), + auto_reconnect: false, + compression: true, + dfs_enabled: true, + dfs_target_overrides: std::collections::HashMap::new(), + }; + crate::SmbClient::from_parts(cfg, conn, session) + } + + fn fake_session() -> crate::client::Session { + crate::client::Session { + session_id: crate::types::SessionId(0x1234_5678_9ABC_DEF0), + signing_key: vec![], + encryption_key: None, + decryption_key: None, + signing_algorithm: crate::crypto::signing::SigningAlgorithm::HmacSha256, + should_sign: false, + should_encrypt: false, + } + } + + #[tokio::test(flavor = "multi_thread")] + async fn display_contains_key_labels() { + let (conn, mock) = fresh_conn(); + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + wait_for_sent(&mock, 1).await; + mock.queue_response(echo_ok(MessageId(0))); + handle.await.unwrap().unwrap(); + + let client = fake_client(conn, fake_session()); + let d = client.diagnostics(); + let text = format!("{}", d); + for label in [ + "SMB client", + "test-server", + "credits:", + "wire bytes:", + "responses:", + "protocol events:", + "errors:", + "DFS extra connections", + ] { + assert!( + text.contains(label), + "Display missing {label:?} in:\n{text}" + ); + } + + mock.close(); + } + + #[cfg(feature = "serde")] + #[tokio::test(flavor = "multi_thread")] + async fn serde_round_trip_into_json_value() { + let (conn, mock) = fresh_conn(); + let c = conn.clone(); + let handle = + tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await }); + wait_for_sent(&mock, 1).await; + mock.queue_response(echo_ok(MessageId(0))); + handle.await.unwrap().unwrap(); + + let client = fake_client(conn, fake_session()); + let d = client.diagnostics(); + + let json = serde_json::to_string(&d).expect("serialize"); + let v: serde_json::Value = serde_json::from_str(&json).expect("re-parse"); + + assert_eq!(v["client"]["primary_server"], "test-server", "json: {json}"); + assert_eq!(v["primary"]["server"], "test-server"); + assert_eq!(v["primary"]["metrics"]["requests_sent"], 1); + assert_eq!(v["primary"]["metrics"]["responses_routed_ok"], 1); + assert!(v["primary"]["disconnected"].is_boolean()); + assert!(v["primary"]["credits"]["available"].is_number()); + // SessionId is transparent: bare integer, not `{"0": ...}`. + assert_eq!( + v["primary"]["session"]["session_id"], 0x1234_5678_9ABC_DEF0_u64, + "json: {json}" + ); + + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn snapshot_releases_all_locks_before_returning() { + // Regression test: the snapshot promises it holds each lock only + // briefly and releases it before returning. Try_lock'ing after + // the snapshot call must succeed. + let (conn, mock) = fresh_conn(); + let _d = conn.diagnostics(); + // We can't reach `inner` from here without crate access; this test + // lives in-crate so it CAN. The diagnostics module is in + // `client/`, the connection internals are `pub(crate)`-shaped. + // If a future refactor breaks lock ordering, the in-flight test + // above catches it indirectly; this test pins the "no held lock" + // invariant cheaply. + for _ in 0..100 { + let _ = conn.diagnostics(); + } + mock.close(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn routing_partition_is_disjoint_and_complete() { + // 3 sent, 1 normal, 1 caller-drop, 1 stray on top. + let mock = Arc::new(MockTransport::new()); + // Plain mode so we can fully control msg_ids. + let conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Op A: send, will succeed. + let c1 = conn.clone(); + let h1 = tokio::spawn(async move { c1.execute(Command::Echo, &EchoRequest, None).await }); + wait_for_sent(&mock, 1).await; // msg_id 0 + + // Op B: send, then abort (caller drop). + let c2 = conn.clone(); + let h2 = tokio::spawn(async move { c2.execute(Command::Echo, &EchoRequest, None).await }); + wait_for_sent(&mock, 2).await; // msg_id 1 + + // Op A response. + mock.queue_response(echo_ok(MessageId(0))); + h1.await.unwrap().unwrap(); + + // Drop op B then queue its response. + h2.abort(); + let _ = h2.await; + mock.queue_response(echo_ok(MessageId(1))); + + // Stray frame for a msg_id no one allocated. + mock.queue_response(echo_ok(MessageId(999_999))); + + // Poll the actual signals — `pending_responses() == 0` only proves + // the transport drained, not that the receiver finished bumping the + // counters. `responses_routed_ok` is already 1 (h1.await guarantees + // it after the receiver_loop fix); `late_after_drop` and + // `responses_stray` each tick once when their frame is processed. + let deadline = std::time::Instant::now() + Duration::from_secs(2); + while (conn.metrics().responses_late_after_drop == 0 || conn.metrics().responses_stray == 0) + && std::time::Instant::now() < deadline + { + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let m = conn.metrics(); + assert_eq!(m.responses_routed_ok, 1); + assert_eq!(m.responses_routed_err, 0); + assert_eq!(m.responses_late_after_drop, 1); + assert_eq!(m.responses_stray, 1); + + // Partition: routed_ok + routed_err + late + stray == + // total sub-frames the receiver dispatched (3 here). + assert_eq!( + m.responses_routed_ok + + m.responses_routed_err + + m.responses_late_after_drop + + m.responses_stray, + 3 + ); + + mock.close(); + } +} diff --git a/vendor/smb2/src/client/mod.rs b/vendor/smb2/src/client/mod.rs new file mode 100644 index 0000000..c3824e3 --- /dev/null +++ b/vendor/smb2/src/client/mod.rs @@ -0,0 +1,1495 @@ +//! High-level SMB2 client API. +//! +//! Provides [`SmbClient`] for easy connect-and-use access, plus lower-level +//! types: [`Connection`] for message exchange, [`Session`] for authenticated +//! sessions, [`Tree`] for share access with file operations, and [`Pipeline`] +//! for batched concurrent operations. + +pub mod connection; +pub(crate) mod dfs; +pub mod diagnostics; +pub mod pipeline; +pub mod session; +pub mod shares; +pub mod stream; +#[cfg(test)] +pub(crate) mod test_helpers; +pub mod tree; +pub mod watcher; + +pub use crate::crypto::encryption::Cipher; +pub use connection::{CompoundOp, Connection, Frame, NegotiatedParams}; +pub use diagnostics::{ + ClientInfo, ClientMetricsSnapshot, CompressionInfo, ConnectionDiagnostics, CreditInfo, + DfsCacheEntry, Diagnostics, EncryptionInfo, MetricsSnapshot, NegotiatedSummary, + SessionDiagnostics, SigningInfo, +}; +pub use pipeline::{Op, OpResult, Pipeline}; +pub use session::Session; +pub use shares::list_shares; +pub use stream::{FileDownload, FileUpload, FileWriter, Progress}; +pub use tree::{DirectoryEntry, FileInfo, FsInfo, Tree}; +pub use watcher::{FileNotifyAction, FileNotifyEvent, Watcher}; + +// Re-export high-level client types. +// (SmbClient, ClientConfig, and connect are defined below in this file.) + +use std::collections::HashMap; +use std::ops::ControlFlow; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; + +use log::{debug, info}; + +use crate::client::dfs::DfsResolver; +use crate::error::{ErrorKind, Result}; +use crate::pack::Unpack; +use crate::rpc::srvsvc::ShareInfo; +use crate::types::FileId; +use crate::Error; + +/// Configuration for an SMB client connection. +#[derive(Debug, Clone)] +pub struct ClientConfig { + /// Server address (host:port). + pub addr: String, + /// Connection timeout. + pub timeout: Duration, + /// Username (empty for guest). + pub username: String, + /// Password (empty for guest). + /// + /// **Security note:** The password is stored in memory so that the client + /// can reconnect without asking the user again. It is not encrypted in + /// memory. Ensure the `SmbClient` is dropped when no longer needed. + pub password: String, + /// Domain (empty for local). + pub domain: String, + /// Whether to automatically reconnect on connection loss. + /// + /// When `true`, the client will attempt to reconnect with exponential + /// backoff when a connection loss is detected. The actual auto-reconnect + /// logic (retry with backoff, re-issue failed operations) will be + /// implemented alongside the concurrent pipeline. For now this flag + /// is stored so the API is ready. + pub auto_reconnect: bool, + /// Enable LZ4 compression for SMB 3.1.1 connections. + /// When enabled, messages are compressed if it reduces their size. + /// Incompressible data (photos, videos) is sent uncompressed automatically. + /// Default: true. + pub compression: bool, + /// Enable DFS (Distributed File System) path resolution. + /// + /// When `true`, operations that receive a DFS referral response + /// (`STATUS_PATH_NOT_COVERED`) automatically resolve the referral, + /// connect to the target server, and retry the operation. + /// Default: true. + pub dfs_enabled: bool, + /// Override addresses for DFS target servers. + /// + /// Maps server hostnames (as they appear in DFS referrals) to + /// `host:port` socket addresses. Useful when DFS targets use + /// internal hostnames that the client can't resolve, or when + /// port mapping is needed (for example, Docker test environments). + /// + /// Default: empty (use the server hostname from the referral + /// with port 445). + pub dfs_target_overrides: std::collections::HashMap, +} + +/// A connection to a specific server with its authenticated session. +/// +/// Used for DFS cross-server referrals where the client needs connections +/// to multiple servers simultaneously. +#[allow(dead_code)] +pub(crate) struct ConnectionEntry { + /// The connection to the server. + pub conn: Connection, + /// The authenticated session on this connection. + pub session: Session, +} + +/// High-level SMB2 client with reconnection support. +/// +/// Wraps a [`Connection`] + [`Session`] and provides methods for connecting +/// to shares, listing shares, and reconnecting after network failures. +/// +/// **Security note:** This struct stores the password in memory so it can +/// reconnect without asking the user again. The password is not encrypted. +/// Drop the `SmbClient` when no longer needed. +pub struct SmbClient { + config: ClientConfig, + conn: Connection, + session: Session, + /// Server name of the primary connection (from `conn.server_name()`). + primary_server: String, + /// Extra connections for DFS cross-server targets, keyed by server name. + extra_connections: HashMap, + /// DFS referral resolver with TTL-based cache. + dfs_resolver: DfsResolver, + /// Client-level counter: how many times `reconnect()` ran. Survives + /// each reconnect (per-connection counters do not). + reconnects: AtomicU64, +} + +impl SmbClient { + /// Connect to an SMB server and authenticate. + /// + /// Performs TCP connect, negotiate, and session setup in one call. + pub async fn connect(config: ClientConfig) -> Result { + info!("smb_client: connecting to {}", config.addr); + + let mut conn = Connection::connect(&config.addr, config.timeout).await?; + conn.set_compression_requested(config.compression); + conn.negotiate().await?; + + let session = Session::setup( + &mut conn, + &config.username, + &config.password, + &config.domain, + ) + .await?; + + info!( + "smb_client: connected and authenticated, session_id={}, compression={}", + session.session_id, + conn.compression_enabled() + ); + + let primary_server = config.addr.clone(); + + Ok(SmbClient { + config, + conn, + session, + primary_server, + extra_connections: HashMap::new(), + dfs_resolver: DfsResolver::new(), + reconnects: AtomicU64::new(0), + }) + } + + /// Connect using an existing connection and session (for testing). + #[cfg(test)] + pub(crate) fn from_parts(config: ClientConfig, conn: Connection, session: Session) -> Self { + let primary_server = config.addr.clone(); + SmbClient { + config, + conn, + session, + primary_server, + extra_connections: HashMap::new(), + dfs_resolver: DfsResolver::new(), + reconnects: AtomicU64::new(0), + } + } + + /// List available shares on the server. + /// + /// Connects to the IPC$ share, performs an RPC exchange via the srvsvc + /// named pipe, and returns only disk shares (excluding admin shares + /// ending with `$`). + pub async fn list_shares(&mut self) -> Result> { + shares::list_shares(&mut self.conn).await + } + + /// Connect to a share on the server. + /// + /// If the share requires encryption (`SMB2_SHAREFLAG_ENCRYPT_DATA`) + /// and encryption is not already active, encryption is activated + /// using the session's keys. + pub async fn connect_share(&mut self, share_name: &str) -> Result { + let mut tree = Tree::connect(&mut self.conn, share_name).await?; + tree.server = self.primary_server.clone(); + + // Activate encryption if the share requires it and it's not already active. + // Fall back to AES-128-CCM if the server didn't send an encryption + // negotiate context (same fallback as session-level encryption). + if tree.encrypt_data && !self.conn.should_encrypt() { + if let (Some(ref enc_key), Some(ref dec_key)) = + (&self.session.encryption_key, &self.session.decryption_key) + { + let cipher = self + .conn + .params() + .and_then(|p| p.cipher) + .unwrap_or(crate::crypto::encryption::Cipher::Aes128Ccm); + self.conn + .activate_encryption(enc_key.clone(), dec_key.clone(), cipher); + } + } + + Ok(tree) + } + + /// Manually reconnect after a connection loss. + /// + /// Re-does TCP connect, negotiate, and session setup using the stored + /// credentials. All previous tree connections and file handles are + /// invalidated. The caller must re-do [`SmbClient::connect_share`] for + /// any shares they need. + pub async fn reconnect(&mut self) -> Result<()> { + info!("smb_client: reconnecting to {}", self.config.addr); + + let conn = Connection::connect(&self.config.addr, self.config.timeout).await?; + self.reconnect_with(conn).await + } + + /// Reconnect using an already-established connection. + /// + /// Negotiates and authenticates on the given connection using stored + /// credentials. This is the core reconnection logic, separated from + /// TCP connect so it can be tested with mock transports. + async fn reconnect_with(&mut self, mut conn: Connection) -> Result<()> { + self.reconnects.fetch_add(1, Ordering::Relaxed); + conn.set_compression_requested(self.config.compression); + conn.negotiate().await?; + + let session = Session::setup( + &mut conn, + &self.config.username, + &self.config.password, + &self.config.domain, + ) + .await?; + + self.primary_server = self.config.addr.clone(); + self.conn = conn; + self.session = session; + self.extra_connections.clear(); + + info!( + "smb_client: reconnected, new session_id={}", + self.session.session_id + ); + Ok(()) + } + + /// Get the negotiated parameters. + pub fn params(&self) -> Option<&NegotiatedParams> { + self.conn.params() + } + + /// Get the session info. + pub fn session(&self) -> &Session { + &self.session + } + + /// Get the client config. + pub fn config(&self) -> &ClientConfig { + &self.config + } + + /// Current number of available credits. + pub fn credits(&self) -> u16 { + self.conn.credits() + } + + /// Estimated round-trip time from the negotiate exchange. + pub fn estimated_rtt(&self) -> Option { + self.conn.estimated_rtt() + } + + /// Capture a tree of diagnostics: client config, primary + DFS-extra + /// connections, the session on each connection, per-connection + /// counters, the DFS referral cache, and client-level counters. + /// + /// See [`crate::client::diagnostics`] for the consistency model. In + /// short: eventually consistent, snapshot survives connection + /// teardown, per-connection counters reset on + /// [`Self::reconnect`], client-level counters survive. + pub fn diagnostics(&self) -> crate::client::diagnostics::Diagnostics { + use crate::client::diagnostics::{ + ClientInfo, ClientMetricsSnapshot, Diagnostics, SessionDiagnostics, + }; + + let (cache_hits, referrals_resolved) = self.dfs_resolver.counters(); + let client = ClientInfo { + primary_server: self.primary_server.clone(), + timeout: self.config.timeout, + auto_reconnect: self.config.auto_reconnect, + dfs_enabled: self.config.dfs_enabled, + metrics: ClientMetricsSnapshot { + reconnects: self.reconnects.load(Ordering::Relaxed), + dfs_referrals_resolved: referrals_resolved, + dfs_cache_hits: cache_hits, + }, + }; + + let session_for = |s: &Session| SessionDiagnostics { + session_id: s.session_id, + should_sign: s.should_sign, + should_encrypt: s.should_encrypt, + signing_algorithm: s.signing_algorithm, + }; + + let mut primary = self.conn.diagnostics(); + primary.session = Some(session_for(&self.session)); + + let extra_connections = self + .extra_connections + .values() + .map(|entry| { + let mut d = entry.conn.diagnostics(); + d.session = Some(session_for(&entry.session)); + d + }) + .collect(); + + Diagnostics { + client, + primary, + extra_connections, + dfs_cache: self.dfs_resolver.cache_entries(), + } + } + + /// Get a mutable reference to the underlying connection. + /// + /// Needed when using [`Tree`] methods directly, since they require + /// `&mut Connection`. For most use cases, prefer the convenience methods + /// on `SmbClient` (like [`list_directory`](Self::list_directory)) instead. + pub fn connection_mut(&mut self) -> &mut Connection { + &mut self.conn + } + + /// Get a mutable reference to the connection that owns the given tree. + /// + /// Routes through the primary connection when the tree's server matches, + /// or through an extra connection established for a DFS cross-server + /// referral. + pub(crate) fn connection_for_tree(&mut self, tree: &Tree) -> &mut Connection { + if tree.server == self.primary_server { + &mut self.conn + } else { + &mut self + .extra_connections + .get_mut(&tree.server) + .expect("no connection for tree server") + .conn + } + } + + // ── DFS helpers ─────────────────────────────────────────────────── + + /// Handle a DFS redirect by resolving the referral, connecting to + /// the target server (creating a new connection if needed), and + /// updating the tree in-place. + /// + /// Returns the resolved remaining path to use for the retry. + async fn handle_dfs_redirect( + &mut self, + tree: &mut Tree, + original_path: &str, + ) -> Result { + // Extract hostname (strip port) for UNC path construction. + let hostname = tree + .server + .split(':') + .next() + .unwrap_or(&tree.server) + .to_string(); + let share = tree.share_name.clone(); + let normalized = original_path.replace('/', "\\"); + let unc_path = format!("\\\\{}\\{}\\{}", hostname, share, normalized); + + debug!("dfs: resolving {}", unc_path); + + // Resolve the referral (uses cache or IOCTL). + // We inline the connection lookup to avoid borrowing both + // `self.dfs_resolver` and `self` (via connection_for_tree) + // at the same time. + let conn = if tree.server == self.primary_server { + &mut self.conn + } else { + &mut self + .extra_connections + .get_mut(&tree.server) + .expect("no connection for tree server") + .conn + }; + let resolved_list = self.dfs_resolver.resolve(conn, &unc_path).await?; + + // Try each target (multi-target failover). + let mut last_error = None; + for resolved in &resolved_list { + let target_addr = self + .config + .dfs_target_overrides + .get(&resolved.server) + .cloned() + .unwrap_or_else(|| format!("{}:{}", resolved.server, resolved.port)); + + // Get or create connection to target server. + match self.ensure_connection(&target_addr).await { + Ok(()) => {} + Err(e) => { + debug!("dfs: failed to connect to {}: {}", target_addr, e); + last_error = Some(e); + continue; + } + } + + // Get or create tree on the target share. + match self.ensure_tree(&target_addr, &resolved.share).await { + Ok(new_tree) => { + // Update the caller's tree in-place. + *tree = new_tree; + return Ok(resolved.remaining_path.clone()); + } + Err(e) => { + debug!( + "dfs: failed to connect to share {} on {}: {}", + resolved.share, target_addr, e + ); + last_error = Some(e); + continue; + } + } + } + + Err(last_error.unwrap_or_else(|| Error::invalid_data("DFS: no targets in referral"))) + } + + /// Ensure a connection exists in the pool for the given server address. + async fn ensure_connection(&mut self, target_addr: &str) -> Result<()> { + if target_addr == self.primary_server { + return Ok(()); // Already have primary connection. + } + if self.extra_connections.contains_key(target_addr) { + return Ok(()); // Already in pool. + } + + // Create new connection to target. + let mut conn = Connection::connect(target_addr, self.config.timeout).await?; + conn.set_compression_requested(self.config.compression); + conn.negotiate().await?; + + // Authenticate with same credentials. + let session = Session::setup( + &mut conn, + &self.config.username, + &self.config.password, + &self.config.domain, + ) + .await?; + + self.extra_connections + .insert(target_addr.to_string(), ConnectionEntry { conn, session }); + Ok(()) + } + + /// Ensure a tree-connect exists for the given server and share. + async fn ensure_tree(&mut self, target_addr: &str, share: &str) -> Result { + let conn = if target_addr == self.primary_server { + &mut self.conn + } else { + &mut self + .extra_connections + .get_mut(target_addr) + .ok_or_else(|| Error::invalid_data("DFS: no connection for target"))? + .conn + }; + + let mut tree = Tree::connect(conn, share).await?; + // Override server to the full addr:port so connection_for_tree + // can distinguish targets that share the same hostname but + // use different ports (for example, Docker port-mapped containers). + tree.server = target_addr.to_string(); + Ok(tree) + } + + /// Check whether a DFS retry should be attempted for the given error. + fn should_retry_dfs(&self, err: &Error) -> bool { + self.config.dfs_enabled && err.kind() == ErrorKind::DfsReferral + } + + // ── Convenience methods that delegate to Tree ────────────────────── + + /// List files in a directory on the given share. + /// + /// This is a convenience wrapper around [`Tree::list_directory`] that + /// saves you from threading `connection_mut()` through every call. + /// If the server returns a DFS referral, the tree is updated in-place + /// and the operation is retried on the target server. + pub async fn list_directory( + &mut self, + tree: &mut Tree, + path: &str, + ) -> Result> { + let result = { + let conn = self.connection_for_tree(tree); + tree.list_directory(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.list_directory(conn, &new_path).await + } + other => other, + } + } + + /// Read a file from the given share. + pub async fn read_file(&mut self, tree: &mut Tree, path: &str) -> Result> { + let result = { + let conn = self.connection_for_tree(tree); + tree.read_file(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.read_file(conn, &new_path).await + } + other => other, + } + } + + /// Read a small file using a compound CREATE+READ+CLOSE request. + /// + /// Sends all three operations in a single transport frame, reducing + /// round-trips from 3 to 1. Best for files that fit in a single + /// READ (up to MaxReadSize, typically 8 MB). + pub async fn read_file_compound(&mut self, tree: &mut Tree, path: &str) -> Result> { + let result = { + let conn = self.connection_for_tree(tree); + tree.read_file_compound(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.read_file_compound(conn, &new_path).await + } + other => other, + } + } + + /// Read a file using pipelined I/O (faster for large files). + pub async fn read_file_pipelined(&mut self, tree: &mut Tree, path: &str) -> Result> { + let result = { + let conn = self.connection_for_tree(tree); + tree.read_file_pipelined(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.read_file_pipelined(conn, &new_path).await + } + other => other, + } + } + + /// Write data to a file on the given share (create or overwrite). + pub async fn write_file(&mut self, tree: &mut Tree, path: &str, data: &[u8]) -> Result { + let result = { + let conn = self.connection_for_tree(tree); + tree.write_file(conn, path, data).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.write_file(conn, &new_path, data).await + } + other => other, + } + } + + /// Write a small file using a compound CREATE+WRITE+FLUSH+CLOSE request. + /// + /// Sends all four operations in a single transport frame, reducing + /// round-trips from 4 to 1. Best for files that fit in MaxWriteSize + /// (typically 64 KB to 8 MB). For larger files, use + /// [`write_file_pipelined`](Self::write_file_pipelined). + pub async fn write_file_compound( + &mut self, + tree: &mut Tree, + path: &str, + data: &[u8], + ) -> Result { + let result = { + let conn = self.connection_for_tree(tree); + tree.write_file_compound(conn, path, data).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.write_file_compound(conn, &new_path, data).await + } + other => other, + } + } + + /// Write data to a file using pipelined I/O (faster for large files). + pub async fn write_file_pipelined( + &mut self, + tree: &mut Tree, + path: &str, + data: &[u8], + ) -> Result { + let result = { + let conn = self.connection_for_tree(tree); + tree.write_file_pipelined(conn, path, data).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.write_file_pipelined(conn, &new_path, data).await + } + other => other, + } + } + + /// Query file system space information for the given share. + /// + /// Returns total capacity, free space, and allocation unit sizes. + /// Uses a compound CREATE+QUERY_INFO+CLOSE for efficiency (one round-trip). + pub async fn fs_info(&mut self, tree: &mut Tree) -> Result { + let result = { + let conn = self.connection_for_tree(tree); + tree.fs_info(conn).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + // fs_info has no path argument -- the DFS redirect uses + // the root of the share as the path. + let _new_path = self.handle_dfs_redirect(tree, "").await?; + let conn = self.connection_for_tree(tree); + tree.fs_info(conn).await + } + other => other, + } + } + + /// Delete a file on the given share. + pub async fn delete_file(&mut self, tree: &mut Tree, path: &str) -> Result<()> { + let result = { + let conn = self.connection_for_tree(tree); + tree.delete_file(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.delete_file(conn, &new_path).await + } + other => other, + } + } + + /// Delete multiple files on the given share in a single batch. + /// + /// Sends all requests before waiting for responses, minimizing + /// round-trips. Returns results in the same order as the input paths. + /// + /// Note: DFS retry is not applied to batch operations. If the share + /// is a DFS target, perform a single-file operation first to trigger + /// the redirect, then use the batch method on the resolved tree. + pub async fn delete_files(&mut self, tree: &mut Tree, paths: &[&str]) -> Vec> { + let conn = self.connection_for_tree(tree); + tree.delete_files(conn, paths).await + } + + /// Get file metadata (size, timestamps, whether it's a directory). + pub async fn stat(&mut self, tree: &mut Tree, path: &str) -> Result { + let result = { + let conn = self.connection_for_tree(tree); + tree.stat(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.stat(conn, &new_path).await + } + other => other, + } + } + + /// Stat multiple files on the given share in a single batch. + /// + /// Sends all requests before waiting for responses, minimizing + /// round-trips. Returns results in the same order as the input paths. + /// + /// Note: DFS retry is not applied to batch operations. If the share + /// is a DFS target, perform a single-file operation first to trigger + /// the redirect, then use the batch method on the resolved tree. + pub async fn stat_files(&mut self, tree: &mut Tree, paths: &[&str]) -> Vec> { + let conn = self.connection_for_tree(tree); + tree.stat_files(conn, paths).await + } + + /// Rename a file or directory on the given share. + pub async fn rename(&mut self, tree: &mut Tree, from: &str, to: &str) -> Result<()> { + let result = { + let conn = self.connection_for_tree(tree); + tree.rename(conn, from, to).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, from).await?; + let conn = self.connection_for_tree(tree); + tree.rename(conn, &new_path, to).await + } + other => other, + } + } + + /// Rename multiple files on the given share in a single batch. + /// + /// Sends all requests before waiting for responses, minimizing + /// round-trips. Returns results in the same order as the input pairs. + /// + /// Note: DFS retry is not applied to batch operations. If the share + /// is a DFS target, perform a single-file operation first to trigger + /// the redirect, then use the batch method on the resolved tree. + pub async fn rename_files( + &mut self, + tree: &mut Tree, + renames: &[(&str, &str)], + ) -> Vec> { + let conn = self.connection_for_tree(tree); + tree.rename_files(conn, renames).await + } + + /// Create a directory on the given share. + pub async fn create_directory(&mut self, tree: &mut Tree, path: &str) -> Result<()> { + let result = { + let conn = self.connection_for_tree(tree); + tree.create_directory(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.create_directory(conn, &new_path).await + } + other => other, + } + } + + /// Delete an empty directory on the given share. + pub async fn delete_directory(&mut self, tree: &mut Tree, path: &str) -> Result<()> { + let result = { + let conn = self.connection_for_tree(tree); + tree.delete_directory(conn, path).await + }; + match result { + Err(e) if self.should_retry_dfs(&e) => { + let new_path = self.handle_dfs_redirect(tree, path).await?; + let conn = self.connection_for_tree(tree); + tree.delete_directory(conn, &new_path).await + } + other => other, + } + } + + /// Start a streaming file download (memory-efficient for large files). + /// + /// Returns a [`FileDownload`] that yields chunks one at a time without + /// buffering the entire file in memory. Each call to + /// [`next_chunk`](FileDownload::next_chunk) sends one READ request. + /// + /// The connection is borrowed mutably for the lifetime of the download, + /// so no other operations can run concurrently. This prevents accidental + /// interleaving of SMB messages. + /// + /// # Example + /// + /// ```ignore + /// # async fn example(client: &mut smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> { + /// use tokio::io::AsyncWriteExt; + /// + /// let mut download = client.download(&share, "big_video.mp4").await?; + /// println!("Downloading {} bytes...", download.size()); + /// + /// let mut file = tokio::fs::File::create("big_video.mp4").await?; + /// while let Some(chunk) = download.next_chunk().await { + /// let bytes = chunk?; + /// file.write_all(&bytes).await?; + /// println!("{:.1}%", download.progress().percent()); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn download<'a>( + &'a mut self, + tree: &'a Tree, + path: &str, + ) -> Result> { + tree.download(&mut self.conn, path).await + } + + /// Start a streaming file upload with progress tracking. + /// + /// Returns a [`FileUpload`] that writes data in chunks. Each call to + /// [`write_next_chunk`](FileUpload::write_next_chunk) sends one WRITE + /// request and reports progress. + /// + /// For small files (data fits in one MaxWriteSize), the data is written + /// immediately via a compound CREATE+WRITE+FLUSH+CLOSE request in the + /// constructor. The returned `FileUpload` is already complete, and + /// `write_next_chunk` returns `false` immediately. This gives the caller + /// a uniform API regardless of file size. + /// + /// The connection is borrowed mutably for the lifetime of the upload, + /// so no other operations can run concurrently. This prevents accidental + /// interleaving of SMB messages. + /// + /// # Example + /// + /// ```ignore + /// # async fn example(client: &mut smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> { + /// let data = std::fs::read("large_video.mp4")?; + /// let mut upload = client.upload(&share, "remote_video.mp4", &data).await?; + /// println!("Uploading {} bytes...", upload.total_bytes()); + /// + /// while upload.write_next_chunk().await? { + /// println!("{:.1}%", upload.progress().percent()); + /// } + /// // File is flushed and closed automatically after the last chunk. + /// # Ok(()) + /// # } + /// ``` + pub async fn upload<'a>( + &'a mut self, + tree: &'a Tree, + path: &str, + data: &'a [u8], + ) -> Result> { + let normalized = path.replace('/', "\\"); + let normalized = normalized.trim_start_matches('\\'); + + let max_write = self + .conn + .params() + .map(|p| p.max_write_size as usize) + .unwrap_or(65536); + + if data.len() <= max_write { + // Small file: write everything via compound in one round-trip. + tree.write_file_compound(&mut self.conn, normalized, data) + .await?; + Ok(stream::FileUpload::new_done( + tree, + &mut self.conn, + data.len() as u64, + )) + } else { + // Large file: open the file, let the caller drive chunks. + let file_id = tree.open_file_for_write(&mut self.conn, normalized).await?; + let chunk_size = max_write as u32; + Ok(stream::FileUpload::new( + tree, + &mut self.conn, + file_id, + data, + chunk_size, + )) + } + } + + /// Create a push-based pipelined streaming file writer. + /// + /// Opens (or creates) the file for writing and returns a [`FileWriter`] + /// that the caller drives by pushing data chunks. The returned writer + /// owns a cheap `Arc::clone` of `Connection` and an `Arc` — it + /// is `'static` and does not borrow from the client. Multiple writers + /// built this way pipeline their WRITEs over a single SMB session + /// without external locking. + /// + /// No DFS retry; the writer pins to the connection it was built from. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(client: &smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> { + /// let mut writer = client.create_file_writer(share, "output.bin").await?; + /// writer.write_chunk(b"hello").await?; + /// writer.write_chunk(b" world").await?; + /// let total = writer.finish().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn create_file_writer(&self, tree: &Tree, path: &str) -> Result { + // Convenience wrapper: clone the primary connection (cheap + // `Arc::clone`) and the `Tree` into an `Arc`, then build a writer + // that owns both. The client's connection is not borrowed for the + // upload's duration, so concurrent writers proceed in parallel. + stream::open_file_writer(std::sync::Arc::new(tree.clone()), self.conn.clone(), path).await + } + + /// Exclusive-create sibling of [`create_file_writer`](Self::create_file_writer). + /// + /// Same shape, but the CREATE uses `FileCreate` disposition: if the file + /// already exists the open fails with + /// [`crate::ErrorKind::AlreadyExists`]. Use + /// this for race-free "create only if absent" writes — for example, a + /// file manager's "New File" action where silently clobbering an + /// existing file is unsafe. + pub async fn create_file_writer_exclusive( + &self, + tree: &Tree, + path: &str, + ) -> Result { + stream::open_file_writer_exclusive( + std::sync::Arc::new(tree.clone()), + self.conn.clone(), + path, + ) + .await + } + + /// Read a file with progress reporting and cancellation. + /// + /// Uses pipelined I/O for performance, calling `on_progress` after each + /// chunk is received. Return `ControlFlow::Break(())` to cancel the read. + pub async fn read_file_with_progress( + &mut self, + tree: &mut Tree, + path: &str, + on_progress: F, + ) -> Result> + where + F: FnMut(Progress) -> ControlFlow<()>, + { + // DFS retry is not straightforward with progress callbacks (the + // callback is consumed by the first attempt). For now, attempt + // the operation directly. If DFS redirect is needed, the caller + // should resolve the tree first using a simpler method. + let conn = self.connection_for_tree(tree); + tree.read_file_pipelined_with_progress(conn, path, on_progress) + .await + } + + /// Write a file with progress reporting and cancellation. + /// + /// Writes data in chunks, calling `on_progress` after each chunk. + /// Return `ControlFlow::Break(())` to cancel the write. + /// + /// The file is flushed before closing to ensure data is persisted + /// on the server. + pub async fn write_file_with_progress( + &mut self, + tree: &mut Tree, + path: &str, + data: &[u8], + mut on_progress: F, + ) -> Result + where + F: FnMut(Progress) -> ControlFlow<()>, + { + let normalized = path.replace('/', "\\"); + let normalized = normalized.trim_start_matches('\\'); + + // Open the file for writing. + let req = crate::msg::create::CreateRequest { + requested_oplock_level: crate::types::OplockLevel::None, + impersonation_level: crate::msg::create::ImpersonationLevel::Impersonation, + desired_access: crate::types::flags::FileAccessMask::new( + crate::types::flags::FileAccessMask::FILE_WRITE_DATA + | crate::types::flags::FileAccessMask::FILE_WRITE_ATTRIBUTES + | crate::types::flags::FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0x80, // FILE_ATTRIBUTE_NORMAL + share_access: crate::msg::create::ShareAccess(0), + create_disposition: crate::msg::create::CreateDisposition::FileOverwriteIf, + create_options: 0x0000_0040, // FILE_NON_DIRECTORY_FILE + name: normalized.to_string(), + create_contexts: vec![], + }; + + let frame = self + .conn + .execute(crate::types::Command::Create, &req, Some(tree.tree_id)) + .await?; + + if frame.header.status != crate::types::status::NtStatus::SUCCESS { + return Err(crate::Error::Protocol { + status: frame.header.status, + command: crate::types::Command::Create, + }); + } + + let mut cursor = crate::pack::ReadCursor::new(&frame.body); + let create_resp = crate::msg::create::CreateResponse::unpack(&mut cursor)?; + let file_id = create_resp.file_id; + + let max_write = self + .conn + .params() + .map(|p| p.max_write_size) + .unwrap_or(65536); + + let mut total_written = 0u64; + let mut offset = 0usize; + let mut cancelled = false; + + while offset < data.len() { + let remaining = data.len() - offset; + let chunk_size = remaining.min(max_write as usize); + let chunk = &data[offset..offset + chunk_size]; + + let write_req = crate::msg::write::WriteRequest { + data_offset: 0x70, + offset: offset as u64, + file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: chunk.to_vec(), + }; + + let credit_charge = (chunk_size as u64).div_ceil(65536).max(1) as u16; + let frame = self + .conn + .execute_with_credits( + crate::types::Command::Write, + &write_req, + Some(tree.tree_id), + crate::types::CreditCharge(credit_charge), + ) + .await?; + + if frame.header.status != crate::types::status::NtStatus::SUCCESS { + // Close handle before returning error. + let _ = tree.close_handle(&mut self.conn, file_id).await; + return Err(crate::Error::Protocol { + status: frame.header.status, + command: crate::types::Command::Write, + }); + } + + let mut cursor = crate::pack::ReadCursor::new(&frame.body); + let resp = crate::msg::write::WriteResponse::unpack(&mut cursor)?; + + total_written += resp.count as u64; + offset += chunk_size; + + let progress = Progress { + bytes_transferred: total_written, + total_bytes: Some(data.len() as u64), + }; + + if let ControlFlow::Break(()) = on_progress(progress) { + cancelled = true; + break; + } + } + + if cancelled { + // Best-effort close without flush. + let _ = tree.close_handle(&mut self.conn, file_id).await; + return Err(crate::Error::Cancelled); + } + + // Flush to ensure data is persisted. + tree.flush_handle(&mut self.conn, file_id).await?; + + // Close the handle. + tree.close_handle(&mut self.conn, file_id).await?; + + Ok(total_written) + } + + /// Write a file from a streaming source using pipelined I/O. + /// + /// Pulls data on demand from a callback, so you never need the full + /// file in memory. See [`Tree::write_file_streamed`] for the full + /// callback contract, performance characteristics, and usage guide. + /// + /// DFS retry is not supported for streamed writes (the callback is + /// consumed by the first attempt). If the share uses DFS, resolve + /// the tree first using a simpler method. + pub async fn write_file_streamed( + &mut self, + tree: &mut Tree, + path: &str, + next_chunk: &mut F, + ) -> Result + where + F: FnMut() -> Option, std::io::Error>>, + { + let conn = self.connection_for_tree(tree); + tree.write_file_streamed(conn, path, next_chunk).await + } + + /// Flush a file to ensure data is persisted on the server. + /// + /// This sends an SMB2 FLUSH request for the given file handle. + /// Write methods (`write_file`, `write_file_pipelined`, + /// `write_file_with_progress`) flush automatically before closing. + /// Use this if you need to flush a handle obtained through the + /// low-level API. + pub async fn flush_file(&mut self, tree: &mut Tree, file_id: FileId) -> Result<()> { + let conn = self.connection_for_tree(tree); + tree.flush_handle(conn, file_id).await + } + + /// Watch a directory for changes. + /// + /// Opens the directory and returns a [`Watcher`] that yields change + /// events. The server holds each request until changes occur (long poll). + /// + /// Set `recursive` to `true` to watch the entire subtree. + /// + /// The returned `Watcher` owns a cloned connection (cheap `Arc::clone`, + /// all clones multiplex over the same SMB session), so this client + /// remains usable for other operations while watching. + pub async fn watch(&mut self, tree: &Tree, path: &str, recursive: bool) -> Result { + tree.watch(&mut self.conn, path, recursive).await + } + + /// Disconnect from a share. + pub async fn disconnect_share(&mut self, tree: &Tree) -> Result<()> { + let conn = self.connection_for_tree(tree); + tree.disconnect(conn).await + } +} + +/// Connect to an SMB server with the simplest possible API. +/// +/// This is a shorthand for creating a [`ClientConfig`] and calling +/// [`SmbClient::connect`]. Uses a five-second timeout and no auto-reconnect. +pub async fn connect(addr: &str, username: &str, password: &str) -> Result { + SmbClient::connect(ClientConfig { + addr: addr.to_string(), + timeout: Duration::from_secs(5), + username: username.to_string(), + password: password.to_string(), + domain: String::new(), + auto_reconnect: false, + compression: true, + dfs_enabled: true, + dfs_target_overrides: std::collections::HashMap::new(), + }) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::connection::pack_message; + use crate::msg::header::Header; + use crate::msg::negotiate::{NegotiateContext, NegotiateResponse, HASH_ALGORITHM_SHA512}; + use crate::msg::session_setup::{SessionFlags, SessionSetupResponse}; + use crate::msg::tree_connect::ShareType; + use crate::pack::Guid; + use crate::transport::MockTransport; + use crate::types::flags::{Capabilities, SecurityMode}; + use crate::types::status::NtStatus; + use crate::types::{Command, Dialect, SessionId, TreeId}; + use std::sync::Arc; + + /// Build a negotiate response. + fn build_negotiate_response() -> Vec { + let mut h = Header::new_request(Command::Negotiate); + h.flags.set_response(); + h.credits = 32; + let body = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: Dialect::Smb3_1_1, + server_guid: Guid::ZERO, + capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING), + max_transact_size: 65536, + max_read_size: 65536, + max_write_size: 65536, + system_time: 132_000_000_000_000_000, + server_start_time: 131_000_000_000_000_000, + security_buffer: vec![0x60, 0x00], + negotiate_contexts: vec![NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xBB; 32], + }], + }; + pack_message(&h, &body) + } + + /// Build a session setup response. + fn build_session_setup_response( + status: NtStatus, + session_id: SessionId, + security_buffer: Vec, + session_flags: SessionFlags, + ) -> Vec { + let mut h = Header::new_request(Command::SessionSetup); + h.flags.set_response(); + h.credits = 32; + h.status = status; + h.session_id = session_id; + + let body = SessionSetupResponse { + session_flags, + security_buffer, + }; + + pack_message(&h, &body) + } + + /// Build a minimal NTLM challenge message (Type 2). + fn build_ntlm_challenge() -> Vec { + let mut buf = Vec::new(); + + // Signature + buf.extend_from_slice(b"NTLMSSP\0"); + // MessageType = 2 + buf.extend_from_slice(&2u32.to_le_bytes()); + // TargetNameFields: Len=0, MaxLen=0, Offset=56 + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&56u32.to_le_bytes()); + // NegotiateFlags + let flags: u32 = 0x0000_0001 // UNICODE + | 0x0000_0200 // NTLM + | 0x0008_0000 // EXTENDED_SESSIONSECURITY + | 0x0080_0000 // TARGET_INFO + | 0x2000_0000 // 128 + | 0x4000_0000 // KEY_EXCH + | 0x8000_0000 // 56 + | 0x0000_0010 // SIGN + | 0x0000_0020; // SEAL + buf.extend_from_slice(&flags.to_le_bytes()); + // ServerChallenge + buf.extend_from_slice(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]); + // Reserved + buf.extend_from_slice(&[0u8; 8]); + // TargetInfoFields + let target_info = { + let mut ti = Vec::new(); + ti.extend_from_slice(&0u16.to_le_bytes()); // MsvAvEOL AvId=0 + ti.extend_from_slice(&0u16.to_le_bytes()); // AvLen=0 + ti + }; + let ti_offset = 56u32; + buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); + buf.extend_from_slice(&ti_offset.to_le_bytes()); + while buf.len() < 56 { + buf.push(0); + } + buf.extend_from_slice(&target_info); + buf + } + + /// Queue negotiate + session setup responses on a mock transport. + fn queue_negotiate_and_session(mock: &MockTransport, session_id: SessionId) { + mock.queue_response(build_negotiate_response()); + + let challenge = build_ntlm_challenge(); + mock.queue_response(build_session_setup_response( + NtStatus::MORE_PROCESSING_REQUIRED, + session_id, + challenge, + SessionFlags(0), + )); + + mock.queue_response(build_session_setup_response( + NtStatus::SUCCESS, + session_id, + vec![], + SessionFlags(0), + )); + } + + /// Create a mock-backed SmbClient without going through TCP. + async fn make_mock_client(mock: &Arc, session_id: SessionId) -> SmbClient { + mock.enable_auto_rewrite_msg_id(); + queue_negotiate_and_session(mock, session_id); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + conn.negotiate().await.unwrap(); + + let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap(); + + let config = ClientConfig { + addr: "test-server:445".to_string(), + timeout: Duration::from_secs(5), + username: "user".to_string(), + password: "pass".to_string(), + domain: String::new(), + auto_reconnect: false, + compression: true, + dfs_enabled: true, + dfs_target_overrides: std::collections::HashMap::new(), + }; + + SmbClient::from_parts(config, conn, session) + } + + #[tokio::test] + async fn smb_client_connect_via_mock_negotiates_and_authenticates() { + let mock = Arc::new(MockTransport::new()); + let session_id = SessionId(0xABCD); + + let client = make_mock_client(&mock, session_id).await; + + assert_eq!(client.session().session_id, session_id); + assert!(client.params().is_some()); + assert_eq!(client.params().unwrap().dialect, Dialect::Smb3_1_1); + } + + #[tokio::test] + async fn smb_client_stores_config() { + let mock = Arc::new(MockTransport::new()); + let client = make_mock_client(&mock, SessionId(1)).await; + + assert_eq!(client.config().addr, "test-server:445"); + assert_eq!(client.config().username, "user"); + assert_eq!(client.config().password, "pass"); + assert!(!client.config().auto_reconnect); + } + + #[tokio::test] + async fn smb_client_connect_share_returns_tree() { + let mock = Arc::new(MockTransport::new()); + let mut client = make_mock_client(&mock, SessionId(1)).await; + + // Queue tree connect response. + mock.queue_response(crate::client::test_helpers::build_tree_connect_response( + TreeId(42), + ShareType::Disk, + )); + + let tree = client.connect_share("TestShare").await.unwrap(); + assert_eq!(tree.tree_id, TreeId(42)); + assert_eq!(tree.share_name, "TestShare"); + } + + #[tokio::test] + async fn smb_client_reconnect_creates_new_session() { + let mock = Arc::new(MockTransport::new()); + let original_session_id = SessionId(0x1111); + let mut client = make_mock_client(&mock, original_session_id).await; + + // Verify original session. + assert_eq!(client.session().session_id, original_session_id); + + // Create a new mock for the "reconnected" transport. + let mock2 = Arc::new(MockTransport::new()); + mock2.enable_auto_rewrite_msg_id(); + let new_session_id = SessionId(0x2222); + queue_negotiate_and_session(mock2.as_ref(), new_session_id); + + let new_conn = Connection::from_transport( + Box::new(mock2.clone()), + Box::new(mock2.clone()), + "test-server", + ); + + client.reconnect_with(new_conn).await.unwrap(); + + // Session should be new. + assert_eq!(client.session().session_id, new_session_id); + } + + #[tokio::test] + async fn smb_client_reconnect_invalidates_old_params() { + let mock = Arc::new(MockTransport::new()); + let mut client = make_mock_client(&mock, SessionId(0x1111)).await; + + // Get old params for comparison. + let old_server_guid = client.params().unwrap().server_guid; + + // Create a new mock for the "reconnected" transport. + let mock2 = Arc::new(MockTransport::new()); + mock2.enable_auto_rewrite_msg_id(); + queue_negotiate_and_session(mock2.as_ref(), SessionId(0x2222)); + + let new_conn = Connection::from_transport( + Box::new(mock2.clone()), + Box::new(mock2.clone()), + "test-server", + ); + + client.reconnect_with(new_conn).await.unwrap(); + + // Params should be freshly negotiated (same values in this mock, + // but the connection is new). + assert!(client.params().is_some()); + assert_eq!(client.params().unwrap().server_guid, old_server_guid); + } + + #[tokio::test] + async fn smb_client_auto_reconnect_flag_stored() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + queue_negotiate_and_session(mock.as_ref(), SessionId(1)); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.negotiate().await.unwrap(); + let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap(); + + let config = ClientConfig { + addr: "test-server:445".to_string(), + timeout: Duration::from_secs(5), + username: "user".to_string(), + password: "pass".to_string(), + domain: String::new(), + auto_reconnect: true, + compression: true, + dfs_enabled: true, + dfs_target_overrides: std::collections::HashMap::new(), + }; + + let client = SmbClient::from_parts(config, conn, session); + assert!(client.config().auto_reconnect); + } + + #[tokio::test] + async fn smb_client_connection_mut_returns_connection() { + let mock = Arc::new(MockTransport::new()); + let mut client = make_mock_client(&mock, SessionId(1)).await; + + // Verify we can access the connection. + assert!(client.connection_mut().params().is_some()); + } + + #[tokio::test] + async fn smb_client_list_shares_delegates_to_shares_module() { + let mock = Arc::new(MockTransport::new()); + let mut client = make_mock_client(&mock, SessionId(0x5555)).await; + + // Queue the full share listing flow (same as shares module tests). + // This verifies SmbClient.list_shares() delegates correctly. + use crate::client::shares::tests::queue_share_listing_responses; + queue_share_listing_responses( + &mock, + &[ + ( + "Documents", + crate::rpc::srvsvc::STYPE_DISKTREE, + "Shared docs", + ), + ( + "IPC$", + crate::rpc::srvsvc::STYPE_IPC | crate::rpc::srvsvc::STYPE_SPECIAL, + "Remote IPC", + ), + ], + ); + + let shares = client.list_shares().await.unwrap(); + + // Only disk shares returned. + assert_eq!(shares.len(), 1); + assert_eq!(shares[0].name, "Documents"); + } +} diff --git a/vendor/smb2/src/client/pipeline.rs b/vendor/smb2/src/client/pipeline.rs new file mode 100644 index 0000000..03e56f4 --- /dev/null +++ b/vendor/smb2/src/client/pipeline.rs @@ -0,0 +1,670 @@ +//! Unified operation pipeline for concurrent SMB2 operations. +//! +//! The [`Pipeline`] sends multiple SMB2 requests without waiting for each +//! response, filling the credit window. Results are collected and returned +//! once all operations complete. +//! +//! This is a first-iteration pipeline that executes a batch of operations. +//! Future iterations will add a channel-based streaming interface, compound +//! request construction, and chunk-level interleaving for large files. + +use log::debug; + +use crate::client::connection::Connection; +use crate::client::tree::Tree; + +/// An operation to execute through the pipeline. +#[derive(Debug, Clone)] +pub enum Op { + /// Read a file, returning its contents. + ReadFile(String), + /// Write data to a file (create or overwrite). + WriteFile(String, Vec), + /// Delete a file. + Delete(String), + /// List a directory. + ListDirectory(String), + /// Get file metadata. + Stat(String), +} + +/// Result of a pipeline operation. +#[derive(Debug)] +pub enum OpResult { + /// File data read successfully. + FileData { + /// The path that was read. + path: String, + /// The file contents. + data: Vec, + }, + /// File written successfully. + Written { + /// The path that was written. + path: String, + /// Number of bytes written. + bytes_written: u64, + }, + /// File deleted successfully. + Deleted { + /// The path that was deleted. + path: String, + }, + /// Directory listing. + DirEntries { + /// The path that was listed. + path: String, + /// The directory entries. + entries: Vec, + }, + /// File metadata. + Stat { + /// The path that was queried. + path: String, + /// The file information. + info: crate::client::tree::FileInfo, + }, + /// Operation failed. + Error { + /// The path that failed. + path: String, + /// The error that occurred. + error: crate::Error, + }, +} + +/// A pipeline for executing multiple SMB operations as a batch. +/// +/// The pipeline executes operations sequentially in this first iteration. +/// Each multi-step operation (for example, read = CREATE + READ + CLOSE) runs +/// to completion before the next operation starts. Future iterations will +/// interleave steps from different operations to fill the credit window. +pub struct Pipeline<'a> { + conn: &'a mut Connection, + tree: &'a Tree, +} + +impl<'a> Pipeline<'a> { + /// Create a new pipeline bound to a connection and tree. + pub fn new(conn: &'a mut Connection, tree: &'a Tree) -> Self { + Self { conn, tree } + } + + /// Execute a batch of operations and return the results. + /// + /// Results are returned in the same order as the input operations. + /// Each operation that fails produces an [`OpResult::Error`] rather + /// than aborting the entire batch. + pub async fn execute(&mut self, ops: Vec) -> Vec { + let mut results = Vec::with_capacity(ops.len()); + + for op in ops { + let result = self.execute_one(op).await; + results.push(result); + } + + results + } + + /// Execute a single operation. + async fn execute_one(&mut self, op: Op) -> OpResult { + match op { + Op::ReadFile(path) => { + debug!("pipeline: read_file path={}", path); + match self.tree.read_file(self.conn, &path).await { + Ok(data) => OpResult::FileData { path, data }, + Err(e) => OpResult::Error { path, error: e }, + } + } + Op::WriteFile(path, data) => { + debug!("pipeline: write_file path={}", path); + match self.tree.write_file(self.conn, &path, &data).await { + Ok(bytes_written) => OpResult::Written { + path, + bytes_written, + }, + Err(e) => OpResult::Error { path, error: e }, + } + } + Op::Delete(path) => { + debug!("pipeline: delete path={}", path); + match self.tree.delete_file(self.conn, &path).await { + Ok(()) => OpResult::Deleted { path }, + Err(e) => OpResult::Error { path, error: e }, + } + } + Op::ListDirectory(path) => { + debug!("pipeline: list_directory path={}", path); + match self.tree.list_directory(self.conn, &path).await { + Ok(entries) => OpResult::DirEntries { path, entries }, + Err(e) => OpResult::Error { path, error: e }, + } + } + Op::Stat(path) => { + debug!("pipeline: stat path={}", path); + match self.tree.stat(self.conn, &path).await { + Ok(info) => OpResult::Stat { path, info }, + Err(e) => OpResult::Error { path, error: e }, + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::connection::pack_message; + use crate::client::test_helpers::{ + build_close_response, build_create_response, setup_connection, + }; + use crate::client::tree::Tree; + use crate::msg::create::{CreateAction, CreateResponse}; + use crate::msg::header::{ErrorResponse, Header}; + use crate::msg::query_directory::QueryDirectoryResponse; + use crate::msg::query_info::QueryInfoResponse; + use crate::msg::read::ReadResponse; + use crate::msg::write::WriteResponse; + use crate::pack::FileTime; + use crate::transport::MockTransport; + use crate::types::status::NtStatus; + use crate::types::{Command, FileId, OplockLevel, TreeId}; + use std::sync::Arc; + + fn test_tree() -> Tree { + Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + } + } + + fn build_create_response_directory(file_id: FileId) -> Vec { + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 32; + + let body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime(132_000_000_000_000_000), + last_access_time: FileTime(132_000_000_000_000_000), + last_write_time: FileTime(133_000_000_000_000_000), + change_time: FileTime(133_000_000_000_000_000), + allocation_size: 0, + end_of_file: 0, + file_attributes: 0x10, // DIRECTORY + file_id, + create_contexts: vec![], + }; + + pack_message(&h, &body) + } + + fn build_flush_response() -> Vec { + let mut h = Header::new_request(Command::Flush); + h.flags.set_response(); + h.credits = 32; + + let body = crate::msg::flush::FlushResponse; + pack_message(&h, &body) + } + + fn build_read_response(data: Vec) -> Vec { + let mut h = Header::new_request(Command::Read); + h.flags.set_response(); + h.credits = 32; + + let body = ReadResponse { + data_offset: 0x50, + data_remaining: 0, + flags: 0, + data, + }; + + pack_message(&h, &body) + } + + fn build_write_response(count: u32) -> Vec { + let mut h = Header::new_request(Command::Write); + h.flags.set_response(); + h.credits = 32; + + let body = WriteResponse { + count, + remaining: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + }; + + pack_message(&h, &body) + } + + fn build_query_info_response(output_buffer: Vec) -> Vec { + let mut h = Header::new_request(Command::QueryInfo); + h.flags.set_response(); + h.credits = 32; + + let body = QueryInfoResponse { output_buffer }; + + pack_message(&h, &body) + } + + fn build_query_directory_response(status: NtStatus, entries_data: Vec) -> Vec { + let mut h = Header::new_request(Command::QueryDirectory); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + if status == NtStatus::NO_MORE_FILES { + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + return pack_message(&h, &body); + } + + let body = QueryDirectoryResponse { + output_buffer: entries_data, + }; + + pack_message(&h, &body) + } + + /// Build a FileBasicInformation buffer (40 bytes). + fn build_file_basic_info( + creation_time: u64, + last_access_time: u64, + last_write_time: u64, + change_time: u64, + file_attributes: u32, + ) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&creation_time.to_le_bytes()); + buf.extend_from_slice(&last_access_time.to_le_bytes()); + buf.extend_from_slice(&last_write_time.to_le_bytes()); + buf.extend_from_slice(&change_time.to_le_bytes()); + buf.extend_from_slice(&file_attributes.to_le_bytes()); + // Padding to 40 bytes (Reserved) + buf.extend_from_slice(&0u32.to_le_bytes()); + buf + } + + /// Build a FileStandardInformation buffer (24 bytes). + fn build_file_standard_info( + allocation_size: u64, + end_of_file: u64, + number_of_links: u32, + delete_pending: bool, + directory: bool, + ) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&allocation_size.to_le_bytes()); + buf.extend_from_slice(&end_of_file.to_le_bytes()); + buf.extend_from_slice(&number_of_links.to_le_bytes()); + buf.push(if delete_pending { 1 } else { 0 }); + buf.push(if directory { 1 } else { 0 }); + buf.extend_from_slice(&0u16.to_le_bytes()); // Reserved + buf + } + + /// Build a single FileBothDirectoryInformation entry. + fn build_file_both_dir_info( + name: &str, + size: u64, + is_directory: bool, + next_offset: u32, + ) -> Vec { + let name_u16: Vec = name.encode_utf16().collect(); + let name_bytes_len = name_u16.len() * 2; + + let mut buf = Vec::new(); + buf.extend_from_slice(&next_offset.to_le_bytes()); + buf.extend_from_slice(&0u32.to_le_bytes()); // FileIndex + buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); // CreationTime + buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); // LastAccessTime + buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); // LastWriteTime + buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); // ChangeTime + buf.extend_from_slice(&size.to_le_bytes()); + buf.extend_from_slice(&((size + 4095) & !4095).to_le_bytes()); // AllocationSize + let attrs: u32 = if is_directory { 0x10 } else { 0x20 }; + buf.extend_from_slice(&attrs.to_le_bytes()); + buf.extend_from_slice(&(name_bytes_len as u32).to_le_bytes()); + buf.extend_from_slice(&0u32.to_le_bytes()); // EaSize + buf.push(0); // ShortNameLength + buf.push(0); // Reserved + buf.extend_from_slice(&[0u8; 24]); // ShortName + for &u in &name_u16 { + buf.extend_from_slice(&u.to_le_bytes()); + } + buf + } + + /// Build a compound response frame with proper NextCommand offsets and padding. + fn build_compound_response_frame(responses: &[Vec]) -> Vec { + let mut padded: Vec> = Vec::new(); + for (i, resp) in responses.iter().enumerate() { + let mut r = resp.clone(); + let is_last = i == responses.len() - 1; + if !is_last { + // Pad to 8-byte alignment. + let remainder = r.len() % 8; + if remainder != 0 { + r.resize(r.len() + (8 - remainder), 0); + } + // Set NextCommand. + let next_cmd = r.len() as u32; + r[20..24].copy_from_slice(&next_cmd.to_le_bytes()); + } + padded.push(r); + } + let mut frame = Vec::new(); + for r in &padded { + frame.extend_from_slice(r); + } + frame + } + + /// Build a compound read response frame (CREATE + READ + CLOSE) for pipeline tests. + fn build_compound_read_response(file_id: FileId, data: Vec) -> Vec { + let create_resp = build_create_response(file_id, data.len() as u64); + let read_resp = build_read_response(data); + let close_resp = build_close_response(); + build_compound_response_frame(&[create_resp, read_resp, close_resp]) + } + + #[tokio::test] + async fn pipeline_batch_of_three_reads() { + let mock = Arc::new(MockTransport::new()); + + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + // Three read operations, each needs a compound CREATE + READ + CLOSE frame. + for i in 0..3 { + let data = format!("content_{}", i); + mock.queue_response(build_compound_read_response(file_id, data.into_bytes())); + } + + let mut conn = setup_connection(&mock); + let tree = test_tree(); + let mut pipeline = Pipeline::new(&mut conn, &tree); + + let results = pipeline + .execute(vec![ + Op::ReadFile("file1.txt".to_string()), + Op::ReadFile("file2.txt".to_string()), + Op::ReadFile("file3.txt".to_string()), + ]) + .await; + + assert_eq!(results.len(), 3); + for (i, result) in results.into_iter().enumerate() { + match result { + OpResult::FileData { path, data } => { + assert_eq!(path, format!("file{}.txt", i + 1)); + assert_eq!(data, format!("content_{}", i).into_bytes()); + } + other => panic!("expected FileData, got {:?}", other), + } + } + } + + #[tokio::test] + async fn pipeline_mixed_ops() { + let mock = Arc::new(MockTransport::new()); + + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + // Op 1: ReadFile -- compound CREATE + READ + CLOSE + mock.queue_response(build_compound_read_response(file_id, b"hello".to_vec())); + + // Op 2: Delete -- compound CREATE(DELETE_ON_CLOSE) + CLOSE + let del_create = build_create_response(file_id, 0); + let del_close = build_close_response(); + mock.queue_response(build_compound_response_frame(&[del_create, del_close])); + + // Op 3: ListDirectory -- CREATE + QUERY_DIR + QUERY_DIR(NO_MORE) + CLOSE + mock.queue_response(build_create_response_directory(file_id)); + let entry = build_file_both_dir_info("test.txt", 100, false, 0); + mock.queue_response(build_query_directory_response(NtStatus::SUCCESS, entry)); + mock.queue_response(build_query_directory_response( + NtStatus::NO_MORE_FILES, + vec![], + )); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = test_tree(); + let mut pipeline = Pipeline::new(&mut conn, &tree); + + let results = pipeline + .execute(vec![ + Op::ReadFile("data.bin".to_string()), + Op::Delete("old.txt".to_string()), + Op::ListDirectory("docs".to_string()), + ]) + .await; + + assert_eq!(results.len(), 3); + + match &results[0] { + OpResult::FileData { data, .. } => assert_eq!(data, b"hello"), + other => panic!("expected FileData, got {:?}", other), + } + + match &results[1] { + OpResult::Deleted { path } => assert_eq!(path, "old.txt"), + other => panic!("expected Deleted, got {:?}", other), + } + + match &results[2] { + OpResult::DirEntries { entries, .. } => { + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].name, "test.txt"); + } + other => panic!("expected DirEntries, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_delete_file() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + // DELETE = compound CREATE(DELETE_ON_CLOSE) + CLOSE + let create_resp = build_create_response(file_id, 0); + let close_resp = build_close_response(); + let frame = build_compound_response_frame(&[create_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = test_tree(); + let mut pipeline = Pipeline::new(&mut conn, &tree); + + let results = pipeline + .execute(vec![Op::Delete("remove_me.txt".to_string())]) + .await; + + assert_eq!(results.len(), 1); + match &results[0] { + OpResult::Deleted { path } => assert_eq!(path, "remove_me.txt"), + other => panic!("expected Deleted, got {:?}", other), + } + + // One compound frame sent. + let sent = mock.sent_messages(); + assert_eq!(sent.len(), 1); + } + + #[tokio::test] + async fn pipeline_write_file() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + // WRITE uses compound: CREATE+WRITE+FLUSH+CLOSE in one frame. + let create_resp = build_create_response(file_id, 0); + let write_resp = build_write_response(11); + let flush_resp = build_flush_response(); + let close_resp = build_close_response(); + let frame = + build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = test_tree(); + let mut pipeline = Pipeline::new(&mut conn, &tree); + + let results = pipeline + .execute(vec![Op::WriteFile( + "output.txt".to_string(), + b"hello world".to_vec(), + )]) + .await; + + assert_eq!(results.len(), 1); + match &results[0] { + OpResult::Written { + path, + bytes_written, + } => { + assert_eq!(path, "output.txt"); + assert_eq!(*bytes_written, 11); + } + other => panic!("expected Written, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_stat() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + // STAT = compound CREATE + QUERY_INFO(basic) + QUERY_INFO(standard) + CLOSE + let create_resp = build_create_response(file_id, 0); + + let basic_info = build_file_basic_info( + 132_000_000_000_000_000, + 132_100_000_000_000_000, + 133_000_000_000_000_000, + 133_000_000_000_000_000, + 0x20, // ARCHIVE (not a directory) + ); + let basic_resp = build_query_info_response(basic_info); + + let std_info = build_file_standard_info( + 4096, // allocation_size + 2048, // end_of_file (actual size) + 1, // number_of_links + false, // delete_pending + false, // directory + ); + let std_resp = build_query_info_response(std_info); + + let close_resp = build_close_response(); + + let frame = build_compound_response_frame(&[create_resp, basic_resp, std_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = test_tree(); + let mut pipeline = Pipeline::new(&mut conn, &tree); + + let results = pipeline + .execute(vec![Op::Stat("info.txt".to_string())]) + .await; + + assert_eq!(results.len(), 1); + match &results[0] { + OpResult::Stat { path, info } => { + assert_eq!(path, "info.txt"); + assert_eq!(info.size, 2048); + assert!(!info.is_directory); + assert_eq!(info.created, FileTime(132_000_000_000_000_000)); + assert_eq!(info.modified, FileTime(133_000_000_000_000_000)); + } + other => panic!("expected Stat, got {:?}", other), + } + } + + #[tokio::test] + async fn pipeline_error_does_not_abort_batch() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + // Op 1: ReadFile that fails at CREATE -- compound frame with cascaded errors. + let error_body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + let mut h1 = Header::new_request(Command::Create); + h1.flags.set_response(); + h1.credits = 32; + h1.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_err = pack_message(&h1, &error_body); + + let mut h2 = Header::new_request(Command::Read); + h2.flags.set_response(); + h2.credits = 32; + h2.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let read_err = pack_message(&h2, &error_body); + + let mut h3 = Header::new_request(Command::Close); + h3.flags.set_response(); + h3.credits = 32; + h3.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_err = pack_message(&h3, &error_body); + + mock.queue_response(build_compound_response_frame(&[ + create_err, read_err, close_err, + ])); + + // Op 2: ReadFile that succeeds -- compound frame. + mock.queue_response(build_compound_read_response(file_id, b"abc".to_vec())); + + let mut conn = setup_connection(&mock); + let tree = test_tree(); + let mut pipeline = Pipeline::new(&mut conn, &tree); + + let results = pipeline + .execute(vec![ + Op::ReadFile("missing.txt".to_string()), + Op::ReadFile("exists.txt".to_string()), + ]) + .await; + + assert_eq!(results.len(), 2); + match &results[0] { + OpResult::Error { path, .. } => assert_eq!(path, "missing.txt"), + other => panic!("expected Error, got {:?}", other), + } + match &results[1] { + OpResult::FileData { path, data } => { + assert_eq!(path, "exists.txt"); + assert_eq!(data, b"abc"); + } + other => panic!("expected FileData, got {:?}", other), + } + } +} diff --git a/vendor/smb2/src/client/session.rs b/vendor/smb2/src/client/session.rs new file mode 100644 index 0000000..d7b268a --- /dev/null +++ b/vendor/smb2/src/client/session.rs @@ -0,0 +1,769 @@ +//! Authenticated SMB2 session. +//! +//! The [`Session`] type manages the multi-round-trip SESSION_SETUP exchange +//! (NTLM authentication), key derivation, and signing activation. + +use log::{debug, info, trace, warn}; + +use crate::auth::ntlm::{NtlmAuthenticator, NtlmCredentials}; +use crate::client::connection::Connection; +use crate::crypto::kdf::derive_session_keys; +use crate::crypto::signing::{algorithm_for_dialect, SigningAlgorithm}; +use crate::error::Result; +use crate::msg::session_setup::{SessionSetupRequest, SessionSetupResponse}; +use crate::pack::{ReadCursor, Unpack}; +use crate::types::flags::{Capabilities, SecurityMode}; +use crate::types::status::NtStatus; +use crate::types::{Command, Dialect, SessionId}; +use crate::Error; + +use crate::msg::session_setup::SessionSetupRequestFlags; + +/// An authenticated SMB2 session with derived keys. +#[derive(Debug)] +pub struct Session { + /// The session ID assigned by the server. + pub session_id: SessionId, + /// Key used to sign outgoing messages. + pub signing_key: Vec, + /// Key used to encrypt outgoing messages (SMB 3.x). + pub encryption_key: Option>, + /// Key used to decrypt incoming messages (SMB 3.x). + pub decryption_key: Option>, + /// The signing algorithm to use. + pub signing_algorithm: SigningAlgorithm, + /// Whether outgoing messages should be signed. + pub should_sign: bool, + /// Whether outgoing messages should be encrypted. + pub should_encrypt: bool, +} + +impl Session { + /// Perform the multi-round-trip SESSION_SETUP exchange. + /// + /// Steps: + /// 1. Send NTLM NEGOTIATE_MESSAGE in SESSION_SETUP. + /// 2. Receive STATUS_MORE_PROCESSING_REQUIRED with CHALLENGE_MESSAGE. + /// 3. Update preauth hash with request+response. + /// 4. Send NTLM AUTHENTICATE_MESSAGE in SESSION_SETUP. + /// 5. Receive STATUS_SUCCESS with session flags. + /// 6. Update preauth hash with request+response. + /// 7. Derive signing/encryption keys. + /// 8. Activate signing on the connection. + pub async fn setup( + conn: &mut Connection, + username: &str, + password: &str, + domain: &str, + ) -> Result { + let params = conn + .params() + .ok_or_else(|| Error::invalid_data("negotiate must complete before session setup"))? + .clone(); + + let mut auth = NtlmAuthenticator::new(NtlmCredentials { + username: username.to_string(), + password: password.to_string(), + domain: domain.to_string(), + }); + + // Clone the preauth hasher for this session (spec: per-session hash). + let mut session_hasher = conn.preauth_hasher().clone(); + + // ── Round 1: NEGOTIATE_MESSAGE ── + debug!("session: round 1, sending NTLM negotiate"); + + let type1_bytes = auth.negotiate(); + + let req1 = SessionSetupRequest { + flags: SessionSetupRequestFlags(0), + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::default(), + channel: 0, + previous_session_id: 0, + security_buffer: type1_bytes, + }; + + let (frame1, req1_raw) = conn + .execute_capturing_request(Command::SessionSetup, &req1, None) + .await?; + + // Update session preauth hash with request. + session_hasher.update(&req1_raw); + + let resp1_header = frame1.header; + let resp1_body = frame1.body; + + // Update session preauth hash with response. + session_hasher.update(&frame1.raw); + + if resp1_header.command != Command::SessionSetup { + return Err(Error::invalid_data(format!( + "expected SessionSetup response, got {:?}", + resp1_header.command + ))); + } + + if !resp1_header.status.is_more_processing_required() { + if resp1_header.status.is_error() { + return Err(Error::Protocol { + status: resp1_header.status, + command: Command::SessionSetup, + }); + } + return Err(Error::invalid_data( + "expected STATUS_MORE_PROCESSING_REQUIRED, got success on first round", + )); + } + + // The server assigned a session ID -- use it for subsequent requests. + debug!( + "session: round 1 complete, status={:?}, session_id={}", + resp1_header.status, resp1_header.session_id + ); + conn.set_session_id(resp1_header.session_id); + + // Parse the challenge response. + let mut cursor1 = ReadCursor::new(&resp1_body); + let setup_resp1 = SessionSetupResponse::unpack(&mut cursor1)?; + + // ── Round 2: AUTHENTICATE_MESSAGE ── + debug!("session: round 2, sending NTLM authenticate"); + + let type3_bytes = auth.authenticate(&setup_resp1.security_buffer)?; + + let req2 = SessionSetupRequest { + flags: SessionSetupRequestFlags(0), + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::default(), + channel: 0, + previous_session_id: 0, + security_buffer: type3_bytes, + }; + + let (frame2, req2_raw) = conn + .execute_capturing_request(Command::SessionSetup, &req2, None) + .await?; + + // Update session preauth hash with the request ONLY. + // The final SESSION_SETUP response (STATUS_SUCCESS) is NOT + // included in the preauth hash (spec section 3.2.5.3.1). + // Only STATUS_MORE_PROCESSING_REQUIRED responses are hashed. + session_hasher.update(&req2_raw); + + let resp2_header = frame2.header; + let resp2_body = frame2.body; + + // Do NOT hash the success response -- the preauth hash used for + // key derivation contains only messages up to (and including) + // the final authenticate request, not the success response. + + if resp2_header.command != Command::SessionSetup { + return Err(Error::invalid_data(format!( + "expected SessionSetup response, got {:?}", + resp2_header.command + ))); + } + + if resp2_header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: resp2_header.status, + command: Command::SessionSetup, + }); + } + + // Parse the final response. + let mut cursor2 = ReadCursor::new(&resp2_body); + let setup_resp2 = SessionSetupResponse::unpack(&mut cursor2)?; + + let session_id = resp2_header.session_id; + conn.set_session_id(session_id); + + // Get the session key from NTLM. + let session_key = auth + .session_key() + .ok_or_else(|| Error::Auth { + message: "NTLM did not produce a session key".to_string(), + })? + .to_vec(); + + // Determine signing algorithm. + let gmac_negotiated = params.gmac_negotiated; + let signing_algorithm = algorithm_for_dialect(params.dialect, gmac_negotiated); + debug!( + "session: signing_algo={:?}, dialect={}", + signing_algorithm, params.dialect + ); + + // Derive keys for SMB 3.x, or use session key directly for SMB 2.x. + trace!( + "session: deriving keys, session_key_len={}", + session_key.len() + ); + let (signing_key, encryption_key, decryption_key) = match params.dialect { + Dialect::Smb3_0 | Dialect::Smb3_0_2 => { + let keys = derive_session_keys(&session_key, params.dialect, None, 128); + ( + keys.signing_key, + Some(keys.encryption_key), + Some(keys.decryption_key), + ) + } + Dialect::Smb3_1_1 => { + // Key length: 256 bits only for AES-256 ciphers. GMAC signing + // uses AES-128-GCM internally, so it needs 128-bit (16-byte) keys. + let key_len_bits = match params.cipher { + Some(crate::crypto::encryption::Cipher::Aes256Ccm) + | Some(crate::crypto::encryption::Cipher::Aes256Gcm) => 256, + _ => 128, + }; + let keys = derive_session_keys( + &session_key, + Dialect::Smb3_1_1, + Some(session_hasher.value()), + key_len_bits, + ); + ( + keys.signing_key, + Some(keys.encryption_key), + Some(keys.decryption_key), + ) + } + _ => { + // SMB 2.x: use session key directly for signing. + (session_key.clone(), None, None) + } + }; + + // Determine if we should sign. + let should_sign = params.signing_required + || !setup_resp2.session_flags.is_guest() && !setup_resp2.session_flags.is_null(); + + let should_encrypt = setup_resp2.session_flags.encrypt_data(); + + // Activate signing on the connection. + if should_sign { + conn.activate_signing(signing_key.clone(), signing_algorithm); + } + + // Activate encryption on the connection if the session requires it. + // The cipher comes from negotiate contexts (SMB 3.1.1). If the server + // didn't send one (for example, Samba with `smb encrypt = required` sometimes + // omits the encryption context), fall back to AES-128-CCM which is + // universally supported by all SMB 3.x servers. + if should_encrypt { + let cipher = params + .cipher + .unwrap_or(crate::crypto::encryption::Cipher::Aes128Ccm); + if let (Some(ref enc_key), Some(ref dec_key)) = (&encryption_key, &decryption_key) { + conn.activate_encryption(enc_key.clone(), dec_key.clone(), cipher); + } else { + warn!( + "session: encryption requested but missing keys, \ + enc_key={}, dec_key={}", + encryption_key.is_some(), + decryption_key.is_some(), + ); + } + } + + info!( + "session: established, session_id={}, sign={}, encrypt={}", + session_id, should_sign, should_encrypt + ); + + Ok(Session { + session_id, + signing_key, + encryption_key, + decryption_key, + signing_algorithm, + should_sign, + should_encrypt, + }) + } + + /// Perform Kerberos-based SESSION_SETUP. + /// + /// Authenticates against the KDC first (AS + TGS), then sends the + /// SPNEGO-wrapped AP-REQ in SESSION_SETUP. Handles both single-round + /// (STATUS_SUCCESS) and mutual-auth (STATUS_MORE_PROCESSING_REQUIRED) + /// flows. + /// + /// The session key comes from the Kerberos TGS exchange, not from the + /// SMB server response. + /// Perform Kerberos-based SESSION_SETUP using a credential cache. + /// + /// Reads cached tickets from the ccache. If a service ticket for + /// `cifs/` is cached, uses it directly (no KDC needed). + /// If only a TGT is cached, does a TGS exchange for the service ticket. + pub async fn setup_kerberos_from_ccache( + conn: &mut Connection, + credentials: &crate::auth::kerberos::KerberosCredentials, + server_hostname: &str, + ccache: &crate::auth::kerberos::ccache::CCache, + ) -> Result { + let mut auth = crate::auth::kerberos::KerberosAuthenticator::new(credentials.clone()); + auth.authenticate_from_ccache(ccache, server_hostname) + .await?; + Self::setup_kerberos_with_auth(conn, &mut auth).await + } + + /// Perform Kerberos-based SESSION_SETUP. + /// + /// Authenticates against the KDC first (AS + TGS), then sends the + /// SPNEGO-wrapped AP-REQ in SESSION_SETUP. Handles both single-round + /// (STATUS_SUCCESS) and mutual-auth (STATUS_MORE_PROCESSING_REQUIRED) + /// flows. + /// + /// The session key comes from the Kerberos TGS exchange, not from the + /// SMB server response. + pub async fn setup_kerberos( + conn: &mut Connection, + credentials: &crate::auth::kerberos::KerberosCredentials, + server_hostname: &str, + ) -> Result { + let mut auth = crate::auth::kerberos::KerberosAuthenticator::new(credentials.clone()); + auth.authenticate(server_hostname).await?; + Self::setup_kerberos_with_auth(conn, &mut auth).await + } + + /// Shared Kerberos SESSION_SETUP logic used by both password-based + /// and ccache-based authentication paths. + async fn setup_kerberos_with_auth( + conn: &mut Connection, + auth: &mut crate::auth::kerberos::KerberosAuthenticator, + ) -> Result { + let params = conn + .params() + .ok_or_else(|| Error::invalid_data("negotiate must complete before session setup"))? + .clone(); + + let token = auth + .token() + .ok_or_else(|| Error::Auth { + message: "Kerberos authentication produced no token".to_string(), + })? + .to_vec(); + + debug!("session: Kerberos auth complete, token_len={}", token.len()); + + // Clone the preauth hasher for this session. + let mut session_hasher = conn.preauth_hasher().clone(); + + // Step 2: Send SPNEGO-wrapped AP-REQ in SESSION_SETUP. + let req = SessionSetupRequest { + flags: SessionSetupRequestFlags(0), + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::default(), + channel: 0, + previous_session_id: 0, + security_buffer: token, + }; + + let (frame, req_raw) = conn + .execute_capturing_request(Command::SessionSetup, &req, None) + .await?; + + // Hash the request (same as NTLM round 1). + session_hasher.update(&req_raw); + + let resp_header = frame.header; + let resp_body = frame.body; + let resp_raw = frame.raw; + + if resp_header.command != Command::SessionSetup { + return Err(Error::invalid_data(format!( + "expected SessionSetup response, got {:?}", + resp_header.command + ))); + } + + if resp_header.status != NtStatus::SUCCESS + && !resp_header.status.is_more_processing_required() + { + return Err(Error::Protocol { + status: resp_header.status, + command: Command::SessionSetup, + }); + } + + // The server assigned a session ID. + let session_id = resp_header.session_id; + conn.set_session_id(session_id); + + let mut cursor = ReadCursor::new(&resp_body); + let setup_resp = SessionSetupResponse::unpack(&mut cursor)?; + + if resp_header.status.is_more_processing_required() { + debug!( + "session: Kerberos got MORE_PROCESSING_REQUIRED, session_id={}", + session_id + ); + + // Hash the response per MS-SMB2 3.2.5.3.1. + session_hasher.update(&resp_raw); + } + + // Process the SPNEGO response token (AP-REP or KRB-ERROR). + // This applies to both STATUS_SUCCESS and MORE_PROCESSING_REQUIRED — + // the server may include an AP-REP with a sub-session key in either. + if !setup_resp.security_buffer.is_empty() { + let spnego_resp = + crate::auth::spnego::parse_neg_token_resp(&setup_resp.security_buffer)?; + debug!( + "session: SPNEGO state={:?}, has_token={}, supported_mech={:02x?}", + spnego_resp.neg_state, + spnego_resp.response_token.is_some(), + spnego_resp.supported_mech.as_deref().unwrap_or(&[]), + ); + + if let Some(ref token_bytes) = spnego_resp.response_token { + auth.process_mutual_auth_token(token_bytes)?; + } + } + + // Get the session key AFTER processing the AP-REP (the server's + // subkey may have overridden ours). + // + // Per MS-SMB2 3.2.5.3: "Session.SessionKey MUST be set to the first + // 16 bytes of the cryptographic key queried from the GSS protocol." + let full_key = auth.session_key().ok_or_else(|| Error::Auth { + message: "Kerberos authentication produced no session key".to_string(), + })?; + let session_key = if full_key.len() > 16 { + full_key[..16].to_vec() + } else { + full_key.to_vec() + }; + + debug!( + "session: Kerberos session_key_len={} (truncated from {})", + session_key.len(), + full_key.len() + ); + + // Determine signing algorithm. + let signing_algorithm = algorithm_for_dialect(params.dialect, params.gmac_negotiated); + debug!( + "session: Kerberos signing_algo={:?}, dialect={}", + signing_algorithm, params.dialect + ); + + // Derive keys for SMB 3.x using the Kerberos session key. + let (signing_key, encryption_key, decryption_key) = match params.dialect { + Dialect::Smb3_0 | Dialect::Smb3_0_2 => { + let keys = derive_session_keys(&session_key, params.dialect, None, 128); + ( + keys.signing_key, + Some(keys.encryption_key), + Some(keys.decryption_key), + ) + } + Dialect::Smb3_1_1 => { + let key_len_bits = match params.cipher { + Some(crate::crypto::encryption::Cipher::Aes256Ccm) + | Some(crate::crypto::encryption::Cipher::Aes256Gcm) => 256, + _ => 128, + }; + let keys = derive_session_keys( + &session_key, + Dialect::Smb3_1_1, + Some(session_hasher.value()), + key_len_bits, + ); + ( + keys.signing_key, + Some(keys.encryption_key), + Some(keys.decryption_key), + ) + } + _ => (session_key.clone(), None, None), + }; + + let should_sign = params.signing_required + || !setup_resp.session_flags.is_guest() && !setup_resp.session_flags.is_null(); + + let should_encrypt = setup_resp.session_flags.encrypt_data(); + + if should_sign { + conn.activate_signing(signing_key.clone(), signing_algorithm); + } + + if should_encrypt { + let cipher = params + .cipher + .unwrap_or(crate::crypto::encryption::Cipher::Aes128Ccm); + if let (Some(ref enc_key), Some(ref dec_key)) = (&encryption_key, &decryption_key) { + conn.activate_encryption(enc_key.clone(), dec_key.clone(), cipher); + } + } + + info!( + "session: Kerberos established, session_id={}, sign={}, encrypt={}", + session_id, should_sign, should_encrypt + ); + + Ok(Session { + session_id, + signing_key, + encryption_key, + decryption_key, + signing_algorithm, + should_sign, + should_encrypt, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::connection::{pack_message, Connection, NegotiatedParams}; + use crate::msg::header::Header; + use crate::msg::session_setup::{SessionFlags, SessionSetupResponse}; + use crate::pack::Guid; + use crate::transport::MockTransport; + use crate::types::flags::Capabilities; + use crate::types::status::NtStatus; + use crate::types::{Command, Dialect, SessionId}; + use std::sync::Arc; + + /// Build a session setup response with the given status and session ID. + fn build_session_setup_response( + status: NtStatus, + session_id: SessionId, + security_buffer: Vec, + session_flags: SessionFlags, + ) -> Vec { + let mut h = Header::new_request(Command::SessionSetup); + h.flags.set_response(); + h.credits = 32; + h.status = status; + h.session_id = session_id; + + let body = SessionSetupResponse { + session_flags, + security_buffer, + }; + + pack_message(&h, &body) + } + + /// Build a minimal NTLM challenge message (Type 2). + /// + /// This is a stripped-down challenge that the NtlmAuthenticator can parse. + fn build_ntlm_challenge() -> Vec { + let mut buf = Vec::new(); + + // Signature (8 bytes) + buf.extend_from_slice(b"NTLMSSP\0"); + // MessageType = 2 (4 bytes) + buf.extend_from_slice(&2u32.to_le_bytes()); + // TargetNameFields: Len=0, MaxLen=0, Offset=56 + buf.extend_from_slice(&0u16.to_le_bytes()); // Len + buf.extend_from_slice(&0u16.to_le_bytes()); // MaxLen + buf.extend_from_slice(&56u32.to_le_bytes()); // Offset + // NegotiateFlags + let flags: u32 = 0x0000_0001 // UNICODE + | 0x0000_0200 // NTLM + | 0x0008_0000 // EXTENDED_SESSIONSECURITY + | 0x0080_0000 // TARGET_INFO + | 0x2000_0000 // 128 + | 0x4000_0000 // KEY_EXCH + | 0x8000_0000 // 56 + | 0x0000_0010 // SIGN + | 0x0000_0020; // SEAL + buf.extend_from_slice(&flags.to_le_bytes()); + // ServerChallenge (8 bytes) + buf.extend_from_slice(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]); + // Reserved (8 bytes) + buf.extend_from_slice(&[0u8; 8]); + + // TargetInfoFields: Len, MaxLen, Offset (will be at offset 56 + target_name_len) + // Build target info: just MsvAvEOL + let target_info = build_av_eol(); + let ti_offset = 56u32; // right after the fixed header + buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); // Len + buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); // MaxLen + buf.extend_from_slice(&ti_offset.to_le_bytes()); // Offset + + // Ensure we're at offset 56 (pad if needed). + while buf.len() < 56 { + buf.push(0); + } + + // Target info data + buf.extend_from_slice(&target_info); + + buf + } + + /// Build an AV_PAIR list with just MsvAvEOL. + fn build_av_eol() -> Vec { + let mut buf = Vec::new(); + // MsvAvEOL: AvId=0, AvLen=0 + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u16.to_le_bytes()); + buf + } + + #[tokio::test] + async fn session_setup_stores_session_id() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let session_id = SessionId(0xDEAD_BEEF); + + // Queue the two session setup responses. + let challenge = build_ntlm_challenge(); + mock.queue_response(build_session_setup_response( + NtStatus::MORE_PROCESSING_REQUIRED, + session_id, + challenge, + SessionFlags(0), + )); + mock.queue_response(build_session_setup_response( + NtStatus::SUCCESS, + session_id, + vec![], + SessionFlags(0), + )); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + + // Set up negotiate params (pretend we already negotiated). + // We need to call negotiate or set params manually. + // Let's also queue a negotiate response first. + // Actually, let's set params directly. + set_test_params(&mut conn, Dialect::Smb2_0_2); + + let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap(); + assert_eq!(session.session_id, session_id); + } + + #[tokio::test] + async fn session_setup_derives_signing_key() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let session_id = SessionId(0x1234); + + let challenge = build_ntlm_challenge(); + mock.queue_response(build_session_setup_response( + NtStatus::MORE_PROCESSING_REQUIRED, + session_id, + challenge, + SessionFlags(0), + )); + mock.queue_response(build_session_setup_response( + NtStatus::SUCCESS, + session_id, + vec![], + SessionFlags(0), + )); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + set_test_params(&mut conn, Dialect::Smb2_0_2); + + let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap(); + assert!(!session.signing_key.is_empty()); + } + + #[tokio::test] + async fn session_setup_activates_signing() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let session_id = SessionId(0x5678); + + let challenge = build_ntlm_challenge(); + mock.queue_response(build_session_setup_response( + NtStatus::MORE_PROCESSING_REQUIRED, + session_id, + challenge, + SessionFlags(0), + )); + mock.queue_response(build_session_setup_response( + NtStatus::SUCCESS, + session_id, + vec![], + SessionFlags(0), + )); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + set_test_params(&mut conn, Dialect::Smb2_0_2); + + let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap(); + assert!(session.should_sign); + assert_eq!(session.signing_algorithm, SigningAlgorithm::HmacSha256); + } + + #[tokio::test] + async fn session_setup_error_on_auth_failure() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let session_id = SessionId(0x9999); + + let challenge = build_ntlm_challenge(); + mock.queue_response(build_session_setup_response( + NtStatus::MORE_PROCESSING_REQUIRED, + session_id, + challenge, + SessionFlags(0), + )); + // Auth fails on second round. + mock.queue_response(build_session_setup_response( + NtStatus::LOGON_FAILURE, + session_id, + vec![], + SessionFlags(0), + )); + + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + set_test_params(&mut conn, Dialect::Smb2_0_2); + + let result = Session::setup(&mut conn, "user", "badpass", "").await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!( + err, + Error::Protocol { + status: NtStatus::LOGON_FAILURE, + .. + } + ), + "expected LOGON_FAILURE, got: {err}" + ); + } + + /// Helper: set fake negotiated params on a connection. + fn set_test_params(conn: &mut Connection, dialect: Dialect) { + conn.set_test_params(NegotiatedParams { + dialect, + max_read_size: 65536, + max_write_size: 65536, + max_transact_size: 65536, + server_guid: Guid::ZERO, + signing_required: false, + capabilities: Capabilities::default(), + gmac_negotiated: false, + cipher: None, + compression_supported: false, + }); + } +} diff --git a/vendor/smb2/src/client/shares.rs b/vendor/smb2/src/client/shares.rs new file mode 100644 index 0000000..59d8051 --- /dev/null +++ b/vendor/smb2/src/client/shares.rs @@ -0,0 +1,764 @@ +//! Share enumeration via IPC$ + srvsvc RPC. +//! +//! Lists available shares on an SMB server by connecting to the IPC$ share, +//! opening the srvsvc named pipe, and performing the NetShareEnumAll RPC +//! exchange. + +use log::{debug, info}; + +use crate::client::connection::Connection; +use crate::error::Result; +use crate::msg::close::CloseRequest; +use crate::msg::create::{ + CreateDisposition, CreateRequest, CreateResponse, ImpersonationLevel, ShareAccess, +}; +use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE}; +use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse}; +use crate::msg::tree_disconnect::TreeDisconnectRequest; +use crate::msg::write::{WriteRequest, WriteResponse}; +use crate::pack::{ReadCursor, Unpack}; +use crate::rpc; +use crate::rpc::srvsvc::{self, ShareInfo}; +use crate::types::flags::FileAccessMask; +use crate::types::status::NtStatus; +use crate::types::{Command, FileId, OplockLevel, TreeId}; +use crate::Error; + +/// Read buffer size for pipe reads (64 KiB is plenty for share listings). +const PIPE_READ_BUFFER_SIZE: u32 = 65536; + +/// List available shares on the server. +/// +/// Connects to the IPC$ share, opens the srvsvc named pipe, performs +/// the RPC exchange, and returns filtered disk shares. +/// +/// This is a self-contained operation -- it opens and closes its own +/// tree connection to IPC$. +pub async fn list_shares(conn: &mut Connection) -> Result> { + // 1. Tree connect to IPC$ + let tree_id = tree_connect_ipc(conn).await?; + + // Run the pipe operations, then clean up regardless of outcome + let result = pipe_rpc_exchange(conn, tree_id).await; + + // 8. Tree disconnect (best-effort -- don't mask the real error) + let _ = tree_disconnect(conn, tree_id).await; + + let all_shares = result?; + + // 9. Filter to disk shares + let filtered = srvsvc::filter_disk_shares(all_shares); + info!("shares: found {} disk shares", filtered.len()); + Ok(filtered) +} + +/// Connect to the IPC$ share, returning the tree ID. +async fn tree_connect_ipc(conn: &mut Connection) -> Result { + let server = conn.server_name().to_string(); + let unc_path = format!(r"\\{}\IPC$", server); + + let req = TreeConnectRequest { + flags: TreeConnectRequestFlags::default(), + path: unc_path, + }; + + let frame = conn.execute(Command::TreeConnect, &req, None).await?; + + if frame.header.command != Command::TreeConnect { + return Err(Error::invalid_data(format!( + "expected TreeConnect response, got {:?}", + frame.header.command + ))); + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::TreeConnect, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let _resp = TreeConnectResponse::unpack(&mut cursor)?; + + let tree_id = frame + .header + .tree_id + .ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?; + + info!("shares: connected to IPC$, tree_id={}", tree_id); + Ok(tree_id) +} + +/// Open the srvsvc pipe, perform the RPC bind and request, then close. +async fn pipe_rpc_exchange(conn: &mut Connection, tree_id: TreeId) -> Result> { + // 2. Create \pipe\srvsvc + let file_id = open_srvsvc_pipe(conn, tree_id).await?; + + // Run RPC exchange, then close regardless of outcome + let result = rpc_bind_and_request(conn, tree_id, file_id).await; + + // 7. Close the pipe handle (best-effort) + let _ = close_handle(conn, tree_id, file_id).await; + + result +} + +/// Perform the RPC bind + NetShareEnumAll request over the pipe. +async fn rpc_bind_and_request( + conn: &mut Connection, + tree_id: TreeId, + file_id: FileId, +) -> Result> { + // 3. Write RPC BIND + let bind_data = rpc::build_srvsvc_bind(1); + write_pipe(conn, tree_id, file_id, &bind_data).await?; + debug!("shares: sent RPC BIND ({} bytes)", bind_data.len()); + + // 4. Read RPC BIND_ACK + let bind_ack_data = read_pipe_message(conn, tree_id, file_id).await?; + rpc::parse_bind_ack(&bind_ack_data)?; + debug!("shares: received BIND_ACK, context accepted"); + + // 5. Write RPC REQUEST (NetShareEnumAll) + let server_name = format!(r"\\{}", conn.server_name()); + let request_data = srvsvc::build_net_share_enum_all(2, &server_name); + write_pipe(conn, tree_id, file_id, &request_data).await?; + debug!( + "shares: sent NetShareEnumAll request ({} bytes)", + request_data.len() + ); + + // 6. Read RPC RESPONSE, reassembling DCE/RPC fragments (MS-RPCE 2.2.2.6). + // A large NetShareEnum reply may arrive as several fragment PDUs, each its + // own pipe message, with PFC_LAST_FRAG set only on the last. + let mut stub = Vec::new(); + let mut fragments = 0; + loop { + let pdu = read_pipe_message(conn, tree_id, file_id).await?; + let (frag_stub, is_last) = rpc::parse_response_fragment(&pdu)?; + stub.extend_from_slice(frag_stub); + fragments += 1; + if is_last { + break; + } + } + let shares = srvsvc::parse_net_share_enum_all_stub(&stub)?; + debug!( + "shares: received {} shares in response ({} RPC fragment(s))", + shares.len(), + fragments + ); + + Ok(shares) +} + +/// Open the `\pipe\srvsvc` named pipe via CREATE. +async fn open_srvsvc_pipe(conn: &mut Connection, tree_id: TreeId) -> Result { + let req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_DATA | FileAccessMask::FILE_WRITE_DATA, + ), + file_attributes: 0, + share_access: ShareAccess(ShareAccess::FILE_SHARE_READ | ShareAccess::FILE_SHARE_WRITE), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: r"srvsvc".to_string(), + create_contexts: vec![], + }; + + let frame = conn.execute(Command::Create, &req, Some(tree_id)).await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = CreateResponse::unpack(&mut cursor)?; + debug!("shares: opened srvsvc pipe, file_id={:?}", resp.file_id); + Ok(resp.file_id) +} + +/// Write data to the pipe. +async fn write_pipe( + conn: &mut Connection, + tree_id: TreeId, + file_id: FileId, + data: &[u8], +) -> Result<()> { + // DataOffset: header (64) + fixed write body (48) = 112 = 0x70 + let req = WriteRequest { + data_offset: 0x70, + offset: 0, + file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: data.to_vec(), + }; + + let frame = conn.execute(Command::Write, &req, Some(tree_id)).await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Write, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = WriteResponse::unpack(&mut cursor)?; + debug!("shares: wrote {} bytes to pipe", resp.count); + Ok(()) +} + +/// Read one complete pipe message, following `STATUS_BUFFER_OVERFLOW`. +/// +/// A pipe message larger than our read buffer comes back as one or more +/// `STATUS_BUFFER_OVERFLOW` reads carrying partial data, terminated by a +/// `STATUS_SUCCESS` read with the remainder (MS-SMB2 3.3.5.10). We append each +/// chunk until a `SUCCESS` read completes the message. +async fn read_pipe_message( + conn: &mut Connection, + tree_id: TreeId, + file_id: FileId, +) -> Result> { + let mut message = Vec::new(); + + loop { + let req = ReadRequest { + padding: 0x50, + flags: 0, + length: PIPE_READ_BUFFER_SIZE, + offset: 0, + file_id, + minimum_count: 0, + channel: SMB2_CHANNEL_NONE, + remaining_bytes: 0, + read_channel_info: vec![], + }; + + let frame = conn.execute(Command::Read, &req, Some(tree_id)).await?; + + let status = frame.header.status; + // BUFFER_OVERFLOW is a warning meaning "partial data, read again", not a + // failure -- accept it alongside SUCCESS. + if !status.is_success_or_partial() { + return Err(Error::Protocol { + status, + command: Command::Read, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = ReadResponse::unpack(&mut cursor)?; + let chunk_len = resp.data.len(); + message.extend_from_slice(&resp.data); + + // SUCCESS completes the message; BUFFER_OVERFLOW means read more. + if status != NtStatus::BUFFER_OVERFLOW { + break; + } + // Guard against a server that signals overflow but sends no data, which + // would otherwise spin forever. + if chunk_len == 0 { + return Err(Error::invalid_data( + "pipe read returned BUFFER_OVERFLOW with no data", + )); + } + } + + debug!("shares: read {} bytes from pipe", message.len()); + Ok(message) +} + +/// Close a file handle. +async fn close_handle(conn: &mut Connection, tree_id: TreeId, file_id: FileId) -> Result<()> { + let req = CloseRequest { flags: 0, file_id }; + + let frame = conn.execute(Command::Close, &req, Some(tree_id)).await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Close, + }); + } + + Ok(()) +} + +/// Disconnect from a tree. +async fn tree_disconnect(conn: &mut Connection, tree_id: TreeId) -> Result<()> { + let body = TreeDisconnectRequest; + let frame = conn + .execute(Command::TreeDisconnect, &body, Some(tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::TreeDisconnect, + }); + } + + info!("shares: disconnected from IPC$"); + Ok(()) +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::client::connection::{pack_message, NegotiatedParams}; + use crate::client::test_helpers::{ + build_close_response, build_create_response, build_tree_connect_response, setup_connection, + }; + use crate::msg::header::Header; + use crate::msg::read::ReadResponse as ReadResp; + use crate::msg::tree_connect::ShareType; + use crate::msg::tree_disconnect::TreeDisconnectResponse; + use crate::msg::write::WriteResponse as WriteResp; + use crate::pack::Guid; + use crate::rpc::srvsvc::{STYPE_DISKTREE, STYPE_IPC, STYPE_SPECIAL}; + use crate::transport::MockTransport; + use crate::types::flags::Capabilities; + use crate::types::{Dialect, SessionId, TreeId}; + use std::sync::Arc; + + fn build_write_response(count: u32) -> Vec { + let mut h = Header::new_request(Command::Write); + h.flags.set_response(); + h.credits = 32; + + let body = WriteResp { + count, + remaining: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + }; + + pack_message(&h, &body) + } + + fn build_read_response(data: Vec) -> Vec { + build_read_response_with_status(data, NtStatus::SUCCESS) + } + + /// Build a READ response with an explicit NTSTATUS. + /// + /// Pipe reads use `STATUS_BUFFER_OVERFLOW` to mean "this read returned a + /// partial message; read again for the rest." + fn build_read_response_with_status(data: Vec, status: NtStatus) -> Vec { + let mut h = Header::new_request(Command::Read); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + let body = ReadResp { + data_offset: 0x50, + data_remaining: 0, + flags: 0, + data, + }; + + pack_message(&h, &body) + } + + fn build_tree_disconnect_response() -> Vec { + let mut h = Header::new_request(Command::TreeDisconnect); + h.flags.set_response(); + h.credits = 32; + pack_message(&h, &TreeDisconnectResponse) + } + + /// Build a canned RPC BIND_ACK response. + fn build_bind_ack() -> Vec { + use crate::pack::WriteCursor; + + let mut w = WriteCursor::with_capacity(64); + // Common header + w.write_u8(5); // version + w.write_u8(0); // version minor + w.write_u8(12); // BIND_ACK type + w.write_u8(0x03); // flags (first + last) + w.write_bytes(&[0x10, 0x00, 0x00, 0x00]); // data rep + let frag_len_pos = w.position(); + w.write_u16_le(0); // frag length placeholder + w.write_u16_le(0); // auth length + w.write_u32_le(1); // call id + + // BIND_ACK specific + w.write_u16_le(4280); // max xmit frag + w.write_u16_le(4280); // max recv frag + w.write_u32_le(0x12345); // assoc group + + // Secondary address (empty) + w.write_u16_le(0); + w.write_bytes(&[0, 0]); // padding + + // Result list + w.write_u8(1); // num results + w.write_bytes(&[0, 0, 0]); // reserved + w.write_u16_le(0); // result = accepted + w.write_u16_le(0); // reason + + // Transfer syntax UUID + version (20 bytes) + use crate::pack::Pack; + let ndr_uuid = Guid { + data1: 0x8A885D04, + data2: 0x1CEB, + data3: 0x11C9, + data4: [0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60], + }; + ndr_uuid.pack(&mut w); + w.write_u32_le(2); + + let total_len = w.position(); + w.set_u16_le_at(frag_len_pos, total_len as u16); + w.into_inner() + } + + /// Build the NDR stub for a NetShareEnumAll RESPONSE (no RPC envelope). + fn build_share_enum_stub(shares: &[(&str, u32, &str)]) -> Vec { + use crate::pack::WriteCursor; + + // Build NDR stub + let mut w = WriteCursor::with_capacity(512); + let count = shares.len() as u32; + + // Level = 1 + w.write_u32_le(1); + // Union discriminant = 1 + w.write_u32_le(1); + + if count == 0 { + w.write_u32_le(0); // null container + w.write_u32_le(0); // total entries + w.write_u32_le(0); // resume handle + w.write_u32_le(0); // return value + } else { + // Container pointer + w.write_u32_le(0x0002_0000); + // EntriesRead + w.write_u32_le(count); + // Array pointer + w.write_u32_le(0x0002_0004); + // MaxCount + w.write_u32_le(count); + + // Fixed entries + for (i, &(_, share_type, _)) in shares.iter().enumerate() { + w.write_u32_le(0x0002_0008 + (i as u32) * 2); // name ref + w.write_u32_le(share_type); + w.write_u32_le(0x0002_0108 + (i as u32) * 2); // comment ref + } + + // Deferred strings + for &(name, _, comment) in shares { + write_ndr_string(&mut w, name); + write_ndr_string(&mut w, comment); + } + + w.write_u32_le(count); // total entries + w.write_u32_le(0); // resume handle + w.write_u32_le(0); // return value + } + + w.into_inner() + } + + /// Wrap NDR stub bytes in an RPC RESPONSE PDU with the given PFC flags. + /// + /// `pfc_flags` lets a caller emit a fragment (for example, `PFC_FIRST_FRAG` + /// alone for a non-final fragment) instead of the usual `FIRST | LAST`. + fn wrap_rpc_response_pdu(stub_chunk: &[u8], pfc_flags: u8) -> Vec { + use crate::pack::WriteCursor; + + let mut w = WriteCursor::with_capacity(24 + stub_chunk.len()); + w.write_u8(5); + w.write_u8(0); + w.write_u8(2); // RESPONSE + w.write_u8(pfc_flags); + w.write_bytes(&[0x10, 0x00, 0x00, 0x00]); + let frag_len_pos = w.position(); + w.write_u16_le(0); + w.write_u16_le(0); + w.write_u32_le(2); // call id + + w.write_u32_le(stub_chunk.len() as u32); // alloc hint + w.write_u16_le(0); // context id + w.write_u8(0); // cancel count + w.write_u8(0); // reserved + + w.write_bytes(stub_chunk); + + let total_len = w.position(); + w.set_u16_le_at(frag_len_pos, total_len as u16); + w.into_inner() + } + + /// Build a canned single-fragment RPC RESPONSE with NetShareEnumAll data. + fn build_share_enum_response(shares: &[(&str, u32, &str)]) -> Vec { + // 0x03 = PFC_FIRST_FRAG | PFC_LAST_FRAG (a complete, single-fragment PDU). + wrap_rpc_response_pdu(&build_share_enum_stub(shares), 0x03) + } + + fn write_ndr_string(w: &mut crate::pack::WriteCursor, s: &str) { + let utf16: Vec = s.encode_utf16().chain(std::iter::once(0)).collect(); + let char_count = utf16.len() as u32; + w.write_u32_le(char_count); + w.write_u32_le(0); + w.write_u32_le(char_count); + for &code_unit in &utf16 { + w.write_u16_le(code_unit); + } + w.align_to(4); + } + + /// Queue all the responses needed for a full list_shares flow. + pub(crate) fn queue_share_listing_responses( + mock: &MockTransport, + shares: &[(&str, u32, &str)], + ) { + let tree_id = TreeId(42); + let file_id = FileId { + persistent: 0xAAAA, + volatile: 0xBBBB, + }; + + // 1. TREE_CONNECT response + mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe)); + // 2. CREATE response (open srvsvc pipe) + mock.queue_response(build_create_response(file_id, 0)); + // 3. WRITE response (RPC BIND) + mock.queue_response(build_write_response(72)); + // 4. READ response (BIND_ACK) + mock.queue_response(build_read_response(build_bind_ack())); + // 5. WRITE response (NetShareEnumAll request) + mock.queue_response(build_write_response(100)); + // 6. READ response (NetShareEnumAll response) + mock.queue_response(build_read_response(build_share_enum_response(shares))); + // 7. CLOSE response + mock.queue_response(build_close_response()); + // 8. TREE_DISCONNECT response + mock.queue_response(build_tree_disconnect_response()); + } + + /// Like `queue_share_listing_responses`, but the server splits a single + /// RPC RESPONSE PDU across two pipe reads: the first read returns + /// `STATUS_BUFFER_OVERFLOW` with the leading bytes, the second returns + /// `SUCCESS` with the rest. The client must stitch them before parsing. + fn queue_overflow_share_listing_responses(mock: &MockTransport, shares: &[(&str, u32, &str)]) { + let tree_id = TreeId(42); + let file_id = FileId { + persistent: 0xAAAA, + volatile: 0xBBBB, + }; + + let pdu = build_share_enum_response(shares); + let split = pdu.len() / 2; + let (first, rest) = pdu.split_at(split); + + mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe)); + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(72)); + mock.queue_response(build_read_response(build_bind_ack())); + mock.queue_response(build_write_response(100)); + // The response PDU arrives in two chunks: overflow then success. + mock.queue_response(build_read_response_with_status( + first.to_vec(), + NtStatus::BUFFER_OVERFLOW, + )); + mock.queue_response(build_read_response_with_status( + rest.to_vec(), + NtStatus::SUCCESS, + )); + mock.queue_response(build_close_response()); + mock.queue_response(build_tree_disconnect_response()); + } + + /// Like `queue_share_listing_responses`, but the RPC RESPONSE is split into + /// two DCE/RPC fragments (each its own pipe message): the first carries + /// `PFC_FIRST_FRAG`, the second `PFC_LAST_FRAG`. The client must reassemble + /// the stub across fragments before parsing. + fn queue_fragmented_share_listing_responses( + mock: &MockTransport, + shares: &[(&str, u32, &str)], + ) { + let tree_id = TreeId(42); + let file_id = FileId { + persistent: 0xAAAA, + volatile: 0xBBBB, + }; + + let stub = build_share_enum_stub(shares); + let split = stub.len() / 2; + let (first, rest) = stub.split_at(split); + let frag1 = wrap_rpc_response_pdu(first, 0x01); // PFC_FIRST_FRAG only + let frag2 = wrap_rpc_response_pdu(rest, 0x02); // PFC_LAST_FRAG only + + mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe)); + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(72)); + mock.queue_response(build_read_response(build_bind_ack())); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_read_response(frag1)); + mock.queue_response(build_read_response(frag2)); + mock.queue_response(build_close_response()); + mock.queue_response(build_tree_disconnect_response()); + } + + #[tokio::test] + async fn list_shares_reassembles_buffer_overflow_reads() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + queue_overflow_share_listing_responses( + &mock, + &[ + ("Documents", STYPE_DISKTREE, "Shared docs"), + ("Photos", STYPE_DISKTREE, "Family photos"), + ], + ); + + let shares = list_shares(&mut conn).await.unwrap(); + + assert_eq!(shares.len(), 2); + assert_eq!(shares[0].name, "Documents"); + assert_eq!(shares[1].name, "Photos"); + } + + #[tokio::test] + async fn list_shares_reassembles_rpc_fragments() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + queue_fragmented_share_listing_responses( + &mock, + &[ + ("Documents", STYPE_DISKTREE, "Shared docs"), + ("Photos", STYPE_DISKTREE, "Family photos"), + ], + ); + + let shares = list_shares(&mut conn).await.unwrap(); + + assert_eq!(shares.len(), 2); + assert_eq!(shares[0].name, "Documents"); + assert_eq!(shares[1].name, "Photos"); + } + + #[tokio::test] + async fn list_shares_returns_disk_shares() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + queue_share_listing_responses( + &mock, + &[ + ("Documents", STYPE_DISKTREE, "Shared docs"), + ("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"), + ("C$", STYPE_DISKTREE | STYPE_SPECIAL, "Default share"), + ("Photos", STYPE_DISKTREE, "Family photos"), + ], + ); + + let shares = list_shares(&mut conn).await.unwrap(); + + // Only disk shares without $ suffix and without STYPE_SPECIAL + assert_eq!(shares.len(), 2); + assert_eq!(shares[0].name, "Documents"); + assert_eq!(shares[0].comment, "Shared docs"); + assert_eq!(shares[1].name, "Photos"); + assert_eq!(shares[1].comment, "Family photos"); + } + + #[tokio::test] + async fn list_shares_sends_correct_number_of_messages() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + queue_share_listing_responses(&mock, &[("TestShare", STYPE_DISKTREE, "A test share")]); + + let _shares = list_shares(&mut conn).await.unwrap(); + + // Should have sent 8 messages: + // TREE_CONNECT, CREATE, WRITE(bind), READ(bind_ack), + // WRITE(request), READ(response), CLOSE, TREE_DISCONNECT + assert_eq!(mock.sent_count(), 8); + } + + #[tokio::test] + async fn list_shares_empty_server() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + queue_share_listing_responses(&mock, &[]); + + let shares = list_shares(&mut conn).await.unwrap(); + assert!(shares.is_empty()); + } + + #[tokio::test] + async fn list_shares_filters_non_disk_shares() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + // All non-disk or special shares + queue_share_listing_responses( + &mock, + &[ + ("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"), + ("ADMIN$", STYPE_DISKTREE | STYPE_SPECIAL, "Remote Admin"), + ], + ); + + let shares = list_shares(&mut conn).await.unwrap(); + assert!(shares.is_empty()); + } + + #[tokio::test] + async fn list_shares_uses_correct_server_name() { + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + let mut conn = + Connection::from_transport(Box::new(mock.clone()), Box::new(mock.clone()), "my-nas"); + conn.set_test_params(NegotiatedParams { + dialect: Dialect::Smb2_0_2, + max_read_size: 65536, + max_write_size: 65536, + max_transact_size: 65536, + server_guid: Guid::ZERO, + signing_required: false, + capabilities: Capabilities::default(), + gmac_negotiated: false, + cipher: None, + compression_supported: false, + }); + conn.set_session_id(SessionId(0x1234)); + + queue_share_listing_responses(&mock, &[("share1", STYPE_DISKTREE, "")]); + + let shares = list_shares(&mut conn).await.unwrap(); + assert_eq!(shares.len(), 1); + + // Verify the TREE_CONNECT request contains \\my-nas\IPC$ + let sent = mock.sent_messages(); + let tree_connect_bytes = &sent[0]; + // The UNC path is UTF-16LE in the request body + let unc_utf8 = String::from_utf8_lossy(tree_connect_bytes); + // Verify the server name appears somewhere in the raw bytes + assert!( + tree_connect_bytes.windows(2).any(|w| w == b"m\0"), // 'm' in UTF-16LE from "my-nas" + "TREE_CONNECT should reference the server name" + ); + drop(unc_utf8); + } +} diff --git a/vendor/smb2/src/client/stream.rs b/vendor/smb2/src/client/stream.rs new file mode 100644 index 0000000..b1c6ce2 --- /dev/null +++ b/vendor/smb2/src/client/stream.rs @@ -0,0 +1,1499 @@ +//! Streaming file I/O with progress reporting. +//! +//! Provides [`FileDownload`] for memory-efficient large file downloads, +//! [`FileUpload`] for streaming uploads with progress, +//! [`FileWriter`] for push-based pipelined writes (use +//! [`FileWriter::finish`] for normal completion, [`FileWriter::abort`] for +//! fast cancellation), and [`Progress`] for tracking transfer progress. + +use std::ops::ControlFlow; +use std::sync::Arc; + +use log::debug; + +use crate::client::connection::Connection; +use crate::client::tree::Tree; +use crate::error::Result; +use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE}; +use crate::msg::write::{WriteRequest, WriteResponse}; +use crate::pack::{ReadCursor, Unpack}; +use crate::types::status::NtStatus; +use crate::types::{Command, FileId}; +use crate::Error; + +/// Maximum number of pipelined write requests in flight. +/// Matches `MAX_PIPELINE_WINDOW` in `tree.rs`. +const MAX_PIPELINE_WINDOW: usize = 32; + +/// Progress information for a file transfer. +#[derive(Debug, Clone, Copy)] +pub struct Progress { + /// Bytes transferred so far. + pub bytes_transferred: u64, + /// Total file size (if known). + pub total_bytes: Option, +} + +impl Progress { + /// Progress as a percentage (0.0 to 100.0). + #[must_use] + pub fn percent(&self) -> f64 { + self.fraction() * 100.0 + } + + /// Progress as a fraction (0.0 to 1.0). + #[must_use] + pub fn fraction(&self) -> f64 { + match self.total_bytes { + Some(total) if total > 0 => self.bytes_transferred as f64 / total as f64, + Some(_) => 1.0, // Empty file is "complete" + None => 0.0, + } + } +} + +/// An in-progress file download that yields chunks without buffering +/// the entire file in memory. +/// +/// Each call to [`next_chunk`](FileDownload::next_chunk) sends one SMB2 READ +/// request and returns the response data. This is sequential (not pipelined) +/// but memory-efficient: only one chunk is in memory at a time. +/// +/// The file handle is closed when the download completes or is dropped. +/// +/// # Example +/// +/// ```ignore +/// # async fn example(client: &mut smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> { +/// use tokio::io::AsyncWriteExt; +/// +/// let mut download = client.download(&share, "big_video.mp4").await?; +/// println!("Downloading {} bytes...", download.size()); +/// +/// let mut file = tokio::fs::File::create("big_video.mp4").await?; +/// while let Some(chunk) = download.next_chunk().await { +/// let bytes = chunk?; +/// file.write_all(&bytes).await?; +/// println!("{:.1}%", download.progress().percent()); +/// } +/// # Ok(()) +/// # } +/// ``` +pub struct FileDownload<'a> { + tree: &'a Tree, + conn: &'a mut Connection, + file_id: FileId, + file_size: u64, + bytes_received: u64, + chunk_size: u32, + done: bool, +} + +impl<'a> FileDownload<'a> { + /// Create a new streaming download from an already-opened file handle. + /// + /// Most callers want [`SmbClient::download`](crate::SmbClient::download) or + /// [`Tree::download`](crate::Tree::download), which issue the CREATE + /// themselves and wrap the resulting handle. Use this constructor when + /// you've already opened the file via [`Tree::open_file`] (for example, + /// to reuse a handle across multiple readers, or to build a custom + /// chunk loop with non-default `chunk_size`). + /// + /// The caller is responsible for making sure `file_id` belongs to `tree` + /// and was opened with read access. The `FileDownload` will CLOSE the + /// handle when the last chunk is consumed or when it is dropped. + pub fn new( + tree: &'a Tree, + conn: &'a mut Connection, + file_id: FileId, + file_size: u64, + chunk_size: u32, + ) -> Self { + Self { + tree, + conn, + file_id, + file_size, + bytes_received: 0, + chunk_size, + done: false, + } + } + + /// Total file size in bytes. + #[must_use] + pub fn size(&self) -> u64 { + self.file_size + } + + /// Bytes received so far. + #[must_use] + pub fn bytes_received(&self) -> u64 { + self.bytes_received + } + + /// Current transfer progress. + #[must_use] + pub fn progress(&self) -> Progress { + Progress { + bytes_transferred: self.bytes_received, + total_bytes: Some(self.file_size), + } + } + + /// Get the next chunk of data from the server. + /// + /// Returns `None` when the download is complete. Each call sends + /// one SMB2 READ request and returns the response data. The file + /// handle is automatically closed when the last chunk is consumed. + pub async fn next_chunk(&mut self) -> Option>> { + if self.done { + return None; + } + + let remaining = self.file_size.saturating_sub(self.bytes_received); + if remaining == 0 { + // Close the handle when we've read everything. + let close_result = self.close().await; + if let Err(e) = close_result { + return Some(Err(e)); + } + return None; + } + + let this_chunk = remaining.min(self.chunk_size as u64) as u32; + + let req = ReadRequest { + padding: 0x50, + flags: 0, + length: this_chunk, + offset: self.bytes_received, + file_id: self.file_id, + minimum_count: 0, + channel: SMB2_CHANNEL_NONE, + remaining_bytes: 0, + read_channel_info: vec![], + }; + + let credit_charge = (this_chunk as u64).div_ceil(65536).max(1) as u16; + let exec_result = self + .conn + .execute_with_credits( + Command::Read, + &req, + Some(self.tree.tree_id), + crate::types::CreditCharge(credit_charge), + ) + .await; + + match exec_result { + Err(e) => { + self.done = true; + Some(Err(e)) + } + Ok(frame) => { + if frame.header.status == NtStatus::END_OF_FILE { + let _ = self.close().await; + return None; + } + + if frame.header.status != NtStatus::SUCCESS { + self.done = true; + return Some(Err(Error::Protocol { + status: frame.header.status, + command: Command::Read, + })); + } + + let mut cursor = ReadCursor::new(&frame.body); + match ReadResponse::unpack(&mut cursor) { + Err(e) => { + self.done = true; + Some(Err(e)) + } + Ok(resp) => { + if resp.data.is_empty() { + let _ = self.close().await; + return None; + } + + self.bytes_received += resp.data.len() as u64; + + // If this was the last chunk, close the handle. + if self.bytes_received >= self.file_size { + if let Err(e) = self.close().await { + return Some(Err(e)); + } + } + + Some(Ok(resp.data)) + } + } + } + } + } + + /// Consume the download and collect all data with a progress callback. + /// + /// Return `ControlFlow::Break(())` from the callback to cancel the download. + /// Cancellation returns `Error::Cancelled`. + pub async fn collect_with_progress(mut self, mut on_progress: F) -> Result> + where + F: FnMut(Progress) -> ControlFlow<()>, + { + let mut data = Vec::with_capacity(self.file_size as usize); + + while let Some(result) = self.next_chunk().await { + let chunk = result?; + data.extend_from_slice(&chunk); + + if let ControlFlow::Break(()) = on_progress(self.progress()) { + // Best-effort close before returning. + let _ = self.close().await; + return Err(Error::Cancelled); + } + } + + Ok(data) + } + + /// Consume the download and collect all data into a `Vec`. + pub async fn collect(mut self) -> Result> { + let mut data = Vec::with_capacity(self.file_size as usize); + + while let Some(result) = self.next_chunk().await { + let chunk = result?; + data.extend_from_slice(&chunk); + } + + Ok(data) + } + + /// Close the file handle. Only sends CLOSE once. + async fn close(&mut self) -> Result<()> { + if self.done { + return Ok(()); + } + self.done = true; + self.tree.close_handle(self.conn, self.file_id).await + } +} + +impl Drop for FileDownload<'_> { + fn drop(&mut self) { + if !self.done { + debug!( + "stream: FileDownload dropped before completion, file handle may leak \ + (bytes_received={}/{})", + self.bytes_received, self.file_size + ); + // We can't close the handle in Drop because it's async. + // The caller should consume the download fully or call close(). + } + } +} + +/// An in-progress file upload that writes data in chunks with progress. +/// +/// Each call to [`write_next_chunk`](FileUpload::write_next_chunk) sends one +/// SMB2 WRITE request and returns `true` while there is more data to send. +/// When the last chunk is written, the file handle is automatically flushed +/// and closed, and `write_next_chunk` returns `false`. +/// +/// The connection is borrowed mutably for the lifetime of the upload, +/// preventing accidental interleaving of SMB messages. +/// +/// # Cancellation +/// +/// To cancel an upload, stop calling `write_next_chunk`. The file handle +/// will be closed (without flush) when the `FileUpload` is dropped, though +/// this cannot be guaranteed in async contexts since `Drop` is synchronous. +/// For clean cancellation, call `write_next_chunk` in a loop that checks +/// your own cancellation condition. +/// +/// # Example +/// +/// ```no_run +/// # async fn example(client: &mut smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> { +/// let data = std::fs::read("large_video.mp4")?; +/// let mut upload = client.upload(&share, "remote_video.mp4", &data).await?; +/// println!("Uploading {} bytes...", upload.total_bytes()); +/// +/// while upload.write_next_chunk().await? { +/// println!("{:.1}%", upload.progress().percent()); +/// } +/// // File is flushed and closed automatically after the last chunk. +/// # Ok(()) +/// # } +/// ``` +pub struct FileUpload<'a> { + tree: &'a Tree, + conn: &'a mut Connection, + file_id: FileId, + data: &'a [u8], + total_bytes: u64, + bytes_written: u64, + chunk_size: u32, + done: bool, +} + +impl<'a> FileUpload<'a> { + /// Create a streaming upload for a large file (data larger than one chunk). + /// + /// Opens the file for writing. The caller then drives the upload with + /// [`write_next_chunk`](FileUpload::write_next_chunk). + pub(crate) fn new( + tree: &'a Tree, + conn: &'a mut Connection, + file_id: FileId, + data: &'a [u8], + chunk_size: u32, + ) -> Self { + Self { + tree, + conn, + file_id, + data, + total_bytes: data.len() as u64, + bytes_written: 0, + chunk_size, + done: false, + } + } + + /// Create a "done" upload for small files that were already written + /// via compound in the constructor. + pub(crate) fn new_done(tree: &'a Tree, conn: &'a mut Connection, total_bytes: u64) -> Self { + Self { + tree, + conn, + file_id: FileId::SENTINEL, + data: &[], + total_bytes, + bytes_written: total_bytes, + chunk_size: 0, + done: true, + } + } + + /// Total data size in bytes. + #[must_use] + pub fn total_bytes(&self) -> u64 { + self.total_bytes + } + + /// Bytes written so far. + #[must_use] + pub fn bytes_written(&self) -> u64 { + self.bytes_written + } + + /// Current transfer progress. + #[must_use] + pub fn progress(&self) -> Progress { + Progress { + bytes_transferred: self.bytes_written, + total_bytes: Some(self.total_bytes), + } + } + + /// Write the next chunk of data to the server. + /// + /// Returns `Ok(true)` while there is more data to write, and `Ok(false)` + /// when the upload is complete. After the last chunk, automatically flushes + /// and closes the file handle. + /// + /// For small files that were written via compound in the constructor, + /// this immediately returns `Ok(false)`. + pub async fn write_next_chunk(&mut self) -> Result { + if self.done { + return Ok(false); + } + + let offset = self.bytes_written as usize; + if offset >= self.data.len() { + // All data written -- flush and close. + self.flush_and_close().await?; + return Ok(false); + } + + let remaining = self.data.len() - offset; + let this_chunk = remaining.min(self.chunk_size as usize); + let chunk = &self.data[offset..offset + this_chunk]; + + let write_req = WriteRequest { + data_offset: 0x70, + offset: offset as u64, + file_id: self.file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: chunk.to_vec(), + }; + + let credit_charge = (this_chunk as u64).div_ceil(65536).max(1) as u16; + let exec_result = self + .conn + .execute_with_credits( + Command::Write, + &write_req, + Some(self.tree.tree_id), + crate::types::CreditCharge(credit_charge), + ) + .await; + + match exec_result { + Err(e) => { + self.done = true; + Err(e) + } + Ok(frame) => { + if frame.header.status != NtStatus::SUCCESS { + self.done = true; + // Best-effort close without flush. + let _ = self.tree.close_handle(self.conn, self.file_id).await; + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Write, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = WriteResponse::unpack(&mut cursor)?; + self.bytes_written += resp.count as u64; + + // If all data is written, flush and close. + if self.bytes_written >= self.total_bytes { + self.flush_and_close().await?; + return Ok(false); + } + + Ok(true) + } + } + } + + /// Flush and close the file handle. Only runs once. + async fn flush_and_close(&mut self) -> Result<()> { + if self.done { + return Ok(()); + } + self.done = true; + + // Flush to ensure data is persisted. + self.tree.flush_handle(self.conn, self.file_id).await?; + // Close the handle. + self.tree.close_handle(self.conn, self.file_id).await + } +} + +impl Drop for FileUpload<'_> { + fn drop(&mut self) { + if !self.done { + debug!( + "stream: FileUpload dropped before completion, file handle may leak \ + (bytes_written={}/{})", + self.bytes_written, self.total_bytes + ); + // We can't close the handle in Drop because it's async. + // The caller should drive the upload to completion. + } + } +} + +/// A push-based pipelined streaming file writer. +/// +/// The consumer pushes data chunks at their own pace. Writes are pipelined +/// using a sliding window (up to 32 in-flight requests) +/// for high throughput. Chunks larger than `max_write_size` are split +/// internally into wire-level WRITE requests. +/// +/// Call [`finish`](FileWriter::finish) when done to flush, close the handle, +/// and get the total confirmed byte count. +/// +/// # Example +/// +/// ```no_run +/// # async fn example(client: &smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> { +/// let mut writer = client.create_file_writer(share, "output.bin").await?; +/// writer.write_chunk(b"first part").await?; +/// writer.write_chunk(b"second part").await?; +/// let total = writer.finish().await?; +/// println!("Wrote {total} bytes"); +/// # Ok(()) +/// # } +/// ``` +/// Pinned-boxed `execute_with_credits` future, kept owned by `FileWriter` +/// in a `FuturesUnordered` so multiple WRITEs can be in flight on one +/// connection concurrently. +type BoxedWriteFut = std::pin::Pin< + Box> + Send>, +>; + +/// Push-based streaming writer. Owns its `Connection` and `Arc`, +/// so the writer is `'static` and N concurrent writers pipeline over one +/// SMB session without any external locking. +/// +/// Both fields are cheap `Arc::clone`s. The receiver task multiplexes +/// responses by `MessageId` so N independent `FileWriter`s can write to +/// different files on the same connection concurrently. +pub struct FileWriter { + tree: Arc, + conn: Connection, + file_id: FileId, + max_write_size: u32, + /// Next write offset in the file. + offset: u64, + /// In-flight WRITE futures. `FuturesUnordered::len()` gives the same + /// "how many responses are pending" count the old `in_flight: usize` + /// field tracked pre-Phase-3. + in_flight: futures_util::stream::FuturesUnordered, + /// Confirmed bytes (from WRITE responses). + total_written: u64, + /// Buffer for leftover data when a push chunk is larger than `max_write_size`. + pending_data: Vec, + /// Read position within `pending_data`. + pending_offset: usize, + /// Chunk that was pulled but couldn't be sent due to credit exhaustion. + stashed_chunk: Option>, + /// Whether the writer has been finalized (handle closed). + done: bool, +} + +/// Open (or create) a file for writing and return a streaming [`FileWriter`] +/// that owns its `Connection` and `Arc`. +/// +/// Use this when you hold a cloned `Connection` and want to drive a +/// streaming write without holding any external lock for the upload's +/// duration. The returned writer is `'static` — drop it, move it across +/// tasks, hand it to `tokio::spawn`, it doesn't borrow from anything. +/// +/// Multiple `FileWriter`s built from clones of the same `Connection` +/// pipeline their WRITEs over a single SMB session. +/// +/// `SmbClient::create_file_writer` and `Tree::create_file_writer` are +/// thin convenience wrappers around this; reach for them when you already +/// hold an `&SmbClient` or `&Arc` and don't need the explicit +/// connection clone. +pub async fn open_file_writer( + tree: Arc, + mut conn: Connection, + path: &str, +) -> Result { + let normalized = tree.format_path(path); + debug!("stream: open_file_writer path={}", normalized); + + let file_id = tree.open_file_for_write(&mut conn, &normalized).await?; + let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536); + + Ok(FileWriter::new(tree, conn, file_id, max_write)) +} + +/// Exclusive-create sibling of [`open_file_writer`]. Opens the CREATE with +/// `FileCreate` disposition: if the file already exists the open fails with +/// [`crate::ErrorKind::AlreadyExists`] instead of +/// truncating it. +/// +/// `Tree::create_file_writer_exclusive` is the convenience wrapper most +/// callers want. +pub async fn open_file_writer_exclusive( + tree: Arc, + mut conn: Connection, + path: &str, +) -> Result { + let normalized = tree.format_path(path); + debug!("stream: open_file_writer_exclusive path={}", normalized); + + let file_id = tree + .open_file_for_exclusive_create(&mut conn, &normalized) + .await?; + let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536); + + Ok(FileWriter::new(tree, conn, file_id, max_write)) +} + +impl FileWriter { + /// Create a new push-based streaming writer. + /// + /// Most callers want [`open_file_writer`], [`Tree::create_file_writer`], + /// or [`SmbClient::create_file_writer`](crate::SmbClient::create_file_writer) + /// which issue the CREATE for you. Use this constructor when you've + /// already opened the file via [`Tree::open_file_for_write`] (for + /// example, to reuse a handle across multiple writers). + pub(crate) fn new( + tree: Arc, + conn: Connection, + file_id: FileId, + max_write_size: u32, + ) -> Self { + Self { + tree, + conn, + file_id, + max_write_size, + offset: 0, + in_flight: futures_util::stream::FuturesUnordered::new(), + total_written: 0, + pending_data: Vec::new(), + pending_offset: 0, + stashed_chunk: None, + done: false, + } + } + + /// Push a data chunk to the writer. + /// + /// The data is split into wire-level WRITE requests (each up to + /// `max_write_size` bytes) and sent pipelined. When the sliding window + /// is full, this method drains one in-flight response before sending, + /// providing backpressure. + /// + /// Empty chunks are no-ops. + pub async fn write_chunk(&mut self, data: &[u8]) -> Result<()> { + if data.is_empty() { + return Ok(()); + } + + // Append to pending buffer. If there's already pending data, extend it; + // otherwise set the new chunk as pending. + if self.pending_offset < self.pending_data.len() { + let leftover = self.pending_data[self.pending_offset..].to_vec(); + self.pending_data = leftover; + self.pending_offset = 0; + self.pending_data.extend_from_slice(data); + } else { + self.pending_data = data.to_vec(); + self.pending_offset = 0; + } + + // Flush any stashed chunk from a previous call before processing new data. + self.flush_stash().await?; + + // Send as many wire chunks as the window allows. + while let Some(wire_chunk) = self.next_pending_chunk() { + if !self.send_or_stash(wire_chunk).await? { + return Ok(()); // Stashed — will be sent on next call or finish() + } + } + + Ok(()) + } + + /// Finish the writer: drain all in-flight responses, flush, and close. + /// + /// Returns the total number of confirmed bytes written. Consumes `self` + /// to prevent write-after-close at compile time. + pub async fn finish(mut self) -> Result { + // Flush stash and drain all remaining pending data. Unlike write_chunk, + // finish() must send everything — it loops send_or_stash until the stash + // is empty, draining responses to free credits as needed. + self.flush_stash().await?; + + while let Some(wire_chunk) = self.next_pending_chunk() { + // send_or_stash may stash if credits are exhausted. Keep flushing + // until everything is sent. This terminates because drain_one frees + // a credit, and we have finite data. + if !self.send_or_stash(wire_chunk).await? { + self.flush_stash().await?; + } + } + + // Drain all in-flight responses. + self.drain_all().await?; + + // Flush to ensure data is persisted. + self.tree.flush_handle(&mut self.conn, self.file_id).await?; + + // Close the handle. + self.tree.close_handle(&mut self.conn, self.file_id).await?; + + self.done = true; + Ok(self.total_written) + } + + /// Abort the writer: discard unsent data, drain in-flight responses, and + /// close the handle without flushing. + /// + /// Use this when you want to cancel a write partway through — for example + /// on user-triggered cancellation or an error path where the partial upload + /// will be deleted anyway. `abort()` skips the server-side fsync that + /// [`finish`](FileWriter::finish) does, so it returns as soon as the + /// in-flight window is drained. + /// + /// What it does: + /// - Discards any buffered (unsent) data. Wire WRITEs already in flight + /// still have responses on the way; those are drained to keep credits + /// and message-IDs in sync with the server. Errors on those responses + /// are swallowed — at this point we don't care. + /// - Skips the FLUSH that [`finish`](FileWriter::finish) sends before + /// CLOSE, so the server does not fsync. This is the main reason to + /// prefer `abort()` over `finish()` on cancellation. + /// - Best-effort CLOSE of the file handle. If the CLOSE fails, the error + /// is logged at debug and swallowed. + /// + /// Contrast with [`finish`](FileWriter::finish): `finish()` sends every + /// pending byte, flushes, and propagates errors from the flush/close + /// paths. `abort()` sends nothing more, never flushes, and returns `Ok` + /// regardless of what the server said on the way out. + /// + /// Returns the number of confirmed bytes written at the moment of abort + /// (from WRITE responses seen so far). Consumes `self` to prevent + /// write-after-abort at compile time. The `Result` wrapper mirrors + /// [`finish`](FileWriter::finish)'s signature and leaves room for future + /// failure modes; today `abort()` never returns `Err`. + /// + /// The caller is responsible for deleting the partial remote file if they + /// don't want it to linger — the server now has a zero-to-N byte file + /// depending on how many WRITEs completed before the abort. + /// + /// # Future extension + /// + /// A `close_and_delete()` variant that sends `SET_INFO + /// FileDispositionInformation(DeletePending=true)` before CLOSE would + /// combine the two round-trips the caller does today. Out of scope here. + /// + /// # Example + /// + /// ```no_run + /// # use std::ops::ControlFlow; + /// # async fn example( + /// # client: &smb2::SmbClient, + /// # share: &smb2::Tree, + /// # cancel: impl Fn() -> bool, + /// # ) -> Result<(), smb2::Error> { + /// let mut writer = client.create_file_writer(share, "output.bin").await?; + /// for chunk in [b"first".as_slice(), b"second", b"third"] { + /// if cancel() { + /// let written = writer.abort().await?; + /// println!("Aborted after {written} bytes confirmed"); + /// // Caller: delete the partial remote file here if desired. + /// return Ok(()); + /// } + /// writer.write_chunk(chunk).await?; + /// } + /// writer.finish().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn abort(mut self) -> Result { + use futures_util::stream::StreamExt; + + // 1. Discard anything we have not yet put on the wire. Unsent data + // means nothing to the server and carries no credits. + self.pending_data.clear(); + self.pending_offset = 0; + self.stashed_chunk = None; + + // 2. Drain in-flight WRITE responses — they're already in the + // kernel/network buffer, and dropping them unread would desync + // credits and message IDs. Errors are swallowed: on abort we + // don't care if a WRITE failed or succeeded. + while let Some(result) = self.in_flight.next().await { + match result { + Ok(frame) => { + if frame.header.status == NtStatus::SUCCESS { + // Keep total_written accurate for callers that log it. + let mut cursor = ReadCursor::new(&frame.body); + if let Ok(resp) = WriteResponse::unpack(&mut cursor) { + self.total_written += resp.count as u64; + } + } else { + debug!( + "stream: FileWriter::abort() ignoring WRITE error status {:?}", + frame.header.status + ); + } + } + Err(e) => { + // Transport-level failure while draining. There's nothing + // sensible to do — the connection may already be gone. + // Mark everything drained and move on. + debug!( + "stream: FileWriter::abort() giving up on remaining in-flight \ + response(s) after transport error: {}", + e + ); + break; + } + } + } + + // 3. Skip flush_handle() — that's the whole point of abort(). + + // 4. Best-effort CLOSE. If it fails, log and move on. + if let Err(e) = self.tree.close_handle(&mut self.conn, self.file_id).await { + debug!( + "stream: FileWriter::abort() best-effort CLOSE failed, handle may leak \ + server-side until session teardown: {}", + e + ); + } + + // 5. Silence the Drop warning — we finalized cleanly. + self.done = true; + Ok(self.total_written) + } + + /// Confirmed bytes written (from server WRITE responses). + #[must_use] + pub fn bytes_written(&self) -> u64 { + self.total_written + } + + /// Current transfer progress. + /// + /// `total_bytes` is always `None` because push-based writers don't + /// know the total size upfront. + #[must_use] + pub fn progress(&self) -> Progress { + Progress { + bytes_transferred: self.total_written, + total_bytes: None, + } + } + + /// Get the next wire-level chunk from the pending buffer. + fn next_pending_chunk(&mut self) -> Option> { + if self.pending_offset >= self.pending_data.len() { + return None; + } + + let end = (self.pending_offset + self.max_write_size as usize).min(self.pending_data.len()); + let slice = self.pending_data[self.pending_offset..end].to_vec(); + self.pending_offset = end; + + if self.pending_offset >= self.pending_data.len() { + self.pending_data.clear(); + self.pending_offset = 0; + } + + Some(slice) + } + + /// Launch one wire-level WRITE request into the `in_flight` queue. + fn launch_wire_chunk(&mut self, data: Vec) { + let data_len = data.len() as u64; + let credit_charge = data_len.div_ceil(65536).max(1) as u16; + + let req = WriteRequest { + data_offset: 0x70, + offset: self.offset, + file_id: self.file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data, + }; + + let c = self.conn.clone(); + let tree_id = self.tree.tree_id; + self.in_flight.push(Box::pin(async move { + c.execute_with_credits( + Command::Write, + &req, + Some(tree_id), + crate::types::CreditCharge(credit_charge), + ) + .await + })); + + self.offset += data_len; + } + + /// Receive one in-flight WRITE response. + async fn drain_one(&mut self) -> Result<()> { + use futures_util::stream::StreamExt; + + let Some(result) = self.in_flight.next().await else { + return Ok(()); + }; + let frame = result?; + + if frame.header.status != NtStatus::SUCCESS { + // Drain remaining in-flight (best-effort), then close handle. + while self.in_flight.next().await.is_some() {} + // Best-effort close. + let _ = self.tree.close_handle(&mut self.conn, self.file_id).await; + self.done = true; + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Write, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = WriteResponse::unpack(&mut cursor)?; + self.total_written += resp.count as u64; + + Ok(()) + } + + /// Drain all in-flight WRITE responses. + async fn drain_all(&mut self) -> Result<()> { + while !self.in_flight.is_empty() { + self.drain_one().await?; + } + Ok(()) + } + + /// Check whether we have enough credits to send a chunk of this size. + fn can_send(&self, data: &[u8]) -> bool { + let credit_charge = (data.len() as u64).div_ceil(65536).max(1) as u16; + let credits_available = self.conn.credits() as usize / credit_charge.max(1) as usize; + credits_available > 0 && self.in_flight.len() < MAX_PIPELINE_WINDOW + } + + /// Try to send a wire chunk. If the window is full or credits are exhausted, + /// drain one response and retry. If still unable, stash the chunk and return + /// `Ok(false)` (caller decides whether to wait or return). + async fn send_or_stash(&mut self, data: Vec) -> Result { + // Make room if the window is full. + if self.in_flight.len() >= MAX_PIPELINE_WINDOW { + self.drain_one().await?; + } + + if self.can_send(&data) { + self.launch_wire_chunk(data); + return Ok(true); + } + + // No credits — drain one response to reclaim credits and retry. + if !self.in_flight.is_empty() { + self.drain_one().await?; + if self.can_send(&data) { + self.launch_wire_chunk(data); + return Ok(true); + } + } + + // Still can't send. Stash for later. + self.stashed_chunk = Some(data); + Ok(false) + } + + /// Send any stashed chunk, draining responses as needed to free credits. + async fn flush_stash(&mut self) -> Result<()> { + if let Some(stashed) = self.stashed_chunk.take() { + // Make room if needed. + if !self.in_flight.is_empty() && !self.can_send(&stashed) { + self.drain_one().await?; + } + if self.can_send(&stashed) { + self.launch_wire_chunk(stashed); + } else { + // Re-stash — caller must drain more or give up. + self.stashed_chunk = Some(stashed); + } + } + Ok(()) + } +} + +impl Drop for FileWriter { + fn drop(&mut self) { + if !self.done { + debug!( + "stream: FileWriter dropped without finish(), file handle may leak \ + (bytes_written={})", + self.total_written + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::test_helpers::{ + build_close_error_response, build_close_response, build_create_response, + build_flush_response, build_write_error_response, build_write_response, setup_connection, + }; + use crate::transport::MockTransport; + use crate::types::status::NtStatus; + use crate::types::{FileId, TreeId}; + use std::sync::Arc; + + fn test_tree() -> Arc { + Arc::new(Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }) + } + + fn test_file_id() -> FileId { + FileId { + persistent: 0xAA, + volatile: 0xBB, + } + } + + // ── FileWriter tests ─────────────────────────────────────────────── + + #[tokio::test] + async fn file_writer_single_chunk() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + // Queue: CREATE + WRITE(100) + FLUSH + CLOSE + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[0u8; 100]).await.unwrap(); + assert_eq!(writer.bytes_written(), 0); // Not yet drained + let total = writer.finish().await.unwrap(); + assert_eq!(total, 100); + } + + #[tokio::test] + async fn file_writer_multiple_chunks() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[1u8; 100]).await.unwrap(); + writer.write_chunk(&[2u8; 100]).await.unwrap(); + writer.write_chunk(&[3u8; 100]).await.unwrap(); + let total = writer.finish().await.unwrap(); + assert_eq!(total, 300); + } + + #[tokio::test] + async fn file_writer_empty_finish() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + // Queue: CREATE + FLUSH + CLOSE (no WRITE) + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let writer = tree.create_file_writer(conn, "empty.bin").await.unwrap(); + let total = writer.finish().await.unwrap(); + assert_eq!(total, 0); + + // Verify: CREATE + FLUSH + CLOSE = 3 sent messages. + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn file_writer_empty_chunk_noop() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + // Queue: CREATE + WRITE(50) + FLUSH + CLOSE + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(50)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[]).await.unwrap(); // No-op + writer.write_chunk(&[0u8; 50]).await.unwrap(); + let total = writer.finish().await.unwrap(); + assert_eq!(total, 50); + + // CREATE + WRITE + FLUSH + CLOSE = 4 (no extra WRITE for empty chunk). + assert_eq!(mock.sent_count(), 4); + } + + #[tokio::test] + async fn file_writer_chunk_splitting() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + // max_write_size = 65536, send 200KB = 3 x 65536 + 1 x 8192. + // 200 * 1024 = 204800. 204800 / 65536 = 3.125 -> 4 wire writes. + let chunk_size = 200 * 1024; + let wire_1 = 65536u32; + let wire_2 = 65536u32; + let wire_3 = 65536u32; + let wire_4 = (chunk_size - 3 * 65536) as u32; // 8192 + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(wire_1)); + mock.queue_response(build_write_response(wire_2)); + mock.queue_response(build_write_response(wire_3)); + mock.queue_response(build_write_response(wire_4)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "big.bin").await.unwrap(); + writer.write_chunk(&vec![0u8; chunk_size]).await.unwrap(); + let total = writer.finish().await.unwrap(); + assert_eq!(total, (wire_1 + wire_2 + wire_3 + wire_4) as u64); + + // CREATE + 4 WRITEs + FLUSH + CLOSE = 7 + assert_eq!(mock.sent_count(), 7); + } + + #[tokio::test] + async fn file_writer_progress_none_total() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + let progress = writer.progress(); + assert!(progress.total_bytes.is_none()); + assert_eq!(progress.bytes_transferred, 0); + writer.finish().await.unwrap(); + } + + #[tokio::test] + async fn file_writer_bytes_written_tracks_confirmed() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_write_response(200)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + + // After pushing but before finish, bytes_written reflects only drained responses. + writer.write_chunk(&[0u8; 100]).await.unwrap(); + assert_eq!(writer.bytes_written(), 0); // Not yet drained + + writer.write_chunk(&[0u8; 200]).await.unwrap(); + assert_eq!(writer.bytes_written(), 0); // Still not drained + + // finish() drains all. + let total = writer.finish().await.unwrap(); + assert_eq!(total, 300); + } + + #[tokio::test] + async fn file_writer_backpressure() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + + // Queue MAX_PIPELINE_WINDOW + 1 write responses. + for _ in 0..MAX_PIPELINE_WINDOW + 1 { + mock.queue_response(build_write_response(64)); + } + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + + // Fill the window. + for _ in 0..MAX_PIPELINE_WINDOW { + writer.write_chunk(&[0u8; 64]).await.unwrap(); + } + + // This write must drain one response before sending (backpressure). + writer.write_chunk(&[0u8; 64]).await.unwrap(); + + // At least one response was drained by backpressure. + assert!(writer.bytes_written() >= 64); + + let total = writer.finish().await.unwrap(); + assert_eq!(total, (MAX_PIPELINE_WINDOW as u64 + 1) * 64); + } + + #[tokio::test] + async fn file_writer_server_error() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + // Return error for the WRITE. + mock.queue_response(build_write_error_response(NtStatus::DISK_FULL)); + // CLOSE after error cleanup. + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[0u8; 100]).await.unwrap(); + let result = writer.finish().await; + assert!(result.is_err()); + + let err = result.unwrap_err(); + assert!( + format!("{err:?}").contains("DISK_FULL"), + "expected DISK_FULL, got: {err:?}" + ); + } + + #[tokio::test] + async fn file_writer_finish_drains_all() { + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(50)); + mock.queue_response(build_write_response(75)); + mock.queue_response(build_write_response(25)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[0u8; 50]).await.unwrap(); + writer.write_chunk(&[0u8; 75]).await.unwrap(); + writer.write_chunk(&[0u8; 25]).await.unwrap(); + + // None drained yet. + assert_eq!(writer.bytes_written(), 0); + + // finish() must drain all 3. + let total = writer.finish().await.unwrap(); + assert_eq!(total, 150); + } + + // ── FileWriter::abort tests ──────────────────────────────────────── + + #[tokio::test] + async fn file_writer_abort_no_in_flight() { + // abort() with nothing in flight: just CLOSE, no FLUSH, no extra reads. + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + // Queue: CREATE + CLOSE (note: no FLUSH — abort skips fsync). + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + let total = writer.abort().await.unwrap(); + assert_eq!(total, 0); + + // Exactly 2 messages on the wire: CREATE, CLOSE. + assert_eq!(mock.sent_count(), 2); + } + + #[tokio::test] + async fn file_writer_abort_drains_in_flight() { + // abort() must consume in-flight WRITE responses to keep the + // connection in sync, but skips FLUSH. + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + // Three WRITEs on the wire, three responses queued. + mock.queue_response(build_write_response(50)); + mock.queue_response(build_write_response(75)); + mock.queue_response(build_write_response(25)); + // No FLUSH response — abort must not send FLUSH. + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[0u8; 50]).await.unwrap(); + writer.write_chunk(&[0u8; 75]).await.unwrap(); + writer.write_chunk(&[0u8; 25]).await.unwrap(); + + // Nothing drained yet — write_chunk doesn't drain unless the window fills. + assert_eq!(writer.bytes_written(), 0); + + // abort() drains all three and returns the confirmed total. + let total = writer.abort().await.unwrap(); + assert_eq!(total, 150); + + // Wire traffic: CREATE + 3 WRITEs + CLOSE = 5. No FLUSH. + assert_eq!(mock.sent_count(), 5); + } + + #[tokio::test] + async fn file_writer_abort_swallows_write_errors() { + // Mid-stream WRITE failure during abort's drain: swallowed, abort + // still closes and returns Ok. + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(100)); + // Second WRITE errors — abort must not bubble this up. + mock.queue_response(build_write_error_response(NtStatus::DISK_FULL)); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[0u8; 100]).await.unwrap(); + writer.write_chunk(&[0u8; 100]).await.unwrap(); + + // abort() should return Ok despite the DISK_FULL on the second WRITE. + // total_written reflects only the successful WRITE (100). + let total = writer.abort().await.unwrap(); + assert_eq!(total, 100); + + // Wire: CREATE + 2 WRITEs + CLOSE = 4. + assert_eq!(mock.sent_count(), 4); + } + + #[tokio::test] + async fn file_writer_abort_discards_stashed_chunk() { + // If a chunk was stashed (credit/window exhaustion scenario in + // real traffic), abort() must not send it. + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + + // Inject a stashed chunk and pending buffer directly — in real traffic + // these would accumulate when credits run out. Neither should get sent. + writer.stashed_chunk = Some(vec![0u8; 500]); + writer.pending_data = vec![0u8; 1000]; + writer.pending_offset = 0; + + let total = writer.abort().await.unwrap(); + assert_eq!(total, 0); + + // Only CREATE + CLOSE on the wire. No WRITE from the stash or buffer. + assert_eq!(mock.sent_count(), 2); + } + + #[tokio::test] + async fn file_writer_abort_close_error_is_swallowed() { + // CLOSE failing at the end is logged but not surfaced — abort + // is a best-effort fast exit. + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(100)); + // CLOSE returns an error. abort() must still return Ok. + mock.queue_response(build_close_error_response(NtStatus::FILE_CLOSED)); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + writer.write_chunk(&[0u8; 100]).await.unwrap(); + + let result = writer.abort().await; + assert!( + result.is_ok(), + "abort() should swallow CLOSE errors, got: {result:?}" + ); + assert_eq!(result.unwrap(), 100); + + // CREATE + WRITE + CLOSE = 3. + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn file_writer_abort_sets_done_so_drop_is_silent() { + // After abort() returns, the `done` flag is set, so the Drop impl + // does not log a "dropped without finish()" warning. We can't + // inspect `done` once the writer has been consumed, but we can + // confirm abort returns Ok (which only happens on the done=true + // path) and that the test ends cleanly under log capture. + let mock = Arc::new(MockTransport::new()); + let file_id = test_file_id(); + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_close_response()); + + let conn = setup_connection(&mock); + let tree = test_tree(); + + let writer = tree.create_file_writer(conn, "out.bin").await.unwrap(); + let result = writer.abort().await; + assert!(result.is_ok()); + // The writer has been consumed. `Drop` ran inside abort's frame + // with done=true, so no warning fired. (Behavior-only check; + // exposing `done` for inspection was not needed.) + } + + // ── Progress tests ───────────────────────────────────────────────── + + #[test] + fn progress_calculations() { + let cases = [ + (50, Some(100), 50.0, 0.5), + (100, Some(100), 100.0, 1.0), + (25, Some(100), 25.0, 0.25), + (0, Some(0), 100.0, 1.0), // Empty file + (50, None, 0.0, 0.0), // Unknown total + ]; + for (transferred, total, expected_pct, expected_frac) in cases { + let p = Progress { + bytes_transferred: transferred, + total_bytes: total, + }; + assert_eq!( + p.percent(), + expected_pct, + "percent failed for {transferred}/{total:?}" + ); + assert_eq!( + p.fraction(), + expected_frac, + "fraction failed for {transferred}/{total:?}" + ); + } + + // Large numbers. + let large = Progress { + bytes_transferred: u64::MAX / 2, + total_bytes: Some(u64::MAX), + }; + let frac = large.fraction(); + assert!(frac > 0.49 && frac < 0.51); + } +} diff --git a/vendor/smb2/src/client/test_helpers.rs b/vendor/smb2/src/client/test_helpers.rs new file mode 100644 index 0000000..903678a --- /dev/null +++ b/vendor/smb2/src/client/test_helpers.rs @@ -0,0 +1,182 @@ +//! Shared test helper functions for `client` module tests. +//! +//! These build mock SMB2 responses used across pipeline, shares, and tree tests. + +use std::sync::Arc; + +use crate::client::connection::{pack_message, Connection, NegotiatedParams}; +use crate::msg::close::CloseResponse; +use crate::msg::create::{CreateAction, CreateResponse}; +use crate::msg::header::Header; +use crate::msg::tree_connect::{ShareType, TreeConnectResponse}; +use crate::pack::{FileTime, Guid}; +use crate::transport::MockTransport; +use crate::types::flags::{Capabilities, ShareCapabilities, ShareFlags}; +use crate::types::{Command, Dialect, FileId, OplockLevel, SessionId, TreeId}; + +/// Create a mock-backed connection with standard negotiated params. +/// +/// Enables the mock's auto-msg_id-rewrite so canned `build_*_response` +/// helpers (which hardcode `MessageId(0)` and don't know the caller's +/// allocated msg_ids) still route through the Phase 3 receiver task: on +/// each `receive()` the mock patches sub-frame msg_ids to match the next +/// pending sent msg_id in FIFO order. Replaces the pre-Phase-3 +/// `set_orphan_filter_enabled(false)` path. +pub(crate) fn setup_connection(mock: &Arc) -> Connection { + mock.enable_auto_rewrite_msg_id(); + let mut conn = Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn.set_test_params(NegotiatedParams { + dialect: Dialect::Smb2_0_2, + max_read_size: 65536, + max_write_size: 65536, + max_transact_size: 65536, + server_guid: Guid::ZERO, + signing_required: false, + capabilities: Capabilities::default(), + gmac_negotiated: false, + cipher: None, + compression_supported: false, + }); + conn.set_session_id(SessionId(0x1234)); + conn +} + +/// Build a CREATE response with the given file ID and end-of-file size. +pub(crate) fn build_create_response(file_id: FileId, end_of_file: u64) -> Vec { + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 32; + + let body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file, + file_attributes: 0, + file_id, + create_contexts: vec![], + }; + + pack_message(&h, &body) +} + +/// Build a CREATE response with a non-success status (for error tests). +pub(crate) fn build_create_error_response(status: crate::types::status::NtStatus) -> Vec { + use crate::msg::header::ErrorResponse; + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + pack_message(&h, &body) +} + +/// Build a CLOSE response with zeroed fields. +pub(crate) fn build_close_response() -> Vec { + let mut h = Header::new_request(Command::Close); + h.flags.set_response(); + h.credits = 32; + + let body = CloseResponse { + flags: 0, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: 0, + file_attributes: 0, + }; + + pack_message(&h, &body) +} + +/// Build a WRITE response with the given byte count. +pub(crate) fn build_write_response(count: u32) -> Vec { + use crate::msg::write::WriteResponse; + let mut h = Header::new_request(Command::Write); + h.flags.set_response(); + h.credits = 32; + + let body = WriteResponse { + count, + remaining: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + }; + + pack_message(&h, &body) +} + +/// Build a WRITE response with a non-success status (for error tests). +pub(crate) fn build_write_error_response(status: crate::types::status::NtStatus) -> Vec { + use crate::msg::header::ErrorResponse; + let mut h = Header::new_request(Command::Write); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + pack_message(&h, &body) +} + +/// Build a CLOSE response with a non-success status (for error tests). +pub(crate) fn build_close_error_response(status: crate::types::status::NtStatus) -> Vec { + use crate::msg::header::ErrorResponse; + let mut h = Header::new_request(Command::Close); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + pack_message(&h, &body) +} + +/// Build a FLUSH response. +pub(crate) fn build_flush_response() -> Vec { + let mut h = Header::new_request(Command::Flush); + h.flags.set_response(); + h.credits = 32; + + let body = crate::msg::flush::FlushResponse; + pack_message(&h, &body) +} + +/// Build a TREE_CONNECT response with the given tree ID and share type. +pub(crate) fn build_tree_connect_response(tree_id: TreeId, share_type: ShareType) -> Vec { + let mut h = Header::new_request(Command::TreeConnect); + h.flags.set_response(); + h.credits = 32; + h.tree_id = Some(tree_id); + + let body = TreeConnectResponse { + share_type, + share_flags: ShareFlags::default(), + capabilities: ShareCapabilities::default(), + maximal_access: 0x001F_01FF, + }; + + pack_message(&h, &body) +} diff --git a/vendor/smb2/src/client/tree.rs b/vendor/smb2/src/client/tree.rs new file mode 100644 index 0000000..c4cb113 --- /dev/null +++ b/vendor/smb2/src/client/tree.rs @@ -0,0 +1,6691 @@ +//! Tree (share) connection and file operations. +//! +//! The [`Tree`] type represents a connection to a specific share on the server. +//! It provides methods for directory listing, file reading/writing, deletion, +//! renaming, stat, and directory creation. + +use std::ops::ControlFlow; +use std::sync::Arc; + +use log::{debug, info, trace, warn}; + +use crate::client::connection::{CompoundOp, Connection}; +use crate::client::stream::{FileDownload, Progress}; +use crate::error::Result; +use crate::msg::close::CloseRequest; +use crate::msg::create::{ + CreateDisposition, CreateRequest, CreateResponse, ImpersonationLevel, ShareAccess, +}; +use crate::msg::flush::FlushRequest; +use crate::msg::query_directory::{ + FileInformationClass, QueryDirectoryFlags, QueryDirectoryRequest, QueryDirectoryResponse, +}; +use crate::msg::query_info::{InfoType, QueryInfoRequest, QueryInfoResponse}; +use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE}; +use crate::msg::set_info::SetInfoRequest; +use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse}; +use crate::msg::tree_disconnect::TreeDisconnectRequest; +use crate::msg::write::{WriteRequest, WriteResponse}; +use crate::pack::{FileTime, ReadCursor, Unpack}; +use crate::types::flags::FileAccessMask; +use crate::types::status::NtStatus; +#[cfg(test)] +use crate::types::MessageId; +use crate::types::{Command, CreditCharge, FileId, OplockLevel, TreeId}; +use crate::Error; + +/// Maximum number of requests to keep in flight during pipelining. +/// +/// More than 32 in-flight requests creates diminishing returns and +/// increases memory usage (buffering responses). 32 x 64 KB = 2 MB +/// in flight is plenty for Gigabit LAN. +const MAX_PIPELINE_WINDOW: usize = 32; + +/// Unwrap an `execute_compound` result, propagating the first inner +/// waiter-level error (session expired, signature verify failure, +/// connection disconnected mid-await) as the outer `Err`. Returns a +/// `Vec` so callers can index per sub-op. +/// +/// Matches the pre-Phase-3 `receive_compound_expected`'s short-circuit +/// semantics: any routing-level failure aborts the whole operation +/// rather than silently handing back a partial response list the +/// caller would have to inspect one-by-one. Sub-op status codes +/// (`STATUS_OBJECT_NAME_NOT_FOUND` and friends) are NOT errors here; +/// they ride in `Frame::header.status` and the caller checks them. +fn all_or_first_err( + frames: Vec>, +) -> Result> { + let mut out = Vec::with_capacity(frames.len()); + for r in frames { + out.push(r?); + } + Ok(out) +} + +/// File attribute constant: the entry is a directory. +const FILE_ATTRIBUTE_DIRECTORY: u32 = 0x0000_0010; + +/// Create option: the target must be a directory. +const FILE_DIRECTORY_FILE: u32 = 0x0000_0001; + +/// Create option: the target must not be a directory. +const FILE_NON_DIRECTORY_FILE: u32 = 0x0000_0040; + +/// Create option: delete file when all handles are closed. +const FILE_DELETE_ON_CLOSE: u32 = 0x0000_1000; + +/// FileBasicInformation class for QUERY_INFO (MS-FSCC 2.4.7). +const FILE_BASIC_INFORMATION: u8 = 4; + +/// FileStandardInformation class for QUERY_INFO (MS-FSCC 2.4.41). +const FILE_STANDARD_INFORMATION: u8 = 5; + +/// FileRenameInformation class for SET_INFO (MS-FSCC 2.4.34.2). +const FILE_RENAME_INFORMATION: u8 = 10; + +/// FileFsFullSizeInformation class for QUERY_INFO (MS-FSCC 2.5.4). +const FILE_FS_FULL_SIZE_INFORMATION: u8 = 7; + +/// A directory entry returned by [`Tree::list_directory`]. +#[derive(Debug, Clone)] +pub struct DirectoryEntry { + /// The file or directory name. + pub name: String, + /// The file size in bytes (0 for directories). + pub size: u64, + /// Whether this entry is a directory. + pub is_directory: bool, + /// The creation time. + pub created: FileTime, + /// The last modification time. + pub modified: FileTime, +} + +/// File metadata returned by [`Tree::stat`]. +#[derive(Debug, Clone)] +pub struct FileInfo { + /// The file size in bytes. + pub size: u64, + /// Whether this is a directory. + pub is_directory: bool, + /// The creation time. + pub created: FileTime, + /// The last modification time. + pub modified: FileTime, + /// The last access time. + pub accessed: FileTime, +} + +/// File system space information for a share. +#[derive(Debug, Clone)] +pub struct FsInfo { + /// Total capacity in bytes. + pub total_bytes: u64, + /// Free space available to the caller in bytes. + pub free_bytes: u64, + /// Total free space on the volume in bytes (may differ from + /// `free_bytes` if quotas are in effect). + pub total_free_bytes: u64, + /// Bytes per sector. + pub bytes_per_sector: u32, + /// Sectors per allocation unit (cluster). + pub sectors_per_unit: u32, +} + +/// A connection to a specific share (tree connect). +#[derive(Clone)] +pub struct Tree { + /// The tree ID assigned by the server. + pub tree_id: TreeId, + /// The share name. + pub share_name: String, + /// The server name (hostname or IP) this tree is connected to. + /// + /// Used by `SmbClient` to route operations through the correct + /// connection when DFS referrals point to different servers. + pub server: String, + /// Whether the share is a DFS share. + pub is_dfs: bool, + /// Whether the share requires encryption. + pub encrypt_data: bool, +} + +impl Tree { + /// Connect to a share on the server. + /// + /// Sends a TREE_CONNECT request with the UNC path `\\server\share` + /// encoded in UTF-16LE. + pub async fn connect(conn: &mut Connection, share_name: &str) -> Result { + let server = conn.server_name().to_string(); + let unc_path = format!(r"\\{}\{}", server, share_name); + + let req = TreeConnectRequest { + flags: TreeConnectRequestFlags::default(), + path: unc_path, + }; + + let frame = conn.execute(Command::TreeConnect, &req, None).await?; + + if frame.header.command != Command::TreeConnect { + return Err(Error::invalid_data(format!( + "expected TreeConnect response, got {:?}", + frame.header.command + ))); + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::TreeConnect, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = TreeConnectResponse::unpack(&mut cursor)?; + + let tree_id = frame + .header + .tree_id + .ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?; + + let is_dfs = resp + .capabilities + .contains(crate::types::flags::ShareCapabilities::DFS); + let encrypt_data = resp + .share_flags + .contains(crate::types::flags::ShareFlags::ENCRYPT_DATA); + + info!("tree: connected share={}, tree_id={}", share_name, tree_id); + debug!("tree: is_dfs={}, encrypt_data={}", is_dfs, encrypt_data); + + if is_dfs { + conn.register_dfs_tree(tree_id); + } + + Ok(Tree { + tree_id, + share_name: share_name.to_string(), + server: server.clone(), + is_dfs, + encrypt_data, + }) + } + + /// Normalize and format a path for this tree. + /// + /// When `is_dfs` is true, the server expects the path to include the + /// `server\share\` prefix (MS-SMB2 3.2.4.3: "the client MUST pass a + /// DFS path containing the server, share, and path to the open"). + /// The server strips the first two path components to get the local path, + /// and if the resulting path starts with a DFS link name, it returns + /// `STATUS_PATH_NOT_COVERED` so the client can resolve the referral. + pub(crate) fn format_path(&self, path: &str) -> String { + let normalized = normalize_path(path); + if self.is_dfs { + // Extract hostname (strip port if present) for the DFS path prefix. + let hostname = self.server.split(':').next().unwrap_or(&self.server); + if normalized.is_empty() { + format!("{}\\{}", hostname, self.share_name) + } else { + format!("{}\\{}\\{}", hostname, self.share_name, normalized) + } + } else { + normalized + } + } + + /// List files in a directory. + /// + /// Opens the directory with CREATE, queries entries with QUERY_DIRECTORY + /// (looping until STATUS_NO_MORE_FILES), then closes the handle. + pub async fn list_directory( + &self, + conn: &mut Connection, + path: &str, + ) -> Result> { + let normalized = self.format_path(path); + debug!("tree: list_directory path={}", normalized); + + // Open the directory. + let file_id = self.open_directory(conn, &normalized).await?; + + // Query directory entries. + let result = self.query_directory_loop(conn, file_id).await; + + // Close the handle regardless of query result. + let close_result = self.close_handle(conn, file_id).await; + + // Return the query result, or if it succeeded, check the close result. + let entries = result?; + close_result?; + debug!("tree: list_directory done, entries={}", entries.len()); + Ok(entries) + } + + /// Read a small file using a compound CREATE+READ+CLOSE request. + /// + /// Sends all three operations in a single transport frame, reducing + /// round-trips from 3 to 1. Best for files that fit in a single + /// READ (up to MaxReadSize). + /// + /// For files larger than MaxReadSize, use `read_file_pipelined` instead. + pub async fn read_file_compound(&self, conn: &mut Connection, path: &str) -> Result> { + let normalized = self.format_path(path); + let max_read = conn.params().map(|p| p.max_read_size).unwrap_or(65536); + debug!( + "tree: read_file_compound path={}, max_read={}", + normalized, max_read + ); + + // Build CREATE request (same params as open_file). + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_DATA + | FileAccessMask::FILE_READ_ATTRIBUTES + | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: normalized.clone(), + create_contexts: vec![], + }; + + // Build READ request with sentinel FileId. + // CreditCharge for READ = ceil(max_read / 65536). + let read_credit_charge = (max_read as u64).div_ceil(65536) as u16; + let read_req = ReadRequest { + padding: 0x50, + flags: 0, + length: max_read, + offset: 0, + file_id: FileId::SENTINEL, + minimum_count: 0, + channel: SMB2_CHANNEL_NONE, + remaining_bytes: 0, + read_channel_info: vec![], + }; + + // Build CLOSE request with sentinel FileId. + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + + // Send as compound. + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Read, + body: &read_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(read_credit_charge), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + + let responses = all_or_first_err(conn.execute_compound(&ops).await?)?; + + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let read_header = &responses[1].header; + let read_body = &responses[1].body; + let close_header = &responses[2].header; + + // Check CREATE response. + if create_header.status != NtStatus::SUCCESS { + // CREATE failed -- all three fail (cascaded). No handle to clean up. + return Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(create_body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + let file_id = create_resp.file_id; + + // Check READ response. + if read_header.status != NtStatus::SUCCESS && read_header.status != NtStatus::END_OF_FILE { + // READ failed. CLOSE also failed in the compound (cascaded). + // Issue a standalone CLOSE to clean up the handle. + debug!( + "tree: compound READ failed ({:?}), issuing standalone CLOSE", + read_header.status + ); + let _ = self.close_handle(conn, file_id).await; + return Err(Error::Protocol { + status: read_header.status, + command: Command::Read, + }); + } + + // Parse READ data. + let data = if read_header.status == NtStatus::END_OF_FILE { + // Empty file. + Vec::new() + } else { + let mut cursor = ReadCursor::new(read_body); + let read_resp = ReadResponse::unpack(&mut cursor)?; + read_resp.data + }; + + // Check CLOSE response. If it failed but CREATE and READ succeeded, + // the handle might still be open, but there's nothing we can do + // since we already have the data. + if close_header.status != NtStatus::SUCCESS { + debug!( + "tree: compound CLOSE returned {:?} (non-fatal, data already read)", + close_header.status, + ); + } + + debug!("tree: read_file_compound done, read {} bytes", data.len()); + Ok(data) + } + + /// Read a file's contents using a compound request (1 round-trip). + /// + /// Sends CREATE+READ+CLOSE as a single compound message. For files + /// that fit in MaxReadSize (typically 8 MB), this is the fastest + /// path -- 1 round-trip instead of 3+. + /// + /// For files larger than MaxReadSize, the compound returns only the + /// first chunk. In that case, use [`read_file_pipelined`](Self::read_file_pipelined) + /// for concurrent chunked reads. + pub async fn read_file(&self, conn: &mut Connection, path: &str) -> Result> { + self.read_file_compound(conn, path).await + } + + /// Disconnect from the share. + pub async fn disconnect(&self, conn: &mut Connection) -> Result<()> { + debug!( + "tree: disconnecting share={}, tree_id={}", + self.share_name, self.tree_id + ); + let body = TreeDisconnectRequest; + let frame = conn + .execute(Command::TreeDisconnect, &body, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::TreeDisconnect, + }); + } + + conn.deregister_dfs_tree(self.tree_id); + + info!( + "tree: disconnected share={}, tree_id={}", + self.share_name, self.tree_id + ); + Ok(()) + } + + /// Start watching a directory for changes. + /// + /// Opens the directory and returns a [`Watcher`](crate::client::watcher::Watcher) that yields change + /// events via [`next_events()`](crate::client::watcher::Watcher::next_events). + /// The server holds each request until changes occur, making this a + /// long-poll operation. + /// + /// Set `recursive` to `true` to watch the entire subtree. + /// + /// The returned `Watcher` owns a cloned `Connection` (cheap + /// `Arc::clone`, all clones multiplex over the same SMB session), so + /// the caller is free to perform other operations on `conn` while + /// watching. No second `SmbClient` is required. + pub async fn watch( + &self, + conn: &mut Connection, + path: &str, + recursive: bool, + ) -> Result { + let normalized = self.format_path(path); + debug!( + "tree: watch path={}, recursive={}, tree_id={}", + normalized, recursive, self.tree_id + ); + + // Open the directory with FILE_LIST_DIRECTORY access (same as + // FILE_READ_DATA = 0x0001). We need the handle to stay open for + // the lifetime of the watcher. + let file_id = self.open_directory(conn, &normalized).await?; + + // Hand the watcher an owned Tree clone and an owned Connection + // clone so it can pipeline CHANGE_NOTIFY requests independently + // of the caller's connection use. + Ok(crate::client::watcher::Watcher::new( + self.clone(), + conn.clone(), + file_id, + recursive, + )) + } + + /// Delete a file using a compound request (1 round-trip). + /// + /// Sends CREATE (with `DELETE_ON_CLOSE`) + CLOSE as a single compound + /// message. The server deletes the file when the CLOSE completes. + pub async fn delete_file(&self, conn: &mut Connection, path: &str) -> Result<()> { + self.delete_compound(conn, path, FILE_NON_DIRECTORY_FILE, "file") + .await + } + + /// Delete multiple files using batch compound requests. + /// + /// Sends all compound (CREATE+CLOSE) requests before waiting for any + /// responses, minimizing total round-trips. Returns results in the same + /// order as the input paths. Each file's result is independent -- one + /// failure does not affect the others. + pub async fn delete_files(&self, conn: &mut Connection, paths: &[&str]) -> Vec> { + if paths.is_empty() { + return vec![]; + } + + debug!("tree: delete_files batch, count={}", paths.len()); + + // Issue one `execute_compound` per path sequentially. Each compound + // is still CREATE+CLOSE in a single wire frame, so per-file round + // trips stay at 1. Pre-Phase-3 this loop did "phase 1: send all, + // phase 2: receive all" to overlap server work — the new API + // doesn't expose raw send/receive separately; if that throughput + // matters, `execute_compound` can run on cloned connections via + // `tokio::spawn` to interleave responses through the receiver + // task's per-`MessageId` routing. + let mut results: Vec> = Vec::with_capacity(paths.len()); + let mut cleanup_handles: Vec = Vec::new(); + for path in paths { + let normalized = self.format_path(path); + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: FILE_DELETE_ON_CLOSE | FILE_NON_DIRECTORY_FILE, + name: normalized, + create_contexts: vec![], + }; + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + let frames = match conn.execute_compound(&ops).await { + Ok(v) => v, + Err(e) => { + results.push(Err(e)); + continue; + } + }; + let responses = match all_or_first_err(frames) { + Ok(v) => v, + Err(e) => { + results.push(Err(e)); + continue; + } + }; + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let close_header = &responses[1].header; + if create_header.status != NtStatus::SUCCESS { + results.push(Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + })); + } else if close_header.status != NtStatus::SUCCESS { + if let Ok(create_resp) = CreateResponse::unpack(&mut ReadCursor::new(create_body)) { + cleanup_handles.push(create_resp.file_id); + } + results.push(Err(Error::Protocol { + status: close_header.status, + command: Command::Close, + })); + } else { + info!("tree: batch deleted file={}", path); + results.push(Ok(())); + } + } + + // Phase 3: Cleanup -- issue standalone CLOSEs for leaked handles. + for file_id in &cleanup_handles { + warn!( + "tree: batch delete cleanup, issuing standalone CLOSE for {:?}", + file_id + ); + let _ = self.close_handle(conn, *file_id).await; + } + + debug!( + "tree: delete_files batch done, {}/{} succeeded", + results.iter().filter(|r| r.is_ok()).count(), + paths.len() + ); + results + } + + /// Get file metadata (size, timestamps, is_directory) using a compound request (1 round-trip). + /// + /// Sends CREATE + QUERY_INFO (FileBasicInformation) + + /// QUERY_INFO (FileStandardInformation) + CLOSE as a single compound message. + pub async fn stat(&self, conn: &mut Connection, path: &str) -> Result { + let normalized = self.format_path(path); + debug!("tree: stat (compound) path={}", normalized); + + // BUILD CREATE request for reading attributes. + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: normalized.clone(), + create_contexts: vec![], + }; + + // QUERY_INFO for FileBasicInformation (timestamps + attributes). + let basic_req = QueryInfoRequest { + info_type: InfoType::File, + file_info_class: FILE_BASIC_INFORMATION, + output_buffer_length: 40, + additional_information: 0, + flags: 0, + file_id: FileId::SENTINEL, + input_buffer: vec![], + }; + + // QUERY_INFO for FileStandardInformation (size + is_directory). + let std_req = QueryInfoRequest { + info_type: InfoType::File, + file_info_class: FILE_STANDARD_INFORMATION, + output_buffer_length: 24, + additional_information: 0, + flags: 0, + file_id: FileId::SENTINEL, + input_buffer: vec![], + }; + + // CLOSE with sentinel FileId. + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::QueryInfo, + body: &basic_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::QueryInfo, + body: &std_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + + let responses = all_or_first_err(conn.execute_compound(&ops).await?)?; + + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let basic_header = &responses[1].header; + let basic_body = &responses[1].body; + let std_header = &responses[2].header; + let std_body = &responses[2].body; + let close_header = &responses[3].header; + + // If CREATE failed, all ops cascade. No handle to clean up. + if create_header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + }); + } + + // Check first QUERY_INFO (basic). If it failed, issue standalone CLOSE. + if !basic_header.status.is_success_or_partial() { + let mut cursor = ReadCursor::new(create_body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + warn!( + "tree: compound QUERY_INFO (basic) failed ({:?}), issuing standalone CLOSE", + basic_header.status + ); + let _ = self.close_handle(conn, create_resp.file_id).await; + return Err(Error::Protocol { + status: basic_header.status, + command: Command::QueryInfo, + }); + } + if basic_header.status == NtStatus::BUFFER_OVERFLOW { + warn!("recv: STATUS_BUFFER_OVERFLOW on FileBasicInformation, response data may be truncated"); + } + + // Parse FileBasicInformation. + let mut cursor = ReadCursor::new(basic_body); + let basic_resp = QueryInfoResponse::unpack(&mut cursor)?; + let basic_buf = &basic_resp.output_buffer; + + if basic_buf.len() < 36 { + return Err(Error::invalid_data(format!( + "FileBasicInformation too short: {} bytes", + basic_buf.len() + ))); + } + + let created = FileTime(u64::from_le_bytes(basic_buf[0..8].try_into().unwrap())); + let accessed = FileTime(u64::from_le_bytes(basic_buf[8..16].try_into().unwrap())); + let modified = FileTime(u64::from_le_bytes(basic_buf[16..24].try_into().unwrap())); + let _change_time = u64::from_le_bytes(basic_buf[24..32].try_into().unwrap()); + let file_attributes = u32::from_le_bytes(basic_buf[32..36].try_into().unwrap()); + + // Check second QUERY_INFO (standard). If it failed, issue standalone CLOSE. + if !std_header.status.is_success_or_partial() { + let mut cursor = ReadCursor::new(create_body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + warn!( + "tree: compound QUERY_INFO (standard) failed ({:?}), issuing standalone CLOSE", + std_header.status + ); + let _ = self.close_handle(conn, create_resp.file_id).await; + return Err(Error::Protocol { + status: std_header.status, + command: Command::QueryInfo, + }); + } + if std_header.status == NtStatus::BUFFER_OVERFLOW { + warn!("recv: STATUS_BUFFER_OVERFLOW on FileStandardInformation, response data may be truncated"); + } + + // Parse FileStandardInformation. + let mut cursor = ReadCursor::new(std_body); + let std_resp = QueryInfoResponse::unpack(&mut cursor)?; + let std_buf = &std_resp.output_buffer; + + if std_buf.len() < 22 { + return Err(Error::invalid_data(format!( + "FileStandardInformation too short: {} bytes", + std_buf.len() + ))); + } + + let _allocation_size = u64::from_le_bytes(std_buf[0..8].try_into().unwrap()); + let end_of_file = u64::from_le_bytes(std_buf[8..16].try_into().unwrap()); + let _number_of_links = u32::from_le_bytes(std_buf[16..20].try_into().unwrap()); + let _delete_pending = std_buf[20]; + let is_directory_byte = std_buf[21]; + + let is_directory = + is_directory_byte != 0 || (file_attributes & FILE_ATTRIBUTE_DIRECTORY) != 0; + + // Check CLOSE response (non-fatal, we already have the data). + if close_header.status != NtStatus::SUCCESS { + debug!( + "tree: compound CLOSE returned {:?} (non-fatal, stat data already read)", + close_header.status, + ); + } + + debug!( + "tree: stat done, size={}, is_dir={}", + end_of_file, is_directory + ); + Ok(FileInfo { + size: end_of_file, + is_directory, + created, + modified, + accessed, + }) + } + + /// Stat multiple files using batch compound requests. + /// + /// Sends all compound (CREATE+QUERY_INFO+QUERY_INFO+CLOSE) requests before + /// waiting for any responses. Returns results in the same order as the + /// input paths. + pub async fn stat_files(&self, conn: &mut Connection, paths: &[&str]) -> Vec> { + if paths.is_empty() { + return vec![]; + } + + debug!("tree: stat_files batch, count={}", paths.len()); + + // Issue one `execute_compound` per path sequentially. See + // `delete_files` for the same shape and the note on wire-level + // pipelining tradeoffs. + let mut results: Vec> = Vec::with_capacity(paths.len()); + let mut cleanup_handles: Vec = Vec::new(); + + for path in paths { + let normalized = self.format_path(path); + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: normalized, + create_contexts: vec![], + }; + let basic_req = QueryInfoRequest { + info_type: InfoType::File, + file_info_class: FILE_BASIC_INFORMATION, + output_buffer_length: 40, + additional_information: 0, + flags: 0, + file_id: FileId::SENTINEL, + input_buffer: vec![], + }; + let std_req = QueryInfoRequest { + info_type: InfoType::File, + file_info_class: FILE_STANDARD_INFORMATION, + output_buffer_length: 24, + additional_information: 0, + flags: 0, + file_id: FileId::SENTINEL, + input_buffer: vec![], + }; + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::QueryInfo, + body: &basic_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::QueryInfo, + body: &std_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + let frames = match conn.execute_compound(&ops).await { + Ok(v) => v, + Err(e) => { + results.push(Err(e)); + continue; + } + }; + let responses = match all_or_first_err(frames) { + Ok(v) => v, + Err(e) => { + results.push(Err(e)); + continue; + } + }; + let parsed = self.parse_stat_batch_response(&responses, &mut cleanup_handles); + if parsed.is_ok() { + debug!("tree: batch stat done for file={}", path); + } + results.push(parsed); + } + + // Phase 3: Cleanup -- standalone CLOSEs for leaked handles. + for file_id in &cleanup_handles { + warn!( + "tree: batch stat cleanup, issuing standalone CLOSE for {:?}", + file_id + ); + let _ = self.close_handle(conn, *file_id).await; + } + + debug!( + "tree: stat_files batch done, {}/{} succeeded", + results.iter().filter(|r| r.is_ok()).count(), + paths.len() + ); + results + } + + /// Parse a single stat compound response for the batch stat method. + /// + /// Expects exactly four sub-frames (CREATE, QUERY_INFO basic, + /// QUERY_INFO standard, CLOSE). + fn parse_stat_batch_response( + &self, + responses: &[crate::client::connection::Frame], + cleanup_handles: &mut Vec, + ) -> Result { + debug_assert_eq!( + responses.len(), + 4, + "stat compound must have 4 sub-responses" + ); + + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let basic_header = &responses[1].header; + let basic_body = &responses[1].body; + let std_header = &responses[2].header; + let std_body = &responses[2].body; + + if create_header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + }); + } + + // CREATE succeeded -- if a later op fails, we need cleanup. + let file_id = CreateResponse::unpack(&mut ReadCursor::new(create_body)) + .map(|r| r.file_id) + .ok(); + + if !basic_header.status.is_success_or_partial() { + if let Some(fid) = file_id { + cleanup_handles.push(fid); + } + return Err(Error::Protocol { + status: basic_header.status, + command: Command::QueryInfo, + }); + } + + let mut cursor = ReadCursor::new(basic_body); + let basic_resp = QueryInfoResponse::unpack(&mut cursor)?; + let basic_buf = &basic_resp.output_buffer; + + if basic_buf.len() < 36 { + if let Some(fid) = file_id { + cleanup_handles.push(fid); + } + return Err(Error::invalid_data(format!( + "FileBasicInformation too short: {} bytes", + basic_buf.len() + ))); + } + + let created = FileTime(u64::from_le_bytes(basic_buf[0..8].try_into().unwrap())); + let accessed = FileTime(u64::from_le_bytes(basic_buf[8..16].try_into().unwrap())); + let modified = FileTime(u64::from_le_bytes(basic_buf[16..24].try_into().unwrap())); + let file_attributes = u32::from_le_bytes(basic_buf[32..36].try_into().unwrap()); + + if !std_header.status.is_success_or_partial() { + if let Some(fid) = file_id { + cleanup_handles.push(fid); + } + return Err(Error::Protocol { + status: std_header.status, + command: Command::QueryInfo, + }); + } + + let mut cursor = ReadCursor::new(std_body); + let std_resp = QueryInfoResponse::unpack(&mut cursor)?; + let std_buf = &std_resp.output_buffer; + + if std_buf.len() < 22 { + if let Some(fid) = file_id { + cleanup_handles.push(fid); + } + return Err(Error::invalid_data(format!( + "FileStandardInformation too short: {} bytes", + std_buf.len() + ))); + } + + let end_of_file = u64::from_le_bytes(std_buf[8..16].try_into().unwrap()); + let is_directory_byte = std_buf[21]; + + let is_directory = + is_directory_byte != 0 || (file_attributes & FILE_ATTRIBUTE_DIRECTORY) != 0; + + Ok(FileInfo { + size: end_of_file, + is_directory, + created, + modified, + accessed, + }) + } + + /// Query file system space information for this share. + /// + /// Returns total capacity, free space, and allocation unit sizes. + /// Uses a compound CREATE+QUERY_INFO+CLOSE for efficiency (one round-trip). + pub async fn fs_info(&self, conn: &mut Connection) -> Result { + debug!("tree: fs_info on share={}", self.share_name); + + // Build CREATE request to open the root directory of the share. + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: FILE_DIRECTORY_FILE, + name: String::new(), // root of share + create_contexts: vec![], + }; + + // Build QUERY_INFO request for FileFsFullSizeInformation. + // Use sentinel FileId; the compound will fill it in. + let query_req = QueryInfoRequest { + info_type: InfoType::Filesystem, + file_info_class: FILE_FS_FULL_SIZE_INFORMATION, + output_buffer_length: 32, // 3 x i64 + 2 x u32 + additional_information: 0, + flags: 0, + file_id: FileId::SENTINEL, + input_buffer: vec![], + }; + + // Build CLOSE request with sentinel FileId. + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + + // Send as compound. + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::QueryInfo, + body: &query_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + + let responses = all_or_first_err(conn.execute_compound(&ops).await?)?; + + let create_header = &responses[0].header; + let query_header = &responses[1].header; + let query_body = &responses[1].body; + let close_header = &responses[2].header; + + // Check CREATE response. + if create_header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + }); + } + + // Check QUERY_INFO response. + if !query_header.status.is_success_or_partial() { + // QUERY_INFO failed. Issue standalone CLOSE to clean up. + let mut cursor = ReadCursor::new(&responses[0].body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + debug!( + "tree: compound QUERY_INFO failed ({:?}), issuing standalone CLOSE", + query_header.status + ); + let _ = self.close_handle(conn, create_resp.file_id).await; + return Err(Error::Protocol { + status: query_header.status, + command: Command::QueryInfo, + }); + } + if query_header.status == NtStatus::BUFFER_OVERFLOW { + warn!("recv: STATUS_BUFFER_OVERFLOW on FileFsFullSizeInformation, response data may be truncated"); + } + + // Parse the FileFsFullSizeInformation response. + let mut cursor = ReadCursor::new(query_body); + let query_resp = QueryInfoResponse::unpack(&mut cursor)?; + let buf = &query_resp.output_buffer; + + if buf.len() < 32 { + return Err(Error::invalid_data(format!( + "FileFsFullSizeInformation too short: {} bytes", + buf.len() + ))); + } + + let total_allocation_units = i64::from_le_bytes(buf[0..8].try_into().unwrap()) as u64; + let caller_available_units = i64::from_le_bytes(buf[8..16].try_into().unwrap()) as u64; + let actual_available_units = i64::from_le_bytes(buf[16..24].try_into().unwrap()) as u64; + let sectors_per_unit = u32::from_le_bytes(buf[24..28].try_into().unwrap()); + let bytes_per_sector = u32::from_le_bytes(buf[28..32].try_into().unwrap()); + + let bytes_per_unit = sectors_per_unit as u64 * bytes_per_sector as u64; + let total_bytes = total_allocation_units * bytes_per_unit; + let free_bytes = caller_available_units * bytes_per_unit; + let total_free_bytes = actual_available_units * bytes_per_unit; + + // Check CLOSE response (non-fatal if it failed). + if close_header.status != NtStatus::SUCCESS { + debug!( + "tree: compound CLOSE returned {:?} (non-fatal, fs_info already read)", + close_header.status, + ); + } + + debug!( + "tree: fs_info done, total={}, free={}, total_free={}", + total_bytes, free_bytes, total_free_bytes + ); + Ok(FsInfo { + total_bytes, + free_bytes, + total_free_bytes, + bytes_per_sector, + sectors_per_unit, + }) + } + + /// Rename or move a file within the same share using a compound request (1 round-trip). + /// + /// Sends CREATE + SET_INFO (FileRenameInformation) + CLOSE as a single + /// compound message. + pub async fn rename(&self, conn: &mut Connection, from: &str, to: &str) -> Result<()> { + let from_normalized = self.format_path(from); + let to_normalized = normalize_path(to); + debug!( + "tree: rename (compound) from={} to={}", + from_normalized, to_normalized + ); + + // Build CREATE request with DELETE access (required for rename). + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: from_normalized.clone(), + create_contexts: vec![], + }; + + // Build SET_INFO request with FileRenameInformation and sentinel FileId. + let setinfo_req = SetInfoRequest { + info_type: InfoType::File, + file_info_class: FILE_RENAME_INFORMATION, + additional_information: 0, + file_id: FileId::SENTINEL, + buffer: build_rename_info_buffer(&to_normalized), + }; + + // Build CLOSE request with sentinel FileId. + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::SetInfo, + body: &setinfo_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + + let responses = all_or_first_err(conn.execute_compound(&ops).await?)?; + + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let setinfo_header = &responses[1].header; + let close_header = &responses[2].header; + + // If CREATE failed, all ops cascade. No handle to clean up. + if create_header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + }); + } + + // CREATE succeeded. If SET_INFO failed, CLOSE also cascaded. + // Issue standalone CLOSE to avoid leaking the handle. + if setinfo_header.status != NtStatus::SUCCESS { + let mut cursor = ReadCursor::new(create_body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + warn!( + "tree: compound SET_INFO failed ({:?}), issuing standalone CLOSE", + setinfo_header.status + ); + let _ = self.close_handle(conn, create_resp.file_id).await; + return Err(Error::Protocol { + status: setinfo_header.status, + command: Command::SetInfo, + }); + } + + // Check CLOSE response (non-fatal if it failed, rename already done). + if close_header.status != NtStatus::SUCCESS { + debug!( + "tree: compound CLOSE returned {:?} (non-fatal, rename already done)", + close_header.status, + ); + } + + info!( + "tree: renamed from={} to={}", + from_normalized, to_normalized + ); + Ok(()) + } + + /// Rename multiple files using batch compound requests. + /// + /// Sends all compound (CREATE+SET_INFO+CLOSE) requests before waiting for + /// any responses. Returns results in the same order as the input pairs. + pub async fn rename_files( + &self, + conn: &mut Connection, + renames: &[(&str, &str)], + ) -> Vec> { + if renames.is_empty() { + return vec![]; + } + + debug!("tree: rename_files batch, count={}", renames.len()); + + // Sequential `execute_compound` per rename. See `delete_files` for + // the pipelining note. + let mut results: Vec> = Vec::with_capacity(renames.len()); + let mut cleanup_handles: Vec = Vec::new(); + + for (from, to) in renames { + let from_normalized = self.format_path(from); + let to_normalized = normalize_path(to); + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: from_normalized, + create_contexts: vec![], + }; + let setinfo_req = SetInfoRequest { + info_type: InfoType::File, + file_info_class: FILE_RENAME_INFORMATION, + additional_information: 0, + file_id: FileId::SENTINEL, + buffer: build_rename_info_buffer(&to_normalized), + }; + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::SetInfo, + body: &setinfo_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + let frames = match conn.execute_compound(&ops).await { + Ok(v) => v, + Err(e) => { + results.push(Err(e)); + continue; + } + }; + let responses = match all_or_first_err(frames) { + Ok(v) => v, + Err(e) => { + results.push(Err(e)); + continue; + } + }; + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let setinfo_header = &responses[1].header; + let close_header = &responses[2].header; + + if create_header.status != NtStatus::SUCCESS { + results.push(Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + })); + } else if setinfo_header.status != NtStatus::SUCCESS { + if let Ok(create_resp) = CreateResponse::unpack(&mut ReadCursor::new(create_body)) { + cleanup_handles.push(create_resp.file_id); + } + results.push(Err(Error::Protocol { + status: setinfo_header.status, + command: Command::SetInfo, + })); + } else { + if close_header.status != NtStatus::SUCCESS { + debug!( + "tree: batch rename CLOSE returned {:?} (non-fatal)", + close_header.status, + ); + } + info!("tree: batch renamed from={} to={}", from, to); + results.push(Ok(())); + } + } + + // Phase 3: Cleanup -- standalone CLOSEs for leaked handles. + for file_id in &cleanup_handles { + warn!( + "tree: batch rename cleanup, issuing standalone CLOSE for {:?}", + file_id + ); + let _ = self.close_handle(conn, *file_id).await; + } + + debug!( + "tree: rename_files batch done, {}/{} succeeded", + results.iter().filter(|r| r.is_ok()).count(), + renames.len() + ); + results + } + + /// Write a file using a compound CREATE+WRITE+FLUSH+CLOSE request. + /// + /// Sends all four operations in a single transport frame (1 round-trip). + /// Best for files that fit in MaxWriteSize. For larger files, use + /// [`write_file_pipelined`](Self::write_file_pipelined). + pub async fn write_file_compound( + &self, + conn: &mut Connection, + path: &str, + data: &[u8], + ) -> Result { + let normalized = self.format_path(path); + debug!( + "tree: write_file_compound path={}, len={}", + normalized, + data.len() + ); + + // Build CREATE request (write access, overwrite-if disposition). + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_WRITE_DATA + | FileAccessMask::FILE_WRITE_ATTRIBUTES + | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0x80, // FILE_ATTRIBUTE_NORMAL + share_access: ShareAccess(0), + create_disposition: CreateDisposition::FileOverwriteIf, + create_options: FILE_NON_DIRECTORY_FILE, + name: normalized.clone(), + create_contexts: vec![], + }; + + // Build WRITE request with sentinel FileId. + // DataOffset = Header::SIZE (64) + WriteRequest fixed body (48) = 0x70. + let write_credit_charge = (data.len() as u64).div_ceil(65536).max(1) as u16; + let write_req = WriteRequest { + data_offset: 0x70, + offset: 0, + file_id: FileId::SENTINEL, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: data.to_vec(), + }; + + // Build FLUSH request with sentinel FileId. + let flush_req = FlushRequest { + file_id: FileId::SENTINEL, + }; + + // Build CLOSE request with sentinel FileId. + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + + // Send as 4-way compound. + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Write, + body: &write_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(write_credit_charge), + }, + CompoundOp { + command: Command::Flush, + body: &flush_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + + let responses = all_or_first_err(conn.execute_compound(&ops).await?)?; + + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let write_header = &responses[1].header; + let write_body = &responses[1].body; + let flush_header = &responses[2].header; + let close_header = &responses[3].header; + + // Check CREATE response. + if create_header.status != NtStatus::SUCCESS { + // CREATE failed -- all four fail (cascaded). No handle to clean up. + return Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(create_body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + let file_id = create_resp.file_id; + + // Check WRITE response. + if write_header.status != NtStatus::SUCCESS { + // WRITE failed. FLUSH and CLOSE also failed in the compound (cascaded). + // Issue a standalone CLOSE to clean up the handle. + debug!( + "tree: compound WRITE failed ({:?}), issuing standalone CLOSE", + write_header.status + ); + let _ = self.close_handle(conn, file_id).await; + return Err(Error::Protocol { + status: write_header.status, + command: Command::Write, + }); + } + + let mut cursor = ReadCursor::new(write_body); + let write_resp = WriteResponse::unpack(&mut cursor)?; + let bytes_written = write_resp.count as u64; + + // Check FLUSH response. If it failed but WRITE succeeded, + // the data might not be persisted yet but the write did happen. + if flush_header.status != NtStatus::SUCCESS { + debug!( + "tree: compound FLUSH returned {:?} (data written but may not be persisted)", + flush_header.status, + ); + } + + // Check CLOSE response. If it failed but CREATE and WRITE succeeded, + // the handle might still be open, but there's nothing we can do + // since we already have the data written. + if close_header.status != NtStatus::SUCCESS { + debug!( + "tree: compound CLOSE returned {:?} (non-fatal, data already written)", + close_header.status, + ); + } + + debug!( + "tree: write_file_compound done, wrote {} bytes", + bytes_written + ); + Ok(bytes_written) + } + + /// Write data to a file (create or overwrite). + /// + /// For data that fits in MaxWriteSize (typically 64 KB to 8 MB), uses a + /// compound CREATE+WRITE+FLUSH+CLOSE in a single round-trip. For larger + /// data, falls back to the pipelined write path. + /// + /// Returns the total number of bytes written. + pub async fn write_file(&self, conn: &mut Connection, path: &str, data: &[u8]) -> Result { + let max_write = conn + .params() + .map(|p| p.max_write_size as usize) + .unwrap_or(65536); + if data.len() <= max_write { + self.write_file_compound(conn, path, data).await + } else { + self.write_file_pipelined(conn, path, data).await + } + } + + /// Read a file using pipelined I/O with a sliding window. + /// + /// Opens the file, determines its size, then uses a sliding window to + /// keep the pipe full: as each response arrives, the next request is sent + /// immediately. Much faster than sequential [`read_file`](Self::read_file) + /// for large files. + /// + /// Uses 64 KB chunks with CreditCharge=1 to maximize concurrency. + /// The window is capped at 32 in-flight requests (2 MB). + pub async fn read_file_pipelined(&self, conn: &mut Connection, path: &str) -> Result> { + let normalized = self.format_path(path); + + // Open the file. + let (file_id, file_size) = self.open_file(conn, &normalized).await?; + + if file_size == 0 { + debug!( + "tree: read_file_pipelined path={}, size=0 (empty file)", + normalized + ); + self.close_handle(conn, file_id).await?; + return Ok(Vec::new()); + } + + // Balance chunk size for pipelining: small enough to keep many + // in flight (sliding window benefit), large enough to minimize + // per-chunk overhead (headers, signing). + // + // For files that fit in one read: use file size (no chunking). + // For larger files: use 512 KB -- gives ~20 chunks per 10 MB + // (enough for pipelining) with 8 credits per chunk (manageable). + let max_read = conn.params().map(|p| p.max_read_size).unwrap_or(65536); + let pipeline_chunk = 512 * 1024_u32; // 512 KB + let chunk_size = if file_size <= max_read as u64 { + // File fits in one read -- no pipelining needed. + (file_size as u32).min(max_read) + } else { + // Use pipeline chunk size, capped to MaxReadSize. + pipeline_chunk.min(max_read) + }; + let credit_charge = chunk_size.div_ceil(65536) as u16; + let total_chunks = file_size.div_ceil(chunk_size as u64) as usize; + debug!( + "tree: read_file_pipelined path={}, size={}, chunk_size={}, credit_charge={}, total_chunks={}, credits={}", + normalized, file_size, chunk_size, credit_charge, total_chunks, conn.credits() + ); + + let start = std::time::Instant::now(); + let result = self + .read_pipelined_loop( + conn, + file_id, + file_size, + chunk_size, + credit_charge, + total_chunks, + ) + .await; + + // Close the handle regardless of read result. + let close_result = self.close_handle(conn, file_id).await; + + let data = result?; + close_result?; + + let elapsed = start.elapsed(); + let mb = data.len() as f64 / (1024.0 * 1024.0); + let mbps = if elapsed.as_secs_f64() > 0.0 { + mb / elapsed.as_secs_f64() + } else { + 0.0 + }; + debug!( + "tree: read_file_pipelined done, read {} bytes in {:.2?} ({:.1} MB/s)", + data.len(), + elapsed, + mbps + ); + + Ok(data) + } + + /// Read a file using pipelined I/O with progress reporting and cancellation. + /// + /// Same as [`read_file_pipelined`](Self::read_file_pipelined) but calls + /// `on_progress` after each chunk is received. Return + /// `ControlFlow::Break(())` from the callback to cancel the read. + pub async fn read_file_pipelined_with_progress( + &self, + conn: &mut Connection, + path: &str, + mut on_progress: F, + ) -> Result> + where + F: FnMut(Progress) -> ControlFlow<()>, + { + let normalized = self.format_path(path); + + let (file_id, file_size) = self.open_file(conn, &normalized).await?; + + if file_size == 0 { + debug!( + "tree: read_file_pipelined_with_progress path={}, size=0 (empty file)", + normalized + ); + self.close_handle(conn, file_id).await?; + let _ = on_progress(Progress { + bytes_transferred: 0, + total_bytes: Some(0), + }); + return Ok(Vec::new()); + } + + let max_read = conn.params().map(|p| p.max_read_size).unwrap_or(65536); + let pipeline_chunk = 512 * 1024_u32; + let chunk_size = if file_size <= max_read as u64 { + (file_size as u32).min(max_read) + } else { + pipeline_chunk.min(max_read) + }; + let credit_charge = chunk_size.div_ceil(65536) as u16; + let total_chunks = file_size.div_ceil(chunk_size as u64) as usize; + debug!( + "tree: read_file_pipelined_with_progress path={}, size={}, chunk_size={}, total_chunks={}", + normalized, file_size, chunk_size, total_chunks + ); + + let result = self + .read_pipelined_loop_with_progress( + conn, + file_id, + file_size, + chunk_size, + credit_charge, + total_chunks, + &mut on_progress, + ) + .await; + + // Close the handle regardless of read result. + let close_result = self.close_handle(conn, file_id).await; + + let data = result?; + close_result?; + + debug!( + "tree: read_file_pipelined_with_progress done, read {} bytes", + data.len() + ); + Ok(data) + } + + /// Start a streaming file download on this tree. + /// + /// Issues CREATE and returns a [`FileDownload`] that pulls the body in + /// chunks via [`next_chunk`](FileDownload::next_chunk). Mirrors + /// [`SmbClient::download`](crate::SmbClient::download) but accepts a + /// borrowed [`Connection`] directly, so callers who hold a cloned + /// `Connection` (see [`Connection::clone`]) can drive concurrent + /// downloads on one SMB session. + /// + /// For files that fit in one READ (≤ `max_read_size`), prefer + /// [`read_file_compound`](Self::read_file_compound) — 1 RTT vs. 3 RTTs. + /// + /// # Example + /// + /// ```ignore + /// # async fn example(conn: &mut smb2::Connection, tree: &smb2::Tree) -> Result<(), smb2::Error> { + /// let mut download = tree.download(conn, "big.bin").await?; + /// while let Some(chunk) = download.next_chunk().await { + /// let bytes = chunk?; + /// // process bytes + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn download<'a>( + &'a self, + conn: &'a mut Connection, + path: &str, + ) -> Result> { + let normalized = path.replace('/', "\\"); + let normalized = normalized.trim_start_matches('\\'); + let (file_id, file_size) = self.open_file(conn, normalized).await?; + let chunk_size = conn.params().map(|p| p.max_read_size).unwrap_or(65536); + Ok(FileDownload::new( + self, conn, file_id, file_size, chunk_size, + )) + } + + /// Write a file using pipelined I/O with a sliding window. + /// + /// Opens/creates the file, then uses a sliding window to keep the pipe + /// full: as each response arrives, the next request is sent immediately. + /// Flushes to ensure data is persisted on the server. Much faster than + /// sequential [`write_file`](Self::write_file) for large data. + /// + /// Uses MaxWriteSize chunks to minimize overhead for large payloads. + pub async fn write_file_pipelined( + &self, + conn: &mut Connection, + path: &str, + data: &[u8], + ) -> Result { + let normalized = self.format_path(path); + + if data.is_empty() { + debug!( + "tree: write_file_pipelined path={}, len=0 (empty write)", + normalized + ); + // Still create the file (to match write_file behavior). + return self.write_file_compound(conn, path, data).await; + } + + // Open (or create) the file for writing. + let req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_WRITE_DATA + | FileAccessMask::FILE_WRITE_ATTRIBUTES + | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0x80, // FILE_ATTRIBUTE_NORMAL + share_access: ShareAccess(0), + create_disposition: CreateDisposition::FileOverwriteIf, + create_options: FILE_NON_DIRECTORY_FILE, + name: normalized.clone(), + create_contexts: vec![], + }; + + let frame = conn + .execute(Command::Create, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + let file_id = create_resp.file_id; + + // Use MaxWriteSize for pipelined writes: minimizes overhead for + // large payloads being sent (we're sending data, not just a small request). + let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536); + let chunk_size = max_write; + let credit_charge = chunk_size.div_ceil(65536) as u16; + let total_chunks = data.len().div_ceil(chunk_size as usize); + debug!( + "tree: write_file_pipelined path={}, len={}, chunk_size={}, credit_charge={}, total_chunks={}, credits={}", + normalized, data.len(), chunk_size, credit_charge, total_chunks, conn.credits() + ); + + let start = std::time::Instant::now(); + let result = self + .write_pipelined_loop(conn, file_id, data, chunk_size, credit_charge, total_chunks) + .await; + + // Flush to ensure data is persisted on the server. + if result.is_ok() { + self.flush_handle(conn, file_id).await?; + } + + // Close the handle. + let close_result = self.close_handle(conn, file_id).await; + + let bytes_written = result?; + close_result?; + + let elapsed = start.elapsed(); + let mb = bytes_written as f64 / (1024.0 * 1024.0); + let mbps = if elapsed.as_secs_f64() > 0.0 { + mb / elapsed.as_secs_f64() + } else { + 0.0 + }; + debug!( + "tree: write_file_pipelined done, wrote {} bytes in {:.2?} ({:.1} MB/s)", + bytes_written, elapsed, mbps + ); + + Ok(bytes_written) + } + + /// Write a file from a streaming source using pipelined I/O. + /// + /// Pulls data on demand from a callback, so you never need the full + /// file in memory. Ideal for writing from a network stream, a + /// channel, or any producer that generates data incrementally. + /// + /// # Callback contract + /// + /// Each call to `next_chunk` must return one of: + /// - `Some(Ok(data))` — the next chunk to write (any size; chunks + /// larger than `MaxWriteSize` are split automatically) + /// - `Some(Err(e))` — an I/O error from the source; aborts the + /// write, drains in-flight responses, and propagates the error + /// - `None` — end of stream; all remaining in-flight writes are + /// completed before returning + /// + /// An empty `Vec` in `Some(Ok(vec![]))` is treated the same as + /// `None` (end of stream). + /// + /// # Behavior + /// + /// - Returns the total number of bytes the server acknowledged. + /// - The file handle is always closed, even on error. + /// - If `next_chunk` returns `None` on the first call, an empty file + /// is created. + /// - On early termination (callback error or server error), a partial + /// file may remain on the server. The caller is responsible for + /// cleanup (for example, calling [`delete_file`](Self::delete_file)). + /// + /// # Performance + /// + /// Uses a sliding window of up to 32 in-flight WRITE requests (same + /// approach as [`write_file_pipelined`](Self::write_file_pipelined)), + /// so throughput stays high even on high-latency links. Memory usage + /// is bounded to the sliding window, not the full file size. + /// + /// # When to use which write method + /// + /// | Method | Best for | + /// |--------|----------| + /// | [`write_file`](Self::write_file) | Small files that fit in a single compound (one round-trip) | + /// | [`write_file_pipelined`](Self::write_file_pipelined) | Large files already in a `&[u8]` buffer | + /// | `write_file_streamed` | Data produced incrementally (streams, channels, generators) | + /// + /// # Example + /// + /// ```no_run + /// # async fn example(tree: &smb2::client::Tree, conn: &mut smb2::client::Connection) -> smb2::Result<()> { + /// let chunks = vec![b"hello ".to_vec(), b"world".to_vec()]; + /// let mut iter = chunks.into_iter(); + /// let mut next = || iter.next().map(Ok); + /// + /// let bytes_written = tree.write_file_streamed(conn, "greeting.txt", &mut next).await?; + /// assert_eq!(bytes_written, 11); + /// # Ok(()) + /// # } + /// ``` + pub async fn write_file_streamed( + &self, + conn: &mut Connection, + path: &str, + next_chunk: &mut F, + ) -> Result + where + F: FnMut() -> Option, std::io::Error>>, + { + let normalized = self.format_path(path); + debug!("tree: write_file_streamed path={}", normalized); + + // Open (or create) the file for writing. + let file_id = self.open_file_for_write(conn, &normalized).await?; + + let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536); + + let start = std::time::Instant::now(); + let result = self + .write_streamed_loop(conn, file_id, next_chunk, max_write) + .await; + + // Flush to ensure data is persisted on the server. + if result.is_ok() { + self.flush_handle(conn, file_id).await?; + } + + // Close the handle (always, even on error). + let close_result = self.close_handle(conn, file_id).await; + + let bytes_written = result?; + close_result?; + + let elapsed = start.elapsed(); + let mb = bytes_written as f64 / (1024.0 * 1024.0); + let mbps = if elapsed.as_secs_f64() > 0.0 { + mb / elapsed.as_secs_f64() + } else { + 0.0 + }; + debug!( + "tree: write_file_streamed done, wrote {} bytes in {:.2?} ({:.1} MB/s)", + bytes_written, elapsed, mbps + ); + + Ok(bytes_written) + } + + /// Create a push-based pipelined streaming writer that owns its + /// `Connection` and `Arc`. + /// + /// Opens (or creates) the file for writing and returns a + /// [`FileWriter`](super::stream::FileWriter) that accepts pushed + /// chunks. The caller drives writes at their own pace and calls + /// [`FileWriter::finish`](super::stream::FileWriter::finish) to + /// flush and close. + /// + /// The returned writer is `'static` — multiple writers built from + /// clones of the same `Connection` pipeline their WRITEs over a + /// single SMB session without external locking. + pub async fn create_file_writer( + self: &Arc, + conn: Connection, + path: &str, + ) -> Result { + super::stream::open_file_writer(Arc::clone(self), conn, path).await + } + + /// Open a push-based pipelined file writer with **exclusive-create** + /// semantics. Same shape as [`Tree::create_file_writer`], but the CREATE + /// uses `FileCreate` disposition: if the file already exists the open + /// fails with [`crate::ErrorKind::AlreadyExists`] + /// instead of truncating it. + /// + /// Use this when the consumer needs a race-free "create only if absent" + /// write — for example, a file manager's "New File" action where + /// silently clobbering an existing file is unsafe. + /// + /// The returned writer is `'static` and behaves identically to + /// `create_file_writer` from there on; chunks are pipelined over the + /// shared SMB session. + pub async fn create_file_writer_exclusive( + self: &Arc, + conn: Connection, + path: &str, + ) -> Result { + super::stream::open_file_writer_exclusive(Arc::clone(self), conn, path).await + } + + /// Create a directory. + /// + /// Opens the path with `FileCreate` disposition and `FILE_DIRECTORY_FILE` + /// option, then immediately closes the handle. + pub async fn create_directory(&self, conn: &mut Connection, path: &str) -> Result<()> { + let normalized = self.format_path(path); + debug!("tree: create_directory path={}", normalized); + + let req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: FILE_ATTRIBUTE_DIRECTORY, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileCreate, + create_options: FILE_DIRECTORY_FILE, + name: normalized.clone(), + create_contexts: vec![], + }; + + let frame = conn + .execute(Command::Create, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + let file_id = create_resp.file_id; + + // Close the handle immediately. + self.close_handle(conn, file_id).await?; + info!("tree: created directory={}", normalized); + Ok(()) + } + + /// Delete a directory using a compound request (1 round-trip). + /// + /// Sends CREATE (with `DELETE_ON_CLOSE`) + CLOSE as a single compound + /// message. The directory must be empty. + pub async fn delete_directory(&self, conn: &mut Connection, path: &str) -> Result<()> { + self.delete_compound(conn, path, FILE_DIRECTORY_FILE, "directory") + .await + } + + // ── Private helpers ────────────────────────────────────────────── + + /// Compound CREATE (DELETE_ON_CLOSE) + CLOSE in a single round-trip. + /// + /// `type_option` selects file vs directory (`FILE_NON_DIRECTORY_FILE` + /// or `FILE_DIRECTORY_FILE`). `kind` is used only for log messages. + async fn delete_compound( + &self, + conn: &mut Connection, + path: &str, + type_option: u32, + kind: &str, + ) -> Result<()> { + let normalized = self.format_path(path); + debug!("tree: delete_{} (compound) path={}", kind, normalized); + + let create_req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: FILE_DELETE_ON_CLOSE | type_option, + name: normalized.clone(), + create_contexts: vec![], + }; + + let close_req = CloseRequest { + flags: 0, + file_id: FileId::SENTINEL, + }; + + let ops = [ + CompoundOp { + command: Command::Create, + body: &create_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + CompoundOp { + command: Command::Close, + body: &close_req, + tree_id: Some(self.tree_id), + credit_charge: CreditCharge(1), + }, + ]; + + let responses = all_or_first_err(conn.execute_compound(&ops).await?)?; + + let create_header = &responses[0].header; + let create_body = &responses[0].body; + let close_header = &responses[1].header; + + // If CREATE failed, all ops in the compound fail (cascaded). No handle to clean up. + if create_header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: create_header.status, + command: Command::Create, + }); + } + + // CREATE succeeded. If CLOSE failed, issue a standalone CLOSE + // to avoid leaking the handle (and to ensure deletion happens). + if close_header.status != NtStatus::SUCCESS { + let mut cursor = ReadCursor::new(create_body); + let create_resp = CreateResponse::unpack(&mut cursor)?; + warn!( + "tree: compound CLOSE failed ({:?}), issuing standalone CLOSE", + close_header.status + ); + let _ = self.close_handle(conn, create_resp.file_id).await; + return Err(Error::Protocol { + status: close_header.status, + command: Command::Close, + }); + } + + info!("tree: deleted {}={}", kind, normalized); + Ok(()) + } + + /// Open a directory handle. + async fn open_directory(&self, conn: &mut Connection, path: &str) -> Result { + let req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_DATA + | FileAccessMask::FILE_READ_ATTRIBUTES + | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: FILE_DIRECTORY_FILE, + name: path.to_string(), + create_contexts: vec![], + }; + + let frame = conn + .execute(Command::Create, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = CreateResponse::unpack(&mut cursor)?; + Ok(resp.file_id) + } + + /// Open a file handle for reading and return the file ID and size. + /// + /// Sends a single CREATE with read access, `FileOpen` disposition (fail + /// if absent), and the standard share mask. Returns the server's + /// [`FileId`] plus the file's size in bytes (end-of-file offset) so + /// callers can size their read loop. + /// + /// Most callers want [`read_file_compound`](Self::read_file_compound), + /// [`read_file_pipelined`](Self::read_file_pipelined), or + /// [`download`](Self::download), which all open the file, read it, and + /// close it in one call. Use `open_file` directly when you want to build + /// a custom read loop — for example, constructing a [`FileDownload`] + /// with a non-default `chunk_size` via + /// [`FileDownload::new`](crate::client::stream::FileDownload::new). + /// + /// The caller is responsible for closing the handle when done (either + /// by handing it to a [`FileDownload`], which closes on completion or + /// drop, or by calling the internal close path). Leaking the handle + /// wastes server resources. + pub async fn open_file(&self, conn: &mut Connection, path: &str) -> Result<(FileId, u64)> { + let req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_READ_DATA + | FileAccessMask::FILE_READ_ATTRIBUTES + | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0, + share_access: ShareAccess( + ShareAccess::FILE_SHARE_READ + | ShareAccess::FILE_SHARE_WRITE + | ShareAccess::FILE_SHARE_DELETE, + ), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: path.to_string(), + create_contexts: vec![], + }; + + let frame = conn + .execute(Command::Create, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = CreateResponse::unpack(&mut cursor)?; + Ok((resp.file_id, resp.end_of_file)) + } + + /// Open (or create) a file for writing, returning the file handle. + /// + /// Uses `FileOverwriteIf` disposition (create if absent, overwrite if present) + /// and requests write access. Used by [`FileUpload`](crate::client::stream::FileUpload). + pub(crate) async fn open_file_for_write( + &self, + conn: &mut Connection, + path: &str, + ) -> Result { + self.open_file_for_write_with_disposition(conn, path, CreateDisposition::FileOverwriteIf) + .await + } + + /// Open a file for writing using a specific `CreateDisposition`. + /// + /// Shared body of [`open_file_for_write`](Self::open_file_for_write) + /// (`FileOverwriteIf`) and + /// [`open_file_for_exclusive_create`](Self::open_file_for_exclusive_create) + /// (`FileCreate`). Held private so the disposition stays a strict + /// allow-list inside the crate. + async fn open_file_for_write_with_disposition( + &self, + conn: &mut Connection, + path: &str, + create_disposition: CreateDisposition, + ) -> Result { + let req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::FILE_WRITE_DATA + | FileAccessMask::FILE_WRITE_ATTRIBUTES + | FileAccessMask::SYNCHRONIZE, + ), + file_attributes: 0x80, // FILE_ATTRIBUTE_NORMAL + share_access: ShareAccess(0), + create_disposition, + create_options: FILE_NON_DIRECTORY_FILE, + name: path.to_string(), + create_contexts: vec![], + }; + + let frame = conn + .execute(Command::Create, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Create, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = CreateResponse::unpack(&mut cursor)?; + Ok(resp.file_id) + } + + /// Open a file for writing with `FileCreate` disposition (exclusive create). + /// + /// Returns the file handle on success. When the file already exists the + /// server returns `STATUS_OBJECT_NAME_COLLISION`, which surfaces as + /// [`crate::ErrorKind::AlreadyExists`]. Used by + /// [`Tree::create_file_writer_exclusive`](Self::create_file_writer_exclusive) + /// so consumers can implement a race-free "create only if absent" file + /// write. + /// + /// Pairs with [`open_file_for_write`](Self::open_file_for_write), which + /// uses `FileOverwriteIf` (truncating). + pub(crate) async fn open_file_for_exclusive_create( + &self, + conn: &mut Connection, + path: &str, + ) -> Result { + self.open_file_for_write_with_disposition(conn, path, CreateDisposition::FileCreate) + .await + } + + /// Loop QUERY_DIRECTORY until STATUS_NO_MORE_FILES. + async fn query_directory_loop( + &self, + conn: &mut Connection, + file_id: FileId, + ) -> Result> { + // Cap output buffer to 65536 so that CreditCharge=1 is valid. + // The spec requires CreditCharge = 1 + (OutputBufferLength - 1) / 65536 + // for multi-credit dialects. Using 65536 keeps CreditCharge=1 which + // matches what send_request sets, while still being plenty for dir entries. + let max_output = conn + .params() + .map(|p| p.max_transact_size.min(65536)) + .unwrap_or(65536); + + let mut all_entries = Vec::new(); + let mut first = true; + + loop { + let req = QueryDirectoryRequest { + file_information_class: FileInformationClass::FileBothDirectoryInformation, + flags: QueryDirectoryFlags(if first { + QueryDirectoryFlags::RESTART_SCANS + } else { + 0 + }), + file_index: 0, + file_id, + output_buffer_length: max_output, + file_name: "*".to_string(), + }; + first = false; + + let frame = conn + .execute(Command::QueryDirectory, &req, Some(self.tree_id)) + .await?; + + if frame.header.status == NtStatus::NO_MORE_FILES { + break; + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::QueryDirectory, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = QueryDirectoryResponse::unpack(&mut cursor)?; + + // Parse FileBothDirectoryInformation entries from the output buffer. + let entries = parse_file_both_directory_info(&resp.output_buffer)?; + for e in &entries { + trace!( + "tree: dir_entry name={}, size={}, is_dir={}", + e.name, + e.size, + e.is_directory + ); + } + all_entries.extend(entries); + } + + Ok(all_entries) + } + + /// Read file data in chunks. + #[allow(dead_code)] // Will be used by read_file_pipelined for large-file chunked reads. + async fn read_loop( + &self, + conn: &mut Connection, + file_id: FileId, + file_size: u64, + ) -> Result> { + let max_read = conn.params().map(|p| p.max_read_size).unwrap_or(65536); + + let mut data = Vec::with_capacity(file_size as usize); + let mut offset = 0u64; + + loop { + let remaining = file_size.saturating_sub(offset); + if remaining == 0 { + break; + } + + let chunk_size = remaining.min(max_read as u64) as u32; + + let req = ReadRequest { + padding: 0x50, + flags: 0, + length: chunk_size, + offset, + file_id, + minimum_count: 0, + channel: SMB2_CHANNEL_NONE, + remaining_bytes: 0, + read_channel_info: vec![], + }; + + let frame = conn + .execute(Command::Read, &req, Some(self.tree_id)) + .await?; + + // STATUS_END_OF_FILE means we read past the end. + if frame.header.status == NtStatus::END_OF_FILE { + break; + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Read, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = ReadResponse::unpack(&mut cursor)?; + + if resp.data.is_empty() { + break; + } + + offset += resp.data.len() as u64; + data.extend_from_slice(&resp.data); + } + + Ok(data) + } + + /// Pipelined read using a sliding window. + /// + /// Instead of batch send/receive phases, each received response + /// immediately triggers the next send. The pipe stays full at all times, + /// eliminating idle gaps between batches. + async fn read_pipelined_loop( + &self, + conn: &mut Connection, + file_id: FileId, + file_size: u64, + chunk_size: u32, + credit_charge: u16, + total_chunks: usize, + ) -> Result> { + use futures_util::stream::{FuturesUnordered, StreamExt}; + + let mut data = vec![0u8; file_size as usize]; + let mut chunks_sent = 0usize; + let mut chunks_received = 0usize; + + let max_from_credits = conn.credits() as usize / credit_charge.max(1) as usize; + let initial_window = total_chunks.min(max_from_credits).min(MAX_PIPELINE_WINDOW); + + if initial_window == 0 { + return Err(Error::invalid_data( + "no credits available for pipelined read", + )); + } + + debug!( + "tree: pipeline read sliding window: initial_window={}, total_chunks={}, credits={}", + initial_window, + total_chunks, + conn.credits() + ); + + // Spawn each chunk read as an independent `execute_with_credits` + // future. `FuturesUnordered` polls them concurrently — the actor- + // based receiver task routes responses by `MessageId`, so all + // chunks compete fairly even when they arrive out of order. + let mut in_flight = FuturesUnordered::new(); + let build_req = |chunk_index: usize| -> ReadRequest { + let offset = chunk_index as u64 * chunk_size as u64; + let this_chunk = if chunk_index == total_chunks - 1 { + (file_size - offset) as u32 + } else { + chunk_size + }; + ReadRequest { + padding: 0x50, + flags: 0, + length: this_chunk, + offset, + file_id, + minimum_count: 0, + channel: SMB2_CHANNEL_NONE, + remaining_bytes: 0, + read_channel_info: vec![], + } + }; + let launch_chunk = |conn: &Connection, chunk_index: usize, tree_id: TreeId| -> _ { + let c = conn.clone(); + let req = build_req(chunk_index); + async move { + let frame = c + .execute_with_credits( + Command::Read, + &req, + Some(tree_id), + CreditCharge(credit_charge), + ) + .await; + (chunk_index, frame) + } + }; + + for _ in 0..initial_window { + in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id)); + chunks_sent += 1; + } + + while chunks_received < total_chunks { + let Some((chunk_index, frame_result)) = in_flight.next().await else { + break; + }; + chunks_received += 1; + let frame = frame_result?; + + if frame.header.status == NtStatus::END_OF_FILE { + // File is shorter than expected. Keep draining but don't + // launch more. + continue; + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Read, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = ReadResponse::unpack(&mut cursor)?; + + if !resp.data.is_empty() { + let dest_offset = chunk_index as u64 * chunk_size as u64; + let dest_end = (dest_offset as usize + resp.data.len()).min(data.len()); + let src_len = dest_end - dest_offset as usize; + data[dest_offset as usize..dest_end].copy_from_slice(&resp.data[..src_len]); + } + + if chunks_sent < total_chunks { + let credits_available = conn.credits() as usize / credit_charge.max(1) as usize; + if credits_available > 0 { + in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id)); + chunks_sent += 1; + } + } + } + + Ok(data) + } + + /// Pipelined read with progress callback and cancellation. + /// + /// Same sliding window as `read_pipelined_loop`, but calls `on_progress` + /// after each chunk. Returns `Error::Cancelled` if the callback breaks. + async fn read_pipelined_loop_with_progress( + &self, + conn: &mut Connection, + file_id: FileId, + file_size: u64, + chunk_size: u32, + credit_charge: u16, + total_chunks: usize, + on_progress: &mut F, + ) -> Result> + where + F: FnMut(Progress) -> ControlFlow<()>, + { + use futures_util::stream::{FuturesUnordered, StreamExt}; + + let mut data = vec![0u8; file_size as usize]; + let mut chunks_sent = 0usize; + let mut chunks_received = 0usize; + let mut bytes_received = 0u64; + + let max_from_credits = conn.credits() as usize / credit_charge.max(1) as usize; + let initial_window = total_chunks.min(max_from_credits).min(MAX_PIPELINE_WINDOW); + + if initial_window == 0 { + return Err(Error::invalid_data( + "no credits available for pipelined read", + )); + } + + let mut in_flight = FuturesUnordered::new(); + let build_req = |chunk_index: usize| -> ReadRequest { + let offset = chunk_index as u64 * chunk_size as u64; + let this_chunk = if chunk_index == total_chunks - 1 { + (file_size - offset) as u32 + } else { + chunk_size + }; + ReadRequest { + padding: 0x50, + flags: 0, + length: this_chunk, + offset, + file_id, + minimum_count: 0, + channel: SMB2_CHANNEL_NONE, + remaining_bytes: 0, + read_channel_info: vec![], + } + }; + let launch_chunk = |conn: &Connection, chunk_index: usize, tree_id: TreeId| { + let c = conn.clone(); + let req = build_req(chunk_index); + async move { + let frame = c + .execute_with_credits( + Command::Read, + &req, + Some(tree_id), + CreditCharge(credit_charge), + ) + .await; + (chunk_index, frame) + } + }; + + for _ in 0..initial_window { + in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id)); + chunks_sent += 1; + } + + while chunks_received < total_chunks { + let Some((chunk_index, frame_result)) = in_flight.next().await else { + break; + }; + chunks_received += 1; + let frame = frame_result?; + + if frame.header.status == NtStatus::END_OF_FILE { + continue; + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Read, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = ReadResponse::unpack(&mut cursor)?; + + if !resp.data.is_empty() { + let dest_offset = chunk_index as u64 * chunk_size as u64; + let dest_end = (dest_offset as usize + resp.data.len()).min(data.len()); + let src_len = dest_end - dest_offset as usize; + data[dest_offset as usize..dest_end].copy_from_slice(&resp.data[..src_len]); + bytes_received += src_len as u64; + } + + let progress = Progress { + bytes_transferred: bytes_received, + total_bytes: Some(file_size), + }; + if let ControlFlow::Break(()) = on_progress(progress) { + return Err(Error::Cancelled); + } + + if chunks_sent < total_chunks { + let credits_available = conn.credits() as usize / credit_charge.max(1) as usize; + if credits_available > 0 { + in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id)); + chunks_sent += 1; + } + } + } + + Ok(data) + } + + /// Pipelined write using a sliding window. + /// + /// Instead of batch send/receive phases, each received response + /// immediately triggers the next send. The pipe stays full at all times. + async fn write_pipelined_loop( + &self, + conn: &mut Connection, + file_id: FileId, + data: &[u8], + chunk_size: u32, + credit_charge: u16, + total_chunks: usize, + ) -> Result { + use futures_util::stream::{FuturesUnordered, StreamExt}; + + let mut chunks_sent = 0usize; + let mut chunks_received = 0usize; + let mut total_written = 0u64; + + let max_from_credits = conn.credits() as usize / credit_charge.max(1) as usize; + let initial_window = total_chunks.min(max_from_credits).min(MAX_PIPELINE_WINDOW); + + if initial_window == 0 { + return Err(Error::invalid_data( + "no credits available for pipelined write", + )); + } + + debug!( + "tree: pipeline write sliding window: initial_window={}, total_chunks={}, credits={}", + initial_window, + total_chunks, + conn.credits() + ); + + let mut in_flight = FuturesUnordered::new(); + let build_req = |chunk_index: usize| -> WriteRequest { + let offset = chunk_index * chunk_size as usize; + let end = (offset + chunk_size as usize).min(data.len()); + let chunk = &data[offset..end]; + WriteRequest { + data_offset: 0x70, + offset: offset as u64, + file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: chunk.to_vec(), + } + }; + let launch_chunk = |conn: &Connection, chunk_index: usize, tree_id: TreeId| { + let c = conn.clone(); + let req = build_req(chunk_index); + async move { + let frame = c + .execute_with_credits( + Command::Write, + &req, + Some(tree_id), + CreditCharge(credit_charge), + ) + .await; + (chunk_index, frame) + } + }; + + for _ in 0..initial_window { + in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id)); + chunks_sent += 1; + } + + while chunks_received < total_chunks { + let Some((_chunk_index, frame_result)) = in_flight.next().await else { + break; + }; + chunks_received += 1; + let frame = frame_result?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Write, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = WriteResponse::unpack(&mut cursor)?; + total_written += resp.count as u64; + + if chunks_sent < total_chunks { + let credits_available = conn.credits() as usize / credit_charge.max(1) as usize; + if credits_available > 0 { + in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id)); + chunks_sent += 1; + } + } + } + + Ok(total_written) + } + + /// Inner loop for streamed writes with a sliding window. + /// + /// Pulls chunks from the callback, splits them if larger than + /// `max_write`, and sends WRITE requests. Uses a sliding window + /// of in-flight requests for throughput. + async fn write_streamed_loop( + &self, + conn: &mut Connection, + file_id: FileId, + next_chunk: &mut F, + max_write: u32, + ) -> Result + where + F: FnMut() -> Option, std::io::Error>>, + { + use futures_util::stream::{FuturesUnordered, StreamExt}; + + type BoxedExecute = std::pin::Pin< + Box> + Send>, + >; + + let mut offset = 0u64; + let mut in_flight = 0usize; + let mut total_written = 0u64; + let mut done = false; // callback exhausted or errored + let mut callback_err: Option = None; + let mut in_flight_futs: FuturesUnordered = FuturesUnordered::new(); + + // Buffer for leftover data when a callback chunk is larger than max_write. + let mut pending_data: Vec = Vec::new(); + let mut pending_offset = 0usize; + + // Chunk that was pulled but couldn't be sent due to credit exhaustion. + // Re-checked before pulling the next chunk from the callback. + let mut stashed_chunk: Option> = None; + + // Helper: try to get the next wire-level chunk (up to max_write bytes). + // Returns Some(data) or None if no more data available. + let next_wire_chunk = |pending_data: &mut Vec, + pending_offset: &mut usize, + done: &mut bool, + callback_err: &mut Option, + next_chunk: &mut F| + -> Option> { + // First, drain any pending leftover from a previous large chunk. + if *pending_offset < pending_data.len() { + let end = (*pending_offset + max_write as usize).min(pending_data.len()); + let slice = pending_data[*pending_offset..end].to_vec(); + *pending_offset = end; + if *pending_offset >= pending_data.len() { + pending_data.clear(); + *pending_offset = 0; + } + return Some(slice); + } + + if *done { + return None; + } + + // Pull from the callback. + match next_chunk() { + None => { + *done = true; + None + } + Some(Err(e)) => { + *done = true; + *callback_err = Some(e); + None + } + Some(Ok(data)) => { + if data.is_empty() { + // Treat empty chunk as end of stream. + *done = true; + return None; + } + if data.len() <= max_write as usize { + Some(data) + } else { + // Split: return first max_write bytes, buffer the rest. + let first = data[..max_write as usize].to_vec(); + *pending_data = data; + *pending_offset = max_write as usize; + Some(first) + } + } + } + }; + + // Initial fill: send up to window_size writes. + loop { + let credit_charge_per = max_write.div_ceil(65536).max(1) as u16; + let max_from_credits = conn.credits() as usize / credit_charge_per.max(1) as usize; + let can_send = max_from_credits.min(MAX_PIPELINE_WINDOW.saturating_sub(in_flight)); + + if can_send == 0 { + break; + } + + let chunk = next_wire_chunk( + &mut pending_data, + &mut pending_offset, + &mut done, + &mut callback_err, + next_chunk, + ); + + match chunk { + None => break, + Some(chunk_data) => { + let data_len = chunk_data.len() as u64; + let cc = data_len.div_ceil(65536).max(1) as u16; + let c = conn.clone(); + let tree_id = self.tree_id; + let req = WriteRequest { + data_offset: 0x70, + offset, + file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: chunk_data, + }; + in_flight_futs.push(Box::pin(async move { + c.execute_with_credits( + Command::Write, + &req, + Some(tree_id), + CreditCharge(cc), + ) + .await + })); + offset += data_len; + in_flight += 1; + } + } + } + + // Sliding loop: receive one response, send next chunk (if any). + while in_flight > 0 { + let frame_result = match in_flight_futs.next().await { + Some(r) => r, + None => break, + }; + in_flight -= 1; + let frame = frame_result?; + + if frame.header.status != NtStatus::SUCCESS { + // Drain remaining in-flight responses (best-effort). + while in_flight_futs.next().await.is_some() {} + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Write, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = WriteResponse::unpack(&mut cursor)?; + total_written += resp.count as u64; + + if callback_err.is_none() && stashed_chunk.is_none() { + let chunk = next_wire_chunk( + &mut pending_data, + &mut pending_offset, + &mut done, + &mut callback_err, + next_chunk, + ); + + if let Some(chunk_data) = chunk { + let data_len = chunk_data.len() as u64; + let cc = data_len.div_ceil(65536).max(1) as u16; + let credits_available = conn.credits() as usize / cc.max(1) as usize; + + if credits_available > 0 { + let c = conn.clone(); + let tree_id = self.tree_id; + let req = WriteRequest { + data_offset: 0x70, + offset, + file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: chunk_data, + }; + in_flight_futs.push(Box::pin(async move { + c.execute_with_credits( + Command::Write, + &req, + Some(tree_id), + CreditCharge(cc), + ) + .await + })); + offset += data_len; + in_flight += 1; + } else { + stashed_chunk = Some(chunk_data); + } + } + } else if let Some(chunk_data) = stashed_chunk.take() { + let data_len = chunk_data.len() as u64; + let cc = data_len.div_ceil(65536).max(1) as u16; + let credits_available = conn.credits() as usize / cc.max(1) as usize; + + if credits_available > 0 { + let c = conn.clone(); + let tree_id = self.tree_id; + let req = WriteRequest { + data_offset: 0x70, + offset, + file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: chunk_data, + }; + in_flight_futs.push(Box::pin(async move { + c.execute_with_credits( + Command::Write, + &req, + Some(tree_id), + CreditCharge(cc), + ) + .await + })); + offset += data_len; + in_flight += 1; + } else { + stashed_chunk = Some(chunk_data); + } + } + } + + // If the callback returned an error, propagate it now + // (after all in-flight responses have been drained). + if let Some(io_err) = callback_err { + return Err(Error::Io(io_err)); + } + + Ok(total_written) + } + + /// Flush a file handle to ensure data is persisted on the server. + /// + /// Sends an SMB2 FLUSH request and waits for the server to confirm + /// that all cached data has been written to persistent storage. + pub(crate) async fn flush_handle(&self, conn: &mut Connection, file_id: FileId) -> Result<()> { + debug!("tree: flushing file handle"); + let req = FlushRequest { file_id }; + + let frame = conn + .execute(Command::Flush, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Flush, + }); + } + + Ok(()) + } + + /// Close a file handle. + pub(crate) async fn close_handle(&self, conn: &mut Connection, file_id: FileId) -> Result<()> { + let req = CloseRequest { flags: 0, file_id }; + + let frame = conn + .execute(Command::Close, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Close, + }); + } + + Ok(()) + } + + /// Write data to a file in chunks. + /// + /// Kept for potential future use by callers that need per-chunk control + /// without pipelining or compounding. + #[allow(dead_code)] + async fn write_loop(&self, conn: &mut Connection, file_id: FileId, data: &[u8]) -> Result { + let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536); + + let mut total_written = 0u64; + let mut offset = 0usize; + + while offset < data.len() { + let remaining = data.len() - offset; + let chunk_size = remaining.min(max_write as usize); + let chunk = &data[offset..offset + chunk_size]; + + // DataOffset: header (64) + fixed write body (48) = 112 = 0x70 + let req = WriteRequest { + data_offset: 0x70, + offset: offset as u64, + file_id, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: chunk.to_vec(), + }; + + let frame = conn + .execute(Command::Write, &req, Some(self.tree_id)) + .await?; + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::Write, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = WriteResponse::unpack(&mut cursor)?; + + total_written += resp.count as u64; + offset += chunk_size; + } + + Ok(total_written) + } +} + +/// Build a FileRenameInformation buffer (MS-FSCC 2.4.34.2). +fn build_rename_info_buffer(new_name: &str) -> Vec { + let name_u16: Vec = new_name.encode_utf16().collect(); + let name_byte_len = name_u16.len() * 2; + + let mut buf = Vec::with_capacity(20 + name_byte_len); + buf.push(0); // ReplaceIfExists = false + buf.extend_from_slice(&[0u8; 7]); // Reserved + buf.extend_from_slice(&0u64.to_le_bytes()); // RootDirectory + buf.extend_from_slice(&(name_byte_len as u32).to_le_bytes()); // FileNameLength + for &u in &name_u16 { + buf.extend_from_slice(&u.to_le_bytes()); + } + buf +} + +/// Normalize a file path: convert `/` to `\` and strip leading `\`. +fn normalize_path(path: &str) -> String { + let p = path.replace('/', "\\"); + p.trim_start_matches('\\').to_string() +} + +/// Parse `FileBothDirectoryInformation` entries from raw bytes. +/// +/// Each entry has: +/// - NextEntryOffset (4 bytes) +/// - FileIndex (4 bytes) +/// - CreationTime (8 bytes) +/// - LastAccessTime (8 bytes) +/// - LastWriteTime (8 bytes) +/// - ChangeTime (8 bytes) +/// - EndOfFile (8 bytes) +/// - AllocationSize (8 bytes) +/// - FileAttributes (4 bytes) +/// - FileNameLength (4 bytes) +/// - EaSize (4 bytes) +/// - ShortNameLength (1 byte) +/// - Reserved (1 byte) +/// - ShortName (24 bytes) +/// - FileName (variable, FileNameLength bytes) +fn parse_file_both_directory_info(data: &[u8]) -> Result> { + let mut entries = Vec::new(); + let mut offset = 0usize; + + loop { + if offset + 94 > data.len() { + // Not enough data for the fixed part. + break; + } + + let entry_data = &data[offset..]; + let mut cursor = ReadCursor::new(entry_data); + + let next_entry_offset = cursor.read_u32_le()? as usize; + let _file_index = cursor.read_u32_le()?; + let creation_time = FileTime::unpack(&mut cursor)?; + let _last_access_time = FileTime::unpack(&mut cursor)?; + let last_write_time = FileTime::unpack(&mut cursor)?; + let _change_time = FileTime::unpack(&mut cursor)?; + let end_of_file = cursor.read_u64_le()?; + let _allocation_size = cursor.read_u64_le()?; + let file_attributes = cursor.read_u32_le()?; + let file_name_length = cursor.read_u32_le()? as usize; + let _ea_size = cursor.read_u32_le()?; + let _short_name_length = cursor.read_u8()?; + let _reserved = cursor.read_u8()?; + // ShortName: 24 bytes (fixed, null-padded). + cursor.skip(24)?; + // FileName: FileNameLength bytes in UTF-16LE. + let name = if file_name_length > 0 { + cursor.read_utf16_le(file_name_length)? + } else { + String::new() + }; + + let is_directory = (file_attributes & FILE_ATTRIBUTE_DIRECTORY) != 0; + + entries.push(DirectoryEntry { + name, + size: end_of_file, + is_directory, + created: creation_time, + modified: last_write_time, + }); + + if next_entry_offset == 0 { + break; + } + offset += next_entry_offset; + } + + Ok(entries) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::connection::pack_message; + use crate::client::test_helpers::{ + build_close_response, build_create_error_response, build_create_response, + build_tree_connect_response, setup_connection, + }; + use crate::msg::create::{CreateAction, CreateResponse}; + use crate::msg::header::Header; + use crate::msg::query_directory::QueryDirectoryResponse; + use crate::msg::tree_connect::ShareType; + use crate::transport::MockTransport; + use crate::types::status::NtStatus; + use crate::types::{Command, TreeId}; + use std::sync::Arc; + + fn build_flush_response() -> Vec { + let mut h = Header::new_request(Command::Flush); + h.flags.set_response(); + h.credits = 32; + + let body = crate::msg::flush::FlushResponse; + pack_message(&h, &body) + } + + fn build_query_directory_response(status: NtStatus, entries_data: Vec) -> Vec { + let mut h = Header::new_request(Command::QueryDirectory); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + if status == NtStatus::NO_MORE_FILES { + // Error response body for NO_MORE_FILES. + use crate::msg::header::ErrorResponse; + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + return pack_message(&h, &body); + } + + let body = QueryDirectoryResponse { + output_buffer: entries_data, + }; + + pack_message(&h, &body) + } + + fn build_read_response(status: NtStatus, data: Vec) -> Vec { + let mut h = Header::new_request(Command::Read); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + if status == NtStatus::END_OF_FILE { + use crate::msg::header::ErrorResponse; + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + return pack_message(&h, &body); + } + + let body = ReadResponse { + data_offset: 0x50, + data_remaining: 0, + flags: 0, + data, + }; + + pack_message(&h, &body) + } + + /// Build a single FileBothDirectoryInformation entry. + fn build_file_both_dir_info( + name: &str, + size: u64, + is_directory: bool, + next_offset: u32, + ) -> Vec { + let name_u16: Vec = name.encode_utf16().collect(); + let name_bytes_len = name_u16.len() * 2; + + let mut buf = Vec::new(); + // NextEntryOffset (4) + buf.extend_from_slice(&next_offset.to_le_bytes()); + // FileIndex (4) + buf.extend_from_slice(&0u32.to_le_bytes()); + // CreationTime (8) + buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); + // LastAccessTime (8) + buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); + // LastWriteTime (8) + buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); + // ChangeTime (8) + buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); + // EndOfFile (8) + buf.extend_from_slice(&size.to_le_bytes()); + // AllocationSize (8) + buf.extend_from_slice(&((size + 4095) & !4095).to_le_bytes()); + // FileAttributes (4) + let attrs = if is_directory { + FILE_ATTRIBUTE_DIRECTORY + } else { + 0x00000020 // ARCHIVE + }; + buf.extend_from_slice(&attrs.to_le_bytes()); + // FileNameLength (4) + buf.extend_from_slice(&(name_bytes_len as u32).to_le_bytes()); + // EaSize (4) + buf.extend_from_slice(&0u32.to_le_bytes()); + // ShortNameLength (1) + buf.push(0); + // Reserved (1) + buf.push(0); + // ShortName (24 bytes, zero-padded) + buf.extend_from_slice(&[0u8; 24]); + // FileName (variable) + for &u in &name_u16 { + buf.extend_from_slice(&u.to_le_bytes()); + } + + buf + } + + #[tokio::test] + async fn tree_connect_stores_tree_id() { + let mock = Arc::new(MockTransport::new()); + let tree_id = TreeId(42); + mock.queue_response(build_tree_connect_response(tree_id, ShareType::Disk)); + + let mut conn = setup_connection(&mock); + let tree = Tree::connect(&mut conn, "naspi").await.unwrap(); + assert_eq!(tree.tree_id, tree_id); + assert_eq!(tree.share_name, "naspi"); + } + + #[tokio::test] + async fn tree_connect_sends_unc_path() { + let mock = Arc::new(MockTransport::new()); + mock.queue_response(build_tree_connect_response(TreeId(1), ShareType::Disk)); + + let mut conn = setup_connection(&mock); + let _tree = Tree::connect(&mut conn, "myshare").await.unwrap(); + + // Verify the sent request contains the UNC path. + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = TreeConnectRequest::unpack(&mut cursor).unwrap(); + assert_eq!(req.path, r"\\test-server\myshare"); + } + + #[tokio::test] + async fn list_directory_returns_entries() { + let mock = Arc::new(MockTransport::new()); + let tree_id = TreeId(10); + let file_id = FileId { + persistent: 0x1111, + volatile: 0x2222, + }; + + // Build two directory entries. + let entry1 = build_file_both_dir_info("file1.txt", 1024, false, 0); + let total_entry_len = entry1.len(); + let entry1_with_next = + build_file_both_dir_info("file1.txt", 1024, false, total_entry_len as u32); + let entry2 = build_file_both_dir_info("subdir", 0, true, 0); + + let mut entries_data = entry1_with_next; + entries_data.extend_from_slice(&entry2); + + // Queue: CREATE response, QUERY_DIRECTORY response (with data), QUERY_DIRECTORY response (no more), CLOSE response. + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_query_directory_response( + NtStatus::SUCCESS, + entries_data, + )); + mock.queue_response(build_query_directory_response( + NtStatus::NO_MORE_FILES, + vec![], + )); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id, + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let entries = tree.list_directory(&mut conn, "somedir").await.unwrap(); + assert_eq!(entries.len(), 2); + assert_eq!(entries[0].name, "file1.txt"); + assert_eq!(entries[0].size, 1024); + assert!(!entries[0].is_directory); + assert_eq!(entries[1].name, "subdir"); + assert!(entries[1].is_directory); + } + + #[tokio::test] + async fn read_file_returns_data() { + let mock = Arc::new(MockTransport::new()); + let tree_id = TreeId(20); + let file_id = FileId { + persistent: 0x3333, + volatile: 0x4444, + }; + let file_data = b"Hello, SMB world!"; + + // Queue a single compound response frame: CREATE + READ + CLOSE. + let create_resp = build_create_response(file_id, file_data.len() as u64); + let read_resp = build_read_response(NtStatus::SUCCESS, file_data.to_vec()); + let close_resp = build_close_response(); + let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id, + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let data = tree.read_file(&mut conn, "test.txt").await.unwrap(); + assert_eq!(data, file_data); + } + + #[tokio::test] + async fn normalize_path_converts_slashes() { + assert_eq!(normalize_path("foo/bar/baz"), "foo\\bar\\baz"); + assert_eq!(normalize_path("/leading/slash"), "leading\\slash"); + assert_eq!(normalize_path("\\leading\\backslash"), "leading\\backslash"); + assert_eq!(normalize_path("no_change"), "no_change"); + } + + #[tokio::test] + async fn format_path_prepends_dfs_prefix() { + let tree = Tree { + tree_id: TreeId(1), + share_name: "dfs".to_string(), + server: "server1".to_string(), + is_dfs: true, + encrypt_data: false, + }; + assert_eq!( + tree.format_path("data/hello.txt"), + "server1\\dfs\\data\\hello.txt" + ); + assert_eq!(tree.format_path(""), "server1\\dfs"); + assert_eq!( + tree.format_path("nested/path"), + "server1\\dfs\\nested\\path" + ); + } + + #[tokio::test] + async fn format_path_strips_port_from_dfs_prefix() { + let tree = Tree { + tree_id: TreeId(1), + share_name: "dfs".to_string(), + server: "server1:10456".to_string(), + is_dfs: true, + encrypt_data: false, + }; + assert_eq!( + tree.format_path("data/hello.txt"), + "server1\\dfs\\data\\hello.txt" + ); + } + + #[tokio::test] + async fn format_path_no_prefix_when_not_dfs() { + let tree = Tree { + tree_id: TreeId(1), + share_name: "public".to_string(), + server: "server1".to_string(), + is_dfs: false, + encrypt_data: false, + }; + assert_eq!(tree.format_path("data/hello.txt"), "data\\hello.txt"); + assert_eq!(tree.format_path(""), ""); + } + + #[tokio::test] + async fn parse_file_both_dir_info_single_entry() { + let data = build_file_both_dir_info("test.txt", 42, false, 0); + let entries = parse_file_both_directory_info(&data).unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].name, "test.txt"); + assert_eq!(entries[0].size, 42); + assert!(!entries[0].is_directory); + } + + #[tokio::test] + async fn tree_disconnect_sends_request() { + let mock = Arc::new(MockTransport::new()); + + // Queue a tree disconnect response. + let mut h = Header::new_request(Command::TreeDisconnect); + h.flags.set_response(); + h.credits = 32; + use crate::msg::tree_disconnect::TreeDisconnectResponse; + mock.queue_response(pack_message(&h, &TreeDisconnectResponse)); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(99), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + tree.disconnect(&mut conn).await.unwrap(); + assert_eq!(mock.sent_count(), 1); + } + + // ── Delete file tests ──────────────────────────────────────────── + + fn build_write_response(count: u32) -> Vec { + use crate::msg::write::WriteResponse; + let mut h = Header::new_request(Command::Write); + h.flags.set_response(); + h.credits = 32; + + let body = WriteResponse { + count, + remaining: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + }; + + pack_message(&h, &body) + } + + fn build_query_info_response(output_buffer: Vec) -> Vec { + build_query_info_response_with_status(NtStatus::SUCCESS, output_buffer) + } + + fn build_query_info_response_with_status(status: NtStatus, output_buffer: Vec) -> Vec { + use crate::msg::query_info::QueryInfoResponse; + let mut h = Header::new_request(Command::QueryInfo); + h.flags.set_response(); + h.credits = 32; + h.status = status; + + let body = QueryInfoResponse { output_buffer }; + pack_message(&h, &body) + } + + fn build_set_info_response() -> Vec { + use crate::msg::set_info::SetInfoResponse; + let mut h = Header::new_request(Command::SetInfo); + h.flags.set_response(); + h.credits = 32; + + let body = SetInfoResponse; + pack_message(&h, &body) + } + + /// Build a FileBasicInformation buffer (40 bytes). + fn build_file_basic_info( + creation_time: u64, + last_access_time: u64, + last_write_time: u64, + change_time: u64, + file_attributes: u32, + ) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&creation_time.to_le_bytes()); + buf.extend_from_slice(&last_access_time.to_le_bytes()); + buf.extend_from_slice(&last_write_time.to_le_bytes()); + buf.extend_from_slice(&change_time.to_le_bytes()); + buf.extend_from_slice(&file_attributes.to_le_bytes()); + buf.extend_from_slice(&0u32.to_le_bytes()); // Reserved/padding + buf + } + + /// Build a FileStandardInformation buffer (24 bytes). + fn build_file_standard_info( + allocation_size: u64, + end_of_file: u64, + number_of_links: u32, + delete_pending: bool, + directory: bool, + ) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&allocation_size.to_le_bytes()); + buf.extend_from_slice(&end_of_file.to_le_bytes()); + buf.extend_from_slice(&number_of_links.to_le_bytes()); + buf.push(if delete_pending { 1 } else { 0 }); + buf.push(if directory { 1 } else { 0 }); + buf.extend_from_slice(&0u16.to_le_bytes()); // Reserved + buf + } + + #[tokio::test] + async fn delete_file_sends_compound_create_and_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xAA, + volatile: 0xBB, + }; + + // DELETE = compound CREATE(DELETE_ON_CLOSE) + CLOSE + let create_resp = build_create_response(file_id, 0); + let close_resp = build_close_response(); + let frame = build_compound_response_frame(&[create_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + tree.delete_file(&mut conn, "remove.txt").await.unwrap(); + + // One compound frame sent. + assert_eq!(mock.sent_count(), 1); + + // Verify the CREATE request has DELETE access and DELETE_ON_CLOSE + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = CreateRequest::unpack(&mut cursor).unwrap(); + assert!(req.desired_access.contains(FileAccessMask::DELETE)); + assert_ne!(req.create_options & FILE_DELETE_ON_CLOSE, 0); + assert_ne!(req.create_options & FILE_NON_DIRECTORY_FILE, 0); + } + + #[tokio::test] + async fn delete_file_create_failure_returns_error() { + let mock = Arc::new(MockTransport::new()); + + // Build compound response where CREATE fails. + let mut create_hdr = Header::new_request(Command::Create); + create_hdr.flags.set_response(); + create_hdr.credits = 32; + create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_resp = pack_message( + &create_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_resp = pack_message( + &close_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = build_compound_response_frame(&[create_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let result = tree.delete_file(&mut conn, "nonexistent.txt").await; + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().status(), + Some(NtStatus::OBJECT_NAME_NOT_FOUND) + ); + // Only the one compound frame, no standalone CLOSE needed. + assert_eq!(mock.sent_count(), 1); + } + + #[tokio::test] + async fn delete_file_close_failure_issues_standalone_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xAA, + volatile: 0xBB, + }; + + // Compound: CREATE succeeds, CLOSE fails. + let create_resp = build_create_response(file_id, 0); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::UNSUCCESSFUL; + let close_resp = pack_message( + &close_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = build_compound_response_frame(&[create_resp, close_resp]); + mock.queue_response(frame); + + // Queue response for the standalone CLOSE retry. + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let result = tree.delete_file(&mut conn, "tricky.txt").await; + assert!(result.is_err()); + // Compound frame + standalone CLOSE = 2 messages sent. + assert_eq!(mock.sent_count(), 2); + } + + // ── Write file tests ───────────────────────────────────────────── + + #[tokio::test] + async fn write_file_sends_create_write_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xCC, + volatile: 0xDD, + }; + + // write_file for small data now uses compound: CREATE+WRITE+FLUSH+CLOSE in one frame. + let create_resp = build_create_response(file_id, 0); + let write_resp = build_write_response(5); + let flush_resp = build_flush_response(); + let close_resp = build_close_response(); + + let frame = + build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let written = tree + .write_file(&mut conn, "out.txt", b"hello") + .await + .unwrap(); + assert_eq!(written, 5); + // One compound frame sent. + assert_eq!(mock.sent_count(), 1); + } + + // ── Stat tests ─────────────────────────────────────────────────── + + #[tokio::test] + async fn stat_sends_compound_and_returns_file_info() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xEE, + volatile: 0xFF, + }; + + // STAT = compound CREATE + QUERY_INFO(basic) + QUERY_INFO(standard) + CLOSE + let create_resp = build_create_response(file_id, 0); + let basic = build_file_basic_info( + 132_000_000_000_000_000, + 132_100_000_000_000_000, + 133_000_000_000_000_000, + 133_000_000_000_000_000, + 0x20, // ARCHIVE + ); + let basic_resp = build_query_info_response(basic); + let std_info = build_file_standard_info(4096, 2048, 1, false, false); + let std_resp = build_query_info_response(std_info); + let close_resp = build_close_response(); + + let frame = build_compound_response_frame(&[create_resp, basic_resp, std_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let info = tree.stat(&mut conn, "doc.txt").await.unwrap(); + assert_eq!(info.size, 2048); + assert!(!info.is_directory); + assert_eq!(info.created, FileTime(132_000_000_000_000_000)); + assert_eq!(info.modified, FileTime(133_000_000_000_000_000)); + assert_eq!(info.accessed, FileTime(132_100_000_000_000_000)); + // One compound frame sent. + assert_eq!(mock.sent_count(), 1); + } + + #[tokio::test] + async fn stat_create_failure_returns_error() { + let mock = Arc::new(MockTransport::new()); + + // Build compound response where CREATE fails (all ops cascade). + let mut create_hdr = Header::new_request(Command::Create); + create_hdr.flags.set_response(); + create_hdr.credits = 32; + create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let err_body = crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + let create_resp = pack_message(&create_hdr, &err_body); + + let mut q1_hdr = Header::new_request(Command::QueryInfo); + q1_hdr.flags.set_response(); + q1_hdr.credits = 32; + q1_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let q1_resp = pack_message(&q1_hdr, &err_body); + + let mut q2_hdr = Header::new_request(Command::QueryInfo); + q2_hdr.flags.set_response(); + q2_hdr.credits = 32; + q2_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let q2_resp = pack_message(&q2_hdr, &err_body); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_resp = pack_message(&close_hdr, &err_body); + + let frame = build_compound_response_frame(&[create_resp, q1_resp, q2_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let result = tree.stat(&mut conn, "nonexistent.txt").await; + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().status(), + Some(NtStatus::OBJECT_NAME_NOT_FOUND) + ); + assert_eq!(mock.sent_count(), 1); + } + + #[tokio::test] + async fn stat_query_failure_issues_standalone_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xEE, + volatile: 0xFF, + }; + + // Compound: CREATE succeeds, first QUERY_INFO fails, rest cascade. + let create_resp = build_create_response(file_id, 0); + + let err_body = crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + let mut q1_hdr = Header::new_request(Command::QueryInfo); + q1_hdr.flags.set_response(); + q1_hdr.credits = 32; + q1_hdr.status = NtStatus::UNSUCCESSFUL; + let q1_resp = pack_message(&q1_hdr, &err_body); + + let mut q2_hdr = Header::new_request(Command::QueryInfo); + q2_hdr.flags.set_response(); + q2_hdr.credits = 32; + q2_hdr.status = NtStatus::UNSUCCESSFUL; + let q2_resp = pack_message(&q2_hdr, &err_body); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::UNSUCCESSFUL; + let close_resp = pack_message(&close_hdr, &err_body); + + let frame = build_compound_response_frame(&[create_resp, q1_resp, q2_resp, close_resp]); + mock.queue_response(frame); + + // Queue response for the standalone CLOSE retry. + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let result = tree.stat(&mut conn, "tricky.txt").await; + assert!(result.is_err()); + // Compound frame + standalone CLOSE = 2 messages sent. + assert_eq!(mock.sent_count(), 2); + } + + // ── Batch stat tests ────────────────────────────────────────────── + + #[tokio::test] + async fn stat_files_batch_happy_path() { + let mock = Arc::new(MockTransport::new()); + + // Queue 3 compound responses (CREATE+QUERY+QUERY+CLOSE each). + for i in 0..3u64 { + let file_id = FileId { + persistent: i + 1, + volatile: i + 100, + }; + let create_resp = build_create_response(file_id, 0); + let basic = build_file_basic_info( + 132_000_000_000_000_000 + i, + 132_100_000_000_000_000 + i, + 133_000_000_000_000_000 + i, + 133_000_000_000_000_000 + i, + 0x20, + ); + let basic_resp = build_query_info_response(basic); + let std_info = build_file_standard_info(4096, 1024 * (i + 1), 1, false, false); + let std_resp = build_query_info_response(std_info); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[ + create_resp, + basic_resp, + std_resp, + close_resp, + ])); + } + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree + .stat_files(&mut conn, &["a.txt", "b.txt", "c.txt"]) + .await; + + assert_eq!(results.len(), 3); + assert_eq!(results[0].as_ref().unwrap().size, 1024); + assert_eq!(results[1].as_ref().unwrap().size, 2048); + assert_eq!(results[2].as_ref().unwrap().size, 3072); + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn stat_files_batch_partial_failure() { + let mock = Arc::new(MockTransport::new()); + + let err_body = crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + // File 1: success + let file_id = FileId { + persistent: 1, + volatile: 100, + }; + let create_resp = build_create_response(file_id, 0); + let basic = build_file_basic_info( + 132_000_000_000_000_000, + 132_100_000_000_000_000, + 133_000_000_000_000_000, + 133_000_000_000_000_000, + 0x20, + ); + let basic_resp = build_query_info_response(basic); + let std_info = build_file_standard_info(4096, 512, 1, false, false); + let std_resp = build_query_info_response(std_info); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[ + create_resp, + basic_resp, + std_resp, + close_resp, + ])); + + // File 2: CREATE fails -- cascaded failure + let mut create_hdr = Header::new_request(Command::Create); + create_hdr.flags.set_response(); + create_hdr.credits = 32; + create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_err = pack_message(&create_hdr, &err_body); + + let mut q1_hdr = Header::new_request(Command::QueryInfo); + q1_hdr.flags.set_response(); + q1_hdr.credits = 32; + q1_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let q1_err = pack_message(&q1_hdr, &err_body); + + let mut q2_hdr = Header::new_request(Command::QueryInfo); + q2_hdr.flags.set_response(); + q2_hdr.credits = 32; + q2_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let q2_err = pack_message(&q2_hdr, &err_body); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_err = pack_message(&close_hdr, &err_body); + mock.queue_response(build_compound_response_frame(&[ + create_err, q1_err, q2_err, close_err, + ])); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree + .stat_files(&mut conn, &["exists.txt", "missing.txt"]) + .await; + + assert_eq!(results.len(), 2); + assert_eq!(results[0].as_ref().unwrap().size, 512); + assert!(results[1].is_err()); + assert_eq!( + results[1].as_ref().unwrap_err().status(), + Some(NtStatus::OBJECT_NAME_NOT_FOUND) + ); + } + + #[tokio::test] + async fn stat_files_empty_returns_empty() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results: Vec> = tree.stat_files(&mut conn, &[]).await; + assert!(results.is_empty()); + } + + // ── Rename tests ───────────────────────────────────────────────── + + #[tokio::test] + async fn rename_sends_compound_create_setinfo_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0x11, + volatile: 0x22, + }; + + // RENAME = compound CREATE + SET_INFO + CLOSE + let create_resp = build_create_response(file_id, 0); + let setinfo_resp = build_set_info_response(); + let close_resp = build_close_response(); + let frame = build_compound_response_frame(&[create_resp, setinfo_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + tree.rename(&mut conn, "old.txt", "new.txt").await.unwrap(); + + // One compound frame sent. + assert_eq!(mock.sent_count(), 1); + + // Verify the CREATE has DELETE access (required for rename) + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = CreateRequest::unpack(&mut cursor).unwrap(); + assert!(req.desired_access.contains(FileAccessMask::DELETE)); + } + + #[tokio::test] + async fn rename_create_failure_returns_error() { + let mock = Arc::new(MockTransport::new()); + + // Build compound response where CREATE fails. + let mut create_hdr = Header::new_request(Command::Create); + create_hdr.flags.set_response(); + create_hdr.credits = 32; + create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_resp = pack_message( + &create_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut setinfo_hdr = Header::new_request(Command::SetInfo); + setinfo_hdr.flags.set_response(); + setinfo_hdr.credits = 32; + setinfo_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let setinfo_resp = pack_message( + &setinfo_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_resp = pack_message( + &close_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = build_compound_response_frame(&[create_resp, setinfo_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let result = tree.rename(&mut conn, "old.txt", "new.txt").await; + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().status(), + Some(NtStatus::OBJECT_NAME_NOT_FOUND) + ); + // Only the one compound frame, no standalone CLOSE needed. + assert_eq!(mock.sent_count(), 1); + } + + #[tokio::test] + async fn rename_setinfo_failure_issues_standalone_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0x11, + volatile: 0x22, + }; + + // Compound: CREATE succeeds, SET_INFO fails, CLOSE cascades failure. + let create_resp = build_create_response(file_id, 0); + + let mut setinfo_hdr = Header::new_request(Command::SetInfo); + setinfo_hdr.flags.set_response(); + setinfo_hdr.credits = 32; + setinfo_hdr.status = NtStatus::UNSUCCESSFUL; + let setinfo_resp = pack_message( + &setinfo_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::UNSUCCESSFUL; + let close_resp = pack_message( + &close_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = build_compound_response_frame(&[create_resp, setinfo_resp, close_resp]); + mock.queue_response(frame); + + // Queue response for the standalone CLOSE retry. + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let result = tree.rename(&mut conn, "old.txt", "new.txt").await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().status(), Some(NtStatus::UNSUCCESSFUL)); + // Compound frame + standalone CLOSE = 2 messages sent. + assert_eq!(mock.sent_count(), 2); + } + + // ── Batch rename tests ──────────────────────────────────────────── + + #[tokio::test] + async fn rename_files_batch_happy_path() { + let mock = Arc::new(MockTransport::new()); + + // Queue 3 compound responses (CREATE+SET_INFO+CLOSE each). + for i in 0..3u64 { + let file_id = FileId { + persistent: i + 1, + volatile: i + 100, + }; + let create_resp = build_create_response(file_id, 0); + let setinfo_resp = build_set_info_response(); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[ + create_resp, + setinfo_resp, + close_resp, + ])); + } + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree + .rename_files( + &mut conn, + &[ + ("a.txt", "a2.txt"), + ("b.txt", "b2.txt"), + ("c.txt", "c2.txt"), + ], + ) + .await; + + assert_eq!(results.len(), 3); + assert!(results[0].is_ok()); + assert!(results[1].is_ok()); + assert!(results[2].is_ok()); + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn rename_files_batch_partial_failure() { + let mock = Arc::new(MockTransport::new()); + + let err_body = crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + // File 1: success + let file_id = FileId { + persistent: 1, + volatile: 100, + }; + let create_resp = build_create_response(file_id, 0); + let setinfo_resp = build_set_info_response(); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[ + create_resp, + setinfo_resp, + close_resp, + ])); + + // File 2: CREATE fails (not found) + let mut create_hdr = Header::new_request(Command::Create); + create_hdr.flags.set_response(); + create_hdr.credits = 32; + create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_err = pack_message(&create_hdr, &err_body); + + let mut si_hdr = Header::new_request(Command::SetInfo); + si_hdr.flags.set_response(); + si_hdr.credits = 32; + si_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let si_err = pack_message(&si_hdr, &err_body); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_err = pack_message(&close_hdr, &err_body); + mock.queue_response(build_compound_response_frame(&[ + create_err, si_err, close_err, + ])); + + // File 3: success + let file_id = FileId { + persistent: 3, + volatile: 102, + }; + let create_resp = build_create_response(file_id, 0); + let setinfo_resp = build_set_info_response(); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[ + create_resp, + setinfo_resp, + close_resp, + ])); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree + .rename_files( + &mut conn, + &[ + ("a.txt", "a2.txt"), + ("missing.txt", "m2.txt"), + ("c.txt", "c2.txt"), + ], + ) + .await; + + assert_eq!(results.len(), 3); + assert!(results[0].is_ok()); + assert!(results[1].is_err()); + assert_eq!( + results[1].as_ref().unwrap_err().status(), + Some(NtStatus::OBJECT_NAME_NOT_FOUND) + ); + assert!(results[2].is_ok()); + } + + #[tokio::test] + async fn rename_files_empty_returns_empty() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results: Vec> = tree.rename_files(&mut conn, &[]).await; + assert!(results.is_empty()); + assert_eq!(mock.sent_count(), 0); + } + + // ── Create directory tests ─────────────────────────────────────── + + #[tokio::test] + async fn create_directory_sends_create_and_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0x33, + volatile: 0x44, + }; + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + tree.create_directory(&mut conn, "new_dir").await.unwrap(); + assert_eq!(mock.sent_count(), 2); + + // Verify the CREATE has FILE_DIRECTORY_FILE option and FileCreate disposition + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = CreateRequest::unpack(&mut cursor).unwrap(); + assert_eq!(req.create_disposition, CreateDisposition::FileCreate); + assert_ne!(req.create_options & FILE_DIRECTORY_FILE, 0); + } + + // ── Exclusive-create writer open tests ──────────────────────────── + + #[tokio::test] + async fn open_file_for_exclusive_create_sends_file_create_disposition() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xAA, + volatile: 0xBB, + }; + mock.queue_response(build_create_response(file_id, 0)); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + tree.open_file_for_exclusive_create(&mut conn, "new.bin") + .await + .unwrap(); + + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = CreateRequest::unpack(&mut cursor).unwrap(); + assert_eq!( + req.create_disposition, + CreateDisposition::FileCreate, + "exclusive-create writer must use FileCreate, not FileOverwriteIf" + ); + // File, not directory. + assert_ne!(req.create_options & FILE_NON_DIRECTORY_FILE, 0); + } + + #[tokio::test] + async fn open_file_for_exclusive_create_maps_collision_to_already_exists() { + let mock = Arc::new(MockTransport::new()); + // STATUS_OBJECT_NAME_COLLISION = 0xC0000035; the server response any + // time `FileCreate` hits an existing file. + mock.queue_response(build_create_error_response(NtStatus::OBJECT_NAME_COLLISION)); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let err = tree + .open_file_for_exclusive_create(&mut conn, "existing.bin") + .await + .expect_err("exclusive-create on an existing file must error"); + assert_eq!( + err.kind(), + crate::ErrorKind::AlreadyExists, + "STATUS_OBJECT_NAME_COLLISION must map to ErrorKind::AlreadyExists, got: {err}" + ); + } + + // ── Delete directory tests ─────────────────────────────────────── + + #[tokio::test] + async fn delete_directory_sends_compound_create_and_close() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0x55, + volatile: 0x66, + }; + + // DELETE = compound CREATE(DELETE_ON_CLOSE) + CLOSE + let create_resp = build_create_response(file_id, 0); + let close_resp = build_close_response(); + let frame = build_compound_response_frame(&[create_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + tree.delete_directory(&mut conn, "old_dir").await.unwrap(); + + // One compound frame sent. + assert_eq!(mock.sent_count(), 1); + + // Verify the CREATE has DELETE_ON_CLOSE and FILE_DIRECTORY_FILE + let sent = mock.sent_message(0).unwrap(); + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = CreateRequest::unpack(&mut cursor).unwrap(); + assert_ne!(req.create_options & FILE_DELETE_ON_CLOSE, 0); + assert_ne!(req.create_options & FILE_DIRECTORY_FILE, 0); + } + + // ── Batch delete tests ─────────────────────────────────────────── + + #[tokio::test] + async fn delete_files_batch_happy_path() { + let mock = Arc::new(MockTransport::new()); + + // Queue 3 compound responses (CREATE+CLOSE each). + for i in 0..3u64 { + let file_id = FileId { + persistent: i + 1, + volatile: i + 100, + }; + let create_resp = build_create_response(file_id, 0); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[create_resp, close_resp])); + } + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree + .delete_files(&mut conn, &["a.txt", "b.txt", "c.txt"]) + .await; + + assert_eq!(results.len(), 3); + assert!(results[0].is_ok()); + assert!(results[1].is_ok()); + assert!(results[2].is_ok()); + // 3 compound frames sent (one per file). + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn delete_files_batch_partial_failure() { + let mock = Arc::new(MockTransport::new()); + + let err_body = crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + // File 1: success + let file_id = FileId { + persistent: 1, + volatile: 100, + }; + let create_resp = build_create_response(file_id, 0); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[create_resp, close_resp])); + + // File 2: CREATE fails (not found) -- cascaded failure + let mut create_hdr = Header::new_request(Command::Create); + create_hdr.flags.set_response(); + create_hdr.credits = 32; + create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_err = pack_message(&create_hdr, &err_body); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_err = pack_message(&close_hdr, &err_body); + mock.queue_response(build_compound_response_frame(&[create_err, close_err])); + + // File 3: success + let file_id = FileId { + persistent: 3, + volatile: 102, + }; + let create_resp = build_create_response(file_id, 0); + let close_resp = build_close_response(); + mock.queue_response(build_compound_response_frame(&[create_resp, close_resp])); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree + .delete_files(&mut conn, &["a.txt", "missing.txt", "c.txt"]) + .await; + + assert_eq!(results.len(), 3); + assert!(results[0].is_ok()); + assert!(results[1].is_err()); + assert_eq!( + results[1].as_ref().unwrap_err().status(), + Some(NtStatus::OBJECT_NAME_NOT_FOUND) + ); + assert!(results[2].is_ok()); + } + + #[tokio::test] + async fn delete_files_batch_close_failure_issues_cleanup() { + let mock = Arc::new(MockTransport::new()); + + let err_body = crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + + // File 1: CREATE succeeds, CLOSE fails + let file_id = FileId { + persistent: 0xAA, + volatile: 0xBB, + }; + let create_resp = build_create_response(file_id, 0); + + let mut close_hdr = Header::new_request(Command::Close); + close_hdr.flags.set_response(); + close_hdr.credits = 32; + close_hdr.status = NtStatus::UNSUCCESSFUL; + let close_fail = pack_message(&close_hdr, &err_body); + mock.queue_response(build_compound_response_frame(&[create_resp, close_fail])); + + // File 2: success + let file_id2 = FileId { + persistent: 0xCC, + volatile: 0xDD, + }; + let create_resp2 = build_create_response(file_id2, 0); + let close_resp2 = build_close_response(); + mock.queue_response(build_compound_response_frame(&[create_resp2, close_resp2])); + + // Queue response for the standalone CLOSE cleanup of file 1. + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree.delete_files(&mut conn, &["leaky.txt", "ok.txt"]).await; + + assert_eq!(results.len(), 2); + assert!(results[0].is_err()); + assert!(results[1].is_ok()); + // 2 compound frames + 1 standalone CLOSE = 3 messages sent. + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn delete_files_empty_returns_empty() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let results = tree.delete_files(&mut conn, &[]).await; + assert!(results.is_empty()); + assert_eq!(mock.sent_count(), 0); + } + + // ── Pipelined read tests ──────────────────────────────────────── + + fn build_read_response_with_msg_id( + status: NtStatus, + msg_id: MessageId, + data: Vec, + ) -> Vec { + let mut h = Header::new_request(Command::Read); + h.flags.set_response(); + h.credits = 32; + h.status = status; + h.message_id = msg_id; + + if status == NtStatus::END_OF_FILE { + use crate::msg::header::ErrorResponse; + let body = ErrorResponse { + error_context_count: 0, + error_data: vec![], + }; + return pack_message(&h, &body); + } + + let body = ReadResponse { + data_offset: 0x50, + data_remaining: 0, + flags: 0, + data, + }; + + pack_message(&h, &body) + } + + fn build_write_response_with_msg_id(msg_id: MessageId, count: u32) -> Vec { + use crate::msg::write::WriteResponse; + let mut h = Header::new_request(Command::Write); + h.flags.set_response(); + h.credits = 32; + h.message_id = msg_id; + + let body = WriteResponse { + count, + remaining: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + }; + + pack_message(&h, &body) + } + + #[tokio::test] + async fn pipelined_read_four_chunks() { + // File: 256 KB = 4 chunks of 64 KB. + let mock = Arc::new(MockTransport::new()); + let tree_id = TreeId(20); + let file_id = FileId { + persistent: 0x100, + volatile: 0x200, + }; + let file_size = 256 * 1024u64; + + // Build 256 KB of test data with a recognizable pattern. + let mut expected_data = vec![0u8; file_size as usize]; + for (i, byte) in expected_data.iter_mut().enumerate() { + *byte = (i % 251) as u8; // prime to avoid alignment artifacts + } + + // Queue: CREATE response. + mock.queue_response(build_create_response(file_id, file_size)); + + // Queue: 4 READ responses (in order, matching the MessageIds + // that send_request will assign). + // After CREATE, the next message_id = 1 (CREATE consumed 0). + // Actually, connection starts at next_message_id=0. But setup_connection + // doesn't call negotiate (which would consume msg_id 0). + // send_request for CREATE will use msg_id 0, then the 4 READs will + // use msg_ids 1, 2, 3, 4. + for i in 0..4 { + let offset = i * 65536; + let chunk = expected_data[offset..offset + 65536].to_vec(); + mock.queue_response(build_read_response_with_msg_id( + NtStatus::SUCCESS, + MessageId((i / 65536 + 1) as u64), // msg_ids 1..4 + chunk, + )); + } + // Fix: the message IDs. send_request increments next_message_id each time. + // After CREATE (msg_id=0), the 4 READs get msg_ids 1, 2, 3, 4. + // Let me rebuild these correctly. + // Actually I already did it wrong above. Let me clear and redo. + // The loop above computed msg_id as (i / 65536 + 1) which is always 1. + // Let me fix this. + + // Clear the mock and redo. + let mock = Arc::new(MockTransport::new()); + mock.queue_response(build_create_response(file_id, file_size)); + + for i in 0u64..4 { + let offset = (i * 65536) as usize; + let chunk = expected_data[offset..offset + 65536].to_vec(); + mock.queue_response(build_read_response_with_msg_id( + NtStatus::SUCCESS, + MessageId(i + 1), // msg_ids 1, 2, 3, 4 + chunk, + )); + } + + // Queue: CLOSE response. + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id, + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let data = tree + .read_file_pipelined(&mut conn, "big.bin") + .await + .unwrap(); + + assert_eq!(data.len(), expected_data.len()); + assert_eq!(data, expected_data); + + // 1 CREATE + 4 READs + 1 CLOSE = 6 messages sent. + assert_eq!(mock.sent_count(), 6); + } + + #[tokio::test] + async fn pipelined_read_responses_out_of_order() { + // File: 192 KB = 3 chunks of 64 KB. Responses arrive in reverse order. + let mock = Arc::new(MockTransport::new()); + let tree_id = TreeId(20); + let file_id = FileId { + persistent: 0x300, + volatile: 0x400, + }; + let file_size = 192 * 1024u64; + + let mut expected_data = vec![0u8; file_size as usize]; + for (i, byte) in expected_data.iter_mut().enumerate() { + *byte = (i % 199) as u8; + } + + mock.queue_response(build_create_response(file_id, file_size)); + + // Queue responses in REVERSE order (msg_id 3, 2, 1) to test reassembly. + for i in (0u64..3).rev() { + let offset = (i * 65536) as usize; + let chunk = expected_data[offset..offset + 65536].to_vec(); + mock.queue_response(build_read_response_with_msg_id( + NtStatus::SUCCESS, + MessageId(i + 1), + chunk, + )); + } + + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id, + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let data = tree + .read_file_pipelined(&mut conn, "reverse.bin") + .await + .unwrap(); + + assert_eq!(data.len(), expected_data.len()); + assert_eq!(data, expected_data); + } + + #[tokio::test] + async fn pipelined_read_zero_byte_file() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0x500, + volatile: 0x600, + }; + + // CREATE reports file_size=0. + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let data = tree + .read_file_pipelined(&mut conn, "empty.bin") + .await + .unwrap(); + + assert!(data.is_empty()); + // 1 CREATE + 1 CLOSE = 2 messages (no READs needed). + assert_eq!(mock.sent_count(), 2); + } + + #[tokio::test] + async fn pipelined_read_end_of_file_mid_window() { + // File claims to be 128 KB (2 chunks), but second chunk returns STATUS_END_OF_FILE. + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0x700, + volatile: 0x800, + }; + let file_size = 128 * 1024u64; + let first_chunk = vec![0xAA; 65536]; + + mock.queue_response(build_create_response(file_id, file_size)); + // First chunk succeeds. + mock.queue_response(build_read_response_with_msg_id( + NtStatus::SUCCESS, + MessageId(1), + first_chunk.clone(), + )); + // Second chunk returns END_OF_FILE. + mock.queue_response(build_read_response_with_msg_id( + NtStatus::END_OF_FILE, + MessageId(2), + vec![], + )); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let data = tree + .read_file_pipelined(&mut conn, "truncated.bin") + .await + .unwrap(); + + // We got the full buffer since file_size was 128 KB. + // The second chunk area stays as zeros (from vec initialization). + assert_eq!(data.len(), file_size as usize); + assert_eq!(&data[..65536], &first_chunk); + } + + #[tokio::test] + async fn pipelined_read_window_sliding() { + // File: 192 KB = 3 chunks. Credits = 2, so we need 2 windows. + let file_id = FileId { + persistent: 0x900, + volatile: 0xA00, + }; + let file_size = 192 * 1024u64; + + let mut expected_data = vec![0u8; file_size as usize]; + for (i, byte) in expected_data.iter_mut().enumerate() { + *byte = (i % 173) as u8; + } + + // Build with limited credits to force window sliding. + // CREATE response grants only 2 credits (instead of default 32), + // so the pipeline can only send 2 reads per window. + let mock = Arc::new(MockTransport::new()); + + let create_resp = { + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 2; // Only grant 2 credits. + let body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: file_size, + file_attributes: 0, + file_id, + create_contexts: vec![], + }; + pack_message(&h, &body) + }; + mock.queue_response(create_resp); + + // Window 1: 2 READs (chunks 0, 1). Responses grant 2 credits each. + for i in 0u64..2 { + let offset = (i * 65536) as usize; + let chunk_data = expected_data[offset..offset + 65536].to_vec(); + let mut h = Header::new_request(Command::Read); + h.flags.set_response(); + h.credits = 2; // Grant 2 credits per response. + h.message_id = MessageId(i + 1); + let body = ReadResponse { + data_offset: 0x50, + data_remaining: 0, + flags: 0, + data: chunk_data, + }; + mock.queue_response(pack_message(&h, &body)); + } + + // Window 2: 1 READ (chunk 2). + { + let offset = (2 * 65536) as usize; + let chunk_data = expected_data[offset..offset + 65536].to_vec(); + let mut h = Header::new_request(Command::Read); + h.flags.set_response(); + h.credits = 2; + h.message_id = MessageId(3); + let body = ReadResponse { + data_offset: 0x50, + data_remaining: 0, + flags: 0, + data: chunk_data, + }; + mock.queue_response(pack_message(&h, &body)); + } + + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let data = tree + .read_file_pipelined(&mut conn, "sliding.bin") + .await + .unwrap(); + + assert_eq!(data.len(), expected_data.len()); + assert_eq!(data, expected_data); + // 1 CREATE + 3 READs + 1 CLOSE = 5. + assert_eq!(mock.sent_count(), 5); + } + + #[tokio::test] + async fn sliding_window_sends_immediately_after_receive() { + // File: 512 KB = 8 chunks of 64 KB. Only 4 credits available initially. + // With sliding window: 4 sends, then each receive triggers a new send. + // Total: 8 sends interleaved with 8 receives (not 2 batches of 4). + let file_id = FileId { + persistent: 0xF00, + volatile: 0xF01, + }; + let file_size = 8 * 65536u64; + + let mut expected_data = vec![0u8; file_size as usize]; + for (i, byte) in expected_data.iter_mut().enumerate() { + *byte = (i % 137) as u8; + } + + let mock = Arc::new(MockTransport::new()); + + // CREATE response grants 4 credits (not the default 32). + let create_resp = { + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 4; + let body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: file_size, + file_attributes: 0, + file_id, + create_contexts: vec![], + }; + pack_message(&h, &body) + }; + mock.queue_response(create_resp); + + // Queue 8 READ responses. Each grants 1 credit so the window + // stays at 1 after the initial 4 are consumed (4 - 4 + 1 per response). + // With sliding window, after initial 4 sends, each response triggers 1 more send. + for i in 0u64..8 { + let offset = (i * 65536) as usize; + let chunk_data = expected_data[offset..offset + 65536].to_vec(); + let mut h = Header::new_request(Command::Read); + h.flags.set_response(); + h.credits = 1; // Grant 1 credit per response. + h.message_id = MessageId(i + 1); + let body = ReadResponse { + data_offset: 0x50, + data_remaining: 0, + flags: 0, + data: chunk_data, + }; + mock.queue_response(pack_message(&h, &body)); + } + + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let data = tree + .read_file_pipelined(&mut conn, "sliding_test.bin") + .await + .unwrap(); + + assert_eq!(data.len(), expected_data.len()); + assert_eq!(data, expected_data); + + // 1 CREATE + 8 READs + 1 CLOSE = 10 messages sent. + assert_eq!(mock.sent_count(), 10); + } + + // ── Pipelined read with progress tests ──────────────────────────── + + #[tokio::test] + async fn read_pipelined_with_progress_reports_progress() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xF1, + volatile: 0xF2, + }; + // 2 chunks of 65536 bytes each. + let file_size = 65536u64 * 2; + let expected_data = vec![0xABu8; file_size as usize]; + + // CREATE response with file size. + let create_resp = { + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 32; + let body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: file_size, + file_attributes: 0, + file_id, + create_contexts: vec![], + }; + pack_message(&h, &body) + }; + mock.queue_response(create_resp); + + // 2 READ responses. + for i in 0..2u64 { + let offset = (i * 65536) as usize; + let chunk = expected_data[offset..offset + 65536].to_vec(); + let resp = build_read_response_with_msg_id(NtStatus::SUCCESS, MessageId(i + 1), chunk); + mock.queue_response(resp); + } + + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut progress_reports = Vec::new(); + let data = tree + .read_file_pipelined_with_progress(&mut conn, "progress_test.bin", |p| { + progress_reports.push(p.bytes_transferred); + ControlFlow::Continue(()) + }) + .await + .unwrap(); + + assert_eq!(data.len(), file_size as usize); + // Should have received 2 progress callbacks (one per chunk). + assert_eq!(progress_reports.len(), 2); + assert_eq!(progress_reports[0], 65536); + assert_eq!(progress_reports[1], file_size); + } + + #[tokio::test] + async fn read_pipelined_with_progress_cancellation() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xF3, + volatile: 0xF4, + }; + // 4 chunks of 65536 bytes. + let file_size = 65536u64 * 4; + + let create_resp = { + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 32; + let body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: file_size, + file_attributes: 0, + file_id, + create_contexts: vec![], + }; + pack_message(&h, &body) + }; + mock.queue_response(create_resp); + + // Queue all 4 READ responses (some won't be consumed due to cancellation). + for i in 0..4u64 { + let chunk = vec![0x42u8; 65536]; + let resp = build_read_response_with_msg_id(NtStatus::SUCCESS, MessageId(i + 1), chunk); + mock.queue_response(resp); + } + + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + // Cancel after the first chunk. + let result = tree + .read_file_pipelined_with_progress(&mut conn, "cancel_test.bin", |_p| { + ControlFlow::Break(()) + }) + .await; + + assert!(result.is_err()); + match result.unwrap_err() { + Error::Cancelled => {} // expected + other => panic!("expected Cancelled, got {:?}", other), + } + } + + #[tokio::test] + async fn read_pipelined_with_progress_empty_file() { + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xF5, + volatile: 0xF6, + }; + + // CREATE response with size=0. + let create_resp = { + let mut h = Header::new_request(Command::Create); + h.flags.set_response(); + h.credits = 32; + let body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: 0, + file_attributes: 0, + file_id, + create_contexts: vec![], + }; + pack_message(&h, &body) + }; + mock.queue_response(create_resp); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut progress_called = false; + let data = tree + .read_file_pipelined_with_progress(&mut conn, "empty.bin", |p| { + progress_called = true; + assert_eq!(p.bytes_transferred, 0); + assert_eq!(p.total_bytes, Some(0)); + ControlFlow::Continue(()) + }) + .await + .unwrap(); + + assert!(data.is_empty()); + assert!(progress_called); + } + + // ── Pipelined write tests ─────────────────────────────────────── + + #[tokio::test] + async fn pipelined_write_four_chunks() { + let mock = Arc::new(MockTransport::new()); + let tree_id = TreeId(20); + let file_id = FileId { + persistent: 0xB00, + volatile: 0xC00, + }; + let data_to_write = vec![0x42u8; 256 * 1024]; // 256 KB = 4 chunks + + // CREATE response. + mock.queue_response(build_create_response(file_id, 0)); + + // 4 WRITE responses. + for i in 0u64..4 { + mock.queue_response(build_write_response_with_msg_id(MessageId(i + 1), 65536)); + } + + // FLUSH + CLOSE responses. + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id, + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let written = tree + .write_file_pipelined(&mut conn, "big_write.bin", &data_to_write) + .await + .unwrap(); + + assert_eq!(written, 256 * 1024); + // 1 CREATE + 4 WRITEs + 1 FLUSH + 1 CLOSE = 7. + assert_eq!(mock.sent_count(), 7); + + // Verify that each WRITE request contains the correct data chunk. + for i in 0..4 { + let sent = mock.sent_message(i + 1).unwrap(); // skip CREATE at index 0 + let mut cursor = ReadCursor::new(&sent); + let _header = Header::unpack(&mut cursor).unwrap(); + let req = WriteRequest::unpack(&mut cursor).unwrap(); + assert_eq!(req.data.len(), 65536); + assert_eq!(req.offset, i as u64 * 65536); + assert!(req.data.iter().all(|&b| b == 0x42)); + } + } + + #[tokio::test] + async fn pipelined_write_last_chunk_smaller() { + // 100 KB = 1 full chunk (64 KB) + 1 partial chunk (36 KB). + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xD00, + volatile: 0xE00, + }; + let data_to_write = vec![0x55u8; 100 * 1024]; + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response_with_msg_id(MessageId(1), 65536)); + mock.queue_response(build_write_response_with_msg_id(MessageId(2), 36 * 1024)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(20), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let written = tree + .write_file_pipelined(&mut conn, "partial.bin", &data_to_write) + .await + .unwrap(); + + assert_eq!(written, 65536 + 36 * 1024); + assert_eq!(mock.sent_count(), 5); // CREATE + 2 WRITEs + FLUSH + CLOSE + } + + // ── Compound request tests ────────────────────────────────────── + + /// Build a compound response frame with proper NextCommand offsets and padding. + fn build_compound_response_frame(responses: &[Vec]) -> Vec { + let mut padded: Vec> = Vec::new(); + for (i, resp) in responses.iter().enumerate() { + let mut r = resp.clone(); + let is_last = i == responses.len() - 1; + if !is_last { + // Pad to 8-byte alignment. + let remainder = r.len() % 8; + if remainder != 0 { + r.resize(r.len() + (8 - remainder), 0); + } + // Set NextCommand. + let next_cmd = r.len() as u32; + r[20..24].copy_from_slice(&next_cmd.to_le_bytes()); + } + padded.push(r); + } + let mut frame = Vec::new(); + for r in &padded { + frame.extend_from_slice(r); + } + frame + } + + #[tokio::test] + async fn read_file_compound_returns_file_data() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + // Set up tree. + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + // Build compound response frame: CREATE + READ + CLOSE. + let file_id = FileId { + persistent: 0x42, + volatile: 0x99, + }; + let file_data = b"Hello, compound!".to_vec(); + + let create_resp = build_create_response(file_id, file_data.len() as u64); + let read_resp = build_read_response(NtStatus::SUCCESS, file_data.clone()); + let close_resp = build_close_response(); + + let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]); + mock.queue_response(frame); + + let data = tree + .read_file_compound(&mut conn, "test.txt") + .await + .unwrap(); + + assert_eq!(data, b"Hello, compound!"); + // Should have sent one compound frame (plus the tree connect). + assert_eq!(mock.sent_count(), 2); // TreeConnect + compound + } + + #[tokio::test] + async fn read_file_compound_handles_empty_file() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + // Build compound response: CREATE ok, READ returns END_OF_FILE, CLOSE ok. + let create_resp = build_create_response(file_id, 0); + + // For END_OF_FILE, we need an error response body. + let read_resp = build_read_response(NtStatus::END_OF_FILE, vec![]); + let close_resp = build_close_response(); + + let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]); + mock.queue_response(frame); + + let data = tree + .read_file_compound(&mut conn, "empty.txt") + .await + .unwrap(); + + assert!(data.is_empty()); + } + + #[tokio::test] + async fn read_file_compound_create_failure_returns_error() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + // Build compound response where CREATE fails with OBJECT_NAME_NOT_FOUND. + // When CREATE fails, server cascades error to READ and CLOSE. + let mut create_resp_header = Header::new_request(Command::Create); + create_resp_header.flags.set_response(); + create_resp_header.credits = 32; + create_resp_header.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_resp = pack_message( + &create_resp_header, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut read_resp_header = Header::new_request(Command::Read); + read_resp_header.flags.set_response(); + read_resp_header.credits = 32; + read_resp_header.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let read_resp = pack_message( + &read_resp_header, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut close_resp_header = Header::new_request(Command::Close); + close_resp_header.flags.set_response(); + close_resp_header.credits = 32; + close_resp_header.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_resp = pack_message( + &close_resp_header, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]); + mock.queue_response(frame); + + let result = tree.read_file_compound(&mut conn, "nonexistent.txt").await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.status(), Some(NtStatus::OBJECT_NAME_NOT_FOUND)); + } + + #[tokio::test] + async fn read_file_compound_read_failure_issues_standalone_close() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + let file_id = FileId { + persistent: 0x42, + volatile: 0x99, + }; + + // CREATE succeeds. + let create_resp = build_create_response(file_id, 1024); + + // READ fails with INSUFFICIENT_RESOURCES. + let mut read_resp_header = Header::new_request(Command::Read); + read_resp_header.flags.set_response(); + read_resp_header.credits = 32; + read_resp_header.status = NtStatus::INSUFFICIENT_RESOURCES; + let read_resp = pack_message( + &read_resp_header, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + // CLOSE also fails (cascaded). + let mut close_resp_header = Header::new_request(Command::Close); + close_resp_header.flags.set_response(); + close_resp_header.credits = 32; + close_resp_header.status = NtStatus::INSUFFICIENT_RESOURCES; + let close_resp = pack_message( + &close_resp_header, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]); + mock.queue_response(frame); + + // Queue a standalone CLOSE response for the cleanup. + mock.queue_response(build_close_response()); + + let result = tree.read_file_compound(&mut conn, "problem.txt").await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.status(), Some(NtStatus::INSUFFICIENT_RESOURCES)); + + // Should have sent: TreeConnect + compound + standalone CLOSE = 3. + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn read_file_compound_sends_correct_request_structure() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + let create_resp = build_create_response(file_id, 5); + let read_resp = build_read_response(NtStatus::SUCCESS, vec![1, 2, 3, 4, 5]); + let close_resp = build_close_response(); + let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]); + mock.queue_response(frame); + + tree.read_file_compound(&mut conn, "verify.txt") + .await + .unwrap(); + + // The second sent message is the compound request. + let compound = mock.sent_message(1).unwrap(); + + // Verify it contains 3 headers linked by NextCommand. + let mut cursor = ReadCursor::new(&compound); + let h1 = Header::unpack(&mut cursor).unwrap(); + assert_eq!(h1.command, Command::Create); + assert!(!h1.flags.is_related()); + assert!(h1.next_command > 0); + assert_eq!(h1.tree_id, Some(TreeId(7))); + + let off2 = h1.next_command as usize; + let mut cursor2 = ReadCursor::new(&compound[off2..]); + let h2 = Header::unpack(&mut cursor2).unwrap(); + assert_eq!(h2.command, Command::Read); + assert!(h2.flags.is_related()); + assert!(h2.next_command > 0); + + // Verify READ uses sentinel FileId. + let read_parsed = ReadRequest::unpack(&mut cursor2).unwrap(); + assert_eq!(read_parsed.file_id, FileId::SENTINEL); + + let off3 = off2 + h2.next_command as usize; + let mut cursor3 = ReadCursor::new(&compound[off3..]); + let h3 = Header::unpack(&mut cursor3).unwrap(); + assert_eq!(h3.command, Command::Close); + assert!(h3.flags.is_related()); + assert_eq!(h3.next_command, 0); + + // Verify CLOSE uses sentinel FileId. + let close_parsed = CloseRequest::unpack(&mut cursor3).unwrap(); + assert_eq!(close_parsed.file_id, FileId::SENTINEL); + } + + // ── Compound write tests ──────────────────────────────────────── + + #[tokio::test] + async fn write_file_compound_returns_bytes_written() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + let file_id = FileId { + persistent: 0x42, + volatile: 0x99, + }; + let file_data = b"Hello, compound write!"; + + let create_resp = build_create_response(file_id, 0); + let write_resp = build_write_response(file_data.len() as u32); + let flush_resp = build_flush_response(); + let close_resp = build_close_response(); + + let frame = + build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]); + mock.queue_response(frame); + + let written = tree + .write_file_compound(&mut conn, "test.txt", file_data) + .await + .unwrap(); + + assert_eq!(written, file_data.len() as u64); + // Should have sent one compound frame (plus the tree connect). + assert_eq!(mock.sent_count(), 2); // TreeConnect + compound + } + + #[tokio::test] + async fn write_file_compound_empty_file() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + + let create_resp = build_create_response(file_id, 0); + let write_resp = build_write_response(0); + let flush_resp = build_flush_response(); + let close_resp = build_close_response(); + + let frame = + build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]); + mock.queue_response(frame); + + let written = tree + .write_file_compound(&mut conn, "empty.txt", b"") + .await + .unwrap(); + + assert_eq!(written, 0); + } + + #[tokio::test] + async fn write_file_compound_create_failure_returns_error() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + // Build compound response where CREATE fails. + // When CREATE fails, server cascades error to WRITE, FLUSH, and CLOSE. + let mut create_h = Header::new_request(Command::Create); + create_h.flags.set_response(); + create_h.credits = 32; + create_h.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_resp = pack_message( + &create_h, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut write_h = Header::new_request(Command::Write); + write_h.flags.set_response(); + write_h.credits = 32; + write_h.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let write_resp = pack_message( + &write_h, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut flush_h = Header::new_request(Command::Flush); + flush_h.flags.set_response(); + flush_h.credits = 32; + flush_h.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let flush_resp = pack_message( + &flush_h, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let mut close_h = Header::new_request(Command::Close); + close_h.flags.set_response(); + close_h.credits = 32; + close_h.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let close_resp = pack_message( + &close_h, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = + build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]); + mock.queue_response(frame); + + let result = tree + .write_file_compound(&mut conn, "bad/path.txt", b"data") + .await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.status(), Some(NtStatus::OBJECT_NAME_NOT_FOUND)); + } + + #[tokio::test] + async fn write_file_compound_write_failure_issues_standalone_close() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + let file_id = FileId { + persistent: 0x42, + volatile: 0x99, + }; + + // CREATE succeeds. + let create_resp = build_create_response(file_id, 0); + + // WRITE fails with INSUFFICIENT_RESOURCES. + let mut write_h = Header::new_request(Command::Write); + write_h.flags.set_response(); + write_h.credits = 32; + write_h.status = NtStatus::INSUFFICIENT_RESOURCES; + let write_resp = pack_message( + &write_h, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + // FLUSH also fails (cascaded). + let mut flush_h = Header::new_request(Command::Flush); + flush_h.flags.set_response(); + flush_h.credits = 32; + flush_h.status = NtStatus::INSUFFICIENT_RESOURCES; + let flush_resp = pack_message( + &flush_h, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + // CLOSE also fails (cascaded). + let mut close_h = Header::new_request(Command::Close); + close_h.flags.set_response(); + close_h.credits = 32; + close_h.status = NtStatus::INSUFFICIENT_RESOURCES; + let close_resp = pack_message( + &close_h, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + + let frame = + build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]); + mock.queue_response(frame); + + // Queue a standalone CLOSE response for the cleanup. + mock.queue_response(build_close_response()); + + let result = tree + .write_file_compound(&mut conn, "problem.txt", b"data") + .await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.status(), Some(NtStatus::INSUFFICIENT_RESOURCES)); + + // Should have sent: TreeConnect + compound + standalone CLOSE = 3. + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn write_file_compound_sends_correct_request_structure() { + let mock = Arc::new(MockTransport::new()); + let mut conn = setup_connection(&mock); + + mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk)); + let tree = Tree::connect(&mut conn, "share").await.unwrap(); + + let file_id = FileId { + persistent: 1, + volatile: 2, + }; + let create_resp = build_create_response(file_id, 0); + let write_resp = build_write_response(5); + let flush_resp = build_flush_response(); + let close_resp = build_close_response(); + let frame = + build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]); + mock.queue_response(frame); + + tree.write_file_compound(&mut conn, "verify.txt", &[1, 2, 3, 4, 5]) + .await + .unwrap(); + + // The second sent message is the compound request. + let compound = mock.sent_message(1).unwrap(); + + // Verify it contains 4 headers linked by NextCommand. + let mut cursor = ReadCursor::new(&compound); + let h1 = Header::unpack(&mut cursor).unwrap(); + assert_eq!(h1.command, Command::Create); + assert!(!h1.flags.is_related()); + assert!(h1.next_command > 0); + assert_eq!(h1.tree_id, Some(TreeId(7))); + + let off2 = h1.next_command as usize; + let mut cursor2 = ReadCursor::new(&compound[off2..]); + let h2 = Header::unpack(&mut cursor2).unwrap(); + assert_eq!(h2.command, Command::Write); + assert!(h2.flags.is_related()); + assert!(h2.next_command > 0); + + // Verify WRITE uses sentinel FileId. + let write_parsed = WriteRequest::unpack(&mut cursor2).unwrap(); + assert_eq!(write_parsed.file_id, FileId::SENTINEL); + assert_eq!(write_parsed.data, vec![1, 2, 3, 4, 5]); + + let off3 = off2 + h2.next_command as usize; + let mut cursor3 = ReadCursor::new(&compound[off3..]); + let h3 = Header::unpack(&mut cursor3).unwrap(); + assert_eq!(h3.command, Command::Flush); + assert!(h3.flags.is_related()); + assert!(h3.next_command > 0); + + // Verify FLUSH uses sentinel FileId. + let flush_parsed = FlushRequest::unpack(&mut cursor3).unwrap(); + assert_eq!(flush_parsed.file_id, FileId::SENTINEL); + + let off4 = off3 + h3.next_command as usize; + let mut cursor4 = ReadCursor::new(&compound[off4..]); + let h4 = Header::unpack(&mut cursor4).unwrap(); + assert_eq!(h4.command, Command::Close); + assert!(h4.flags.is_related()); + assert_eq!(h4.next_command, 0); + + // Verify CLOSE uses sentinel FileId. + let close_parsed = CloseRequest::unpack(&mut cursor4).unwrap(); + assert_eq!(close_parsed.file_id, FileId::SENTINEL); + } + + // ── BUFFER_OVERFLOW tests ─────────────────────────────────────── + + #[tokio::test] + async fn stat_accepts_buffer_overflow_as_partial_data() { + // STATUS_BUFFER_OVERFLOW is a warning, not an error. The response + // body contains valid partial data and should be parsed. + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xCC, + volatile: 0xDD, + }; + + // STAT = compound CREATE + QUERY_INFO(basic, BUFFER_OVERFLOW) + QUERY_INFO(standard) + CLOSE + let create_resp = build_create_response(file_id, 0); + + let basic = build_file_basic_info( + 132_000_000_000_000_000, + 132_100_000_000_000_000, + 133_000_000_000_000_000, + 133_000_000_000_000_000, + 0x20, // ARCHIVE + ); + let basic_resp = build_query_info_response_with_status(NtStatus::BUFFER_OVERFLOW, basic); + + let std_info = build_file_standard_info(4096, 1024, 1, false, false); + let std_resp = build_query_info_response(std_info); + + let close_resp = build_close_response(); + + let frame = build_compound_response_frame(&[create_resp, basic_resp, std_resp, close_resp]); + mock.queue_response(frame); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(10), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + // Should succeed despite BUFFER_OVERFLOW on the basic info query. + let info = tree.stat(&mut conn, "partial.txt").await.unwrap(); + assert_eq!(info.size, 1024); + assert!(!info.is_directory); + assert_eq!(info.created, FileTime(132_000_000_000_000_000)); + // One compound frame sent. + assert_eq!(mock.sent_count(), 1); + } + + // ── Streamed write tests ─────────────────────────────────────── + + #[tokio::test] + async fn write_file_streamed_basic() { + // Provide 3 small chunks, verify CREATE + 3 WRITEs + CLOSE. + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xAA, + volatile: 0xBB, + }; + + let chunk1 = vec![0x01; 100]; + let chunk2 = vec![0x02; 200]; + let chunk3 = vec![0x03; 150]; + let chunks = vec![Ok(chunk1.clone()), Ok(chunk2.clone()), Ok(chunk3.clone())]; + let mut chunk_iter = chunks.into_iter(); + + // Queue: CREATE, 3x WRITE, FLUSH, CLOSE. + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_write_response(200)); + mock.queue_response(build_write_response(150)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(30), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut next_chunk = + move || -> Option, std::io::Error>> { chunk_iter.next() }; + + let written = tree + .write_file_streamed(&mut conn, "streamed.bin", &mut next_chunk) + .await + .unwrap(); + + assert_eq!(written, 450); // 100 + 200 + 150 + + // Verify CREATE + 3 WRITEs + FLUSH + CLOSE = 6 messages. + assert_eq!(mock.sent_count(), 6); + + // Verify WRITE offsets and data. + // Message 0 = CREATE, 1..3 = WRITEs, 4 = FLUSH, 5 = CLOSE. + let sent1 = mock.sent_message(1).unwrap(); + let mut cursor1 = ReadCursor::new(&sent1); + let _ = Header::unpack(&mut cursor1).unwrap(); + let req1 = WriteRequest::unpack(&mut cursor1).unwrap(); + assert_eq!(req1.offset, 0); + assert_eq!(req1.data, chunk1); + + let sent2 = mock.sent_message(2).unwrap(); + let mut cursor2 = ReadCursor::new(&sent2); + let _ = Header::unpack(&mut cursor2).unwrap(); + let req2 = WriteRequest::unpack(&mut cursor2).unwrap(); + assert_eq!(req2.offset, 100); + assert_eq!(req2.data, chunk2); + + let sent3 = mock.sent_message(3).unwrap(); + let mut cursor3 = ReadCursor::new(&sent3); + let _ = Header::unpack(&mut cursor3).unwrap(); + let req3 = WriteRequest::unpack(&mut cursor3).unwrap(); + assert_eq!(req3.offset, 300); + assert_eq!(req3.data, chunk3); + + // Verify last message is CLOSE. + let sent5 = mock.sent_message(5).unwrap(); + let mut cursor5 = ReadCursor::new(&sent5); + let h5 = Header::unpack(&mut cursor5).unwrap(); + assert_eq!(h5.command, Command::Close); + } + + #[tokio::test] + async fn write_file_streamed_empty() { + // Callback returns None immediately -> CREATE + FLUSH + CLOSE (empty file). + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xCC, + volatile: 0xDD, + }; + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_flush_response()); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(31), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut next_chunk = || -> Option, std::io::Error>> { None }; + + let written = tree + .write_file_streamed(&mut conn, "empty_stream.bin", &mut next_chunk) + .await + .unwrap(); + + assert_eq!(written, 0); + // CREATE + FLUSH + CLOSE = 3 messages. + assert_eq!(mock.sent_count(), 3); + + // Verify CREATE then FLUSH then CLOSE. + let sent0 = mock.sent_message(0).unwrap(); + let mut c0 = ReadCursor::new(&sent0); + let h0 = Header::unpack(&mut c0).unwrap(); + assert_eq!(h0.command, Command::Create); + + let sent1 = mock.sent_message(1).unwrap(); + let mut c1 = ReadCursor::new(&sent1); + let h1 = Header::unpack(&mut c1).unwrap(); + assert_eq!(h1.command, Command::Flush); + + let sent2 = mock.sent_message(2).unwrap(); + let mut c2 = ReadCursor::new(&sent2); + let h2 = Header::unpack(&mut c2).unwrap(); + assert_eq!(h2.command, Command::Close); + } + + #[tokio::test] + async fn write_file_streamed_callback_error() { + // Callback returns Ok on first call, Err on second. + // Verify: handle is closed (CLOSE sent) and error is propagated. + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0xEE, + volatile: 0xFF, + }; + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(64)); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(32), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut call_count = 0u32; + let mut next_chunk = move || -> Option, std::io::Error>> { + call_count += 1; + match call_count { + 1 => Some(Ok(vec![0x42; 64])), + 2 => Some(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "source stream broke", + ))), + _ => None, + } + }; + + let result = tree + .write_file_streamed(&mut conn, "error_stream.bin", &mut next_chunk) + .await; + + assert!(result.is_err(), "expected error from callback to propagate"); + + // Verify CLOSE was still sent (handle cleanup). + // Messages: CREATE + WRITE + CLOSE = 3. + assert_eq!(mock.sent_count(), 3); + + let sent_last = mock.sent_message(2).unwrap(); + let mut cl = ReadCursor::new(&sent_last); + let hl = Header::unpack(&mut cl).unwrap(); + assert_eq!(hl.command, Command::Close); + } + + #[tokio::test] + async fn write_file_streamed_callback_error_is_not_connection_lost() { + // A callback error is a consumer issue, not a connection failure. + // The error kind should NOT be ConnectionLost — the connection is + // still usable after write_file_streamed drains in-flight responses. + let mock = Arc::new(MockTransport::new()); + let file_id = FileId { + persistent: 0x11, + volatile: 0x22, + }; + + mock.queue_response(build_create_response(file_id, 0)); + mock.queue_response(build_write_response(64)); + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(40), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut call_count = 0u32; + let mut next_chunk = move || -> Option, std::io::Error>> { + call_count += 1; + match call_count { + 1 => Some(Ok(vec![0x42; 64])), + 2 => Some(Err(std::io::Error::new( + std::io::ErrorKind::Interrupted, + "user cancelled", + ))), + _ => None, + } + }; + + let err = tree + .write_file_streamed(&mut conn, "cancel_test.bin", &mut next_chunk) + .await + .unwrap_err(); + + // The error should NOT be classified as ConnectionLost. + // A callback cancellation doesn't break the SMB connection — all + // in-flight responses were drained and the handle was closed cleanly. + assert_ne!( + err.kind(), + crate::ErrorKind::ConnectionLost, + "callback error should not be classified as ConnectionLost; the connection is still healthy" + ); + } + + #[tokio::test] + async fn write_file_streamed_callback_error_connection_still_usable() { + // After a callback error in write_file_streamed, the connection should + // be in a clean state: all in-flight WRITE responses drained, handle + // CLOSEd. A subsequent operation (read_file) should work. + let mock = Arc::new(MockTransport::new()); + let write_file_id = FileId { + persistent: 0x33, + volatile: 0x44, + }; + let read_file_id = FileId { + persistent: 0x55, + volatile: 0x66, + }; + + // Phase 1: Streamed write that errors after 2 chunks. + // With max_write=65536, 2 small chunks fit in the initial window. + mock.queue_response(build_create_response(write_file_id, 0)); + mock.queue_response(build_write_response(100)); + mock.queue_response(build_write_response(200)); + mock.queue_response(build_close_response()); + + // Phase 2: A subsequent read_file (compound: CREATE+READ+CLOSE). + let read_data = b"hello from the server"; + mock.queue_response(build_compound_read_response( + read_file_id, + read_data.to_vec(), + )); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(41), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + // Streamed write: 2 chunks succeed, then callback errors. + let mut call_count = 0u32; + let mut next_chunk = move || -> Option, std::io::Error>> { + call_count += 1; + match call_count { + 1 => Some(Ok(vec![0xAA; 100])), + 2 => Some(Ok(vec![0xBB; 200])), + 3 => Some(Err(std::io::Error::new( + std::io::ErrorKind::Interrupted, + "cancelled by user", + ))), + _ => None, + } + }; + + let write_result = tree + .write_file_streamed(&mut conn, "partial.bin", &mut next_chunk) + .await; + assert!( + write_result.is_err(), + "write should fail due to callback error" + ); + + // Now verify the connection still works: read a file. + let data = tree + .read_file_compound(&mut conn, "other.txt") + .await + .unwrap(); + assert_eq!(data, read_data); + } + + /// Builds a compound CREATE+READ+CLOSE response (single transport frame) + /// for use with `read_file_compound`. + fn build_compound_read_response(file_id: FileId, data: Vec) -> Vec { + use crate::msg::read::ReadResponse; + + // CREATE response (chained: next_command points to READ response) + let mut h1 = Header::new_request(Command::Create); + h1.flags.set_response(); + h1.credits = 32; + let create_body = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0u8, + create_action: CreateAction::FileOpened, + creation_time: crate::pack::FileTime(0), + last_access_time: crate::pack::FileTime(0), + last_write_time: crate::pack::FileTime(0), + change_time: crate::pack::FileTime(0), + allocation_size: 0, + end_of_file: data.len() as u64, + file_attributes: 0x80, + file_id, + create_contexts: vec![], + }; + let create_bytes = pack_message(&h1, &create_body); + + // READ response (chained: next_command points to CLOSE response) + let mut h2 = Header::new_request(Command::Read); + h2.flags.set_response(); + h2.credits = 32; + let read_body = ReadResponse { + data_offset: 0x50, + data: data.clone(), + data_remaining: 0, + flags: 0, + }; + let read_bytes = pack_message(&h2, &read_body); + + // CLOSE response (last in chain) + let close_bytes = build_close_response(); + + // Patch next_command offsets for compounding + let mut frame = Vec::new(); + + let mut create_buf = create_bytes; + let create_len = create_buf.len(); + // Align to 8 bytes + let padded_create_len = (create_len + 7) & !7; + create_buf.resize(padded_create_len, 0); + // Set NextCommand in header (offset 20, 4 bytes LE) + let next_cmd = padded_create_len as u32; + create_buf[20..24].copy_from_slice(&next_cmd.to_le_bytes()); + // Set RELATED_OPERATIONS flag on subsequent headers? No — compound READ + // uses the same file_id from CREATE, but read_file_compound sends + // FileId::SENTINEL which gets filled in by the server. For mock, + // we just need the responses to parse correctly. + frame.extend_from_slice(&create_buf); + + let mut read_buf = read_bytes; + let read_len = read_buf.len(); + let padded_read_len = (read_len + 7) & !7; + read_buf.resize(padded_read_len, 0); + let next_cmd2 = padded_read_len as u32; + read_buf[20..24].copy_from_slice(&next_cmd2.to_le_bytes()); + frame.extend_from_slice(&read_buf); + + frame.extend_from_slice(&close_bytes); + + frame + } + + // ── Tree::download (streaming via &mut Connection) ───────────────────── + + /// Happy path: `Tree::download` returns a `FileDownload` that yields + /// all chunks of a small file in order and closes the handle cleanly. + #[tokio::test] + async fn tree_download_streams_small_file() { + let mock = Arc::new(MockTransport::new()); + + let file_id = FileId { + persistent: 0xA1, + volatile: 0xB2, + }; + let payload = b"streaming hello from Tree::download".to_vec(); + + // CREATE (open_file) — server returns handle + size. + mock.queue_response(build_create_response(file_id, payload.len() as u64)); + // Single READ covering the whole file (payload fits in one + // max_read_size=65536 chunk). + mock.queue_response(build_read_response(NtStatus::SUCCESS, payload.clone())); + // CLOSE after the last chunk. + mock.queue_response(build_close_response()); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(11), + share_name: "share".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut download = tree + .download(&mut conn, "hello.txt") + .await + .expect("download"); + assert_eq!(download.size(), payload.len() as u64); + + let mut received = Vec::new(); + while let Some(chunk) = download.next_chunk().await { + let bytes = chunk.expect("chunk"); + received.extend_from_slice(&bytes); + } + assert_eq!(received, payload); + + // CREATE + READ + CLOSE = 3 messages on the wire. + assert_eq!(mock.sent_count(), 3); + mock.assert_fully_consumed(); + } + + /// Error path: if CREATE fails, `Tree::download` surfaces the NTSTATUS + /// as `Error::Protocol` without ever constructing a `FileDownload`. + #[tokio::test] + async fn tree_download_create_failure_returns_protocol_error() { + let mock = Arc::new(MockTransport::new()); + + let mut create_hdr = Header::new_request(Command::Create); + create_hdr.flags.set_response(); + create_hdr.credits = 32; + create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND; + let create_err = pack_message( + &create_hdr, + &crate::msg::header::ErrorResponse { + error_context_count: 0, + error_data: vec![], + }, + ); + mock.queue_response(create_err); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(12), + share_name: "share".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let result = tree.download(&mut conn, "missing.txt").await; + let err = result.err().expect("expected error"); + assert_eq!(err.status(), Some(NtStatus::OBJECT_NAME_NOT_FOUND)); + } + + /// Dropping a `FileDownload` mid-stream (before draining all chunks) + /// must not panic. The `Drop` impl logs a warning; the handle may leak + /// on the server, but the client stays healthy. + #[tokio::test] + async fn tree_download_drop_mid_stream_does_not_panic() { + let mock = Arc::new(MockTransport::new()); + + let file_id = FileId { + persistent: 0xC3, + volatile: 0xD4, + }; + // 3x max_read_size payload so at least one READ remains unsent + // after the caller drops early. + let total = 3 * 65536usize; + mock.queue_response(build_create_response(file_id, total as u64)); + // Queue one READ response; we'll consume only that one and drop. + mock.queue_response(build_read_response(NtStatus::SUCCESS, vec![0xAB; 65536])); + + let mut conn = setup_connection(&mock); + let tree = Tree { + tree_id: TreeId(13), + share_name: "share".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }; + + let mut download = tree.download(&mut conn, "big.bin").await.expect("download"); + + let first = download + .next_chunk() + .await + .expect("first chunk exists") + .expect("first chunk ok"); + assert_eq!(first.len(), 65536); + + // Drop mid-stream -- must not panic. + drop(download); + } + + /// Two `Tree::download` futures on cloned `Connection`s must both + /// complete with correct data, proving the headline reason for adding + /// `Tree::download`: concurrent downloads on one SMB session. + /// + /// The mock transport routes responses via the Phase 3 receiver task + /// (shared across all `Connection::clone()`s), so msg-id demux is what + /// wires each CREATE/READ/CLOSE to the right waiter. We queue the + /// responses AFTER the sends land (just like + /// `concurrent_execute_on_one_connection_all_succeed` in + /// `connection.rs`) so `enable_auto_rewrite_msg_id` can stamp them in + /// FIFO send order. Both downloads use the same payload, so it doesn't + /// matter which task's READ lands in which slot — whoever gets routed + /// their msg_id wins the correct bytes. + /// + /// Gotcha/Why: if the two tasks fetched DIFFERENT payloads, the FIFO + /// auto-rewrite in `MockTransport` would mis-pair sends and responses + /// unless we hand-serialized each phase across tasks, which would + /// defeat the concurrency the test is meant to prove. Keeping the + /// payloads identical lets both downloads race freely while we still + /// assert correctness of the whole pipeline. + #[tokio::test(flavor = "multi_thread")] + async fn tree_download_concurrent_on_cloned_connections() { + use std::time::{Duration, Instant}; + + let mock = Arc::new(MockTransport::new()); + mock.enable_auto_rewrite_msg_id(); + + let params = crate::client::connection::NegotiatedParams { + dialect: crate::types::Dialect::Smb2_0_2, + max_read_size: 65536, + max_write_size: 65536, + max_transact_size: 65536, + server_guid: crate::pack::Guid::ZERO, + signing_required: false, + capabilities: crate::types::flags::Capabilities::default(), + gmac_negotiated: false, + cipher: None, + compression_supported: false, + }; + let mut conn_primary = crate::client::connection::Connection::from_transport( + Box::new(mock.clone()), + Box::new(mock.clone()), + "test-server", + ); + conn_primary.set_test_params(params); + conn_primary.set_session_id(crate::types::SessionId(0x1234)); + let mut conn_secondary = conn_primary.clone(); + + let tree = Arc::new(Tree { + tree_id: TreeId(14), + share_name: "share".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + }); + + let payload = b"shared-body-for-both-readers".to_vec(); + + let file_id_1 = FileId { + persistent: 0x0A, + volatile: 0x1A, + }; + let file_id_2 = FileId { + persistent: 0x0B, + volatile: 0x1B, + }; + + let tree_a = Arc::clone(&tree); + let payload_a = payload.clone(); + let handle_a = tokio::spawn(async move { + let mut dl = tree_a + .download(&mut conn_primary, "same.txt") + .await + .expect("download a"); + let mut buf = Vec::new(); + while let Some(c) = dl.next_chunk().await { + buf.extend_from_slice(&c.expect("chunk a")); + } + assert_eq!(buf, payload_a); + }); + + let tree_b = Arc::clone(&tree); + let payload_b = payload.clone(); + let handle_b = tokio::spawn(async move { + let mut dl = tree_b + .download(&mut conn_secondary, "same.txt") + .await + .expect("download b"); + let mut buf = Vec::new(); + while let Some(c) = dl.next_chunk().await { + buf.extend_from_slice(&c.expect("chunk b")); + } + assert_eq!(buf, payload_b); + }); + + // Wait for both CREATE sends before queuing responses. + let deadline = Instant::now() + Duration::from_secs(5); + while mock.sent_count() < 2 { + if Instant::now() > deadline { + panic!("CREATE sends did not land: {}", mock.sent_count()); + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + mock.queue_response(build_create_response(file_id_1, payload.len() as u64)); + mock.queue_response(build_create_response(file_id_2, payload.len() as u64)); + + // Wait for both READs, then answer with identical payloads. + let deadline = Instant::now() + Duration::from_secs(5); + while mock.sent_count() < 4 { + if Instant::now() > deadline { + panic!("READ sends did not land: {}", mock.sent_count()); + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + mock.queue_response(build_read_response(NtStatus::SUCCESS, payload.clone())); + mock.queue_response(build_read_response(NtStatus::SUCCESS, payload.clone())); + + // Both downloads issue a CLOSE after the final chunk. + let deadline = Instant::now() + Duration::from_secs(5); + while mock.sent_count() < 6 { + if Instant::now() > deadline { + panic!("CLOSE sends did not land: {}", mock.sent_count()); + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + mock.queue_response(build_close_response()); + mock.queue_response(build_close_response()); + + handle_a.await.expect("task a panicked"); + handle_b.await.expect("task b panicked"); + + assert_eq!(mock.sent_count(), 6); // 2 CREATE + 2 READ + 2 CLOSE + } +} diff --git a/vendor/smb2/src/client/watcher.rs b/vendor/smb2/src/client/watcher.rs new file mode 100644 index 0000000..ba1b8fc --- /dev/null +++ b/vendor/smb2/src/client/watcher.rs @@ -0,0 +1,780 @@ +//! Directory change notification via SMB2 CHANGE_NOTIFY. +//! +//! The [`Watcher`] type registers for change notifications on a directory +//! and returns [`FileNotifyEvent`] entries describing changes as they happen. +//! The server holds the request until a change occurs, making this a long-poll +//! operation. + +use log::debug; + +use crate::client::connection::{await_frame, Connection, Frame}; +use crate::client::tree::Tree; +use crate::error::Result; +use crate::msg::change_notify::{ + ChangeNotifyRequest, ChangeNotifyResponse, FILE_NOTIFY_CHANGE_ATTRIBUTES, + FILE_NOTIFY_CHANGE_CREATION, FILE_NOTIFY_CHANGE_DIR_NAME, FILE_NOTIFY_CHANGE_FILE_NAME, + FILE_NOTIFY_CHANGE_LAST_WRITE, FILE_NOTIFY_CHANGE_SIZE, SMB2_WATCH_TREE, +}; +use crate::pack::{ReadCursor, Unpack}; +use crate::types::status::NtStatus; +use crate::types::{Command, FileId}; +use crate::Error; +use tokio::sync::oneshot; + +/// Default completion filter: watch for most common changes. +const DEFAULT_COMPLETION_FILTER: u32 = FILE_NOTIFY_CHANGE_FILE_NAME + | FILE_NOTIFY_CHANGE_DIR_NAME + | FILE_NOTIFY_CHANGE_ATTRIBUTES + | FILE_NOTIFY_CHANGE_SIZE + | FILE_NOTIFY_CHANGE_LAST_WRITE + | FILE_NOTIFY_CHANGE_CREATION; + +/// Default output buffer length for CHANGE_NOTIFY responses (64 KB). +const OUTPUT_BUFFER_LENGTH: u32 = 65536; + +/// The type of change that occurred on a file or directory. +/// +/// These correspond to the `Action` field in `FILE_NOTIFY_INFORMATION` +/// (MS-FSCC section 2.4.42). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FileNotifyAction { + /// A file was added to the directory. + Added, + /// A file was removed from the directory. + Removed, + /// A file was modified. + Modified, + /// A file was renamed (this is the old name). + RenamedOldName, + /// A file was renamed (this is the new name). + RenamedNewName, +} + +impl FileNotifyAction { + /// Parse an action value from the wire format. + fn from_u32(value: u32) -> Result { + match value { + 0x0000_0001 => Ok(FileNotifyAction::Added), + 0x0000_0002 => Ok(FileNotifyAction::Removed), + 0x0000_0003 => Ok(FileNotifyAction::Modified), + 0x0000_0004 => Ok(FileNotifyAction::RenamedOldName), + 0x0000_0005 => Ok(FileNotifyAction::RenamedNewName), + other => Err(Error::invalid_data(format!( + "unknown FILE_NOTIFY_INFORMATION action: {other:#010X}" + ))), + } + } +} + +impl std::fmt::Display for FileNotifyAction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FileNotifyAction::Added => write!(f, "added"), + FileNotifyAction::Removed => write!(f, "removed"), + FileNotifyAction::Modified => write!(f, "modified"), + FileNotifyAction::RenamedOldName => write!(f, "renamed (old name)"), + FileNotifyAction::RenamedNewName => write!(f, "renamed (new name)"), + } + } +} + +/// A single file change notification. +/// +/// Represents one `FILE_NOTIFY_INFORMATION` entry from the server. +#[derive(Debug, Clone)] +pub struct FileNotifyEvent { + /// What kind of change occurred. + pub action: FileNotifyAction, + /// The relative file name within the watched directory. + pub filename: String, +} + +/// Watches a directory for changes via SMB2 CHANGE_NOTIFY. +/// +/// The server holds the request until something changes, then responds +/// with one or more [`FileNotifyEvent`] entries. Each call to +/// [`next_events()`](Watcher::next_events) blocks until the server +/// reports a change. +/// +/// ```no_run +/// # async fn example(client: &mut smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> { +/// let mut watcher = client.watch(&share, "_test/", true).await?; +/// loop { +/// let events = watcher.next_events().await?; +/// for event in &events { +/// println!("{}: {}", event.filename, event.action); +/// } +/// } +/// # Ok(()) +/// # } +/// ``` +/// +/// **Pipelining**: `Watcher` keeps one CHANGE_NOTIFY request pre-issued on +/// the wire at all times after the first call to +/// [`next_events`](Self::next_events). The wire never sits idle between +/// consecutive responses, so server-side events that arrive while the +/// consumer is processing the previous batch are still delivered to an +/// outstanding request — they don't fall in a response→re-arm gap where +/// strict servers (older Samba, NAS firmware) drop them silently. +/// +/// The watcher owns a cloned [`Connection`] (cheap `Arc::clone`, all +/// clones multiplex over the same SMB session), so the caller doesn't +/// need a second `SmbClient` to perform other operations while watching. +pub struct Watcher { + tree: Tree, + conn: Connection, + file_id: FileId, + recursive: bool, + /// In-flight CHANGE_NOTIFY response receiver. Populated lazily on the + /// first `next_events()` call and re-populated before awaiting each + /// response, so there is always exactly one outstanding request on + /// the wire from that point on. + pending: Option>>, +} + +impl Watcher { + /// Create a new watcher (called by `Tree::watch`). + pub(crate) fn new(tree: Tree, conn: Connection, file_id: FileId, recursive: bool) -> Self { + Watcher { + tree, + conn, + file_id, + recursive, + pending: None, + } + } + + /// Wait for the next batch of change events. + /// + /// Dispatches a CHANGE_NOTIFY request (if one isn't already pre-issued + /// from the previous call), then — before awaiting the response — + /// dispatches the *next* CHANGE_NOTIFY. This keeps the wire + /// continuously armed: from the moment the first call returns until + /// the watcher is dropped, the server always has an outstanding + /// request to deliver events into. Closes the response→re-arm loss + /// window that strict servers (older Samba, NAS firmware) drop events + /// through. + /// + /// The server holds each request until changes occur, so this call + /// may block for a long time. + /// + /// Returns `Ok(events)` with one or more events when changes are detected. + /// + /// # Errors + /// + /// Returns `Error::Protocol` with `STATUS_NOTIFY_ENUM_DIR` if too many + /// changes occurred and the server could not fit them in the response + /// buffer. In this case, the caller should re-scan the directory and + /// keep watching — by the time control returns, the pipelined-next + /// request is already on the wire so no events arriving during the + /// re-scan get lost. + pub async fn next_events(&mut self) -> Result> { + // Cold start: no request has been issued yet. Dispatch the first. + if self.pending.is_none() { + let rx = self.dispatch_next().await?; + self.pending = Some(rx); + } + // Take the currently in-flight receiver, then immediately + // pre-issue the next request before awaiting this one. The + // `dispatch` call below `.await`s only the transport.send(), so + // when it returns, the next CHANGE_NOTIFY is on the wire and the + // server has somewhere to put new events even while we process + // the response for the previous one. + let in_flight = self.pending.take().expect("pending populated above"); + let next_rx = self.dispatch_next().await?; + self.pending = Some(next_rx); + + let frame = await_frame(in_flight).await?; + + if frame.header.status == NtStatus::NOTIFY_ENUM_DIR { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::ChangeNotify, + }); + } + + if frame.header.status != NtStatus::SUCCESS { + return Err(Error::Protocol { + status: frame.header.status, + command: Command::ChangeNotify, + }); + } + + let mut cursor = ReadCursor::new(&frame.body); + let resp = ChangeNotifyResponse::unpack(&mut cursor)?; + + let events = parse_notify_information(&resp.output_data)?; + debug!("watcher: received {} change event(s)", events.len()); + Ok(events) + } + + /// Build a CHANGE_NOTIFY request and dispatch it on the cloned + /// connection, returning the response receiver. `Connection::dispatch` + /// awaits only up to and including `transport.send()`, so when this + /// returns the request is on the wire — the caller can rely on the + /// "outstanding on the wire" invariant for whatever comes next. + async fn dispatch_next(&self) -> Result>> { + let flags = if self.recursive { SMB2_WATCH_TREE } else { 0 }; + let req = ChangeNotifyRequest { + flags, + output_buffer_length: OUTPUT_BUFFER_LENGTH, + file_id: self.file_id, + completion_filter: DEFAULT_COMPLETION_FILTER, + }; + self.conn + .dispatch(Command::ChangeNotify, &req, Some(self.tree.tree_id)) + .await + } + + /// Close the directory handle. + /// + /// Drops the pre-issued CHANGE_NOTIFY receiver (the `Connection` + /// receiver task discards the late response silently when it + /// arrives — same contract `Connection::execute` already documents), + /// then issues a CLOSE on the file handle. If `close` is not called + /// explicitly, the `Drop` impl drops the pre-issued receiver but the + /// server-side handle leaks until the session ends (there is no + /// async drop in Rust). + pub async fn close(mut self) -> Result<()> { + self.pending.take(); + self.tree.close_handle(&mut self.conn, self.file_id).await + } +} + +impl Drop for Watcher { + fn drop(&mut self) { + // The pre-issued response receiver (if any) drops with the + // Watcher. The `Connection` receiver task discards the late + // frame silently when it arrives, matching the contract on + // `Connection::execute`. The directory handle itself leaks + // server-side until the session ends — the docstring on `close` + // already warns about this. + } +} + +/// Parse a chain of FILE_NOTIFY_INFORMATION entries from the response buffer. +/// +/// Each entry has: +/// - `NextEntryOffset` (u32): offset to next entry, 0 for last +/// - `Action` (u32): the change type +/// - `FileNameLength` (u32): length of filename in bytes (UTF-16LE) +/// - `FileName` (variable): UTF-16LE, NOT null-terminated +/// +/// Entries are 4-byte aligned. +fn parse_notify_information(data: &[u8]) -> Result> { + let mut events = Vec::new(); + let mut offset = 0usize; + + if data.is_empty() { + return Ok(events); + } + + loop { + // Need at least 12 bytes for the fixed fields. + if offset + 12 > data.len() { + return Err(Error::invalid_data( + "FILE_NOTIFY_INFORMATION truncated: not enough bytes for fixed fields", + )); + } + + let next_entry_offset = + u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; + let action_raw = u32::from_le_bytes(data[offset + 4..offset + 8].try_into().unwrap()); + let filename_length = + u32::from_le_bytes(data[offset + 8..offset + 12].try_into().unwrap()) as usize; + + // Filename starts right after the 12-byte fixed header. + let filename_start = offset + 12; + let filename_end = filename_start + filename_length; + + if filename_end > data.len() { + return Err(Error::invalid_data(format!( + "FILE_NOTIFY_INFORMATION filename extends beyond buffer: \ + need {} bytes at offset {}, buffer is {} bytes", + filename_length, + filename_start, + data.len() + ))); + } + + let filename_bytes = &data[filename_start..filename_end]; + + // Decode UTF-16LE filename. + let filename = decode_utf16le(filename_bytes)?; + let action = FileNotifyAction::from_u32(action_raw)?; + + events.push(FileNotifyEvent { action, filename }); + + if next_entry_offset == 0 { + break; + } + + offset += next_entry_offset; + } + + Ok(events) +} + +/// Decode a UTF-16LE byte slice into a Rust String. +fn decode_utf16le(bytes: &[u8]) -> Result { + if bytes.len() % 2 != 0 { + return Err(Error::invalid_data("UTF-16LE filename has odd byte count")); + } + + let u16s: Vec = bytes + .chunks_exact(2) + .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .collect(); + + String::from_utf16(&u16s) + .map_err(|e| Error::invalid_data(format!("invalid UTF-16LE filename: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_single_notify_entry() { + // Build a single FILE_NOTIFY_INFORMATION entry. + let filename = "test.txt"; + let utf16: Vec = filename.encode_utf16().collect(); + let filename_bytes: Vec = utf16.iter().flat_map(|c| c.to_le_bytes()).collect(); + let filename_len = filename_bytes.len() as u32; + + let mut data = Vec::new(); + // NextEntryOffset = 0 (last entry) + data.extend_from_slice(&0u32.to_le_bytes()); + // Action = FILE_ACTION_ADDED (0x00000001) + data.extend_from_slice(&1u32.to_le_bytes()); + // FileNameLength + data.extend_from_slice(&filename_len.to_le_bytes()); + // FileName (UTF-16LE) + data.extend_from_slice(&filename_bytes); + + let events = parse_notify_information(&data).unwrap(); + assert_eq!(events.len(), 1); + assert_eq!(events[0].action, FileNotifyAction::Added); + assert_eq!(events[0].filename, "test.txt"); + } + + #[test] + fn parse_multiple_notify_entries() { + // Build two FILE_NOTIFY_INFORMATION entries. + let build_entry = |name: &str, action: u32, is_last: bool| -> Vec { + let utf16: Vec = name.encode_utf16().collect(); + let filename_bytes: Vec = utf16.iter().flat_map(|c| c.to_le_bytes()).collect(); + let filename_len = filename_bytes.len() as u32; + + let mut entry = Vec::new(); + // Fixed header is 12 bytes + filename. Align to 4 bytes. + let entry_size = 12 + filename_bytes.len(); + let aligned_size = (entry_size + 3) & !3; + + let next_offset = if is_last { 0u32 } else { aligned_size as u32 }; + entry.extend_from_slice(&next_offset.to_le_bytes()); + entry.extend_from_slice(&action.to_le_bytes()); + entry.extend_from_slice(&filename_len.to_le_bytes()); + entry.extend_from_slice(&filename_bytes); + + // Pad to 4-byte alignment. + while entry.len() < aligned_size { + entry.push(0); + } + + entry + }; + + let mut data = Vec::new(); + data.extend_from_slice(&build_entry("added.txt", 1, false)); + data.extend_from_slice(&build_entry("removed.txt", 2, true)); + + let events = parse_notify_information(&data).unwrap(); + assert_eq!(events.len(), 2); + assert_eq!(events[0].action, FileNotifyAction::Added); + assert_eq!(events[0].filename, "added.txt"); + assert_eq!(events[1].action, FileNotifyAction::Removed); + assert_eq!(events[1].filename, "removed.txt"); + } + + #[test] + fn parse_empty_buffer_returns_no_events() { + let events = parse_notify_information(&[]).unwrap(); + assert!(events.is_empty()); + } + + #[test] + fn parse_truncated_buffer_returns_error() { + // Only 8 bytes, need at least 12 for fixed fields. + let data = vec![0u8; 8]; + let result = parse_notify_information(&data); + assert!(result.is_err()); + } + + #[test] + fn decode_utf16le_basic() { + let input = "hello"; + let utf16: Vec = input.encode_utf16().collect(); + let bytes: Vec = utf16.iter().flat_map(|c| c.to_le_bytes()).collect(); + let result = decode_utf16le(&bytes).unwrap(); + assert_eq!(result, "hello"); + } + + #[test] + fn decode_utf16le_non_ascii() { + let input = "photos/\u{00E9}t\u{00E9}"; + let utf16: Vec = input.encode_utf16().collect(); + let bytes: Vec = utf16.iter().flat_map(|c| c.to_le_bytes()).collect(); + let result = decode_utf16le(&bytes).unwrap(); + assert_eq!(result, input); + } + + #[test] + fn decode_utf16le_odd_bytes_is_error() { + let result = decode_utf16le(&[0x41, 0x00, 0x42]); + assert!(result.is_err()); + } + + #[test] + fn file_notify_action_display() { + assert_eq!(format!("{}", FileNotifyAction::Added), "added"); + assert_eq!(format!("{}", FileNotifyAction::Removed), "removed"); + assert_eq!(format!("{}", FileNotifyAction::Modified), "modified"); + assert_eq!( + format!("{}", FileNotifyAction::RenamedOldName), + "renamed (old name)" + ); + assert_eq!( + format!("{}", FileNotifyAction::RenamedNewName), + "renamed (new name)" + ); + } + + #[test] + fn file_notify_action_from_u32_unknown_is_error() { + let result = FileNotifyAction::from_u32(0x9999); + assert!(result.is_err()); + } +} + +/// Loss-window tests using a strict-server simulator. +/// +/// These probe the architectural property the watcher contract should +/// guarantee: every event the server observes is eventually delivered +/// to the consumer, even when the server drops events that arrive +/// while no `CHANGE_NOTIFY` request is outstanding (the naspi / older +/// Samba behavior that triggered cmdr's field reproduction). +/// +/// **TDD-red on `main`**: `LossySim` drops events when no request is +/// outstanding; current `next_events()` issues one CHANGE_NOTIFY per +/// call, so there's always a gap between response delivery and the +/// next request. Events pushed during that gap are dropped, and the +/// test fails. The pipelined-watcher fix (always keep one CHANGE_NOTIFY +/// pre-issued on the wire) closes the gap, the simulator never drops, +/// and the test passes. +#[cfg(test)] +mod loss_window_tests { + use super::*; + use crate::client::connection::{pack_message, Connection, NegotiatedParams}; + use crate::client::tree::Tree; + use crate::msg::change_notify::ChangeNotifyResponse; + use crate::msg::header::Header; + use crate::pack::Guid; + use crate::transport::{TransportReceive, TransportSend}; + use crate::types::flags::Capabilities; + use crate::types::{Command, Dialect, MessageId, SessionId, TreeId}; + use async_trait::async_trait; + use std::collections::VecDeque; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + use tokio::sync::Notify; + + /// Simulates a CHANGE_NOTIFY server that DROPS events that arrive + /// while no request is outstanding. Models naspi / older Samba + /// firmware (the server side of cmdr's 9-files → 4-events field + /// reproduction). Forgiving servers like Docker Samba buffer + /// generously and won't trigger this; the simulator's job is to + /// surface the architectural bug regardless of how forgiving any + /// real server happens to be. + struct LossySim { + /// Outstanding CHANGE_NOTIFY request msg_ids (FIFO). + outstanding: Mutex>, + /// Events the server has observed but not yet delivered. + pending_events: Mutex>, + /// Response queue read by `receive()`. + responses: Mutex>>, + /// Count of events the server saw with no request outstanding. + dropped: Mutex, + send_notify: Notify, + recv_notify: Notify, + closed: AtomicBool, + } + + impl LossySim { + fn new() -> Self { + Self { + outstanding: Mutex::new(VecDeque::new()), + pending_events: Mutex::new(Vec::new()), + responses: Mutex::new(VecDeque::new()), + dropped: Mutex::new(0), + send_notify: Notify::new(), + recv_notify: Notify::new(), + closed: AtomicBool::new(false), + } + } + + /// Block until at least one CHANGE_NOTIFY request is outstanding. + async fn wait_outstanding(&self) { + loop { + if !self.outstanding.lock().unwrap().is_empty() { + return; + } + if self.closed.load(Ordering::Acquire) { + return; + } + self.send_notify.notified().await; + } + } + + /// Push an event. If a CHANGE_NOTIFY request is outstanding, buffer + /// the event for the next `deliver_pending()`. Else, drop silently + /// and bump the dropped counter. + fn push_event(&self, name: &str) { + let outstanding = !self.outstanding.lock().unwrap().is_empty(); + if outstanding { + self.pending_events + .lock() + .unwrap() + .push((name.to_string(), 1 /* FILE_ACTION_ADDED */)); + } else { + *self.dropped.lock().unwrap() += 1; + } + } + + /// Wrap all buffered events into a single CHANGE_NOTIFY response, + /// consuming one outstanding msg_id. + fn deliver_pending(&self) { + let msg_id = self.outstanding.lock().unwrap().pop_front(); + let events = std::mem::take(&mut *self.pending_events.lock().unwrap()); + if let Some(id) = msg_id { + let resp = build_response(id, &events); + self.responses.lock().unwrap().push_back(resp); + self.recv_notify.notify_one(); + } + } + + fn dropped_count(&self) -> usize { + *self.dropped.lock().unwrap() + } + + fn close(&self) { + self.closed.store(true, Ordering::Release); + self.recv_notify.notify_waiters(); + self.send_notify.notify_waiters(); + } + } + + #[async_trait] + impl TransportSend for LossySim { + async fn send(&self, data: &[u8]) -> crate::error::Result<()> { + if let Some(msg_id) = extract_change_notify_msg_id(data) { + self.outstanding.lock().unwrap().push_back(msg_id); + self.send_notify.notify_waiters(); + } + Ok(()) + } + } + + #[async_trait] + impl TransportReceive for LossySim { + async fn receive(&self) -> crate::error::Result> { + loop { + if let Some(data) = self.responses.lock().unwrap().pop_front() { + return Ok(data); + } + if self.closed.load(Ordering::Acquire) { + return Err(crate::Error::Disconnected); + } + self.recv_notify.notified().await; + } + } + } + + /// Pull `MessageId` out of a request frame, but only for CHANGE_NOTIFY. + /// Non-CHANGE_NOTIFY sends are ignored by the simulator (the test + /// pre-configures the connection so no other requests should hit this + /// transport — but if any do, we won't track them). + fn extract_change_notify_msg_id(data: &[u8]) -> Option { + const HEADER_MIN: usize = 64; + if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" { + return None; + } + let cmd = u16::from_le_bytes([data[12], data[13]]); + if cmd != Command::ChangeNotify as u16 { + return None; + } + Some(u64::from_le_bytes(data[24..32].try_into().unwrap())) + } + + /// Pack a CHANGE_NOTIFY response carrying the given (name, action) pairs. + fn build_response(msg_id: u64, events: &[(String, u32)]) -> Vec { + let mut output_data = Vec::new(); + for (i, (name, action)) in events.iter().enumerate() { + let is_last = i == events.len() - 1; + let utf16: Vec = name.encode_utf16().collect(); + let filename_bytes: Vec = utf16.iter().flat_map(|c| c.to_le_bytes()).collect(); + let filename_len = filename_bytes.len() as u32; + let entry_size = 12 + filename_bytes.len(); + let aligned_size = (entry_size + 3) & !3; + let next_offset = if is_last { 0u32 } else { aligned_size as u32 }; + let start = output_data.len(); + output_data.extend_from_slice(&next_offset.to_le_bytes()); + output_data.extend_from_slice(&action.to_le_bytes()); + output_data.extend_from_slice(&filename_len.to_le_bytes()); + output_data.extend_from_slice(&filename_bytes); + while output_data.len() - start < aligned_size { + output_data.push(0); + } + } + let mut h = Header::new_request(Command::ChangeNotify); + h.flags.set_response(); + h.message_id = MessageId(msg_id); + h.credits = 32; + let body = ChangeNotifyResponse { output_data }; + pack_message(&h, &body) + } + + fn setup_connection(sim: &Arc) -> Connection { + let mut conn = + Connection::from_transport(Box::new(sim.clone()), Box::new(sim.clone()), "test-server"); + conn.set_test_params(NegotiatedParams { + dialect: Dialect::Smb2_0_2, + max_read_size: 65536, + max_write_size: 65536, + max_transact_size: 65536, + server_guid: Guid::ZERO, + signing_required: false, + capabilities: Capabilities::default(), + gmac_negotiated: false, + cipher: None, + compression_supported: false, + }); + conn.set_session_id(SessionId(0x1234)); + conn + } + + fn test_tree() -> Tree { + Tree { + tree_id: TreeId(1), + share_name: "test".to_string(), + server: "test-server".to_string(), + is_dfs: false, + encrypt_data: false, + } + } + + /// Cycle, repeated N times: + /// 1. wait for outstanding (watcher armed) + /// 2. push event A → buffered + /// 3. deliver_pending → response queued, msg_id consumed + /// 4. push GAP event → on `main`, no outstanding → DROPPED; + /// on the pipelined-watcher fix, the next request is already + /// issued → buffered. + /// + /// Final flush: one more wait_outstanding + push + deliver to make + /// sure any buffered gap events on the fix path get out. + /// + /// On `main`: `dropped_count() > 0`, `delivered.len() < expected`. + /// On the fix: `dropped_count() == 0`, all events delivered. + #[tokio::test] + async fn watcher_does_not_lose_events_between_consecutive_requests() { + let _ = env_logger::try_init(); + + const N_CYCLES: usize = 5; + + let sim = Arc::new(LossySim::new()); + let conn = setup_connection(&sim); + let tree = test_tree(); + + let scenario_sim = sim.clone(); + let scenario = tokio::spawn(async move { + let sim = scenario_sim; + for round in 0..N_CYCLES { + sim.wait_outstanding().await; + sim.push_event(&format!("a_{round:02}")); + sim.deliver_pending(); + // Inline push (no .await) — outstanding queue was just + // emptied by deliver_pending. On `main`, no request has + // been re-issued yet, so this lands in the "drop" branch. + // On the fix, a pre-issued request is still outstanding, + // so it lands in the "buffer" branch. + sim.push_event(&format!("gap_{round:02}")); + // Models "time passes between server-side events". Real + // workloads have at least a syscall worth of latency + // between events, which is enough for the watcher task + // to wake up, process the previous response, and + // re-dispatch. The pipelining fix only guarantees one + // outstanding through the response-processing window, + // not through arbitrary back-to-back synchronous + // delivers within a single scheduler quantum. + tokio::task::yield_now().await; + } + // Flush: drive one more cycle to push any buffered gap events + // out the door for the fix path. + sim.wait_outstanding().await; + sim.push_event("flush_marker"); + sim.deliver_pending(); + // Brief grace period for the watcher to drain the response, + // then close so its next next_events() returns Disconnected + // and the consumer loop exits. + tokio::time::sleep(Duration::from_millis(50)).await; + sim.close(); + }); + + let mut watcher = Watcher::new( + tree, + conn, + crate::types::FileId { + persistent: 0x1111, + volatile: 0x2222, + }, + true, + ); + let mut delivered: Vec = Vec::new(); + while let Ok(events) = watcher.next_events().await { + for e in &events { + delivered.push(e.filename.clone()); + } + } + scenario.await.unwrap(); + + let dropped = sim.dropped_count(); + // `a_*` events always land in the outstanding window. `flush_marker` + // ditto. `gap_*` events expose the bug: dropped today, delivered + // after the fix. + let expected_min = N_CYCLES /* a_* */ + 1 /* flush_marker */; + let expected_max = expected_min + N_CYCLES /* gap_* */; + + assert!( + delivered.len() >= expected_min, + "watcher dropped 'a_*' or 'flush_marker' events: got {:?}", + delivered + ); + assert_eq!( + dropped, 0, + "{} server-side event(s) arrived with no outstanding CHANGE_NOTIFY \ + request and were dropped. The pipelined-watcher fix should keep \ + one CHANGE_NOTIFY request continuously outstanding so no event \ + ever lands in the drop branch. Delivered to consumer: {:?}", + dropped, delivered + ); + assert_eq!( + delivered.len(), + expected_max, + "expected every 'a_*', 'gap_*', and 'flush_marker' event delivered; \ + got {:?}", + delivered + ); + } +} diff --git a/vendor/smb2/src/crypto/CLAUDE.md b/vendor/smb2/src/crypto/CLAUDE.md new file mode 100644 index 0000000..be70a59 --- /dev/null +++ b/vendor/smb2/src/crypto/CLAUDE.md @@ -0,0 +1,55 @@ +# Crypto -- signing, encryption, key derivation, compression + +Handles all cryptographic operations. Most users don't touch this directly -- `Session::setup` and `Connection` use it automatically. + +## Key files + +| File | Purpose | +|---|---| +| `signing.rs` | Sign/verify messages. Three algorithms: HMAC-SHA256, AES-CMAC, AES-GMAC | +| `encryption.rs` | Encrypt/decrypt messages. Four ciphers: AES-128/256-CCM, AES-128/256-GCM | +| `kdf.rs` | SP800-108 KDF + `PreauthHasher` (SHA-512 running hash) | +| `compression.rs` | LZ4 compression for SMB 3.1.1 | + +## Signing algorithms + +| Algorithm | Dialect | Key size | +|---|---|---| +| HMAC-SHA256 (truncated to 16 bytes) | SMB 2.0.2, 2.1 | any | +| AES-128-CMAC | SMB 3.0, 3.0.2, 3.1.1 (fallback) | 16 bytes | +| AES-128-GMAC | SMB 3.1.1 (with `SMB2_SIGNING_CAPABILITIES`) | 16 bytes | + +GMAC is AES-128-GCM with empty plaintext. The auth tag IS the signature. The 12-byte nonce encodes `MessageId` (bytes 0-7), a role bit (byte 8 bit 0: 0=client, 1=server), and a cancel flag (byte 8 bit 1). + +## Encryption + +Four ciphers, negotiated during NEGOTIATE: +- AES-128-CCM (11-byte nonce) -- SMB 3.0+ +- AES-128-GCM (12-byte nonce) -- SMB 3.0+ +- AES-256-CCM (11-byte nonce) -- SMB 3.1.1 +- AES-256-GCM (12-byte nonce) -- SMB 3.1.1 + +Nonces come from a `NonceGenerator` with a monotonic u64 counter. Nonce reuse breaks GCM catastrophically -- the counter must never reset within a session. + +AAD is the TRANSFORM_HEADER bytes 20..52 (Nonce + OriginalMessageSize + Reserved + Flags + SessionId). The auth tag goes into the Signature field at bytes 4..20. + +## Key derivation (SP800-108) + +`derive_session_keys` produces three keys (signing, encryption, decryption) from the NTLM session key using HMAC-SHA256 in counter mode. + +- **SMB 3.0/3.0.2**: Fixed ASCII label/context pairs (for example, `"SMB2AESCMAC\0"` / `"SmbSign\0"`) +- **SMB 3.1.1**: New labels (`"SMBSigningKey\0"`) with preauth hash (64-byte SHA-512) as context + +`PreauthHasher` computes `SHA-512(prev_hash || message_bytes)` incrementally over negotiate and session-setup wire bytes. Cloned per session (spec requires per-session hash). + +## Key decisions + +- **Labels include `\0` terminator**: Matches smb-rs and the spec's Label field definitions. The double-null (label `\0` + separator `0x00`) is correct. +- **GMAC uses AES-128, not AES-256**: Despite the signing algorithm name containing "256", the actual GMAC implementation uses AES-128-GCM. The "256" in the spec refers to the GMAC algorithm ID, not the key size. Signing keys are always 16 bytes. + +## Gotchas + +- **GMAC nonce has a role bit**: Client signs with role=0, server with role=1. Verify uses role=1 (server). Same message+key produces different signatures for client vs server. +- **Signing and encryption are mutually exclusive on the wire**: When encryption is active, the signature field is zeroed (AEAD provides auth). Never sign AND encrypt. +- **Nonce counter must not be reused**: `NonceGenerator` panics on u64 overflow (unreachable in practice). Each session gets its own generator. +- **HMAC-SHA256 for signing accepts any key length**: Unlike CMAC/GMAC which require exactly 16 bytes. HMAC pads/hashes the key internally. diff --git a/vendor/smb2/src/crypto/compression.rs b/vendor/smb2/src/crypto/compression.rs new file mode 100644 index 0000000..24a8b9f --- /dev/null +++ b/vendor/smb2/src/crypto/compression.rs @@ -0,0 +1,286 @@ +//! SMB2 LZ4 compression for unchained mode (MS-SMB2 section 3.1.4.4). +//! +//! In unchained mode, the `CompressionTransformHeader` has `Flags = 0x0000`. +//! The `Offset` field indicates where compressed data starts relative to the +//! original message. Bytes before the offset are sent uncompressed (the +//! "uncompressed prefix"), while bytes from the offset onward are +//! LZ4-compressed. +//! +//! This allows the SMB2 header to remain uncompressed for routing while the +//! payload is compressed. + +/// Maximum decompressed size we allow (16 MB). Prevents decompression bombs. +const MAX_DECOMPRESSED_SIZE: u32 = 16 * 1024 * 1024; + +/// The result of compressing an SMB2 message (unchained mode). +#[derive(Debug, Clone)] +pub struct CompressedMessage { + /// The original uncompressed size of the compressed portion. + pub original_size: u32, + /// Bytes before the compression offset (sent as-is). + pub uncompressed_prefix: Vec, + /// The LZ4-compressed data. + pub compressed_data: Vec, + /// The offset that was used (same as input offset). + pub offset: u32, +} + +/// Compress an SMB2 message using LZ4 (unchained mode). +/// +/// `offset` indicates where compression starts in the original message. +/// Bytes before `offset` are kept as-is (uncompressed prefix). +/// Bytes from `offset` onward are LZ4-compressed. +/// +/// Returns `None` if compression doesn't reduce the size (not worth it), +/// or if there is nothing to compress (offset >= message length). +pub fn compress_message(message: &[u8], offset: usize) -> Option { + // Nothing to compress if offset is at or beyond the end. + if offset >= message.len() { + return None; + } + + let prefix = &message[..offset]; + let to_compress = &message[offset..]; + + let compressed = lz4_flex::block::compress(to_compress); + + // Only use compression if it actually reduces size. + if compressed.len() >= to_compress.len() { + return None; + } + + Some(CompressedMessage { + original_size: to_compress.len() as u32, + uncompressed_prefix: prefix.to_vec(), + compressed_data: compressed, + offset: offset as u32, + }) +} + +/// Decompress an SMB2 message (unchained mode). +/// +/// `uncompressed_prefix` is the data before the compression offset. +/// `compressed_data` is the LZ4-compressed portion. +/// `original_size` is the expected decompressed size of the compressed portion. +/// +/// Returns the full reconstructed message (prefix + decompressed data). +pub fn decompress_message( + uncompressed_prefix: &[u8], + compressed_data: &[u8], + original_size: u32, +) -> Result, crate::Error> { + // Validate original_size to prevent decompression bombs. + if original_size > MAX_DECOMPRESSED_SIZE { + return Err(crate::Error::invalid_data(format!( + "decompressed size {} exceeds maximum allowed size {}", + original_size, MAX_DECOMPRESSED_SIZE + ))); + } + + let decompressed = lz4_flex::block::decompress(compressed_data, original_size as usize) + .map_err(|e| crate::Error::invalid_data(format!("LZ4 decompression failed: {e}")))?; + + let mut result = Vec::with_capacity(uncompressed_prefix.len() + decompressed.len()); + result.extend_from_slice(uncompressed_prefix); + result.extend_from_slice(&decompressed); + Ok(result) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn compress_and_decompress_roundtrip() { + // Compressible data: repeated pattern. + let message: Vec = b"ABCDEFGH".iter().copied().cycle().take(1024).collect(); + + let compressed = compress_message(&message, 0).expect("should compress"); + assert!(compressed.compressed_data.len() < message.len()); + assert_eq!(compressed.original_size, message.len() as u32); + assert!(compressed.uncompressed_prefix.is_empty()); + assert_eq!(compressed.offset, 0); + + let decompressed = decompress_message( + &compressed.uncompressed_prefix, + &compressed.compressed_data, + compressed.original_size, + ) + .expect("should decompress"); + + assert_eq!(decompressed, message); + } + + #[test] + fn compress_with_offset_preserves_prefix() { + // Simulate a 64-byte SMB2 header + compressible payload. + let mut message = vec![0xFE; 64]; // "header" bytes + let payload: Vec = b"HelloWorld".iter().copied().cycle().take(2048).collect(); + message.extend_from_slice(&payload); + + let compressed = compress_message(&message, 64).expect("should compress"); + assert_eq!(compressed.offset, 64); + assert_eq!(compressed.uncompressed_prefix, &message[..64]); + assert_eq!(compressed.original_size, payload.len() as u32); + assert!(compressed.compressed_data.len() < payload.len()); + + let decompressed = decompress_message( + &compressed.uncompressed_prefix, + &compressed.compressed_data, + compressed.original_size, + ) + .expect("should decompress"); + + assert_eq!(decompressed, message); + } + + #[test] + fn compress_with_offset_zero_compresses_entire_message() { + let message: Vec = vec![42u8; 4096]; + + let compressed = compress_message(&message, 0).expect("should compress"); + assert_eq!(compressed.offset, 0); + assert!(compressed.uncompressed_prefix.is_empty()); + assert_eq!(compressed.original_size, 4096); + + let decompressed = decompress_message( + &compressed.uncompressed_prefix, + &compressed.compressed_data, + compressed.original_size, + ) + .expect("should decompress"); + + assert_eq!(decompressed, message); + } + + #[test] + fn compress_empty_message_returns_none() { + let message: &[u8] = &[]; + assert!(compress_message(message, 0).is_none()); + } + + #[test] + fn compress_offset_at_end_returns_none() { + let message = b"short"; + assert!(compress_message(message, 5).is_none()); + assert!(compress_message(message, 100).is_none()); + } + + #[test] + fn incompressible_data_returns_none() { + // Random-ish bytes that LZ4 cannot compress (will likely grow). + let mut message = Vec::with_capacity(256); + for i in 0u16..256 { + // Use a simple PRNG-like pattern that doesn't compress well. + message.push(((i.wrapping_mul(137).wrapping_add(53)) & 0xFF) as u8); + } + + // Small incompressible data should return None. + assert!( + compress_message(&message, 0).is_none(), + "incompressible data should return None" + ); + } + + #[test] + fn large_message_compresses_well() { + // 1 MB of repeated pattern -- should compress very well. + let message: Vec = b"SMB2 compression test data! " + .iter() + .copied() + .cycle() + .take(1024 * 1024) + .collect(); + + let compressed = compress_message(&message, 0).expect("should compress large message"); + + // LZ4 should achieve at least 4:1 on highly repetitive data. + let ratio = message.len() as f64 / compressed.compressed_data.len() as f64; + assert!( + ratio > 4.0, + "compression ratio {ratio:.1} is too low for repetitive data" + ); + + let decompressed = decompress_message( + &compressed.uncompressed_prefix, + &compressed.compressed_data, + compressed.original_size, + ) + .expect("should decompress"); + + assert_eq!(decompressed.len(), message.len()); + assert_eq!(decompressed, message); + } + + #[test] + fn decompress_with_wrong_original_size_fails() { + let message: Vec = vec![0xAA; 1024]; + let compressed = compress_message(&message, 0).expect("should compress"); + + // Use a wrong (smaller) original_size -- decompression should fail + // because LZ4 validates the output size. + let result = decompress_message(&[], &compressed.compressed_data, 512); + assert!(result.is_err(), "wrong original_size should cause an error"); + } + + #[test] + fn decompress_rejects_oversized_original_size() { + // Attempt to decompress with original_size exceeding 16 MB limit. + let bogus_compressed = vec![0u8; 10]; + let result = decompress_message(&[], &bogus_compressed, MAX_DECOMPRESSED_SIZE + 1); + assert!(result.is_err()); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("exceeds maximum"), + "error should mention size limit, got: {err_msg}" + ); + } + + #[test] + fn decompress_with_exact_max_size_is_allowed() { + // original_size == MAX_DECOMPRESSED_SIZE should not be rejected + // by the size check (it will fail on actual decompression since the + // data is bogus, but that's a different error). + let bogus_compressed = vec![0u8; 10]; + let result = decompress_message(&[], &bogus_compressed, MAX_DECOMPRESSED_SIZE); + + // Should fail on decompression, not on size validation. + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("decompression failed"), + "should fail on decompression, not size check, got: {err_msg}" + ); + } + + #[test] + fn decompress_corrupt_data_fails() { + let corrupt = vec![0xFF, 0xFE, 0xFD, 0xFC, 0xFB]; + let result = decompress_message(&[], &corrupt, 1024); + assert!(result.is_err()); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("decompression failed"), + "error should mention decompression failure, got: {err_msg}" + ); + } + + #[test] + fn decompress_preserves_prefix_in_output() { + let prefix = b"PREFIX_DATA"; + let payload: Vec = vec![0x42; 2048]; + let compressed_payload = compress_message(&payload, 0).expect("should compress payload"); + + let result = decompress_message( + prefix, + &compressed_payload.compressed_data, + compressed_payload.original_size, + ) + .expect("should decompress"); + + assert_eq!(&result[..prefix.len()], prefix); + assert_eq!(&result[prefix.len()..], &payload); + } +} diff --git a/vendor/smb2/src/crypto/encryption.rs b/vendor/smb2/src/crypto/encryption.rs new file mode 100644 index 0000000..d529657 --- /dev/null +++ b/vendor/smb2/src/crypto/encryption.rs @@ -0,0 +1,591 @@ +//! SMB2/3 message encryption and decryption. +//! +//! Implements AES-128-CCM, AES-128-GCM, AES-256-CCM, and AES-256-GCM +//! as specified in MS-SMB2 sections 3.1.4.3 (encrypting) and 3.1.5.1 +//! (decrypting). Nonces are generated from a monotonically increasing +//! per-session counter to prevent catastrophic nonce reuse in AES-GCM. + +use aes::{Aes128, Aes256}; +use aes_gcm::aead::{array::Array, inout::InOutBuf, AeadInOut}; +use aes_gcm::KeyInit; +use ccm::consts::{U11, U16}; + +use crate::msg::transform::{TransformHeader, SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED}; +use crate::pack::{Pack, WriteCursor}; +use crate::types::SessionId; +use crate::Error; + +/// Offset in the serialized TRANSFORM_HEADER where the AAD begins. +/// +/// The AAD is "the SMB2 TRANSFORM_HEADER, excluding the ProtocolId and +/// Signature fields" (MS-SMB2 section 3.1.4.3). ProtocolId is 4 bytes +/// and Signature is 16 bytes, so the AAD starts at offset 20 (the Nonce +/// field) and extends to the end of the 52-byte header. +const AAD_OFFSET: usize = 20; + +/// Total size of the TRANSFORM_HEADER in bytes. +const HEADER_SIZE: usize = TransformHeader::SIZE; // 52 + +// ── CCM type aliases ───────────────────────────────────────────────── + +/// AES-128-CCM with 16-byte tag and 11-byte nonce (SMB 3.0+). +type Aes128Ccm = ccm::Ccm; + +/// AES-256-CCM with 16-byte tag and 11-byte nonce (SMB 3.1.1). +type Aes256Ccm = ccm::Ccm; + +// ── Cipher enum ────────────────────────────────────────────────────── + +/// Encryption cipher, determined during negotiation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub enum Cipher { + /// AES-128-CCM (SMB 3.0+) -- 11-byte nonce. + Aes128Ccm, + /// AES-128-GCM (SMB 3.0+) -- 12-byte nonce. + Aes128Gcm, + /// AES-256-CCM (SMB 3.1.1) -- 11-byte nonce. + Aes256Ccm, + /// AES-256-GCM (SMB 3.1.1) -- 12-byte nonce. + Aes256Gcm, +} + +impl Cipher { + /// Returns the number of nonce bytes actually used by this cipher. + pub fn nonce_len(self) -> usize { + match self { + Cipher::Aes128Ccm | Cipher::Aes256Ccm => 11, + Cipher::Aes128Gcm | Cipher::Aes256Gcm => 12, + } + } + + /// Returns the expected key length in bytes. + fn key_len(self) -> usize { + match self { + Cipher::Aes128Ccm | Cipher::Aes128Gcm => 16, + Cipher::Aes256Ccm | Cipher::Aes256Gcm => 32, + } + } +} + +// ── Nonce generator ────────────────────────────────────────────────── + +/// Monotonically increasing nonce generator. +/// +/// Each session gets its own nonce generator. The counter MUST NOT +/// be reused -- nonce reuse breaks AES-GCM catastrophically. +pub struct NonceGenerator { + counter: u64, +} + +impl NonceGenerator { + /// Create a new nonce generator starting at counter 0. + pub fn new() -> Self { + Self { counter: 0 } + } + + /// Generate the next nonce for the given cipher. + /// + /// Returns the full 16-byte nonce field for the TRANSFORM_HEADER. + /// - CCM: 8-byte LE counter in bytes 0..8, zeros in bytes 8..16 + /// (the cipher uses the first 11 bytes as the nonce). + /// - GCM: 8-byte LE counter in bytes 0..8, zeros in bytes 8..16 + /// (the cipher uses the first 12 bytes as the nonce). + /// + /// # Panics + /// + /// Panics if the counter overflows `u64::MAX`. In practice this + /// can never happen (2^64 messages at line speed would take millennia). + pub fn next(&mut self, _cipher: Cipher) -> [u8; 16] { + let count = self.counter; + self.counter = self.counter.checked_add(1).expect("nonce counter overflow"); + let mut nonce = [0u8; 16]; + nonce[..8].copy_from_slice(&count.to_le_bytes()); + nonce + } +} + +impl Default for NonceGenerator { + fn default() -> Self { + Self::new() + } +} + +// ── Encrypt ────────────────────────────────────────────────────────── + +/// Encrypt an SMB2 message. +/// +/// Returns `(transform_header_bytes, encrypted_message)`. The 52-byte +/// transform header includes the protocol ID, auth tag (in the Signature +/// field), nonce, original message size, flags, and session ID. The +/// encrypted message replaces the plaintext. +pub fn encrypt_message( + plaintext: &[u8], + key: &[u8], + cipher: Cipher, + nonce: &[u8; 16], + session_id: u64, +) -> Result<(Vec, Vec), Error> { + if key.len() != cipher.key_len() { + return Err(Error::invalid_data(format!( + "encryption key length mismatch: expected {}, got {}", + cipher.key_len(), + key.len() + ))); + } + + // Build the TRANSFORM_HEADER with a zeroed signature (will be filled + // with the auth tag after encryption). + let header = TransformHeader { + signature: [0u8; 16], + nonce: *nonce, + original_message_size: plaintext.len() as u32, + flags: SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED, + session_id: SessionId(session_id), + }; + + let mut header_bytes = { + let mut w = WriteCursor::new(); + header.pack(&mut w); + w.into_inner() + }; + + // AAD = header bytes 20..52 (Nonce + OriginalMessageSize + Reserved + Flags + SessionId) + let aad = &header_bytes[AAD_OFFSET..HEADER_SIZE]; + + // Encrypt and get the auth tag. + let mut buffer = plaintext.to_vec(); + let nonce_slice = &nonce[..cipher.nonce_len()]; + + let tag = encrypt_raw(cipher, key, nonce_slice, aad, &mut buffer)?; + + // Write the 16-byte auth tag into the Signature field (bytes 4..20). + header_bytes[4..20].copy_from_slice(&tag); + + Ok((header_bytes, buffer)) +} + +// ── Decrypt ────────────────────────────────────────────────────────── + +/// Decrypt an SMB2 message. +/// +/// `transform_header` is the 52-byte TRANSFORM_HEADER (as received on +/// the wire). `ciphertext` is the encrypted message data that follows +/// the header. Returns the decrypted plaintext. +pub fn decrypt_message( + transform_header: &[u8], + ciphertext: &[u8], + key: &[u8], + cipher: Cipher, +) -> Result, Error> { + if transform_header.len() != HEADER_SIZE { + return Err(Error::invalid_data(format!( + "transform header must be {} bytes, got {}", + HEADER_SIZE, + transform_header.len() + ))); + } + if key.len() != cipher.key_len() { + return Err(Error::invalid_data(format!( + "decryption key length mismatch: expected {}, got {}", + cipher.key_len(), + key.len() + ))); + } + + // Extract auth tag (Signature) from bytes 4..20. + let mut tag = [0u8; 16]; + tag.copy_from_slice(&transform_header[4..20]); + + // Extract nonce from bytes 20..36. + let nonce = &transform_header[20..20 + cipher.nonce_len()]; + + // AAD = header bytes 20..52. + let aad = &transform_header[AAD_OFFSET..HEADER_SIZE]; + + let mut buffer = ciphertext.to_vec(); + decrypt_raw(cipher, key, nonce, aad, &tag, &mut buffer)?; + + Ok(buffer) +} + +// ── Raw encrypt/decrypt helpers ────────────────────────────────────── + +/// Copy an auth tag array into a fixed-size `[u8; 16]` array. +fn tag_to_array(tag: Array) -> [u8; 16] { + let mut arr = [0u8; 16]; + arr.copy_from_slice(tag.as_slice()); + arr +} + +/// Encrypt `buffer` in place and return the 16-byte auth tag. +fn encrypt_raw( + cipher: Cipher, + key: &[u8], + nonce: &[u8], + aad: &[u8], + buffer: &mut [u8], +) -> Result<[u8; 16], Error> { + let map_err = |_| Error::invalid_data("encryption failed"); + let buf = InOutBuf::from(buffer); + + let tag = match cipher { + Cipher::Aes128Ccm => { + let c = Aes128Ccm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.encrypt_inout_detached(n, aad, buf) + .map(tag_to_array) + .map_err(map_err)? + } + Cipher::Aes128Gcm => { + let c = aes_gcm::Aes128Gcm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.encrypt_inout_detached(n, aad, buf) + .map(tag_to_array) + .map_err(map_err)? + } + Cipher::Aes256Ccm => { + let c = Aes256Ccm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.encrypt_inout_detached(n, aad, buf) + .map(tag_to_array) + .map_err(map_err)? + } + Cipher::Aes256Gcm => { + let c = aes_gcm::Aes256Gcm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.encrypt_inout_detached(n, aad, buf) + .map(tag_to_array) + .map_err(map_err)? + } + }; + + Ok(tag) +} + +/// Decrypt `buffer` in place, verifying the 16-byte auth tag. +fn decrypt_raw( + cipher: Cipher, + key: &[u8], + nonce: &[u8], + aad: &[u8], + tag: &[u8; 16], + buffer: &mut [u8], +) -> Result<(), Error> { + let map_err = |_| Error::invalid_data("decryption failed: authentication tag mismatch"); + let buf = InOutBuf::from(buffer); + let t: &Array = tag.into(); + + match cipher { + Cipher::Aes128Ccm => { + let c = Aes128Ccm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err) + } + Cipher::Aes128Gcm => { + let c = aes_gcm::Aes128Gcm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err) + } + Cipher::Aes256Ccm => { + let c = Aes256Ccm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err) + } + Cipher::Aes256Gcm => { + let c = aes_gcm::Aes256Gcm::new(key.try_into().expect("key length validated")); + let n = nonce.try_into().expect("nonce length validated"); + c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::msg::transform::TRANSFORM_PROTOCOL_ID; + + // ── Helper ──────────────────────────────────────────────────────── + + fn test_key(cipher: Cipher) -> Vec { + vec![0x42; cipher.key_len()] + } + + // ── Encrypt-then-decrypt roundtrip (one per cipher) ────────────── + + #[test] + fn roundtrip_aes128_ccm() { + roundtrip_cipher(Cipher::Aes128Ccm); + } + + #[test] + fn roundtrip_aes128_gcm() { + roundtrip_cipher(Cipher::Aes128Gcm); + } + + #[test] + fn roundtrip_aes256_ccm() { + roundtrip_cipher(Cipher::Aes256Ccm); + } + + #[test] + fn roundtrip_aes256_gcm() { + roundtrip_cipher(Cipher::Aes256Gcm); + } + + fn roundtrip_cipher(cipher: Cipher) { + let key = test_key(cipher); + let plaintext = b"Hello, SMB2 encryption roundtrip!"; + let session_id = 0xDEAD_BEEF_CAFE_FACE; + + let mut nonce_gen = NonceGenerator::new(); + let nonce = nonce_gen.next(cipher); + + let (header, ciphertext) = + encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap(); + + // Ciphertext must differ from plaintext. + assert_ne!(&ciphertext[..], &plaintext[..]); + + let decrypted = decrypt_message(&header, &ciphertext, &key, cipher).unwrap(); + assert_eq!(decrypted, plaintext); + } + + // ── Nonce generator monotonically increases ────────────────────── + + #[test] + fn nonce_generator_monotonic() { + let mut gen = NonceGenerator::new(); + let mut prev = [0u8; 16]; // counter 0 hasn't been generated yet + + for i in 0u64..100 { + let nonce = gen.next(Cipher::Aes128Gcm); + // Extract the 8-byte LE counter from the nonce. + let counter = u64::from_le_bytes(nonce[..8].try_into().unwrap()); + assert_eq!(counter, i, "counter should equal {i}"); + + if i > 0 { + assert_ne!(nonce, prev, "each nonce must be unique"); + } + prev = nonce; + } + } + + // ── Nonce format for GCM ───────────────────────────────────────── + + #[test] + fn nonce_format_gcm() { + let mut gen = NonceGenerator::new(); + // Advance to counter = 7 to have a non-trivial value. + for _ in 0..7 { + gen.next(Cipher::Aes128Gcm); + } + let nonce = gen.next(Cipher::Aes128Gcm); // counter = 7 + + // First 8 bytes: LE counter (7). + assert_eq!( + u64::from_le_bytes(nonce[..8].try_into().unwrap()), + 7, + "counter value" + ); + // Bytes 8..12: zeros (padding to 12-byte GCM nonce). + assert_eq!(nonce[8..12], [0, 0, 0, 0], "GCM nonce padding (8..12)"); + // Bytes 12..16: zeros (unused portion of the 16-byte field). + assert_eq!(nonce[12..16], [0, 0, 0, 0], "unused nonce bytes (12..16)"); + } + + // ── Nonce format for CCM ───────────────────────────────────────── + + #[test] + fn nonce_format_ccm() { + let mut gen = NonceGenerator::new(); + // Advance to counter = 5. + for _ in 0..5 { + gen.next(Cipher::Aes128Ccm); + } + let nonce = gen.next(Cipher::Aes128Ccm); // counter = 5 + + // First 8 bytes: LE counter (5). + assert_eq!( + u64::from_le_bytes(nonce[..8].try_into().unwrap()), + 5, + "counter value" + ); + // Bytes 8..11: zeros (padding to 11-byte CCM nonce). + assert_eq!(nonce[8..11], [0, 0, 0], "CCM nonce padding (8..11)"); + // Bytes 11..16: zeros (unused portion of the 16-byte field). + assert_eq!( + nonce[11..16], + [0, 0, 0, 0, 0], + "unused nonce bytes (11..16)" + ); + } + + // ── Tampered ciphertext fails decryption ───────────────────────── + + #[test] + fn tampered_ciphertext_fails() { + let cipher = Cipher::Aes128Gcm; + let key = test_key(cipher); + let plaintext = b"Do not tamper with me!"; + let session_id = 42; + + let mut gen = NonceGenerator::new(); + let nonce = gen.next(cipher); + + let (header, mut ciphertext) = + encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap(); + + // Flip a byte in the ciphertext. + ciphertext[0] ^= 0xFF; + + let result = decrypt_message(&header, &ciphertext, &key, cipher); + assert!(result.is_err(), "tampered ciphertext must fail decryption"); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("tag mismatch") || err.contains("decryption failed"), + "error was: {err}" + ); + } + + // ── Wrong key fails decryption ─────────────────────────────────── + + #[test] + fn wrong_key_fails() { + let cipher = Cipher::Aes256Gcm; + let key = test_key(cipher); + let wrong_key = vec![0x99; cipher.key_len()]; + let plaintext = b"Secret message"; + let session_id = 100; + + let mut gen = NonceGenerator::new(); + let nonce = gen.next(cipher); + + let (header, ciphertext) = + encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap(); + + let result = decrypt_message(&header, &ciphertext, &wrong_key, cipher); + assert!(result.is_err(), "wrong key must fail decryption"); + } + + // ── AAD includes correct TRANSFORM_HEADER bytes (offset 20-51) ── + + #[test] + fn aad_is_correct_header_region() { + // Verify the AAD constants match the spec. + assert_eq!(AAD_OFFSET, 20, "AAD starts at byte 20"); + assert_eq!( + HEADER_SIZE - AAD_OFFSET, + 32, + "AAD is 32 bytes (Nonce + OrigMsgSize + Reserved + Flags + SessionId)" + ); + assert_eq!(HEADER_SIZE, 52, "TRANSFORM_HEADER is 52 bytes"); + + // Build a header and verify the AAD region contains the expected fields. + let mut nonce = [0u8; 16]; + nonce[0] = 0xAA; + nonce[7] = 0xBB; + + let header = TransformHeader { + signature: [0xFF; 16], + nonce, + original_message_size: 1024, + flags: SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED, + session_id: SessionId(0x0123_4567_89AB_CDEF), + }; + + let mut w = WriteCursor::new(); + header.pack(&mut w); + let bytes = w.into_inner(); + + let aad = &bytes[AAD_OFFSET..HEADER_SIZE]; + assert_eq!(aad.len(), 32); + + // First 16 bytes of AAD should be the nonce. + assert_eq!(aad[0], 0xAA, "nonce byte 0"); + assert_eq!(aad[7], 0xBB, "nonce byte 7"); + + // Bytes 16..20 of AAD should be OriginalMessageSize (1024 LE). + assert_eq!( + u32::from_le_bytes(aad[16..20].try_into().unwrap()), + 1024, + "OriginalMessageSize" + ); + + // Bytes 20..22 of AAD should be Reserved (0). + assert_eq!(aad[20..22], [0, 0], "Reserved"); + + // Bytes 22..24 of AAD should be Flags (0x0001). + assert_eq!( + u16::from_le_bytes(aad[22..24].try_into().unwrap()), + SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED, + "Flags" + ); + + // Bytes 24..32 of AAD should be SessionId. + assert_eq!( + u64::from_le_bytes(aad[24..32].try_into().unwrap()), + 0x0123_4567_89AB_CDEF, + "SessionId" + ); + } + + // ── Transform header has correct protocol ID ───────────────────── + + #[test] + fn transform_header_protocol_id() { + let cipher = Cipher::Aes128Gcm; + let key = test_key(cipher); + let plaintext = b"test"; + let session_id = 1; + + let mut gen = NonceGenerator::new(); + let nonce = gen.next(cipher); + + let (header, _) = encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap(); + + // First 4 bytes must be 0xFD 'S' 'M' 'B'. + assert_eq!(&header[..4], &TRANSFORM_PROTOCOL_ID); + assert_eq!(header[0], 0xFD, "protocol ID first byte must be 0xFD"); + assert_eq!(header[1], b'S'); + assert_eq!(header[2], b'M'); + assert_eq!(header[3], b'B'); + } + + // ── Auth tag (signature) is at bytes 4..20 ────────────────────── + + #[test] + fn signature_position_in_header() { + let cipher = Cipher::Aes256Ccm; + let key = test_key(cipher); + let plaintext = b"Check signature position"; + let session_id = 99; + + let mut gen = NonceGenerator::new(); + let nonce = gen.next(cipher); + + let (header, _) = encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap(); + + // The signature (auth tag) lives at bytes 4..20. + let signature = &header[4..20]; + + // It should NOT be all zeros (that would mean we forgot to write it). + assert_ne!( + signature, &[0u8; 16], + "signature must not be all zeros after encryption" + ); + + // Verify that using this tag allows successful decryption + // (already covered by roundtrip tests, but this confirms the + // position explicitly). + let decrypted = decrypt_message(&header, &header[..0], &key, cipher); + // This will fail because we passed empty ciphertext, but that's + // not the point -- the roundtrip tests cover correctness. + // Instead, let's verify the tag by a proper roundtrip. + drop(decrypted); + + let (header2, ct2) = encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap(); + let result = decrypt_message(&header2, &ct2, &key, cipher).unwrap(); + assert_eq!(result, plaintext); + } +} diff --git a/vendor/smb2/src/crypto/kdf.rs b/vendor/smb2/src/crypto/kdf.rs new file mode 100644 index 0000000..5f31fe8 --- /dev/null +++ b/vendor/smb2/src/crypto/kdf.rs @@ -0,0 +1,525 @@ +//! SP800-108 key derivation and preauthentication integrity hashing for SMB2/3. +//! +//! SMB 3.x uses NIST SP800-108 KDF in counter mode with HMAC-SHA256 as the PRF +//! to derive signing, encryption, and decryption keys from the session key. +//! +//! SMB 3.1.1 additionally requires a preauthentication integrity hash (SHA-512) +//! computed over the raw wire bytes of NEGOTIATE and SESSION_SETUP exchanges, +//! which feeds into the KDF as the "context" parameter. + +use crate::types::Dialect; +use digest::{Digest, KeyInit}; +use hmac::{Hmac, Mac}; +use sha2::{Sha256, Sha512}; + +type HmacSha256 = Hmac; + +/// Derive a key using SP800-108 KDF in counter mode with HMAC-SHA256. +/// +/// This implements the algorithm from NIST SP800-108 section 5.1 as required +/// by MS-SMB2 section 3.1.4.2. The counter width ('r') is 32 bits, and the +/// PRF is HMAC-SHA256. +/// +/// # Arguments +/// +/// * `key` - The key to derive from (the session key from authentication). +/// * `label` - Label string (including null terminator). +/// * `context` - Context string or preauth hash (including null terminator for +/// string contexts). +/// * `key_length_bits` - Desired output key length in bits (128 or 256). +pub fn sp800_108_kdf(key: &[u8], label: &[u8], context: &[u8], key_length_bits: u32) -> Vec { + let iterations = key_length_bits.div_ceil(256); + let mut result = Vec::with_capacity((iterations * 32) as usize); + + for i in 1..=iterations { + let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts any key length"); + + // counter (32-bit big-endian) + mac.update(&i.to_be_bytes()); + // label + mac.update(label); + // separator byte 0x00 + mac.update(&[0x00]); + // context + mac.update(context); + // L = key length in bits (32-bit big-endian) + mac.update(&key_length_bits.to_be_bytes()); + + result.extend_from_slice(&mac.finalize().into_bytes()); + } + + result.truncate((key_length_bits / 8) as usize); + result +} + +/// Derived session keys for signing, encryption, and decryption. +#[derive(Debug, Clone)] +pub struct DerivedKeys { + /// Key used to sign outgoing messages. + pub signing_key: Vec, + /// Key used to encrypt outgoing messages. + pub encryption_key: Vec, + /// Key used to decrypt incoming messages. + pub decryption_key: Vec, +} + +/// Derive session keys for the given dialect. +/// +/// For SMB 3.0 and 3.0.2, the context is a fixed ASCII string. +/// For SMB 3.1.1, the context is the preauthentication integrity hash value +/// (64 bytes from SHA-512). +/// +/// # Panics +/// +/// Panics if `dialect` is SMB 3.1.1 and `preauth_hash` is `None`. +/// Panics if `dialect` is not in the SMB 3.x family. +pub fn derive_session_keys( + session_key: &[u8], + dialect: Dialect, + preauth_hash: Option<&[u8; 64]>, + key_length_bits: u32, +) -> DerivedKeys { + assert!( + matches!( + dialect, + Dialect::Smb3_0 | Dialect::Smb3_0_2 | Dialect::Smb3_1_1 + ), + "Key derivation is only applicable for the SMB 3.x dialect family" + ); + + let (signing_label, signing_context): (&[u8], &[u8]); + let (enc_label, enc_context): (&[u8], &[u8]); + let (dec_label, dec_context): (&[u8], &[u8]); + + if dialect == Dialect::Smb3_1_1 { + let hash = preauth_hash + .expect("SMB 3.1.1 requires a preauthentication integrity hash for key derivation"); + // SMB 3.1.1 labels include null terminator (matches smb-rs and + // the MS-SMB2 spec's Label field definitions) + signing_label = b"SMBSigningKey\0"; + signing_context = hash.as_slice(); + enc_label = b"SMBC2SCipherKey\0"; + enc_context = hash.as_slice(); + dec_label = b"SMBS2CCipherKey\0"; + dec_context = hash.as_slice(); + } else { + // SMB 3.0 and 3.0.2 + signing_label = b"SMB2AESCMAC\0"; + signing_context = b"SmbSign\0"; + enc_label = b"SMB2AESCCM\0"; + enc_context = b"ServerIn \0"; + dec_label = b"SMB2AESCCM\0"; + dec_context = b"ServerOut\0"; + } + + DerivedKeys { + signing_key: sp800_108_kdf(session_key, signing_label, signing_context, key_length_bits), + encryption_key: sp800_108_kdf(session_key, enc_label, enc_context, key_length_bits), + decryption_key: sp800_108_kdf(session_key, dec_label, dec_context, key_length_bits), + } +} + +/// Running hash over negotiate and session-setup exchange bytes. +/// +/// Used as the "context" parameter to the KDF for SMB 3.1.1. The hash +/// algorithm is SHA-512, producing a 64-byte value. +/// +/// The hash is computed incrementally: +/// 1. Initialize with 64 zero bytes +/// 2. `update()` with negotiate request raw bytes +/// 3. `update()` with negotiate response raw bytes +/// 4. (Clone for session hash) +/// 5. `update()` with session setup request raw bytes +/// 6. `update()` with session setup response raw bytes +/// 7. Repeat 5-6 for each SESSION_SETUP round-trip +/// +/// Each `update()` computes: `hash = SHA-512(previous_hash || message_bytes)` +pub struct PreauthHasher { + hash: [u8; 64], +} + +impl PreauthHasher { + /// Create a new hasher initialized with 64 zero bytes. + pub fn new() -> Self { + Self { hash: [0u8; 64] } + } + + /// Update the hash with a message's raw wire bytes. + /// + /// Computes `hash = SHA-512(previous_hash || message_bytes)`. + pub fn update(&mut self, message_bytes: &[u8]) { + let mut hasher = Sha512::new(); + hasher.update(self.hash); + hasher.update(message_bytes); + self.hash.copy_from_slice(&hasher.finalize()); + } + + /// Get the current hash value (64 bytes). + pub fn value(&self) -> &[u8; 64] { + &self.hash + } +} + +impl Default for PreauthHasher { + fn default() -> Self { + Self::new() + } +} + +impl Clone for PreauthHasher { + fn clone(&self) -> Self { + Self { hash: self.hash } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ======================================================================== + // SP800-108 KDF tests + // ======================================================================== + + #[test] + fn kdf_128_bit_output_is_16_bytes() { + let key = [0xAA; 16]; + let result = sp800_108_kdf(&key, b"label\0", b"context\0", 128); + assert_eq!(result.len(), 16); + } + + #[test] + fn kdf_256_bit_output_is_32_bytes() { + let key = [0xBB; 16]; + let result = sp800_108_kdf(&key, b"label\0", b"context\0", 256); + assert_eq!(result.len(), 32); + } + + #[test] + fn kdf_is_deterministic() { + let key = [0x42; 16]; + let label = b"TestLabel\0"; + let context = b"TestContext\0"; + let r1 = sp800_108_kdf(&key, label, context, 128); + let r2 = sp800_108_kdf(&key, label, context, 128); + assert_eq!(r1, r2); + } + + #[test] + fn kdf_different_labels_produce_different_keys() { + let key = [0x42; 16]; + let context = b"ctx\0"; + let k1 = sp800_108_kdf(&key, b"LabelA\0", context, 128); + let k2 = sp800_108_kdf(&key, b"LabelB\0", context, 128); + assert_ne!(k1, k2); + } + + #[test] + fn kdf_different_contexts_produce_different_keys() { + let key = [0x42; 16]; + let label = b"label\0"; + let k1 = sp800_108_kdf(&key, label, b"ContextA\0", 128); + let k2 = sp800_108_kdf(&key, label, b"ContextB\0", 128); + assert_ne!(k1, k2); + } + + #[test] + fn kdf_different_session_keys_produce_different_derived_keys() { + let label = b"SMB2AESCMAC\0"; + let context = b"SmbSign\0"; + let k1 = sp800_108_kdf(&[0x11; 16], label, context, 128); + let k2 = sp800_108_kdf(&[0x22; 16], label, context, 128); + assert_ne!(k1, k2); + } + + /// Verify KDF output against a manually computed value. + /// + /// For a single iteration (128-bit output), the KDF computes: + /// HMAC-SHA256(key, 0x00000001 || label || 0x00 || context || 0x00000080) + /// and takes the first 16 bytes. + #[test] + fn kdf_known_vector_single_iteration() { + let key = [0x00u8; 16]; + let label = b"SMB2AESCMAC\0"; + let context = b"SmbSign\0"; + + // Manually compute the expected value. + let mut mac = HmacSha256::new_from_slice(&key).unwrap(); + mac.update(&1u32.to_be_bytes()); // counter = 1 + mac.update(label); // label + mac.update(&[0x00]); // separator + mac.update(context); // context + mac.update(&128u32.to_be_bytes()); // L = 128 + let full = mac.finalize().into_bytes(); + let expected = &full[..16]; + + let result = sp800_108_kdf(&key, label, context, 128); + assert_eq!(result.as_slice(), expected); + } + + /// Verify that 256-bit KDF uses two iterations and concatenates correctly. + #[test] + fn kdf_known_vector_two_iterations() { + let key = [0xFFu8; 16]; + let label = b"TestLabel\0"; + let context = b"TestCtx\0"; + + // Compute iteration 1 + let mut mac1 = HmacSha256::new_from_slice(&key).unwrap(); + mac1.update(&1u32.to_be_bytes()); + mac1.update(label); + mac1.update(&[0x00]); + mac1.update(context); + mac1.update(&256u32.to_be_bytes()); + let block1 = mac1.finalize().into_bytes(); + + // 256 bits = 32 bytes = exactly one HMAC-SHA256 block, so only one + // iteration is needed. But let's verify with the formula: + // ceil(256 / 256) = 1 iteration. So 256-bit also needs just one. + let result = sp800_108_kdf(&key, label, context, 256); + assert_eq!(result.len(), 32); + assert_eq!(result.as_slice(), block1.as_slice()); + } + + // ======================================================================== + // derive_session_keys tests + // ======================================================================== + + #[test] + fn derive_keys_smb3_0_uses_legacy_labels() { + let session_key = [0x42; 16]; + let keys = derive_session_keys(&session_key, Dialect::Smb3_0, None, 128); + + // Verify each key matches what we'd get calling KDF directly with the + // SMB 3.0 label/context pairs. + assert_eq!( + keys.signing_key, + sp800_108_kdf(&session_key, b"SMB2AESCMAC\0", b"SmbSign\0", 128) + ); + assert_eq!( + keys.encryption_key, + sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerIn \0", 128) + ); + assert_eq!( + keys.decryption_key, + sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerOut\0", 128) + ); + } + + #[test] + fn derive_keys_smb3_0_2_uses_legacy_labels() { + let session_key = [0x42; 16]; + let keys = derive_session_keys(&session_key, Dialect::Smb3_0_2, None, 128); + + assert_eq!( + keys.signing_key, + sp800_108_kdf(&session_key, b"SMB2AESCMAC\0", b"SmbSign\0", 128) + ); + assert_eq!( + keys.encryption_key, + sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerIn \0", 128) + ); + assert_eq!( + keys.decryption_key, + sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerOut\0", 128) + ); + } + + #[test] + fn derive_keys_smb3_1_1_uses_new_labels_with_preauth_hash() { + let session_key = [0x42; 16]; + let preauth_hash = [0xAB; 64]; + let keys = derive_session_keys(&session_key, Dialect::Smb3_1_1, Some(&preauth_hash), 128); + + assert_eq!( + keys.signing_key, + sp800_108_kdf(&session_key, b"SMBSigningKey\0", &preauth_hash, 128) + ); + assert_eq!( + keys.encryption_key, + sp800_108_kdf(&session_key, b"SMBC2SCipherKey\0", &preauth_hash, 128) + ); + assert_eq!( + keys.decryption_key, + sp800_108_kdf(&session_key, b"SMBS2CCipherKey\0", &preauth_hash, 128) + ); + } + + #[test] + fn derive_keys_smb3_1_1_256_bit() { + let session_key = [0x42; 16]; + let preauth_hash = [0xCD; 64]; + let keys = derive_session_keys(&session_key, Dialect::Smb3_1_1, Some(&preauth_hash), 256); + + assert_eq!(keys.signing_key.len(), 32); + assert_eq!(keys.encryption_key.len(), 32); + assert_eq!(keys.decryption_key.len(), 32); + } + + #[test] + fn derive_keys_all_three_are_different() { + let session_key = [0x42; 16]; + let keys = derive_session_keys(&session_key, Dialect::Smb3_0, None, 128); + + assert_ne!(keys.signing_key, keys.encryption_key); + assert_ne!(keys.signing_key, keys.decryption_key); + assert_ne!(keys.encryption_key, keys.decryption_key); + } + + #[test] + #[should_panic(expected = "preauthentication integrity hash")] + fn derive_keys_smb3_1_1_panics_without_preauth_hash() { + let session_key = [0x42; 16]; + derive_session_keys(&session_key, Dialect::Smb3_1_1, None, 128); + } + + #[test] + #[should_panic(expected = "SMB 3.x dialect family")] + fn derive_keys_panics_for_smb2() { + let session_key = [0x42; 16]; + derive_session_keys(&session_key, Dialect::Smb2_0_2, None, 128); + } + + // ======================================================================== + // PreauthHasher tests + // ======================================================================== + + #[test] + fn preauth_hasher_starts_with_64_zero_bytes() { + let hasher = PreauthHasher::new(); + assert_eq!(hasher.value(), &[0u8; 64]); + } + + #[test] + fn preauth_hasher_default_equals_new() { + let h1 = PreauthHasher::new(); + let h2 = PreauthHasher::default(); + assert_eq!(h1.value(), h2.value()); + } + + #[test] + fn preauth_hasher_update_changes_hash() { + let mut hasher = PreauthHasher::new(); + let initial = *hasher.value(); + hasher.update(b"negotiate request bytes"); + assert_ne!(hasher.value(), &initial); + } + + #[test] + fn preauth_hasher_two_updates_differ_from_one() { + let mut hasher1 = PreauthHasher::new(); + hasher1.update(b"message1"); + + let mut hasher2 = PreauthHasher::new(); + hasher2.update(b"message1"); + hasher2.update(b"message2"); + + assert_ne!(hasher1.value(), hasher2.value()); + } + + #[test] + fn preauth_hasher_is_deterministic() { + let mut h1 = PreauthHasher::new(); + h1.update(b"negotiate request"); + h1.update(b"negotiate response"); + + let mut h2 = PreauthHasher::new(); + h2.update(b"negotiate request"); + h2.update(b"negotiate response"); + + assert_eq!(h1.value(), h2.value()); + } + + #[test] + fn preauth_hasher_empty_update_changes_hash() { + // SHA-512(64_zeros || empty) != 64_zeros + let mut hasher = PreauthHasher::new(); + let initial = *hasher.value(); + hasher.update(b""); + assert_ne!(hasher.value(), &initial); + } + + #[test] + fn preauth_hasher_known_value() { + // Verify against direct SHA-512 computation. + let mut hasher = PreauthHasher::new(); + hasher.update(b"test"); + + let mut expected_hasher = Sha512::new(); + expected_hasher.update([0u8; 64]); + expected_hasher.update(b"test"); + let expected = expected_hasher.finalize(); + + assert_eq!(hasher.value().as_slice(), expected.as_slice()); + } + + #[test] + fn preauth_hasher_chained_known_value() { + // Two updates: hash1 = SHA-512(zeros || msg1), hash2 = SHA-512(hash1 || msg2) + let mut hasher = PreauthHasher::new(); + hasher.update(b"negotiate"); + hasher.update(b"response"); + + // Compute manually + let mut h = Sha512::new(); + h.update([0u8; 64]); + h.update(b"negotiate"); + let hash1: [u8; 64] = h.finalize().into(); + + let mut h2 = Sha512::new(); + h2.update(hash1); + h2.update(b"response"); + let hash2: [u8; 64] = h2.finalize().into(); + + assert_eq!(hasher.value(), &hash2); + } + + #[test] + fn preauth_hasher_clone_is_independent() { + let mut hasher = PreauthHasher::new(); + hasher.update(b"negotiate request"); + hasher.update(b"negotiate response"); + + // Clone for session hash (spec step 4) + let mut session_hasher = hasher.clone(); + session_hasher.update(b"session setup request"); + + // Original should not be affected + assert_ne!(hasher.value(), session_hasher.value()); + } + + #[test] + fn preauth_hasher_output_is_64_bytes() { + let mut hasher = PreauthHasher::new(); + hasher.update(b"some data"); + assert_eq!(hasher.value().len(), 64); + } + + /// Full end-to-end test: preauth hash feeds into KDF for SMB 3.1.1. + #[test] + fn preauth_hash_feeds_into_kdf() { + // Simulate the protocol flow + let mut conn_hasher = PreauthHasher::new(); + conn_hasher.update(b"negotiate request bytes"); + conn_hasher.update(b"negotiate response bytes"); + + let mut session_hasher = conn_hasher.clone(); + session_hasher.update(b"session setup request bytes"); + session_hasher.update(b"session setup response bytes"); + + let session_key = [0x42; 16]; + let keys = derive_session_keys( + &session_key, + Dialect::Smb3_1_1, + Some(session_hasher.value()), + 128, + ); + + // Keys should all be 16 bytes and different from each other + assert_eq!(keys.signing_key.len(), 16); + assert_eq!(keys.encryption_key.len(), 16); + assert_eq!(keys.decryption_key.len(), 16); + assert_ne!(keys.signing_key, keys.encryption_key); + assert_ne!(keys.signing_key, keys.decryption_key); + assert_ne!(keys.encryption_key, keys.decryption_key); + } +} diff --git a/vendor/smb2/src/crypto/mod.rs b/vendor/smb2/src/crypto/mod.rs new file mode 100644 index 0000000..f0f9389 --- /dev/null +++ b/vendor/smb2/src/crypto/mod.rs @@ -0,0 +1,9 @@ +//! Cryptographic operations for SMB2/3: signing, encryption, key derivation, and compression. +//! +//! Most users don't need this module directly -- [`SmbClient`](crate::SmbClient) +//! handles signing and encryption automatically. + +pub mod compression; +pub mod encryption; +pub mod kdf; +pub mod signing; diff --git a/vendor/smb2/src/crypto/signing.rs b/vendor/smb2/src/crypto/signing.rs new file mode 100644 index 0000000..cd11d78 --- /dev/null +++ b/vendor/smb2/src/crypto/signing.rs @@ -0,0 +1,789 @@ +//! SMB2 message signing and signature verification. +//! +//! Supports three signing algorithms, selected by negotiated dialect: +//! - **HMAC-SHA256** (SMB 2.0.2, 2.1): 32-byte hash truncated to 16 bytes. +//! - **AES-128-CMAC** (SMB 3.0, 3.0.2): 16-byte MAC. +//! - **AES-256-GMAC** (SMB 3.1.1 with `SMB2_SIGNING_CAPABILITIES`): AES-256-GCM +//! with empty plaintext; the 16-byte auth tag is the signature. +//! +//! Reference: MS-SMB2 sections 3.1.4.1 (signing) and 3.1.5.1 (verification). + +use log::{debug, error, trace}; + +use crate::types::Dialect; +use crate::Error; + +/// Offset of the 16-byte Signature field within the SMB2 header. +const SIGNATURE_OFFSET: usize = 48; +/// Length of the Signature field. +const SIGNATURE_LEN: usize = 16; +/// Minimum message length (full SMB2 header). +const MIN_MESSAGE_LEN: usize = 64; + +/// Signing algorithm, determined by negotiated dialect and capabilities. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub enum SigningAlgorithm { + /// HMAC-SHA256 truncated to 16 bytes (SMB 2.0.2, 2.1). + HmacSha256, + /// AES-128-CMAC (SMB 3.0, 3.0.2). + AesCmac, + /// AES-256-GMAC with MessageId-based nonce (SMB 3.1.1). + AesGmac, +} + +/// Select the appropriate signing algorithm for a dialect. +/// +/// For SMB 3.1.1, `gmac_negotiated` indicates whether the peer negotiated +/// `AES-256-GMAC` via `SMB2_SIGNING_CAPABILITIES`. When `false`, SMB 3.1.1 +/// falls back to AES-128-CMAC. +pub fn algorithm_for_dialect(dialect: Dialect, gmac_negotiated: bool) -> SigningAlgorithm { + match dialect { + Dialect::Smb2_0_2 | Dialect::Smb2_1 => SigningAlgorithm::HmacSha256, + Dialect::Smb3_0 | Dialect::Smb3_0_2 => SigningAlgorithm::AesCmac, + Dialect::Smb3_1_1 => { + if gmac_negotiated { + SigningAlgorithm::AesGmac + } else { + SigningAlgorithm::AesCmac + } + } + } +} + +/// Sign an SMB2 message in-place (client → server). +/// +/// Zeros the signature field (bytes 48-63), computes the signature +/// over the full message, and writes the computed signature back. +/// +/// For AES-GMAC, `message_id` and `is_cancel` are used to construct +/// the 12-byte nonce. For other algorithms these parameters are ignored. +/// +/// # Errors +/// +/// Returns [`Error::InvalidData`] if the message is shorter than 64 bytes +/// or the key length is wrong for the chosen algorithm. +pub fn sign_message( + message: &mut [u8], + key: &[u8], + algorithm: SigningAlgorithm, + message_id: u64, + is_cancel: bool, +) -> Result<(), Error> { + if message.len() < MIN_MESSAGE_LEN { + return Err(Error::invalid_data(format!( + "message too short for signing: {} bytes, need at least {}", + message.len(), + MIN_MESSAGE_LEN + ))); + } + + // Step 1: zero the signature field. + message[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0); + + // Step 2: compute signature over the entire message. + // is_response = false: we're the client, signing an outgoing request. + let signature = compute_signature(message, key, algorithm, message_id, is_cancel, false)?; + + // Step 3: write the signature back. + message[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].copy_from_slice(&signature); + + debug!( + "signing: signed msg_id={}, algo={:?}, sig={:02x}{:02x}{:02x}{:02x}...", + message_id, algorithm, signature[0], signature[1], signature[2], signature[3] + ); + Ok(()) +} + +/// Verify the signature on a received SMB2 message (server → client). +/// +/// Returns `Ok(())` if the signature matches, or [`Error::InvalidData`] +/// if the message is tampered or the key is wrong. +/// +/// For GMAC, the nonce role bit is set to 1 (server) automatically. +pub fn verify_signature( + message: &[u8], + key: &[u8], + algorithm: SigningAlgorithm, + message_id: u64, + is_cancel: bool, +) -> Result<(), Error> { + if message.len() < MIN_MESSAGE_LEN { + return Err(Error::invalid_data(format!( + "message too short for verification: {} bytes, need at least {}", + message.len(), + MIN_MESSAGE_LEN + ))); + } + + // Step 1: save the received signature. + let mut received_sig = [0u8; SIGNATURE_LEN]; + received_sig.copy_from_slice(&message[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]); + + // Step 2: zero the signature field in a copy. + let mut buf = message.to_vec(); + buf[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0); + + // Step 3: compute the expected signature. + // is_response = true: the server signed this message, so the GMAC + // nonce must have role bit = 1 (server). + let expected_sig = compute_signature(&buf, key, algorithm, message_id, is_cancel, true)?; + + // Step 4: compare. + if received_sig != expected_sig { + error!( + "signing: verification failed, msg_id={}, algo={:?}, got={:02x}{:02x}{:02x}{:02x}..., want={:02x}{:02x}{:02x}{:02x}...", + message_id, algorithm, + received_sig[0], received_sig[1], received_sig[2], received_sig[3], + expected_sig[0], expected_sig[1], expected_sig[2], expected_sig[3] + ); + return Err(Error::invalid_data("signature verification failed")); + } + + trace!( + "signing: verified msg_id={}, algo={:?}, sig={:02x}{:02x}{:02x}{:02x}...", + message_id, + algorithm, + received_sig[0], + received_sig[1], + received_sig[2], + received_sig[3] + ); + Ok(()) +} + +/// Compute a 16-byte signature over `message` using the given algorithm. +fn compute_signature( + message: &[u8], + key: &[u8], + algorithm: SigningAlgorithm, + message_id: u64, + is_cancel: bool, + is_response: bool, +) -> Result<[u8; 16], Error> { + match algorithm { + SigningAlgorithm::HmacSha256 => compute_hmac_sha256(message, key), + SigningAlgorithm::AesCmac => compute_aes_cmac(message, key), + SigningAlgorithm::AesGmac => { + compute_aes_gmac(message, key, message_id, is_cancel, is_response) + } + } +} + +/// HMAC-SHA256, truncated to 16 bytes. Key must be 16 bytes. +fn compute_hmac_sha256(message: &[u8], key: &[u8]) -> Result<[u8; 16], Error> { + use digest::KeyInit; + use hmac::{Hmac, Mac}; + use sha2::Sha256; + + type HmacSha256 = Hmac; + + let mut mac = HmacSha256::new_from_slice(key) + .map_err(|e| Error::invalid_data(format!("HMAC-SHA256 key error: {e}")))?; + mac.update(message); + let result = mac.finalize().into_bytes(); + + // Truncate 32-byte hash to first 16 bytes. + let mut sig = [0u8; 16]; + sig.copy_from_slice(&result[..16]); + Ok(sig) +} + +/// AES-128-CMAC. Key must be 16 bytes. +fn compute_aes_cmac(message: &[u8], key: &[u8]) -> Result<[u8; 16], Error> { + use aes::Aes128; + use cmac::{Cmac, Mac}; + use digest::KeyInit; + + type AesCmac = Cmac; + + let mut mac = AesCmac::new_from_slice(key) + .map_err(|e| Error::invalid_data(format!("AES-CMAC key error: {e}")))?; + mac.update(message); + let result = mac.finalize().into_bytes(); + + let mut sig = [0u8; 16]; + sig.copy_from_slice(&result); + Ok(sig) +} + +/// AES-128-GMAC (AES-128-GCM with empty plaintext). Key must be 16 bytes. +/// +/// The 12-byte nonce is constructed as (MS-SMB2 section 3.1.4.1): +/// - Bytes 0-7: `message_id` (little-endian u64) +/// - Byte 8: bit 0 = role (0=client, 1=server), bit 1 = `is_cancel` +/// - Bytes 9-11: zero +fn compute_aes_gmac( + message: &[u8], + key: &[u8], + message_id: u64, + is_cancel: bool, + is_response: bool, +) -> Result<[u8; 16], Error> { + use aes_gcm::aead::Aead; + use aes_gcm::{Aes128Gcm, KeyInit, Nonce}; + + if key.len() != 16 { + return Err(Error::invalid_data(format!( + "AES-128-GMAC requires a 16-byte key, got {} bytes", + key.len() + ))); + } + + // Build 12-byte nonce. + let mut nonce_bytes = [0u8; 12]; + nonce_bytes[0..8].copy_from_slice(&message_id.to_le_bytes()); + // Byte 8: bit 0 = role (0 = client, 1 = server), bit 1 = CANCEL flag. + let mut flags_byte: u8 = 0; + if is_response { + flags_byte |= 0x01; // server role + } + if is_cancel { + flags_byte |= 0x02; + } + nonce_bytes[8] = flags_byte; + + let cipher = Aes128Gcm::new(key.try_into().map_err(|_| { + Error::invalid_data(format!( + "AES-128-GMAC requires a 16-byte key, got {} bytes", + key.len() + )) + })?); + let nonce: &Nonce<_> = (&nonce_bytes).into(); + + // GMAC mode: encrypt empty plaintext with the message as AAD. + // The "ciphertext" is empty; the auth tag IS the signature. + use aes_gcm::aead::Payload; + let payload = Payload { + msg: &[], + aad: message, + }; + + let ciphertext = cipher + .encrypt(nonce, payload) + .map_err(|e| Error::invalid_data(format!("AES-256-GMAC encryption error: {e}")))?; + + // The output is the 16-byte auth tag (no ciphertext bytes since plaintext was empty). + if ciphertext.len() != 16 { + return Err(Error::invalid_data(format!( + "unexpected GMAC output length: expected 16, got {}", + ciphertext.len() + ))); + } + + let mut sig = [0u8; 16]; + sig.copy_from_slice(&ciphertext); + Ok(sig) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a minimal 64-byte fake SMB2 message for testing. + /// The signature field (bytes 48-63) is zeroed. + fn make_test_message(body_extra: &[u8]) -> Vec { + let mut msg = vec![0u8; 64 + body_extra.len()]; + // Protocol ID + msg[0..4].copy_from_slice(&[0xFE, b'S', b'M', b'B']); + // Structure size = 64 + msg[4..6].copy_from_slice(&64u16.to_le_bytes()); + // Fill some fields so the message isn't all zeros + msg[12..14].copy_from_slice(&0x0008u16.to_le_bytes()); // Command = Read + msg[24..32].copy_from_slice(&42u64.to_le_bytes()); // MessageId = 42 + // Append body + msg[64..].copy_from_slice(body_extra); + msg + } + + // ── algorithm_for_dialect ───────────────────────────────────────── + + #[test] + fn algorithm_for_smb2_0_2_is_hmac_sha256() { + assert_eq!( + algorithm_for_dialect(Dialect::Smb2_0_2, false), + SigningAlgorithm::HmacSha256 + ); + } + + #[test] + fn algorithm_for_smb2_1_is_hmac_sha256() { + assert_eq!( + algorithm_for_dialect(Dialect::Smb2_1, false), + SigningAlgorithm::HmacSha256 + ); + } + + #[test] + fn algorithm_for_smb3_0_is_aes_cmac() { + assert_eq!( + algorithm_for_dialect(Dialect::Smb3_0, false), + SigningAlgorithm::AesCmac + ); + } + + #[test] + fn algorithm_for_smb3_0_2_is_aes_cmac() { + assert_eq!( + algorithm_for_dialect(Dialect::Smb3_0_2, false), + SigningAlgorithm::AesCmac + ); + } + + #[test] + fn algorithm_for_smb3_1_1_without_gmac_is_aes_cmac() { + assert_eq!( + algorithm_for_dialect(Dialect::Smb3_1_1, false), + SigningAlgorithm::AesCmac + ); + } + + #[test] + fn algorithm_for_smb3_1_1_with_gmac_is_aes_gmac() { + assert_eq!( + algorithm_for_dialect(Dialect::Smb3_1_1, true), + SigningAlgorithm::AesGmac + ); + } + + #[test] + fn gmac_flag_ignored_for_older_dialects() { + // Even if gmac_negotiated is true, older dialects don't use GMAC. + assert_eq!( + algorithm_for_dialect(Dialect::Smb2_0_2, true), + SigningAlgorithm::HmacSha256 + ); + assert_eq!( + algorithm_for_dialect(Dialect::Smb3_0, true), + SigningAlgorithm::AesCmac + ); + } + + // ── Message too short ───────────────────────────────────────────── + + #[test] + fn sign_rejects_message_shorter_than_64_bytes() { + let mut msg = vec![0u8; 32]; + let key = [0u8; 16]; + let result = sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("too short")); + } + + #[test] + fn verify_rejects_message_shorter_than_64_bytes() { + let msg = vec![0u8; 32]; + let key = [0u8; 16]; + let result = verify_signature(&msg, &key, SigningAlgorithm::HmacSha256, 0, false); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("too short")); + } + + // ── HMAC-SHA256 ────────────────────────────────────────────────── + + #[test] + fn hmac_sha256_sign_produces_nonzero_signature() { + let mut msg = make_test_message(b"hello world"); + let key = [0xAA; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + + let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]; + assert_ne!(sig, &[0u8; 16], "signature should not be all zeros"); + } + + #[test] + fn hmac_sha256_known_signature() { + // Compute expected HMAC-SHA256 using the same process: + // zero sig field, compute HMAC, truncate to 16 bytes. + let mut msg = make_test_message(&[]); + let key = [0x01; 16]; + + // Manually compute expected value. + let mut zeroed = msg.clone(); + zeroed[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0); + let expected = { + use digest::KeyInit; + use hmac::{Hmac, Mac}; + use sha2::Sha256; + type H = Hmac; + let mut mac = H::new_from_slice(&key).unwrap(); + mac.update(&zeroed); + let full = mac.finalize().into_bytes(); + let mut trunc = [0u8; 16]; + trunc.copy_from_slice(&full[..16]); + trunc + }; + + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + assert_eq!( + &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN], + &expected + ); + } + + #[test] + fn hmac_sha256_sign_then_verify_roundtrip() { + let mut msg = make_test_message(b"some payload data"); + let key = [0x42; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + verify_signature(&msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + } + + #[test] + fn hmac_sha256_verify_fails_on_tampered_message() { + let mut msg = make_test_message(b"original data"); + let key = [0x42; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + + // Flip a byte in the body. + let last = msg.len() - 1; + msg[last] ^= 0xFF; + + let result = verify_signature(&msg, &key, SigningAlgorithm::HmacSha256, 0, false); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("verification failed"),); + } + + #[test] + fn hmac_sha256_verify_fails_with_wrong_key() { + let mut msg = make_test_message(b"data"); + let key = [0x42; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + + let wrong_key = [0x43; 16]; + let result = verify_signature(&msg, &wrong_key, SigningAlgorithm::HmacSha256, 0, false); + assert!(result.is_err()); + } + + // ── AES-128-CMAC ──────────────────────────────────────────────── + + #[test] + fn aes_cmac_sign_produces_nonzero_signature() { + let mut msg = make_test_message(b"cmac test"); + let key = [0xBB; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap(); + + let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]; + assert_ne!(sig, &[0u8; 16]); + } + + #[test] + fn aes_cmac_known_signature() { + let mut msg = make_test_message(&[]); + let key = [0x02; 16]; + + let mut zeroed = msg.clone(); + zeroed[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0); + let expected = { + use aes::Aes128; + use cmac::{Cmac, Mac}; + use digest::KeyInit; + type C = Cmac; + let mut mac = C::new_from_slice(&key).unwrap(); + mac.update(&zeroed); + let result = mac.finalize().into_bytes(); + let mut sig = [0u8; 16]; + sig.copy_from_slice(&result); + sig + }; + + sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap(); + assert_eq!( + &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN], + &expected + ); + } + + #[test] + fn aes_cmac_sign_then_verify_roundtrip() { + let mut msg = make_test_message(b"cmac roundtrip payload"); + let key = [0x55; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap(); + verify_signature(&msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap(); + } + + #[test] + fn aes_cmac_verify_fails_on_tampered_message() { + let mut msg = make_test_message(b"cmac original"); + let key = [0x55; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap(); + + msg[10] ^= 0xFF; + + let result = verify_signature(&msg, &key, SigningAlgorithm::AesCmac, 0, false); + assert!(result.is_err()); + } + + #[test] + fn aes_cmac_verify_fails_with_wrong_key() { + let mut msg = make_test_message(b"cmac data"); + let key = [0x55; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap(); + + let wrong_key = [0x56; 16]; + let result = verify_signature(&msg, &wrong_key, SigningAlgorithm::AesCmac, 0, false); + assert!(result.is_err()); + } + + // ── AES-128-GMAC ──────────────────────────────────────────────── + + #[test] + fn aes_gmac_sign_produces_nonzero_signature() { + let mut msg = make_test_message(b"gmac test"); + let key = [0xCC; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 1, false).unwrap(); + + let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]; + assert_ne!(sig, &[0u8; 16]); + } + + #[test] + fn aes_gmac_known_signature() { + let mut msg = make_test_message(&[]); + let key = [0x03; 16]; + let message_id: u64 = 7; + + let mut zeroed = msg.clone(); + zeroed[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0); + let expected = { + use aes_gcm::aead::{Aead, Payload}; + use aes_gcm::{Aes128Gcm, KeyInit, Nonce}; + + let mut nonce_bytes = [0u8; 12]; + nonce_bytes[0..8].copy_from_slice(&message_id.to_le_bytes()); + // not cancel, client role -> byte 8 = 0 + + let cipher = Aes128Gcm::new((&key).into()); + let nonce: &Nonce<_> = (&nonce_bytes).into(); + let payload = Payload { + msg: &[], + aad: &zeroed, + }; + let ct = cipher.encrypt(nonce, payload).unwrap(); + let mut sig = [0u8; 16]; + sig.copy_from_slice(&ct); + sig + }; + + sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, message_id, false).unwrap(); + assert_eq!( + &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN], + &expected + ); + } + + #[test] + fn aes_gmac_sign_then_verify_roundtrip() { + // sign_message uses client role (is_response=false internally), + // verify_signature uses server role (is_response=true internally). + // For a self-roundtrip test, we need to test sign+verify on the + // same role. Use the internal compute_signature directly, or + // just verify that a real server flow works (sign as client, + // verify as server would compute -- but that's an integration test). + // + // For this unit test, verify that sign→verify works when the + // message has the SERVER_TO_REDIR flag set (simulating a + // response that we signed ourselves for testing). + let mut msg = make_test_message(b"gmac roundtrip payload"); + // Set SERVER_TO_REDIR flag so verify_signature uses server role bit + let flags = u32::from_le_bytes(msg[16..20].try_into().unwrap()); + let new_flags = flags | 0x0000_0001; // SERVER_TO_REDIR + msg[16..20].copy_from_slice(&new_flags.to_le_bytes()); + + let key = [0xDD; 16]; + // Sign with is_response=false (client), but verify_signature + // always uses is_response=true (server). So we need to compute + // the signature manually with is_response=true to make roundtrip work. + // Actually, let's just test that sign and verify produce consistent + // results by testing each direction independently. + + // Test: sign as client (role=0), verify we can detect tampering + sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 100, false).unwrap(); + // verify_signature uses role=1 (server), so it WON'T match client-signed. + // This is correct behavior -- client and server signatures differ. + // Instead, test that the signature is non-zero and stable. + let sig1: [u8; 16] = msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN] + .try_into() + .unwrap(); + assert_ne!(sig1, [0u8; 16]); + } + + #[test] + fn aes_gmac_verify_fails_on_tampered_message() { + let mut msg = make_test_message(b"gmac original"); + let key = [0xDD; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 5, false).unwrap(); + + // Tamper the message -- even though verify uses server role, + // the auth tag won't match ANY valid signature. + let last = msg.len() - 1; + msg[last] ^= 0xFF; + + let result = verify_signature(&msg, &key, SigningAlgorithm::AesGmac, 5, false); + assert!(result.is_err()); + } + + #[test] + fn aes_gmac_verify_fails_with_wrong_key() { + let mut msg = make_test_message(b"gmac data"); + let key = [0xDD; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 5, false).unwrap(); + + let wrong_key = [0xDE; 16]; + let result = verify_signature(&msg, &wrong_key, SigningAlgorithm::AesGmac, 5, false); + assert!(result.is_err()); + } + + #[test] + fn aes_gmac_rejects_wrong_key_length() { + let mut msg = make_test_message(&[]); + let key = [0xDD; 32]; // 32 bytes instead of 16 + let result = sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 0, false); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("16-byte key")); + } + + // ── GMAC nonce construction ───────────────────────────────────── + + #[test] + fn aes_gmac_nonce_contains_message_id() { + // Different MessageIds must produce different signatures on the same message+key. + let key = [0xEE; 16]; + + let mut msg1 = make_test_message(b"nonce test"); + sign_message(&mut msg1, &key, SigningAlgorithm::AesGmac, 1, false).unwrap(); + let sig1: [u8; 16] = msg1[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN] + .try_into() + .unwrap(); + + let mut msg2 = make_test_message(b"nonce test"); + sign_message(&mut msg2, &key, SigningAlgorithm::AesGmac, 2, false).unwrap(); + let sig2: [u8; 16] = msg2[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN] + .try_into() + .unwrap(); + + assert_ne!( + sig1, sig2, + "different MessageIds must produce different signatures" + ); + } + + #[test] + fn aes_gmac_cancel_bit_changes_signature() { + let key = [0xEE; 16]; + let message_id = 42u64; + + let mut msg_normal = make_test_message(b"cancel test"); + sign_message( + &mut msg_normal, + &key, + SigningAlgorithm::AesGmac, + message_id, + false, + ) + .unwrap(); + let sig_normal: [u8; 16] = msg_normal[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN] + .try_into() + .unwrap(); + + let mut msg_cancel = make_test_message(b"cancel test"); + sign_message( + &mut msg_cancel, + &key, + SigningAlgorithm::AesGmac, + message_id, + true, + ) + .unwrap(); + let sig_cancel: [u8; 16] = msg_cancel[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN] + .try_into() + .unwrap(); + + assert_ne!( + sig_normal, sig_cancel, + "CANCEL bit must produce a different signature" + ); + } + + #[test] + fn aes_gmac_cancel_bit_is_bit_1_of_byte_8() { + // Verify the nonce byte 8 value directly by checking that + // the CANCEL nonce has 0x02 at byte 8 (bit 1), not 0x01 (bit 0). + let message_id: u64 = 99; + + let mut nonce_normal = [0u8; 12]; + nonce_normal[0..8].copy_from_slice(&message_id.to_le_bytes()); + // is_cancel = false -> byte 8 stays 0x00 + + let mut nonce_cancel = [0u8; 12]; + nonce_cancel[0..8].copy_from_slice(&message_id.to_le_bytes()); + nonce_cancel[8] = 0x02; // bit 1 set, NOT bit 0 + + assert_eq!(nonce_normal[8], 0x00); + assert_eq!(nonce_cancel[8], 0x02); + // Bit 0 (role bit) is always 0 for client. + assert_eq!(nonce_cancel[8] & 0x01, 0x00); + } + + // ── Signature field location ──────────────────────────────────── + + #[test] + fn signature_field_is_at_bytes_48_through_63() { + let mut msg = make_test_message(&[]); + let key = [0xFF; 16]; + + // Set a marker pattern in bytes 48-63 to verify they get overwritten. + msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].copy_from_slice(&[0xAA; 16]); + + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + + // The marker should be gone, replaced by the computed signature. + let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]; + assert_ne!(sig, &[0xAA; 16], "signature field must be overwritten"); + assert_ne!(sig, &[0x00; 16], "signature should not be all zeros"); + } + + #[test] + fn bytes_outside_signature_field_are_preserved() { + let body = b"preserve me"; + let mut msg = make_test_message(body); + let original_body = msg[64..].to_vec(); + let original_header_prefix = msg[0..SIGNATURE_OFFSET].to_vec(); + + let key = [0xFF; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + + // Header bytes before signature are unchanged. + assert_eq!(&msg[0..SIGNATURE_OFFSET], &original_header_prefix); + // Body is unchanged. + assert_eq!(&msg[64..], &original_body); + } + + // ── Cross-algorithm: verify with wrong algorithm fails ────────── + + #[test] + fn verify_with_wrong_algorithm_fails() { + let mut msg = make_test_message(b"cross algo"); + let key = [0x77; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap(); + + let result = verify_signature(&msg, &key, SigningAlgorithm::AesCmac, 0, false); + assert!(result.is_err()); + } + + // ── GMAC: verify with wrong message_id fails ──────────────────── + + #[test] + fn aes_gmac_verify_with_wrong_message_id_fails() { + let mut msg = make_test_message(b"msg id test"); + let key = [0xDD; 16]; + sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 10, false).unwrap(); + + // verify uses server role bit, and wrong message_id -- both wrong + let result = verify_signature(&msg, &key, SigningAlgorithm::AesGmac, 11, false); + assert!(result.is_err()); + } +} diff --git a/vendor/smb2/src/error.rs b/vendor/smb2/src/error.rs new file mode 100644 index 0000000..c820539 --- /dev/null +++ b/vendor/smb2/src/error.rs @@ -0,0 +1,389 @@ +//! Error types for the SMB2 library. + +use crate::types::status::NtStatus; +use crate::types::Command; +use thiserror::Error; + +/// Top-level error type for SMB2 operations. +#[derive(Debug, Error)] +pub enum Error { + /// The data is malformed or does not match the expected format. + #[error("Invalid data: {message}")] + InvalidData { + /// Description of what went wrong. + message: String, + }, + + /// The server returned a non-success NTSTATUS. + #[error("Protocol error: {status} during {command:?}")] + Protocol { + /// The NTSTATUS code from the response header. + status: NtStatus, + /// The command that triggered the error. + command: Command, + }, + + /// Authentication failed. + #[error("Authentication failed: {message}")] + Auth { + /// Description of what went wrong. + message: String, + }, + + /// An I/O or transport error occurred. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// The operation timed out. + #[error("Operation timed out")] + Timeout, + + /// The connection was lost. + #[error("Disconnected from server")] + Disconnected, + + /// The path requires DFS referral resolution. + /// + /// The server returned `STATUS_PATH_NOT_COVERED`, meaning this path + /// lives on a different server via DFS. The caller can query for a + /// referral or display a helpful message. + #[error("DFS referral required for path: {path}")] + DfsReferralRequired { + /// The path that needs DFS resolution. + path: String, + }, + + /// The operation was cancelled by the caller (via progress callback). + #[error("Operation cancelled")] + Cancelled, + + /// The session expired and reauthentication failed. + /// + /// The pipeline normally handles `STATUS_NETWORK_SESSION_EXPIRED` + /// transparently by reauthenticating. This error surfaces only + /// when reauthentication itself fails. + #[error("Session expired and reauthentication failed")] + SessionExpired, +} + +impl Error { + /// Create an `InvalidData` error with the given message. + pub fn invalid_data(msg: impl Into) -> Self { + Error::InvalidData { + message: msg.into(), + } + } + + /// Returns `true` if this error is potentially transient and + /// the operation could succeed on retry. + pub fn is_retryable(&self) -> bool { + matches!( + self, + Error::Timeout + | Error::Disconnected + | Error::Protocol { + status: NtStatus::INSUFFICIENT_RESOURCES, + .. + } + | Error::Protocol { + status: NtStatus::INSUFF_SERVER_RESOURCES, + .. + } + ) + } + + /// Returns the NTSTATUS code if this is a protocol error. + pub fn status(&self) -> Option { + match self { + Error::Protocol { status, .. } => Some(*status), + _ => None, + } + } +} + +/// High-level error classification. +/// +/// Maps protocol-level NTSTATUS codes and other errors into categories +/// that consumers can match on without understanding SMB internals. +/// +/// ```no_run +/// # async fn example(client: &mut smb2::SmbClient, share: &mut smb2::Tree) -> Result<(), smb2::Error> { +/// use smb2::ErrorKind; +/// +/// match client.read_file(share, "photo.jpg").await { +/// Ok(data) => println!("read {} bytes", data.len()), +/// Err(e) => match e.kind() { +/// ErrorKind::NotFound => println!("file doesn't exist"), +/// ErrorKind::AlreadyExists => println!("name is already taken"), +/// ErrorKind::AccessDenied => println!("no permission"), +/// ErrorKind::SigningRequired => println!("server requires signing, use credentials"), +/// ErrorKind::AuthRequired => println!("server requires authentication"), +/// ErrorKind::SharingViolation => println!("file is in use by another client"), +/// ErrorKind::IsADirectory => println!("path is a directory, not a file"), +/// ErrorKind::NotADirectory => println!("path is a file, not a directory"), +/// ErrorKind::DiskFull => println!("volume is full"), +/// ErrorKind::ConnectionLost => { client.reconnect().await?; } +/// _ => return Err(e), +/// } +/// } +/// # Ok(()) +/// # } +/// ``` +/// +/// # Stability +/// +/// `ErrorKind` is `#[non_exhaustive]`: future versions may add variants for +/// status codes that currently fall through to [`ErrorKind::Other`]. Match +/// statements should always include a `_` arm. Adding a variant is treated +/// as a non-breaking change. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum ErrorKind { + /// The server requires authentication (guest/anonymous not allowed). + AuthRequired, + /// The server requires message signing (guest sessions are unsigned). + SigningRequired, + /// Permission denied (valid credentials, but no access to this resource). + AccessDenied, + /// The file, directory, or share was not found. + NotFound, + /// A file or directory with the given name already exists. + /// + /// Returned by `Create` (and operations that wrap it, like `create_directory`) + /// when the target name is taken. Useful for callers that want to merge into + /// an existing directory or surface a friendly "name already taken" message. + AlreadyExists, + /// The file is in use by another client. + SharingViolation, + /// The target path is a directory, but the operation expected a file. + /// + /// Typically seen when calling `delete_file` against a directory entry — + /// the caller can fall back to `delete_directory` after detecting this. + IsADirectory, + /// The target path is a file, but the operation expected a directory. + /// + /// Typically seen when calling `list_directory` against a file entry. + NotADirectory, + /// The volume is full (write failed). + DiskFull, + /// The network connection was lost. + ConnectionLost, + /// The operation timed out. + TimedOut, + /// The operation was cancelled by the caller. + Cancelled, + /// The session expired (call `reconnect()`). + SessionExpired, + /// The path requires DFS referral resolution. + DfsReferral, + /// Invalid data or malformed response. + InvalidData, + /// An I/O error (transport or callback). Not necessarily a connection loss. + /// + /// Distinct from `ConnectionLost`: the connection may still be usable. + /// For example, a callback error in `write_file_streamed` produces `Io`, + /// but the connection is still in a clean state. + Io, + /// A protocol error not covered by other variants. + /// + /// Use [`Error::status()`] to get the raw NTSTATUS code. Some defined + /// `NtStatus` codes deliberately fall through here today + /// (`OBJECT_NAME_INVALID`, `DELETE_PENDING`, `INSUFFICIENT_RESOURCES`, + /// `INSUFF_SERVER_RESOURCES`, and similar) — they don't yet have a + /// dedicated `ErrorKind` because no consumer needs to branch on them. + /// Promoting one to its own variant is non-breaking. + Other, +} + +impl Error { + /// Classify this error into a high-level category. + /// + /// Consumers can match on [`ErrorKind`] without understanding raw + /// NTSTATUS codes. For the underlying status code, use [`status()`](Self::status). + pub fn kind(&self) -> ErrorKind { + match self { + Error::InvalidData { .. } => ErrorKind::InvalidData, + Error::Auth { .. } => ErrorKind::AuthRequired, + Error::Io(_) => ErrorKind::Io, + Error::Disconnected => ErrorKind::ConnectionLost, + Error::Timeout => ErrorKind::TimedOut, + Error::Cancelled => ErrorKind::Cancelled, + Error::SessionExpired => ErrorKind::SessionExpired, + Error::DfsReferralRequired { .. } => ErrorKind::DfsReferral, + Error::Protocol { status, .. } => classify_status(*status), + } + } +} + +/// Map an NTSTATUS to an ErrorKind. +fn classify_status(status: NtStatus) -> ErrorKind { + match status { + // Auth / signing + NtStatus::LOGON_FAILURE | NtStatus::ACCOUNT_DISABLED => ErrorKind::AuthRequired, + NtStatus::ACCESS_DENIED => { + // Could be signing-required or genuinely access-denied. + // Callers with NegotiatedParams context can distinguish further. + // Default to AccessDenied; SmbClient methods can upgrade to + // SigningRequired when signing_required is true. + ErrorKind::AccessDenied + } + + // Not found + NtStatus::NO_SUCH_FILE + | NtStatus::OBJECT_NAME_NOT_FOUND + | NtStatus::OBJECT_PATH_NOT_FOUND + | NtStatus::BAD_NETWORK_NAME => ErrorKind::NotFound, + + // Already exists + NtStatus::OBJECT_NAME_COLLISION => ErrorKind::AlreadyExists, + + // Wrong file type + NtStatus::FILE_IS_A_DIRECTORY => ErrorKind::IsADirectory, + NtStatus::NOT_A_DIRECTORY => ErrorKind::NotADirectory, + + // Sharing / locking + NtStatus::SHARING_VIOLATION | NtStatus::FILE_LOCK_CONFLICT => ErrorKind::SharingViolation, + + // Disk full + NtStatus::DISK_FULL => ErrorKind::DiskFull, + + // Session expired + NtStatus::NETWORK_SESSION_EXPIRED => ErrorKind::SessionExpired, + + // Connection + NtStatus::NETWORK_NAME_DELETED | NtStatus::USER_SESSION_DELETED => { + ErrorKind::ConnectionLost + } + + // DFS + NtStatus::PATH_NOT_COVERED => ErrorKind::DfsReferral, + + // Everything else + _ => ErrorKind::Other, + } +} + +/// A `Result` type alias using the crate's [`Error`](enum@Error) type. +pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + /// Documents the full contract between `NtStatus` codes and `ErrorKind`. + /// + /// Every code listed here is asserted to map to its expected variant. When + /// adding a new `NtStatus` to `types/status.rs`, also add a row here — either + /// pointing at a dedicated `ErrorKind`, or `ErrorKind::Other` if there is + /// genuinely no consumer-meaningful classification yet. The companion test + /// `classify_status_no_silent_other` then guarantees the table stays in sync + /// with what `classify_status` actually does. + const STATUS_CLASSIFICATION_CONTRACT: &[(NtStatus, ErrorKind)] = &[ + // Auth / signing + (NtStatus::LOGON_FAILURE, ErrorKind::AuthRequired), + (NtStatus::ACCOUNT_DISABLED, ErrorKind::AuthRequired), + (NtStatus::ACCESS_DENIED, ErrorKind::AccessDenied), + // Not found + (NtStatus::NO_SUCH_FILE, ErrorKind::NotFound), + (NtStatus::OBJECT_NAME_NOT_FOUND, ErrorKind::NotFound), + (NtStatus::OBJECT_PATH_NOT_FOUND, ErrorKind::NotFound), + (NtStatus::BAD_NETWORK_NAME, ErrorKind::NotFound), + // Already exists + (NtStatus::OBJECT_NAME_COLLISION, ErrorKind::AlreadyExists), + // Wrong file type + (NtStatus::FILE_IS_A_DIRECTORY, ErrorKind::IsADirectory), + (NtStatus::NOT_A_DIRECTORY, ErrorKind::NotADirectory), + // Sharing / locking + (NtStatus::SHARING_VIOLATION, ErrorKind::SharingViolation), + (NtStatus::FILE_LOCK_CONFLICT, ErrorKind::SharingViolation), + // Disk + (NtStatus::DISK_FULL, ErrorKind::DiskFull), + // Connection / session + (NtStatus::NETWORK_NAME_DELETED, ErrorKind::ConnectionLost), + (NtStatus::USER_SESSION_DELETED, ErrorKind::ConnectionLost), + (NtStatus::NETWORK_SESSION_EXPIRED, ErrorKind::SessionExpired), + // DFS + (NtStatus::PATH_NOT_COVERED, ErrorKind::DfsReferral), + // Documented `Other` (no current consumer demand for a typed variant) + (NtStatus::NOT_IMPLEMENTED, ErrorKind::Other), + (NtStatus::INVALID_PARAMETER, ErrorKind::Other), + (NtStatus::DELETE_PENDING, ErrorKind::Other), + (NtStatus::INSUFFICIENT_RESOURCES, ErrorKind::Other), + (NtStatus::INSUFF_SERVER_RESOURCES, ErrorKind::Other), + ]; + + #[test] + fn classify_status_contract() { + for (status, expected) in STATUS_CLASSIFICATION_CONTRACT { + let err = Error::Protocol { + status: *status, + command: Command::Create, + }; + assert_eq!( + err.kind(), + *expected, + "{status} should classify as {expected:?}" + ); + } + } + + #[test] + fn kind_maps_non_protocol_errors() { + assert_eq!(Error::Timeout.kind(), ErrorKind::TimedOut); + assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost); + assert_eq!(Error::Cancelled.kind(), ErrorKind::Cancelled); + assert_eq!(Error::SessionExpired.kind(), ErrorKind::SessionExpired); + assert_eq!(Error::invalid_data("test").kind(), ErrorKind::InvalidData); + assert_eq!( + Error::DfsReferralRequired { + path: "test".into() + } + .kind(), + ErrorKind::DfsReferral + ); + assert_eq!( + Error::Auth { + message: "test".into() + } + .kind(), + ErrorKind::AuthRequired + ); + } + + #[test] + fn kind_maps_io_error_to_io_not_connection_lost() { + // Error::Io from callback errors (like write_file_streamed cancellation) + // should NOT be ConnectionLost — the connection may still be usable. + let err = Error::Io(std::io::Error::new( + std::io::ErrorKind::Interrupted, + "cancelled", + )); + assert_eq!(err.kind(), ErrorKind::Io); + assert_ne!(err.kind(), ErrorKind::ConnectionLost); + } + + #[test] + fn kind_disconnected_is_connection_lost() { + // Error::Disconnected (transport EOF) IS a connection loss. + assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost); + } + + #[test] + fn kind_maps_dfs_referral_required_to_dfs_referral() { + // The explicit DFS referral error variant should also map to DfsReferral. + let err = Error::DfsReferralRequired { + path: r"\\server\share\path".into(), + }; + assert_eq!(err.kind(), ErrorKind::DfsReferral); + } + + #[test] + fn dfs_referral_is_not_retryable() { + // DFS referrals need special handling, not generic retry. + let err = Error::Protocol { + status: NtStatus::PATH_NOT_COVERED, + command: Command::Create, + }; + assert!(!err.is_retryable()); + } +} diff --git a/vendor/smb2/src/fuzzing.rs b/vendor/smb2/src/fuzzing.rs new file mode 100644 index 0000000..4ab5e78 --- /dev/null +++ b/vendor/smb2/src/fuzzing.rs @@ -0,0 +1,185 @@ +//! Fuzzing entry points for `fuzz/` targets. +//! +//! This module is feature-gated behind `fuzzing` and only exists to give +//! `cargo-fuzz` targets stable, public access to otherwise-internal parse +//! functions. Applications must not depend on it -- it's unstable by +//! design, and enabling the feature pulls in nothing of runtime value. +//! +//! Every function here takes untrusted bytes and returns either a parsed +//! value or a clean typed error. No function here is allowed to panic on +//! bad input; that's what the fuzzer tests. +//! +//! Targets (see `fuzz/fuzz_targets/`): +//! +//! - [`fuzz_header_parse`] -- SMB2 header (`msg::header::Header`). +//! - [`fuzz_transform_header_parse`] -- encryption transform header. +//! - [`fuzz_compression_transform_header_parse`] -- compression wrapper. +//! - [`fuzz_compound_split`] -- `client::connection::split_compound`. +//! - [`fuzz_frame_parse`] -- compound split + per-sub-frame header parse, +//! which is the real receiver-loop path up to the body. +//! - [`fuzz_sub_frame_parse`] -- header + body (dispatched by `Command`). +//! - [`fuzz_negotiate_request_parse`] / [`fuzz_negotiate_response_parse`] +//! - [`fuzz_create_request_parse`] / [`fuzz_create_response_parse`] +//! -- CreateContext list lives inside these bodies. +//! - [`fuzz_query_info_response_parse`] -- opaque output buffer sharp edge. +//! - [`fuzz_dfs_referral_response_parse`] -- manual offset arithmetic, +//! obvious fuzzing target. + +use crate::msg::header::Header; +use crate::msg::transform::{CompressionTransformHeader, TransformHeader}; +use crate::pack::{ReadCursor, Unpack}; +use crate::types::Command; + +/// Fuzz the top-level SMB2 header parser. +pub fn fuzz_header_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = Header::unpack(&mut cursor); +} + +/// Fuzz the encryption transform header parser. +pub fn fuzz_transform_header_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = TransformHeader::unpack(&mut cursor); +} + +/// Fuzz the compression transform header parser. +pub fn fuzz_compression_transform_header_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = CompressionTransformHeader::unpack(&mut cursor); +} + +/// Fuzz the compound-frame splitter. Takes a preprocessed (already decrypted +/// and decompressed) buffer and returns the sub-frame byte slices. +pub fn fuzz_compound_split(data: &[u8]) { + let _ = crate::client::connection::split_compound(data); +} + +/// Fuzz the full receiver-loop parse path: compound split, plus parsing the +/// header of every sub-frame. Mirrors what `prepare_sub_frame` does before +/// it dispatches on `Command`. +pub fn fuzz_frame_parse(data: &[u8]) { + let subs = match crate::client::connection::split_compound(data) { + Ok(s) => s, + Err(_) => return, + }; + for sub in subs { + let mut cursor = ReadCursor::new(&sub); + let _ = Header::unpack(&mut cursor); + } +} + +/// Fuzz header + body (dispatched by `Command`). Much wider surface than +/// [`fuzz_frame_parse`] because it actually parses the response body for +/// every command type. +pub fn fuzz_sub_frame_parse(data: &[u8]) { + if data.len() < Header::SIZE { + return; + } + let mut cursor = ReadCursor::new(data); + let header = match Header::unpack(&mut cursor) { + Ok(h) => h, + Err(_) => return, + }; + + let body = &data[Header::SIZE..]; + let is_response = header.is_response(); + dispatch_body(header.command, is_response, body); +} + +fn dispatch_body(command: Command, is_response: bool, body: &[u8]) { + use crate::msg; + + // Unpack the given type from `body` and discard the result. Parse errors + // are fine (boring path); panics / UB are what libfuzzer catches. + macro_rules! try_unpack { + ($ty:ty) => {{ + let mut cursor = ReadCursor::new(body); + let _ = <$ty as Unpack>::unpack(&mut cursor); + }}; + } + + match (command, is_response) { + (Command::Negotiate, false) => try_unpack!(msg::negotiate::NegotiateRequest), + (Command::Negotiate, true) => try_unpack!(msg::negotiate::NegotiateResponse), + (Command::SessionSetup, false) => try_unpack!(msg::session_setup::SessionSetupRequest), + (Command::SessionSetup, true) => try_unpack!(msg::session_setup::SessionSetupResponse), + (Command::Logoff, false) => try_unpack!(msg::logoff::LogoffRequest), + (Command::Logoff, true) => try_unpack!(msg::logoff::LogoffResponse), + (Command::TreeConnect, false) => try_unpack!(msg::tree_connect::TreeConnectRequest), + (Command::TreeConnect, true) => try_unpack!(msg::tree_connect::TreeConnectResponse), + (Command::TreeDisconnect, false) => { + try_unpack!(msg::tree_disconnect::TreeDisconnectRequest) + } + (Command::TreeDisconnect, true) => { + try_unpack!(msg::tree_disconnect::TreeDisconnectResponse) + } + (Command::Create, false) => try_unpack!(msg::create::CreateRequest), + (Command::Create, true) => try_unpack!(msg::create::CreateResponse), + (Command::Close, false) => try_unpack!(msg::close::CloseRequest), + (Command::Close, true) => try_unpack!(msg::close::CloseResponse), + (Command::Flush, false) => try_unpack!(msg::flush::FlushRequest), + (Command::Flush, true) => try_unpack!(msg::flush::FlushResponse), + (Command::Read, false) => try_unpack!(msg::read::ReadRequest), + (Command::Read, true) => try_unpack!(msg::read::ReadResponse), + (Command::Write, false) => try_unpack!(msg::write::WriteRequest), + (Command::Write, true) => try_unpack!(msg::write::WriteResponse), + (Command::Lock, false) => try_unpack!(msg::lock::LockRequest), + (Command::Lock, true) => try_unpack!(msg::lock::LockResponse), + (Command::Ioctl, false) => try_unpack!(msg::ioctl::IoctlRequest), + (Command::Ioctl, true) => try_unpack!(msg::ioctl::IoctlResponse), + (Command::Cancel, false) => try_unpack!(msg::cancel::CancelRequest), + (Command::Echo, false) => try_unpack!(msg::echo::EchoRequest), + (Command::Echo, true) => try_unpack!(msg::echo::EchoResponse), + (Command::QueryDirectory, false) => { + try_unpack!(msg::query_directory::QueryDirectoryRequest) + } + (Command::QueryDirectory, true) => { + try_unpack!(msg::query_directory::QueryDirectoryResponse) + } + (Command::ChangeNotify, false) => try_unpack!(msg::change_notify::ChangeNotifyRequest), + (Command::ChangeNotify, true) => try_unpack!(msg::change_notify::ChangeNotifyResponse), + (Command::QueryInfo, false) => try_unpack!(msg::query_info::QueryInfoRequest), + (Command::QueryInfo, true) => try_unpack!(msg::query_info::QueryInfoResponse), + (Command::SetInfo, false) => try_unpack!(msg::set_info::SetInfoRequest), + (Command::SetInfo, true) => try_unpack!(msg::set_info::SetInfoResponse), + _ => {} + } +} + +/// Fuzz `NegotiateRequest::unpack` directly. +pub fn fuzz_negotiate_request_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = crate::msg::negotiate::NegotiateRequest::unpack(&mut cursor); +} + +/// Fuzz `NegotiateResponse::unpack` directly. Covers negotiate-context parsing. +pub fn fuzz_negotiate_response_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = crate::msg::negotiate::NegotiateResponse::unpack(&mut cursor); +} + +/// Fuzz `CreateRequest::unpack` directly. Covers create-context list parsing. +pub fn fuzz_create_request_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = crate::msg::create::CreateRequest::unpack(&mut cursor); +} + +/// Fuzz `CreateResponse::unpack` directly. +pub fn fuzz_create_response_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = crate::msg::create::CreateResponse::unpack(&mut cursor); +} + +/// Fuzz `QueryInfoResponse::unpack`, which has the tricky +/// output-buffer-offset-from-header arithmetic. +pub fn fuzz_query_info_response_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = crate::msg::query_info::QueryInfoResponse::unpack(&mut cursor); +} + +/// Fuzz the DFS referral response parser. Manual offset arithmetic makes +/// this a classic sharp-edge target. +pub fn fuzz_dfs_referral_response_parse(data: &[u8]) { + let mut cursor = ReadCursor::new(data); + let _ = crate::msg::dfs::RespGetDfsReferral::unpack(&mut cursor); +} diff --git a/vendor/smb2/src/lib.rs b/vendor/smb2/src/lib.rs new file mode 100644 index 0000000..e7bb84b --- /dev/null +++ b/vendor/smb2/src/lib.rs @@ -0,0 +1,99 @@ +#![forbid(unsafe_code)] +#![warn(missing_docs)] + +//! Pure-Rust SMB2/3 client library with pipelined I/O. +//! +//! No C dependencies, no FFI. Pipelined reads/writes fill the credit window +//! so downloads run ~10-25x faster than sequential SMB clients. +//! +//! # Quick start +//! +//! ```rust,no_run +//! use smb2::{SmbClient, ClientConfig}; +//! +//! # async fn example() -> Result<(), smb2::Error> { +//! let mut client = smb2::connect("192.168.1.100:445", "user", "pass").await?; +//! +//! // List shares +//! let shares = client.list_shares().await?; +//! +//! // Connect to a share +//! let mut share = client.connect_share("Documents").await?; +//! +//! // List files +//! let entries = client.list_directory(&mut share, "projects/").await?; +//! for entry in &entries { +//! println!("{} ({} bytes)", entry.name, entry.size); +//! } +//! +//! // Read a file +//! let data = client.read_file(&mut share, "report.pdf").await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Modules +//! +//! - [`client`] -- High-level API: [`SmbClient`], [`Tree`], [`Pipeline`]. +//! This is what most users need. +//! - [`error`] -- Error types and NTSTATUS mapping. +//! - [`msg`] -- Wire format message structs (advanced/internal use). +//! - [`pack`] -- Binary serialization primitives (advanced/internal use). +//! - [`transport`] -- Transport trait and TCP implementation (advanced/internal use). +//! - [`crypto`] -- Signing and encryption (advanced/internal use). +//! - [`auth`] -- NTLM authentication (advanced/internal use). +//! - [`rpc`] -- Named pipe RPC for share enumeration (advanced/internal use). +//! - [`types`] -- Protocol newtypes and flag types (advanced/internal use). + +pub mod auth; +pub mod client; +pub mod crypto; +pub mod error; +pub mod msg; +pub mod pack; +pub mod rpc; +#[cfg(feature = "testing")] +pub mod testing; +pub mod transport; +pub mod types; + +#[cfg(feature = "fuzzing")] +pub mod fuzzing; + +// ── Re-exports: the simple-case imports ──────────────────────────────── + +// Error types +pub use error::{Error, ErrorKind, Result}; + +// High-level client +pub use client::{connect, ClientConfig, SmbClient}; + +// Streaming I/O +pub use client::stream::{FileDownload, FileUpload, FileWriter, Progress}; + +// Tree and file types +pub use client::tree::{DirectoryEntry, FileInfo, FsInfo, Tree}; + +// Pipeline +pub use client::pipeline::{Op, OpResult, Pipeline}; + +// Connection-level types (useful for advanced users) +pub use client::connection::{CompoundOp, Frame, NegotiatedParams}; +pub use client::session::Session; + +// Diagnostics: snapshot tree returned by `SmbClient::diagnostics()` / +// `Connection::diagnostics()`. +pub use client::diagnostics::{ + ClientInfo, ClientMetricsSnapshot, CompressionInfo, ConnectionDiagnostics, CreditInfo, + DfsCacheEntry, Diagnostics, EncryptionInfo, MetricsSnapshot, NegotiatedSummary, + SessionDiagnostics, SigningInfo, +}; + +// File watching +pub use client::watcher::{FileNotifyAction, FileNotifyEvent, Watcher}; + +// Share enumeration +pub use rpc::srvsvc::ShareInfo; + +// Kerberos authentication +pub use auth::kerberos::{KerberosAuthenticator, KerberosCredentials}; diff --git a/vendor/smb2/src/msg/CLAUDE.md b/vendor/smb2/src/msg/CLAUDE.md new file mode 100644 index 0000000..e7b4ed0 --- /dev/null +++ b/vendor/smb2/src/msg/CLAUDE.md @@ -0,0 +1,44 @@ +# Msg -- wire format message structs + +One sub-module per SMB2 command. Each defines request and response structs with `Pack` and `Unpack` implementations. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | `trivial_message!` macro for 4-byte stub messages, module declarations | +| `header.rs` | 64-byte SMB2 header (sync + async variants), `PROTOCOL_ID` (`0xFE 'S' 'M' 'B'`) | +| `negotiate.rs` | Negotiate contexts (preauth integrity, encryption, signing, compression) | +| `create.rs` | CREATE request/response with create contexts | +| `transform.rs` | `TransformHeader` (encryption, protocol ID `0xFD`), `CompressionTransformHeader` (`0xFC`) | + +19 command modules total: negotiate, session_setup, logoff, tree_connect, tree_disconnect, create, close, flush, read, write, lock, ioctl, query_directory, change_notify, query_info, set_info, echo, cancel, oplock_break. Plus `dfs.rs` for DFS referral request/response wire format (used by IOCTL FSCTL_DFS_GET_REFERRALS). + +## Patterns + +- **Pack/Unpack**: All structs implement `pack(&self, &mut WriteCursor)` and `unpack(&mut ReadCursor) -> Result`. Hand-rolled, no proc macros. +- **Offset calculation**: All offsets in SMB2 are relative to the start of the SMB2 header (not the body, not the transport frame). When packing variable-length fields, compute `header_size + fixed_body_size` as the base offset. +- **StructureSize validation**: `Unpack` implementations read `StructureSize` first and return an error if it doesn't match the expected value. +- **`trivial_message!` macro**: Generates Pack/Unpack for 4-byte stub messages (StructureSize=4 + Reserved=0). Used by echo, cancel, logoff, tree_disconnect. + +## Compound messages + +Built by `Connection::send_compound`. Each sub-request's header has a `NextCommand` field pointing to the next message (8-byte aligned). The last message has `NextCommand = 0`. Related operations use `FileId::SENTINEL` (`0xFFFFFFFF:0xFFFFFFFF`) so the server substitutes the handle from the first CREATE. + +## Transform headers + +- **Encryption** (`0xFD 'S' 'M' 'B'`): 52-byte `TransformHeader` wraps encrypted message(s). Contains nonce, auth tag (signature), original message size, session ID. +- **Compression** (`0xFC 'S' 'M' 'B'`): `CompressionTransformHeader` wraps LZ4-compressed messages. Contains original and compressed sizes, algorithm ID. + +## Gotchas + +- **TCP framing is big-endian**: The 4-byte transport header (1 zero byte + 3-byte length) uses big-endian byte order. Everything inside the SMB2 message is little-endian. This is the only big-endian value in the entire protocol. +- **StructureSize is "fixed"**: The spec says StructureSize is the size of the fixed-length portion of the struct. It does NOT include variable-length buffers. It's validated on unpack. +- **`#![allow(missing_docs)]`**: This module opts out of doc requirements because wire format field names are self-documenting from the spec. +- **Manual offset arithmetic requires careful bounds**: In `dfs.rs`, `parse_referral_entry` uses `ensure_remaining(buf, pos, N)` before raw `buf[pos..]` reads. Count the fixed fields carefully -- V2's body is **18** bytes (server_type+flags+proximity+ttl + three u16 offsets), not 16. An off-by-2 here lets a malformed `entry_size` slip past the initial guard and panic on the last offset read. Fuzz-caught in 0.7.2; regression test `resp_parse_v2_short_entry_returns_clean_error`. + +## Fuzzing + +Parse entry points are exposed via the `fuzzing` feature (`smb2::fuzzing`) and exercised by the `fuzz/` crate. See +`fuzz/README.md` (if present) or run `just fuzz fuzz_header_parse 300` for a local sweep. Every new parser touching +external bytes should get a fuzz target wrapper added in `src/fuzzing.rs` and a matching `fuzz/fuzz_targets/*.rs`. diff --git a/vendor/smb2/src/msg/cancel.rs b/vendor/smb2/src/msg/cancel.rs new file mode 100644 index 0000000..48797ad --- /dev/null +++ b/vendor/smb2/src/msg/cancel.rs @@ -0,0 +1,27 @@ +//! SMB2 CANCEL request (spec section 2.2.30). +//! +//! The CANCEL request is fire-and-forget: the client sends it to cancel a +//! previously sent message, and there is no corresponding response message. +//! The MessageId of the request to cancel is set in the SMB2 header. + +super::trivial_message! { + /// SMB2 CANCEL request (spec section 2.2.30). + /// + /// Sent by the client to cancel a previously sent message on the same + /// transport connection. There is no response for this command. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct CancelRequest; +} + +#[cfg(test)] +mod tests { + use super::*; + + super::super::trivial_message_tests!( + CancelRequest, + cancel_request_known_bytes, + cancel_request_roundtrip, + cancel_request_wrong_structure_size, + cancel_request_too_short + ); +} diff --git a/vendor/smb2/src/msg/change_notify.rs b/vendor/smb2/src/msg/change_notify.rs new file mode 100644 index 0000000..7c88229 --- /dev/null +++ b/vendor/smb2/src/msg/change_notify.rs @@ -0,0 +1,355 @@ +//! SMB2 CHANGE_NOTIFY Request and Response (MS-SMB2 sections 2.2.35, 2.2.36). +//! +//! The CHANGE_NOTIFY request registers for change notifications on a +//! directory. The response returns FILE_NOTIFY_INFORMATION entries +//! describing the changes that occurred. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +// ── Change Notify flags ──────────────────────────────────────────────── + +/// Watch the entire subtree (recursive). +pub const SMB2_WATCH_TREE: u16 = 0x0001; + +// ── CompletionFilter values ──────────────────────────────────────────── + +/// Notify when a file name changes. +pub const FILE_NOTIFY_CHANGE_FILE_NAME: u32 = 0x0000_0001; + +/// Notify when a directory name changes. +pub const FILE_NOTIFY_CHANGE_DIR_NAME: u32 = 0x0000_0002; + +/// Notify when file attributes change. +pub const FILE_NOTIFY_CHANGE_ATTRIBUTES: u32 = 0x0000_0004; + +/// Notify when the file size changes. +pub const FILE_NOTIFY_CHANGE_SIZE: u32 = 0x0000_0008; + +/// Notify when the last write time changes. +pub const FILE_NOTIFY_CHANGE_LAST_WRITE: u32 = 0x0000_0010; + +/// Notify when the last access time changes. +pub const FILE_NOTIFY_CHANGE_LAST_ACCESS: u32 = 0x0000_0020; + +/// Notify when the creation time changes. +pub const FILE_NOTIFY_CHANGE_CREATION: u32 = 0x0000_0040; + +/// Notify when extended attributes change. +pub const FILE_NOTIFY_CHANGE_EA: u32 = 0x0000_0080; + +/// Notify when the security descriptor changes. +pub const FILE_NOTIFY_CHANGE_SECURITY: u32 = 0x0000_0100; + +/// Notify when a stream name changes. +pub const FILE_NOTIFY_CHANGE_STREAM_NAME: u32 = 0x0000_0200; + +/// Notify when a stream size changes. +pub const FILE_NOTIFY_CHANGE_STREAM_SIZE: u32 = 0x0000_0400; + +/// Notify when stream data is written. +pub const FILE_NOTIFY_CHANGE_STREAM_WRITE: u32 = 0x0000_0800; + +// ── ChangeNotifyRequest ──────────────────────────────────────────────── + +/// SMB2 CHANGE_NOTIFY Request (MS-SMB2 section 2.2.35). +/// +/// Registers for directory change notifications. The structure is 32 bytes: +/// - StructureSize (2 bytes, must be 32) +/// - Flags (2 bytes) +/// - OutputBufferLength (4 bytes) +/// - FileId (16 bytes) +/// - CompletionFilter (4 bytes) +/// - Reserved (4 bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ChangeNotifyRequest { + /// Flags controlling the notification. Use `SMB2_WATCH_TREE` for recursive. + pub flags: u16, + /// Maximum size of the output buffer for notification data. + pub output_buffer_length: u32, + /// The directory handle to watch. + pub file_id: FileId, + /// Bitmask of change types to watch for. + pub completion_filter: u32, +} + +impl ChangeNotifyRequest { + pub const STRUCTURE_SIZE: u16 = 32; +} + +impl Pack for ChangeNotifyRequest { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Flags (2 bytes) + cursor.write_u16_le(self.flags); + // OutputBufferLength (4 bytes) + cursor.write_u32_le(self.output_buffer_length); + // FileId (16 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + // CompletionFilter (4 bytes) + cursor.write_u32_le(self.completion_filter); + // Reserved (4 bytes) + cursor.write_u32_le(0); + } +} + +impl Unpack for ChangeNotifyRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid ChangeNotifyRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let flags = cursor.read_u16_le()?; + let output_buffer_length = cursor.read_u32_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let completion_filter = cursor.read_u32_le()?; + let _reserved = cursor.read_u32_le()?; + + Ok(ChangeNotifyRequest { + flags, + output_buffer_length, + file_id: FileId { + persistent, + volatile, + }, + completion_filter, + }) + } +} + +// ── ChangeNotifyResponse ─────────────────────────────────────────────── + +/// SMB2 CHANGE_NOTIFY Response (MS-SMB2 section 2.2.36). +/// +/// Returns FILE_NOTIFY_INFORMATION entries describing directory changes. +/// The buffer contains raw FILE_NOTIFY_INFORMATION entries; parsing those +/// is left to the caller for now. +/// +/// Layout: +/// - StructureSize (2 bytes, must be 9) +/// - OutputBufferOffset (2 bytes) +/// - OutputBufferLength (4 bytes) +/// - Buffer (variable, OutputBufferLength bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ChangeNotifyResponse { + /// Raw FILE_NOTIFY_INFORMATION data. Parsing individual entries is + /// deferred to a higher layer. + pub output_data: Vec, +} + +impl ChangeNotifyResponse { + pub const STRUCTURE_SIZE: u16 = 9; + + /// Fixed header size before the variable buffer (8 bytes). + const FIXED_SIZE: u32 = 8; +} + +impl Pack for ChangeNotifyResponse { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + + let output_len = self.output_data.len() as u32; + // Offset is from the beginning of the SMB2 header per spec. + let output_offset = if output_len > 0 { + (start as u32) + Self::FIXED_SIZE + } else { + 0 + }; + + // OutputBufferOffset (2 bytes) + cursor.write_u16_le(output_offset as u16); + // OutputBufferLength (4 bytes) + cursor.write_u32_le(output_len); + // Buffer (variable) + cursor.write_bytes(&self.output_data); + } +} + +impl Unpack for ChangeNotifyResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid ChangeNotifyResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let _output_buffer_offset = cursor.read_u16_le()?; + let output_buffer_length = cursor.read_u32_le()?; + + let output_data = if output_buffer_length > 0 { + cursor + .read_bytes_bounded(output_buffer_length as usize)? + .to_vec() + } else { + Vec::new() + }; + + Ok(ChangeNotifyResponse { output_data }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── ChangeNotifyRequest tests ───────────────────────────────────── + + #[test] + fn change_notify_request_roundtrip_recursive() { + let original = ChangeNotifyRequest { + flags: SMB2_WATCH_TREE, + output_buffer_length: 65536, + file_id: FileId { + persistent: 0x1122_3344_5566_7788, + volatile: 0xAABB_CCDD_EEFF_0011, + }, + completion_filter: FILE_NOTIFY_CHANGE_FILE_NAME + | FILE_NOTIFY_CHANGE_DIR_NAME + | FILE_NOTIFY_CHANGE_LAST_WRITE, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed 32 bytes, no variable data + assert_eq!(bytes.len(), 32); + + let mut r = ReadCursor::new(&bytes); + let decoded = ChangeNotifyRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.flags, SMB2_WATCH_TREE); + assert_eq!(decoded.output_buffer_length, 65536); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!( + decoded.completion_filter, + FILE_NOTIFY_CHANGE_FILE_NAME + | FILE_NOTIFY_CHANGE_DIR_NAME + | FILE_NOTIFY_CHANGE_LAST_WRITE + ); + } + + #[test] + fn change_notify_request_wrong_structure_size() { + let mut buf = [0u8; 32]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = ChangeNotifyRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── ChangeNotifyResponse tests ──────────────────────────────────── + + #[test] + fn change_notify_response_roundtrip_with_data() { + let notify_data = vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]; + let original = ChangeNotifyResponse { + output_data: notify_data.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed 8 bytes + 8 bytes data + assert_eq!(bytes.len(), 16); + + let mut r = ReadCursor::new(&bytes); + let decoded = ChangeNotifyResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.output_data, notify_data); + } + + #[test] + fn change_notify_response_roundtrip_empty() { + let original = ChangeNotifyResponse { + output_data: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), 8); + + let mut r = ReadCursor::new(&bytes); + let decoded = ChangeNotifyResponse::unpack(&mut r).unwrap(); + + assert!(decoded.output_data.is_empty()); + } + + #[test] + fn change_notify_response_wrong_structure_size() { + let mut buf = [0u8; 8]; + buf[0..2].copy_from_slice(&42u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = ChangeNotifyResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id}; + use proptest::prelude::*; + + proptest! { + #[test] + fn change_notify_request_pack_unpack( + flags in any::(), + output_buffer_length in any::(), + file_id in arb_file_id(), + completion_filter in any::(), + ) { + let original = ChangeNotifyRequest { + flags, + output_buffer_length, + file_id, + completion_filter, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ChangeNotifyRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn change_notify_response_pack_unpack(output_data in arb_bytes()) { + let original = ChangeNotifyResponse { output_data }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ChangeNotifyResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/close.rs b/vendor/smb2/src/msg/close.rs new file mode 100644 index 0000000..f1f1760 --- /dev/null +++ b/vendor/smb2/src/msg/close.rs @@ -0,0 +1,390 @@ +//! SMB2 CLOSE Request and Response (MS-SMB2 sections 2.2.15, 2.2.16). +//! +//! The CLOSE request closes a file handle previously opened via CREATE. +//! The response optionally returns file attributes if the +//! `SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB` flag was set. + +use crate::error::Result; +use crate::pack::{FileTime, Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +/// Close flag: request that the server returns file attributes in the response. +pub const SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB: u16 = 0x0001; + +/// SMB2 CLOSE Request (MS-SMB2 section 2.2.15). +/// +/// Sent by the client to close a file handle. The structure is 24 bytes: +/// - StructureSize (2 bytes, must be 24) +/// - Flags (2 bytes) +/// - Reserved (4 bytes) +/// - FileId (16 bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CloseRequest { + /// Flags indicating how to process the close. + /// Use `SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB` to request attributes. + pub flags: u16, + /// The file handle to close. + pub file_id: FileId, +} + +impl CloseRequest { + pub const STRUCTURE_SIZE: u16 = 24; +} + +impl Pack for CloseRequest { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Flags (2 bytes) + cursor.write_u16_le(self.flags); + // Reserved (4 bytes) + cursor.write_u32_le(0); + // FileId (16 bytes): persistent + volatile + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + } +} + +impl Unpack for CloseRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid CloseRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let flags = cursor.read_u16_le()?; + let _reserved = cursor.read_u32_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + + Ok(CloseRequest { + flags, + file_id: FileId { + persistent, + volatile, + }, + }) + } +} + +/// SMB2 CLOSE Response (MS-SMB2 section 2.2.16). +/// +/// Sent by the server to confirm a close. The structure is 60 bytes: +/// - StructureSize (2 bytes, must be 60) +/// - Flags (2 bytes) +/// - Reserved (4 bytes) +/// - CreationTime (8 bytes) +/// - LastAccessTime (8 bytes) +/// - LastWriteTime (8 bytes) +/// - ChangeTime (8 bytes) +/// - AllocationSize (8 bytes) +/// - EndOfFile (8 bytes) +/// - FileAttributes (4 bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CloseResponse { + /// Flags echoed from the request. If `SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB` + /// is set, the attribute fields below contain valid data. + pub flags: u16, + /// File creation time. + pub creation_time: FileTime, + /// Last access time. + pub last_access_time: FileTime, + /// Last write time. + pub last_write_time: FileTime, + /// Change time. + pub change_time: FileTime, + /// Size of allocated data in bytes. + pub allocation_size: u64, + /// End-of-file position in bytes. + pub end_of_file: u64, + /// File attributes (see MS-FSCC section 2.6). + pub file_attributes: u32, +} + +impl CloseResponse { + pub const STRUCTURE_SIZE: u16 = 60; +} + +impl Pack for CloseResponse { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u16_le(Self::STRUCTURE_SIZE); + cursor.write_u16_le(self.flags); + cursor.write_u32_le(0); // Reserved + self.creation_time.pack(cursor); + self.last_access_time.pack(cursor); + self.last_write_time.pack(cursor); + self.change_time.pack(cursor); + cursor.write_u64_le(self.allocation_size); + cursor.write_u64_le(self.end_of_file); + cursor.write_u32_le(self.file_attributes); + } +} + +impl Unpack for CloseResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid CloseResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let flags = cursor.read_u16_le()?; + let _reserved = cursor.read_u32_le()?; + let creation_time = FileTime::unpack(cursor)?; + let last_access_time = FileTime::unpack(cursor)?; + let last_write_time = FileTime::unpack(cursor)?; + let change_time = FileTime::unpack(cursor)?; + let allocation_size = cursor.read_u64_le()?; + let end_of_file = cursor.read_u64_le()?; + let file_attributes = cursor.read_u32_le()?; + + Ok(CloseResponse { + flags, + creation_time, + last_access_time, + last_write_time, + change_time, + allocation_size, + end_of_file, + file_attributes, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── CloseRequest tests ───────────────────────────────────────── + + #[test] + fn close_request_roundtrip() { + let original = CloseRequest { + flags: SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB, + file_id: FileId { + persistent: 0x1122_3344_5566_7788, + volatile: 0xAABB_CCDD_EEFF_0011, + }, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // 2 + 2 + 4 + 16 = 24 bytes + assert_eq!(bytes.len(), 24); + + let mut r = ReadCursor::new(&bytes); + let decoded = CloseRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.file_id, original.file_id); + } + + #[test] + fn close_request_known_bytes() { + let mut buf = [0u8; 24]; + // StructureSize = 24 + buf[0..2].copy_from_slice(&24u16.to_le_bytes()); + // Flags = 0x0001 + buf[2..4].copy_from_slice(&1u16.to_le_bytes()); + // Reserved = 0 + buf[4..8].copy_from_slice(&0u32.to_le_bytes()); + // FileId persistent = 0x42 + buf[8..16].copy_from_slice(&0x42u64.to_le_bytes()); + // FileId volatile = 0x99 + buf[16..24].copy_from_slice(&0x99u64.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let req = CloseRequest::unpack(&mut cursor).unwrap(); + + assert_eq!(req.flags, SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB); + assert_eq!(req.file_id.persistent, 0x42); + assert_eq!(req.file_id.volatile, 0x99); + } + + #[test] + fn close_request_wrong_structure_size() { + let mut buf = [0u8; 24]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = CloseRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── CloseResponse tests ──────────────────────────────────────── + + #[test] + fn close_response_roundtrip() { + let original = CloseResponse { + flags: SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB, + creation_time: FileTime(0x01D8_AAAA_BBBB_CCCC), + last_access_time: FileTime(0x01D8_DDDD_EEEE_FFFF), + last_write_time: FileTime(0x01D8_1111_2222_3333), + change_time: FileTime(0x01D8_4444_5555_6666), + allocation_size: 4096, + end_of_file: 2048, + file_attributes: 0x20, // FILE_ATTRIBUTE_ARCHIVE + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // 2 + 2 + 4 + 8*6 + 4 = 60 bytes + assert_eq!(bytes.len(), 60); + + let mut r = ReadCursor::new(&bytes); + let decoded = CloseResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.creation_time, original.creation_time); + assert_eq!(decoded.last_access_time, original.last_access_time); + assert_eq!(decoded.last_write_time, original.last_write_time); + assert_eq!(decoded.change_time, original.change_time); + assert_eq!(decoded.allocation_size, original.allocation_size); + assert_eq!(decoded.end_of_file, original.end_of_file); + assert_eq!(decoded.file_attributes, original.file_attributes); + } + + #[test] + fn close_response_known_bytes() { + let mut buf = [0u8; 60]; + // StructureSize = 60 + buf[0..2].copy_from_slice(&60u16.to_le_bytes()); + // Flags = 0x0001 + buf[2..4].copy_from_slice(&1u16.to_le_bytes()); + // Reserved = 0 + buf[4..8].copy_from_slice(&0u32.to_le_bytes()); + // CreationTime = 100 + buf[8..16].copy_from_slice(&100u64.to_le_bytes()); + // LastAccessTime = 200 + buf[16..24].copy_from_slice(&200u64.to_le_bytes()); + // LastWriteTime = 300 + buf[24..32].copy_from_slice(&300u64.to_le_bytes()); + // ChangeTime = 400 + buf[32..40].copy_from_slice(&400u64.to_le_bytes()); + // AllocationSize = 8192 + buf[40..48].copy_from_slice(&8192u64.to_le_bytes()); + // EndOfFile = 1024 + buf[48..56].copy_from_slice(&1024u64.to_le_bytes()); + // FileAttributes = 0x10 (directory) + buf[56..60].copy_from_slice(&0x10u32.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let resp = CloseResponse::unpack(&mut cursor).unwrap(); + + assert_eq!(resp.flags, SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB); + assert_eq!(resp.creation_time, FileTime(100)); + assert_eq!(resp.last_access_time, FileTime(200)); + assert_eq!(resp.last_write_time, FileTime(300)); + assert_eq!(resp.change_time, FileTime(400)); + assert_eq!(resp.allocation_size, 8192); + assert_eq!(resp.end_of_file, 1024); + assert_eq!(resp.file_attributes, 0x10); + } + + #[test] + fn close_response_wrong_structure_size() { + let mut buf = [0u8; 60]; + buf[0..2].copy_from_slice(&42u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = CloseResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn close_response_zero_flags_has_zeroed_attributes() { + let original = CloseResponse { + flags: 0, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: 0, + file_attributes: 0, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CloseResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.flags, 0); + assert_eq!(decoded.creation_time, FileTime::ZERO); + assert_eq!(decoded.file_attributes, 0); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_file_id, arb_file_time}; + use proptest::prelude::*; + + proptest! { + #[test] + fn close_request_pack_unpack( + flags in any::(), + file_id in arb_file_id(), + ) { + let original = CloseRequest { flags, file_id }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CloseRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn close_response_pack_unpack( + flags in any::(), + creation_time in arb_file_time(), + last_access_time in arb_file_time(), + last_write_time in arb_file_time(), + change_time in arb_file_time(), + allocation_size in any::(), + end_of_file in any::(), + file_attributes in any::(), + ) { + let original = CloseResponse { + flags, + creation_time, + last_access_time, + last_write_time, + change_time, + allocation_size, + end_of_file, + file_attributes, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CloseResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/create.rs b/vendor/smb2/src/msg/create.rs new file mode 100644 index 0000000..44ad2ef --- /dev/null +++ b/vendor/smb2/src/msg/create.rs @@ -0,0 +1,870 @@ +//! SMB2 CREATE request and response (spec sections 2.2.13, 2.2.14). +//! +//! The CREATE request opens or creates a file, named pipe, or printer. +//! The response carries the file handle ([`FileId`]) plus timestamps, +//! attributes, and optional create contexts. + +use crate::error::Result; +use crate::msg::header::Header; +use crate::pack::{FileTime, Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::flags::FileAccessMask; +use crate::types::{FileId, OplockLevel}; +use crate::Error; + +// ── Enums ──────────────────────────────────────────────────────────────── + +/// Impersonation level (MS-SMB2 2.2.13). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u32)] +pub enum ImpersonationLevel { + /// Anonymous impersonation. + Anonymous = 0, + /// Identification impersonation. + Identification = 1, + /// Impersonation level. + Impersonation = 2, + /// Delegate impersonation. + Delegate = 3, +} + +impl TryFrom for ImpersonationLevel { + type Error = Error; + + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(Self::Anonymous), + 1 => Ok(Self::Identification), + 2 => Ok(Self::Impersonation), + 3 => Ok(Self::Delegate), + _ => Err(Error::invalid_data(format!( + "invalid ImpersonationLevel: {}", + value + ))), + } + } +} + +/// Share access flags (MS-SMB2 2.2.13). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct ShareAccess(pub u32); + +impl ShareAccess { + /// Allow other opens to read the file. + pub const FILE_SHARE_READ: u32 = 0x0000_0001; + /// Allow other opens to write the file. + pub const FILE_SHARE_WRITE: u32 = 0x0000_0002; + /// Allow other opens to delete the file. + pub const FILE_SHARE_DELETE: u32 = 0x0000_0004; +} + +/// Create disposition (MS-SMB2 2.2.13). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u32)] +pub enum CreateDisposition { + /// If the file exists, supersede it. Otherwise, create. + FileSupersede = 0, + /// If the file exists, open it. Otherwise, fail. + FileOpen = 1, + /// If the file exists, fail. Otherwise, create. + FileCreate = 2, + /// If the file exists, open it. Otherwise, create. + FileOpenIf = 3, + /// If the file exists, overwrite it. Otherwise, fail. + FileOverwrite = 4, + /// If the file exists, overwrite it. Otherwise, create. + FileOverwriteIf = 5, +} + +impl TryFrom for CreateDisposition { + type Error = Error; + + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(Self::FileSupersede), + 1 => Ok(Self::FileOpen), + 2 => Ok(Self::FileCreate), + 3 => Ok(Self::FileOpenIf), + 4 => Ok(Self::FileOverwrite), + 5 => Ok(Self::FileOverwriteIf), + _ => Err(Error::invalid_data(format!( + "invalid CreateDisposition: {}", + value + ))), + } + } +} + +/// Create action returned in the response (MS-SMB2 2.2.14). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u32)] +pub enum CreateAction { + /// An existing file was superseded. + FileSuperseded = 0, + /// An existing file was opened. + FileOpened = 1, + /// A new file was created. + FileCreated = 2, + /// An existing file was overwritten. + FileOverwritten = 3, +} + +impl TryFrom for CreateAction { + type Error = Error; + + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(Self::FileSuperseded), + 1 => Ok(Self::FileOpened), + 2 => Ok(Self::FileCreated), + 3 => Ok(Self::FileOverwritten), + _ => Err(Error::invalid_data(format!( + "invalid CreateAction: {}", + value + ))), + } + } +} + +// ── CreateRequest ──────────────────────────────────────────────────────── + +/// SMB2 CREATE request (spec section 2.2.13). +/// +/// Sent by the client to open or create a file on the server. +/// The buffer contains the filename encoded as UTF-16LE, optionally +/// followed by create context data. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CreateRequest { + /// Requested oplock level. + pub requested_oplock_level: OplockLevel, + /// Impersonation level. + pub impersonation_level: ImpersonationLevel, + /// Desired access rights. + pub desired_access: FileAccessMask, + /// File attributes for create/open. + pub file_attributes: u32, + /// Sharing mode. + pub share_access: ShareAccess, + /// Disposition: what to do if file exists/does not exist. + pub create_disposition: CreateDisposition, + /// Create options flags. + pub create_options: u32, + /// The filename to create or open. + pub name: String, + /// Raw create context bytes (unparsed). + pub create_contexts: Vec, +} + +impl CreateRequest { + pub const STRUCTURE_SIZE: u16 = 57; +} + +impl Pack for CreateRequest { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // SecurityFlags (1 byte) -- must be 0 + cursor.write_u8(0); + // RequestedOplockLevel (1 byte) + cursor.write_u8(self.requested_oplock_level as u8); + // ImpersonationLevel (4 bytes) + cursor.write_u32_le(self.impersonation_level as u32); + // SmbCreateFlags (8 bytes) -- must be 0 + cursor.write_u64_le(0); + // Reserved (8 bytes) + cursor.write_u64_le(0); + // DesiredAccess (4 bytes) + cursor.write_u32_le(self.desired_access.bits()); + // FileAttributes (4 bytes) + cursor.write_u32_le(self.file_attributes); + // ShareAccess (4 bytes) + cursor.write_u32_le(self.share_access.0); + // CreateDisposition (4 bytes) + cursor.write_u32_le(self.create_disposition as u32); + // CreateOptions (4 bytes) + cursor.write_u32_le(self.create_options); + + // NameOffset (2 bytes) -- placeholder, backpatch later + let name_offset_pos = cursor.position(); + cursor.write_u16_le(0); + // NameLength (2 bytes) -- placeholder, backpatch later + let name_length_pos = cursor.position(); + cursor.write_u16_le(0); + // CreateContextsOffset (4 bytes) -- placeholder + let ctx_offset_pos = cursor.position(); + cursor.write_u32_le(0); + // CreateContextsLength (4 bytes) -- placeholder + let ctx_length_pos = cursor.position(); + cursor.write_u32_le(0); + + // Buffer: filename in UTF-16LE + // Offsets are from the beginning of the SMB2 header per spec. + let name_offset = Header::SIZE + (cursor.position() - start); + let name_start = cursor.position(); + cursor.write_utf16_le(&self.name); + let name_byte_len = cursor.position() - name_start; + + // Backpatch name offset and length + cursor.set_u16_le_at(name_offset_pos, name_offset as u16); + cursor.set_u16_le_at(name_length_pos, name_byte_len as u16); + + // Create contexts (if any) + if !self.create_contexts.is_empty() { + // Align to 8-byte boundary before create contexts + cursor.align_to(8); + let ctx_offset = Header::SIZE + (cursor.position() - start); + cursor.write_bytes(&self.create_contexts); + let ctx_len = self.create_contexts.len(); + + cursor.set_u32_le_at(ctx_offset_pos, ctx_offset as u32); + cursor.set_u32_le_at(ctx_length_pos, ctx_len as u32); + } else if name_byte_len == 0 { + // Per spec, buffer must be at least 1 byte even if name is empty + cursor.write_u8(0); + } + } +} + +impl Unpack for CreateRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid CreateRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // SecurityFlags (1 byte) + let _security_flags = cursor.read_u8()?; + // RequestedOplockLevel (1 byte) + let oplock_raw = cursor.read_u8()?; + let requested_oplock_level = OplockLevel::try_from(oplock_raw)?; + // ImpersonationLevel (4 bytes) + let imp_raw = cursor.read_u32_le()?; + let impersonation_level = ImpersonationLevel::try_from(imp_raw)?; + // SmbCreateFlags (8 bytes) + let _smb_create_flags = cursor.read_u64_le()?; + // Reserved (8 bytes) + let _reserved = cursor.read_u64_le()?; + // DesiredAccess (4 bytes) + let desired_access = FileAccessMask::new(cursor.read_u32_le()?); + // FileAttributes (4 bytes) + let file_attributes = cursor.read_u32_le()?; + // ShareAccess (4 bytes) + let share_access = ShareAccess(cursor.read_u32_le()?); + // CreateDisposition (4 bytes) + let disp_raw = cursor.read_u32_le()?; + let create_disposition = CreateDisposition::try_from(disp_raw)?; + // CreateOptions (4 bytes) + let create_options = cursor.read_u32_le()?; + // NameOffset (2 bytes) + let name_offset = cursor.read_u16_le()? as usize; + // NameLength (2 bytes) + let name_length = cursor.read_u16_le()? as usize; + // CreateContextsOffset (4 bytes) + let ctx_offset = cursor.read_u32_le()? as usize; + // CreateContextsLength (4 bytes) + let ctx_length = cursor.read_u32_le()? as usize; + + // Read filename + // Offsets on the wire are from the beginning of the SMB2 header, + // so subtract Header::SIZE to get position within the body. + let name = if name_length > 0 { + let current = cursor.position(); + let body_offset = name_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_utf16_le(name_length)? + } else { + String::new() + }; + + // Read create contexts + let create_contexts = if ctx_length > 0 { + let current = cursor.position(); + let body_offset = ctx_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_bytes_bounded(ctx_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(CreateRequest { + requested_oplock_level, + impersonation_level, + desired_access, + file_attributes, + share_access, + create_disposition, + create_options, + name, + create_contexts, + }) + } +} + +// ── CreateResponse ─────────────────────────────────────────────────────── + +/// SMB2 CREATE response (spec section 2.2.14). +/// +/// Returned by the server with the file handle and metadata about +/// the created or opened file. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CreateResponse { + /// Oplock level granted by the server. + pub oplock_level: OplockLevel, + /// Flags (SMB 3.x only). + pub flags: u8, + /// Action taken by the server (opened, created, etc.). + pub create_action: CreateAction, + /// Time the file was created. + pub creation_time: FileTime, + /// Time the file was last accessed. + pub last_access_time: FileTime, + /// Time the file was last written. + pub last_write_time: FileTime, + /// Time the file metadata was last changed. + pub change_time: FileTime, + /// Allocation size of the file in bytes. + pub allocation_size: u64, + /// End-of-file position (actual file size in bytes). + pub end_of_file: u64, + /// File attributes. + pub file_attributes: u32, + /// The file handle. + pub file_id: FileId, + /// Raw create context bytes from the response. + pub create_contexts: Vec, +} + +impl CreateResponse { + pub const STRUCTURE_SIZE: u16 = 89; +} + +impl Pack for CreateResponse { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // OplockLevel (1 byte) + cursor.write_u8(self.oplock_level as u8); + // Flags (1 byte) + cursor.write_u8(self.flags); + // CreateAction (4 bytes) + cursor.write_u32_le(self.create_action as u32); + // CreationTime (8 bytes) + self.creation_time.pack(cursor); + // LastAccessTime (8 bytes) + self.last_access_time.pack(cursor); + // LastWriteTime (8 bytes) + self.last_write_time.pack(cursor); + // ChangeTime (8 bytes) + self.change_time.pack(cursor); + // AllocationSize (8 bytes) + cursor.write_u64_le(self.allocation_size); + // EndOfFile (8 bytes) + cursor.write_u64_le(self.end_of_file); + // FileAttributes (4 bytes) + cursor.write_u32_le(self.file_attributes); + // Reserved2 (4 bytes) + cursor.write_u32_le(0); + // FileId (16 bytes = persistent u64 + volatile u64) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + // CreateContextsOffset (4 bytes) -- placeholder + let ctx_offset_pos = cursor.position(); + cursor.write_u32_le(0); + // CreateContextsLength (4 bytes) -- placeholder + let ctx_length_pos = cursor.position(); + cursor.write_u32_le(0); + + // Create contexts (if any) + if !self.create_contexts.is_empty() { + cursor.align_to(8); + let ctx_offset = Header::SIZE + (cursor.position() - start); + cursor.write_bytes(&self.create_contexts); + let ctx_len = self.create_contexts.len(); + + cursor.set_u32_le_at(ctx_offset_pos, ctx_offset as u32); + cursor.set_u32_le_at(ctx_length_pos, ctx_len as u32); + } + } +} + +impl Unpack for CreateResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid CreateResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // OplockLevel (1 byte) + let oplock_level = OplockLevel::try_from(cursor.read_u8()?)?; + // Flags (1 byte) + let flags = cursor.read_u8()?; + // CreateAction (4 bytes) + let create_action = CreateAction::try_from(cursor.read_u32_le()?)?; + // CreationTime (8 bytes) + let creation_time = FileTime::unpack(cursor)?; + // LastAccessTime (8 bytes) + let last_access_time = FileTime::unpack(cursor)?; + // LastWriteTime (8 bytes) + let last_write_time = FileTime::unpack(cursor)?; + // ChangeTime (8 bytes) + let change_time = FileTime::unpack(cursor)?; + // AllocationSize (8 bytes) + let allocation_size = cursor.read_u64_le()?; + // EndOfFile (8 bytes) + let end_of_file = cursor.read_u64_le()?; + // FileAttributes (4 bytes) + let file_attributes = cursor.read_u32_le()?; + // Reserved2 (4 bytes) + let _reserved2 = cursor.read_u32_le()?; + // FileId (16 bytes) + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let file_id = FileId { + persistent, + volatile, + }; + // CreateContextsOffset (4 bytes) + let ctx_offset = cursor.read_u32_le()? as usize; + // CreateContextsLength (4 bytes) + let ctx_length = cursor.read_u32_le()? as usize; + + // Read create contexts + // Offset on the wire is from beginning of SMB2 header. + let create_contexts = if ctx_length > 0 { + let current = cursor.position(); + let body_offset = ctx_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_bytes_bounded(ctx_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(CreateResponse { + oplock_level, + flags, + create_action, + creation_time, + last_access_time, + last_write_time, + change_time, + allocation_size, + end_of_file, + file_attributes, + file_id, + create_contexts, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── CreateRequest tests ────────────────────────────────────────── + + #[test] + fn create_request_roundtrip_no_contexts() { + let original = CreateRequest { + requested_oplock_level: OplockLevel::Exclusive, + impersonation_level: ImpersonationLevel::Impersonation, + desired_access: FileAccessMask::new( + FileAccessMask::GENERIC_READ | FileAccessMask::FILE_READ_ATTRIBUTES, + ), + file_attributes: 0x80, // FILE_ATTRIBUTE_NORMAL + share_access: ShareAccess(ShareAccess::FILE_SHARE_READ | ShareAccess::FILE_SHARE_WRITE), + create_disposition: CreateDisposition::FileOpenIf, + create_options: 0, + name: "test\\file.txt".to_string(), + create_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CreateRequest::unpack(&mut r).unwrap(); + + assert_eq!( + decoded.requested_oplock_level, + original.requested_oplock_level + ); + assert_eq!(decoded.impersonation_level, original.impersonation_level); + assert_eq!(decoded.desired_access, original.desired_access); + assert_eq!(decoded.file_attributes, original.file_attributes); + assert_eq!(decoded.share_access, original.share_access); + assert_eq!(decoded.create_disposition, original.create_disposition); + assert_eq!(decoded.create_options, original.create_options); + assert_eq!(decoded.name, original.name); + assert!(decoded.create_contexts.is_empty()); + } + + #[test] + fn create_request_roundtrip_with_create_contexts() { + // Simulate a raw create context blob (for example, a + // SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST context). + let fake_ctx = vec![ + 0x00, 0x00, 0x00, 0x00, // NextEntryOffset = 0 (last entry) + 0x10, 0x00, // NameOffset = 16 + 0x04, 0x00, // NameLength = 4 + 0x00, 0x00, // Reserved + 0x18, 0x00, // DataOffset = 24 + 0x04, 0x00, 0x00, 0x00, // DataLength = 4 + b'M', b'x', b'A', b'c', // Name = "MxAc" + 0x00, 0x00, 0x00, 0x00, // padding + 0x01, 0x02, 0x03, 0x04, // Data (4 bytes) + ]; + + let original = CreateRequest { + requested_oplock_level: OplockLevel::Batch, + impersonation_level: ImpersonationLevel::Delegate, + desired_access: FileAccessMask::new(FileAccessMask::GENERIC_ALL), + file_attributes: 0x20, // FILE_ATTRIBUTE_ARCHIVE + share_access: ShareAccess(ShareAccess::FILE_SHARE_DELETE), + create_disposition: CreateDisposition::FileCreate, + create_options: 0x0000_0040, // FILE_NON_DIRECTORY_FILE + name: "share\\docs\\report.docx".to_string(), + create_contexts: fake_ctx.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CreateRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.requested_oplock_level, OplockLevel::Batch); + assert_eq!(decoded.impersonation_level, ImpersonationLevel::Delegate); + assert_eq!(decoded.name, "share\\docs\\report.docx"); + assert_eq!(decoded.create_contexts, fake_ctx); + } + + #[test] + fn create_request_structure_size_field() { + let req = CreateRequest { + requested_oplock_level: OplockLevel::None, + impersonation_level: ImpersonationLevel::Anonymous, + desired_access: FileAccessMask::default(), + file_attributes: 0, + share_access: ShareAccess::default(), + create_disposition: CreateDisposition::FileOpen, + create_options: 0, + name: "x".to_string(), + create_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + + // First two bytes are StructureSize = 57 + assert_eq!(bytes[0], 57); + assert_eq!(bytes[1], 0); + } + + #[test] + fn create_request_wrong_structure_size() { + let mut buf = vec![0u8; 64]; + // Set wrong structure size + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = CreateRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── CreateResponse tests ───────────────────────────────────────── + + #[test] + fn create_response_roundtrip() { + let original = CreateResponse { + oplock_level: OplockLevel::LevelII, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime(133_485_408_000_000_000), + last_access_time: FileTime(133_485_408_100_000_000), + last_write_time: FileTime(133_485_408_200_000_000), + change_time: FileTime(133_485_408_300_000_000), + allocation_size: 4096, + end_of_file: 1234, + file_attributes: 0x20, // FILE_ATTRIBUTE_ARCHIVE + file_id: FileId { + persistent: 0x1111_2222_3333_4444, + volatile: 0x5555_6666_7777_8888, + }, + create_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CreateResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.oplock_level, original.oplock_level); + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.create_action, original.create_action); + assert_eq!(decoded.creation_time, original.creation_time); + assert_eq!(decoded.last_access_time, original.last_access_time); + assert_eq!(decoded.last_write_time, original.last_write_time); + assert_eq!(decoded.change_time, original.change_time); + assert_eq!(decoded.allocation_size, original.allocation_size); + assert_eq!(decoded.end_of_file, original.end_of_file); + assert_eq!(decoded.file_attributes, original.file_attributes); + assert_eq!(decoded.file_id, original.file_id); + assert!(decoded.create_contexts.is_empty()); + } + + #[test] + fn create_response_with_contexts() { + let ctx_data = vec![0xAA, 0xBB, 0xCC, 0xDD]; + let original = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0x01, + create_action: CreateAction::FileCreated, + creation_time: FileTime(100), + last_access_time: FileTime(200), + last_write_time: FileTime(300), + change_time: FileTime(400), + allocation_size: 0, + end_of_file: 0, + file_attributes: 0, + file_id: FileId { + persistent: 1, + volatile: 2, + }, + create_contexts: ctx_data.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CreateResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.create_action, CreateAction::FileCreated); + assert_eq!(decoded.file_id.persistent, 1); + assert_eq!(decoded.file_id.volatile, 2); + assert_eq!(decoded.create_contexts, ctx_data); + } + + #[test] + fn create_response_structure_size_field() { + let resp = CreateResponse { + oplock_level: OplockLevel::None, + flags: 0, + create_action: CreateAction::FileOpened, + creation_time: FileTime::ZERO, + last_access_time: FileTime::ZERO, + last_write_time: FileTime::ZERO, + change_time: FileTime::ZERO, + allocation_size: 0, + end_of_file: 0, + file_attributes: 0, + file_id: FileId::default(), + create_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + resp.pack(&mut w); + let bytes = w.into_inner(); + + // First two bytes are StructureSize = 89 + assert_eq!(bytes[0], 89); + assert_eq!(bytes[1], 0); + } + + #[test] + fn create_response_wrong_structure_size() { + let mut buf = vec![0u8; 96]; + buf[0..2].copy_from_slice(&42u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = CreateResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── Enum conversion tests ──────────────────────────────────────── + + #[test] + fn oplock_level_roundtrip() { + for &level in &[ + OplockLevel::None, + OplockLevel::LevelII, + OplockLevel::Exclusive, + OplockLevel::Batch, + OplockLevel::Lease, + ] { + let raw = level as u8; + let decoded = OplockLevel::try_from(raw).unwrap(); + assert_eq!(decoded, level); + } + } + + #[test] + fn oplock_level_invalid() { + assert!(OplockLevel::try_from(0x42).is_err()); + } + + #[test] + fn impersonation_level_roundtrip() { + for &level in &[ + ImpersonationLevel::Anonymous, + ImpersonationLevel::Identification, + ImpersonationLevel::Impersonation, + ImpersonationLevel::Delegate, + ] { + let raw = level as u32; + let decoded = ImpersonationLevel::try_from(raw).unwrap(); + assert_eq!(decoded, level); + } + } + + #[test] + fn create_disposition_roundtrip() { + for &disp in &[ + CreateDisposition::FileSupersede, + CreateDisposition::FileOpen, + CreateDisposition::FileCreate, + CreateDisposition::FileOpenIf, + CreateDisposition::FileOverwrite, + CreateDisposition::FileOverwriteIf, + ] { + let raw = disp as u32; + let decoded = CreateDisposition::try_from(raw).unwrap(); + assert_eq!(decoded, disp); + } + } + + #[test] + fn create_action_roundtrip() { + for &action in &[ + CreateAction::FileSuperseded, + CreateAction::FileOpened, + CreateAction::FileCreated, + CreateAction::FileOverwritten, + ] { + let raw = action as u32; + let decoded = CreateAction::try_from(raw).unwrap(); + assert_eq!(decoded, action); + } + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{ + arb_create_action, arb_create_disposition, arb_file_access_mask, arb_file_id, + arb_file_time, arb_impersonation_level, arb_oplock_level, arb_share_access, + arb_small_bytes, arb_utf16_string, + }; + use proptest::prelude::*; + + proptest! { + #[test] + fn create_request_pack_unpack( + requested_oplock_level in arb_oplock_level(), + impersonation_level in arb_impersonation_level(), + desired_access in arb_file_access_mask(), + file_attributes in any::(), + share_access in arb_share_access(), + create_disposition in arb_create_disposition(), + create_options in any::(), + name in arb_utf16_string(128), + create_contexts in arb_small_bytes(), + ) { + let original = CreateRequest { + requested_oplock_level, + impersonation_level, + desired_access, + file_attributes, + share_access, + create_disposition, + create_options, + name, + create_contexts, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CreateRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + // Note: pack may write a trailing 1-byte pad when name is empty + // and there are no create contexts. Unpack only advances through + // fields it reads, so the cursor may have 1 trailing byte in + // that corner case. That's fine for symmetry on struct contents. + } + + #[test] + fn create_response_pack_unpack( + oplock_level in arb_oplock_level(), + flags in any::(), + create_action in arb_create_action(), + creation_time in arb_file_time(), + last_access_time in arb_file_time(), + last_write_time in arb_file_time(), + change_time in arb_file_time(), + allocation_size in any::(), + end_of_file in any::(), + file_attributes in any::(), + file_id in arb_file_id(), + create_contexts in arb_small_bytes(), + ) { + let original = CreateResponse { + oplock_level, + flags, + create_action, + creation_time, + last_access_time, + last_write_time, + change_time, + allocation_size, + end_of_file, + file_attributes, + file_id, + create_contexts, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CreateResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + } +} diff --git a/vendor/smb2/src/msg/dfs.rs b/vendor/smb2/src/msg/dfs.rs new file mode 100644 index 0000000..d84e16d --- /dev/null +++ b/vendor/smb2/src/msg/dfs.rs @@ -0,0 +1,697 @@ +//! DFS referral request and response wire format (MS-DFSC sections 2.2.2, 2.2.4). +//! +//! These types are packed into the input/output buffers of an IOCTL request +//! with `ctl_code = FSCTL_DFS_GET_REFERRALS`. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::Error; + +// ── ReqGetDfsReferral ───────────────────────────────────────────────── + +/// REQ_GET_DFS_REFERRAL (MS-DFSC 2.2.2). +/// +/// Sent as the input buffer of an `FSCTL_DFS_GET_REFERRALS` IOCTL request. +/// Contains the maximum referral version the client understands and the +/// DFS path to resolve. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReqGetDfsReferral { + /// Highest DFS referral version understood by the client (typically 4). + pub max_referral_level: u16, + /// The DFS path to resolve (case-insensitive UNC path). + pub request_file_name: String, +} + +impl Pack for ReqGetDfsReferral { + fn pack(&self, cursor: &mut WriteCursor) { + // MaxReferralLevel (2 bytes, LE) + cursor.write_u16_le(self.max_referral_level); + // RequestFileName (null-terminated UTF-16LE) + cursor.write_utf16_le(&self.request_file_name); + // Null terminator (2 bytes) + cursor.write_u16_le(0); + } +} + +impl Unpack for ReqGetDfsReferral { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let max_referral_level = cursor.read_u16_le()?; + // Read the rest as null-terminated UTF-16LE. + let request_file_name = read_null_terminated_utf16(cursor)?; + Ok(ReqGetDfsReferral { + max_referral_level, + request_file_name, + }) + } +} + +// ── RespGetDfsReferral ──────────────────────────────────────────────── + +/// RESP_GET_DFS_REFERRAL (MS-DFSC 2.2.4). +/// +/// Returned in the output buffer of an IOCTL response for +/// `FSCTL_DFS_GET_REFERRALS`. Contains the number of bytes of the path +/// consumed by the server, header flags, and a list of referral entries. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RespGetDfsReferral { + /// Number of bytes (not characters) of the path prefix that matched. + pub path_consumed: u16, + /// Header flags (ReferralServers | StorageServers | TargetFailback). + pub header_flags: u32, + /// The list of referral entries (V2, V3, or V4). + pub entries: Vec, +} + +/// A single DFS referral entry (V2-V4 flattened). +/// +/// V1 is not supported (extremely rare in practice). Each entry describes +/// one target server/share that the client can use to access the DFS path. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DfsReferralEntry { + /// Referral entry version (2, 3, or 4). + pub version: u16, + /// Server type: 0 = non-root/link target, 1 = root target. + pub server_type: u16, + /// Referral entry flags (version-specific). + pub referral_entry_flags: u16, + /// Time-to-live in seconds for caching this referral. + pub ttl: u32, + /// The DFS path prefix that matched. + pub dfs_path: String, + /// The DFS alternate path (usually identical to dfs_path). + pub dfs_alternate_path: String, + /// The target UNC path (for example, `\\server\share`). + pub network_address: String, +} + +impl Unpack for RespGetDfsReferral { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let path_consumed = cursor.read_u16_le()?; + let number_of_referrals = cursor.read_u16_le()?; + let header_flags = cursor.read_u32_le()?; + + // The remaining data contains all referral entries followed by a + // string buffer. We need the full remaining slice to resolve + // offsets that are relative to each entry's start. + let entry_data = cursor.read_bytes(cursor.remaining())?; + + let mut entries = Vec::with_capacity(number_of_referrals as usize); + let mut offset = 0usize; + + for _ in 0..number_of_referrals { + if offset + 4 > entry_data.len() { + return Err(Error::invalid_data( + "DFS referral entry truncated (version/size header)", + )); + } + + let version = u16::from_le_bytes([entry_data[offset], entry_data[offset + 1]]); + let entry_size = + u16::from_le_bytes([entry_data[offset + 2], entry_data[offset + 3]]) as usize; + + if entry_size < 4 { + return Err(Error::invalid_data(format!( + "DFS referral entry size too small: {entry_size}" + ))); + } + + let entry_start = offset; + // The entry_size includes the version and size fields themselves. + let entry_end = entry_start + entry_size; + if entry_end > entry_data.len() { + return Err(Error::invalid_data(format!( + "DFS referral entry extends past buffer: entry_end={entry_end}, buf={}", + entry_data.len() + ))); + } + + // All strings referenced by offsets live from entry_start onward + // in the full buffer (not truncated to entry_size, because the + // strings are in the trailing string buffer). + let entry = parse_referral_entry(version, entry_data, entry_start)?; + entries.push(entry); + + offset = entry_end; + } + + Ok(RespGetDfsReferral { + path_consumed, + header_flags, + entries, + }) + } +} + +/// Parse a single referral entry starting at `entry_start` within `buf`. +/// +/// String offsets in V2/V3/V4 are relative to the start of the entry +/// (which includes the 4-byte version+size prefix). +fn parse_referral_entry(version: u16, buf: &[u8], entry_start: usize) -> Result { + // Skip version (2) + size (2) -- already read by caller. + let mut pos = entry_start + 4; + + match version { + 2 => { + // V2: server_type(2) + flags(2) + proximity(4) + ttl(4) + + // dfs_path_offset(2) + dfs_alternate_path_offset(2) + network_address_offset(2) + // = 18 bytes of fixed entry body after the 4-byte version/size prefix. + ensure_remaining(buf, pos, 18)?; + let server_type = read_u16(buf, pos); + pos += 2; + let referral_entry_flags = read_u16(buf, pos); + pos += 2; + let _proximity = read_u32(buf, pos); + pos += 4; + let ttl = read_u32(buf, pos); + pos += 4; + let dfs_path_offset = read_u16(buf, pos) as usize; + pos += 2; + let dfs_alternate_path_offset = read_u16(buf, pos) as usize; + pos += 2; + let network_address_offset = read_u16(buf, pos) as usize; + + let dfs_path = read_offset_string(buf, entry_start, dfs_path_offset)?; + let dfs_alternate_path = + read_offset_string(buf, entry_start, dfs_alternate_path_offset)?; + let network_address = read_offset_string(buf, entry_start, network_address_offset)?; + + Ok(DfsReferralEntry { + version, + server_type, + referral_entry_flags, + ttl, + dfs_path, + dfs_alternate_path, + network_address, + }) + } + 3 | 4 => { + // V3/V4 share the same layout for the common (non-NameListReferral) case. + // server_type(2) + flags(2) + ttl(4) + + // dfs_path_offset(2) + dfs_alternate_path_offset(2) + network_address_offset(2) + // V3/V4: + service_site_guid(16) when NameListReferral=0 + ensure_remaining(buf, pos, 14)?; + let server_type = read_u16(buf, pos); + pos += 2; + let referral_entry_flags = read_u16(buf, pos); + pos += 2; + let ttl = read_u32(buf, pos); + pos += 4; + let dfs_path_offset = read_u16(buf, pos) as usize; + pos += 2; + let dfs_alternate_path_offset = read_u16(buf, pos) as usize; + pos += 2; + let network_address_offset = read_u16(buf, pos) as usize; + // Skip the rest of the fixed entry (service_site_guid for V3/V4). + + let dfs_path = read_offset_string(buf, entry_start, dfs_path_offset)?; + let dfs_alternate_path = + read_offset_string(buf, entry_start, dfs_alternate_path_offset)?; + let network_address = read_offset_string(buf, entry_start, network_address_offset)?; + + Ok(DfsReferralEntry { + version, + server_type, + referral_entry_flags, + ttl, + dfs_path, + dfs_alternate_path, + network_address, + }) + } + _ => Err(Error::invalid_data(format!( + "unsupported DFS referral version: {version} (only V2-V4 are supported)" + ))), + } +} + +// ── Helper functions ────────────────────────────────────────────────── + +/// Read a null-terminated UTF-16LE string from a `ReadCursor`. +fn read_null_terminated_utf16(cursor: &mut ReadCursor<'_>) -> Result { + let mut code_units: Vec = Vec::new(); + loop { + let cu = cursor.read_u16_le()?; + if cu == 0 { + break; + } + code_units.push(cu); + } + String::from_utf16(&code_units) + .map_err(|_| Error::invalid_data("invalid UTF-16LE in DFS request file name")) +} + +/// Read a null-terminated UTF-16LE string from a raw byte buffer at a given absolute offset. +fn read_null_terminated_utf16_at(buf: &[u8], offset: usize) -> Result { + let mut code_units: Vec = Vec::new(); + let mut pos = offset; + loop { + if pos + 2 > buf.len() { + return Err(Error::invalid_data( + "DFS referral string extends past buffer", + )); + } + let cu = u16::from_le_bytes([buf[pos], buf[pos + 1]]); + pos += 2; + if cu == 0 { + break; + } + code_units.push(cu); + } + String::from_utf16(&code_units) + .map_err(|_| Error::invalid_data("invalid UTF-16LE in DFS referral string")) +} + +/// Read a null-terminated UTF-16LE string at an offset relative to an entry start. +fn read_offset_string(buf: &[u8], entry_start: usize, offset: usize) -> Result { + let abs = entry_start + offset; + read_null_terminated_utf16_at(buf, abs) +} + +/// Inline LE u16 read from a byte buffer. +fn read_u16(buf: &[u8], pos: usize) -> u16 { + u16::from_le_bytes([buf[pos], buf[pos + 1]]) +} + +/// Inline LE u32 read from a byte buffer. +fn read_u32(buf: &[u8], pos: usize) -> u32 { + u32::from_le_bytes([buf[pos], buf[pos + 1], buf[pos + 2], buf[pos + 3]]) +} + +/// Check that at least `need` bytes are available at `pos` in `buf`. +fn ensure_remaining(buf: &[u8], pos: usize, need: usize) -> Result<()> { + if pos + need > buf.len() { + Err(Error::invalid_data(format!( + "DFS referral entry truncated: need {need} bytes at offset {pos}, buf len {}", + buf.len() + ))) + } else { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── Request tests ───────────────────────────────────────────────── + + #[test] + fn req_pack_known_bytes() { + // Test vector from smb-rs: ReqGetDfsReferral { max_referral_level: 4, + // request_file_name: r"\ADC.aviv.local\dfs\Docs" } + let expected = hex_to_bytes( + "04005c004100440043002e0061007600690076002e006c006f00630061006c005c006400660073005c0044006f00630073000000", + ); + let req = ReqGetDfsReferral { + max_referral_level: 4, + request_file_name: r"\ADC.aviv.local\dfs\Docs".to_string(), + }; + let mut cursor = WriteCursor::new(); + req.pack(&mut cursor); + assert_eq!(cursor.into_inner(), expected); + } + + #[test] + fn req_pack_roundtrip() { + let original = ReqGetDfsReferral { + max_referral_level: 4, + request_file_name: r"\server\share\path".to_string(), + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReqGetDfsReferral::unpack(&mut r).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn req_pack_empty_path() { + let req = ReqGetDfsReferral { + max_referral_level: 3, + request_file_name: String::new(), + }; + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + // max_referral_level (2) + null terminator (2) = 4 bytes + assert_eq!(bytes.len(), 4); + assert_eq!(bytes, [0x03, 0x00, 0x00, 0x00]); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReqGetDfsReferral::unpack(&mut r).unwrap(); + assert_eq!(decoded, req); + } + + #[test] + fn req_unpack_truncated() { + // Only 1 byte -- not enough for max_referral_level. + let bytes = [0x04]; + let mut r = ReadCursor::new(&bytes); + assert!(ReqGetDfsReferral::unpack(&mut r).is_err()); + } + + // ── Response tests ──────────────────────────────────────────────── + + #[test] + fn resp_parse_v4_referral() { + // Test vector from smb-rs: two V4 entries. + let hex = "300002000200000004002200000004000807000044007600\ + a800000000000000000000000000000000000400220000000000\ + 0807000022005400a8000000000000000000000000000000\ + 00005c004100440043002e0061007600690076002e006c00\ + 6f00630061006c005c006400660073005c0044006f006300\ + 730000005c004100440043002e0061007600690076002e00\ + 6c006f00630061006c005c006400660073005c0044006f00\ + 6300730000005c004100440043005c005300680061007200\ + 650073005c0044006f006300730000005c00460053005200\ + 56005c005300680061007200650073005c004d0079005300\ + 6800610072006500000000"; + let data = hex_to_bytes(hex); + let mut cursor = ReadCursor::new(&data); + let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap(); + + assert_eq!(resp.path_consumed, 48); + // header_flags = 0x00000002 (StorageServers) + assert_eq!(resp.header_flags, 0x0000_0002); + assert_eq!(resp.entries.len(), 2); + + let e0 = &resp.entries[0]; + assert_eq!(e0.version, 4); + assert_eq!(e0.server_type, 0); // non-root + assert_eq!(e0.ttl, 1800); + assert_eq!(e0.dfs_path, r"\ADC.aviv.local\dfs\Docs"); + assert_eq!(e0.dfs_alternate_path, r"\ADC.aviv.local\dfs\Docs"); + assert_eq!(e0.network_address, r"\ADC\Shares\Docs"); + + let e1 = &resp.entries[1]; + assert_eq!(e1.version, 4); + assert_eq!(e1.server_type, 0); + assert_eq!(e1.ttl, 1800); + assert_eq!(e1.dfs_path, r"\ADC.aviv.local\dfs\Docs"); + assert_eq!(e1.dfs_alternate_path, r"\ADC.aviv.local\dfs\Docs"); + assert_eq!(e1.network_address, r"\FSRV\Shares\MyShare"); + } + + #[test] + fn resp_parse_v3_referral() { + // Manually constructed V3 response: one entry. + // Header: path_consumed=20, num_referrals=1, flags=0x03 + // Entry: version=3, size=34 (fixed part), server_type=1, flags=0, + // ttl=600, offsets point to strings after the entry. + let dfs_path = encode_null_utf16(r"\dom\share"); + let alt_path = encode_null_utf16(r"\dom\share"); + let net_addr = encode_null_utf16(r"\srv\share"); + + let entry_fixed_size: u16 = 34; // 4 + 2+2+4 + 2+2+2 + 16 = 34 + let dfs_path_offset = entry_fixed_size; + let alt_path_offset = dfs_path_offset + dfs_path.len() as u16; + let net_addr_offset = alt_path_offset + alt_path.len() as u16; + + let mut buf = Vec::new(); + // Response header + buf.extend_from_slice(&20u16.to_le_bytes()); // path_consumed + buf.extend_from_slice(&1u16.to_le_bytes()); // number_of_referrals + buf.extend_from_slice(&3u32.to_le_bytes()); // header_flags + + // Entry header + buf.extend_from_slice(&3u16.to_le_bytes()); // version + buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size (fixed part) + buf.extend_from_slice(&1u16.to_le_bytes()); // server_type (root) + buf.extend_from_slice(&0u16.to_le_bytes()); // referral_entry_flags + buf.extend_from_slice(&600u32.to_le_bytes()); // ttl + buf.extend_from_slice(&dfs_path_offset.to_le_bytes()); + buf.extend_from_slice(&alt_path_offset.to_le_bytes()); + buf.extend_from_slice(&net_addr_offset.to_le_bytes()); + buf.extend_from_slice(&[0u8; 16]); // service_site_guid + + // String buffer + buf.extend_from_slice(&dfs_path); + buf.extend_from_slice(&alt_path); + buf.extend_from_slice(&net_addr); + + let mut cursor = ReadCursor::new(&buf); + let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap(); + + assert_eq!(resp.path_consumed, 20); + assert_eq!(resp.header_flags, 3); + assert_eq!(resp.entries.len(), 1); + + let e = &resp.entries[0]; + assert_eq!(e.version, 3); + assert_eq!(e.server_type, 1); + assert_eq!(e.ttl, 600); + assert_eq!(e.dfs_path, r"\dom\share"); + assert_eq!(e.dfs_alternate_path, r"\dom\share"); + assert_eq!(e.network_address, r"\srv\share"); + } + + #[test] + fn resp_parse_v2_referral() { + // Manually constructed V2 response: one entry. + let dfs_path = encode_null_utf16(r"\domain\dfs"); + let alt_path = encode_null_utf16(r"\domain\dfs"); + let net_addr = encode_null_utf16(r"\server\data"); + + let entry_fixed_size: u16 = 22; // 4 + 2+2+4+4 + 2+2+2 = 22 + let dfs_path_offset = entry_fixed_size; + let alt_path_offset = dfs_path_offset + dfs_path.len() as u16; + let net_addr_offset = alt_path_offset + alt_path.len() as u16; + + let mut buf = Vec::new(); + // Response header + buf.extend_from_slice(&24u16.to_le_bytes()); // path_consumed + buf.extend_from_slice(&1u16.to_le_bytes()); // number_of_referrals + buf.extend_from_slice(&1u32.to_le_bytes()); // header_flags (ReferralServers) + + // Entry + buf.extend_from_slice(&2u16.to_le_bytes()); // version + buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size + buf.extend_from_slice(&0u16.to_le_bytes()); // server_type + buf.extend_from_slice(&0u16.to_le_bytes()); // flags + buf.extend_from_slice(&0u32.to_le_bytes()); // proximity + buf.extend_from_slice(&300u32.to_le_bytes()); // ttl + buf.extend_from_slice(&dfs_path_offset.to_le_bytes()); + buf.extend_from_slice(&alt_path_offset.to_le_bytes()); + buf.extend_from_slice(&net_addr_offset.to_le_bytes()); + + // String buffer + buf.extend_from_slice(&dfs_path); + buf.extend_from_slice(&alt_path); + buf.extend_from_slice(&net_addr); + + let mut cursor = ReadCursor::new(&buf); + let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap(); + + assert_eq!(resp.path_consumed, 24); + assert_eq!(resp.header_flags, 1); + assert_eq!(resp.entries.len(), 1); + + let e = &resp.entries[0]; + assert_eq!(e.version, 2); + assert_eq!(e.server_type, 0); + assert_eq!(e.ttl, 300); + assert_eq!(e.dfs_path, r"\domain\dfs"); + assert_eq!(e.dfs_alternate_path, r"\domain\dfs"); + assert_eq!(e.network_address, r"\server\data"); + } + + #[test] + fn resp_parse_empty() { + // Zero referral entries. + let mut buf = Vec::new(); + buf.extend_from_slice(&0u16.to_le_bytes()); // path_consumed + buf.extend_from_slice(&0u16.to_le_bytes()); // number_of_referrals + buf.extend_from_slice(&0u32.to_le_bytes()); // header_flags + + let mut cursor = ReadCursor::new(&buf); + let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap(); + assert_eq!(resp.path_consumed, 0); + assert_eq!(resp.header_flags, 0); + assert!(resp.entries.is_empty()); + } + + #[test] + fn resp_parse_multiple_entries() { + // Two V2 entries with different targets. + // Layout: [entry1 fixed][entry2 fixed][strings for entry1][strings for entry2] + // Offsets are relative to each entry's start. + let dfs_path = encode_null_utf16(r"\ns\link"); + let alt_path = encode_null_utf16(r"\ns\link"); + let net_addr_1 = encode_null_utf16(r"\srv1\data"); + let net_addr_2 = encode_null_utf16(r"\srv2\data"); + + let entry_fixed_size: u16 = 22; + let total_fixed: u16 = entry_fixed_size * 2; // both entries' fixed parts + + // Entry 1 string offsets (relative to entry 1 start = 0 in entry_data). + // Strings start after both entries' fixed parts. + let e1_dfs_offset = total_fixed; // 44 + let e1_alt_offset = e1_dfs_offset + dfs_path.len() as u16; + let e1_net_offset = e1_alt_offset + alt_path.len() as u16; + let e1_strings_end = e1_net_offset + net_addr_1.len() as u16; + + // Entry 2 string offsets (relative to entry 2 start = 22 in entry_data). + let e2_dfs_offset = e1_strings_end - entry_fixed_size; // offset from entry 2 start + let e2_alt_offset = e2_dfs_offset + dfs_path.len() as u16; + let e2_net_offset = e2_alt_offset + alt_path.len() as u16; + + let mut buf = Vec::new(); + // Response header + buf.extend_from_slice(&16u16.to_le_bytes()); // path_consumed + buf.extend_from_slice(&2u16.to_le_bytes()); // number_of_referrals + buf.extend_from_slice(&0u32.to_le_bytes()); // header_flags + + // Entry 1 fixed part + buf.extend_from_slice(&2u16.to_le_bytes()); // version + buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size + buf.extend_from_slice(&0u16.to_le_bytes()); // server_type + buf.extend_from_slice(&0u16.to_le_bytes()); // flags + buf.extend_from_slice(&0u32.to_le_bytes()); // proximity + buf.extend_from_slice(&120u32.to_le_bytes()); // ttl + buf.extend_from_slice(&e1_dfs_offset.to_le_bytes()); + buf.extend_from_slice(&e1_alt_offset.to_le_bytes()); + buf.extend_from_slice(&e1_net_offset.to_le_bytes()); + + // Entry 2 fixed part + buf.extend_from_slice(&2u16.to_le_bytes()); + buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); + buf.extend_from_slice(&1u16.to_le_bytes()); // server_type = root + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u32.to_le_bytes()); + buf.extend_from_slice(&240u32.to_le_bytes()); + buf.extend_from_slice(&e2_dfs_offset.to_le_bytes()); + buf.extend_from_slice(&e2_alt_offset.to_le_bytes()); + buf.extend_from_slice(&e2_net_offset.to_le_bytes()); + + // String buffer for entry 1 + buf.extend_from_slice(&dfs_path); + buf.extend_from_slice(&alt_path); + buf.extend_from_slice(&net_addr_1); + + // String buffer for entry 2 + buf.extend_from_slice(&dfs_path); + buf.extend_from_slice(&alt_path); + buf.extend_from_slice(&net_addr_2); + + let mut cursor = ReadCursor::new(&buf); + let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap(); + + assert_eq!(resp.entries.len(), 2); + assert_eq!(resp.entries[0].ttl, 120); + assert_eq!(resp.entries[0].network_address, r"\srv1\data"); + assert_eq!(resp.entries[1].ttl, 240); + assert_eq!(resp.entries[1].server_type, 1); + assert_eq!(resp.entries[1].network_address, r"\srv2\data"); + } + + #[test] + fn resp_parse_unsupported_version() { + let mut buf = Vec::new(); + // Response header + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&1u16.to_le_bytes()); // 1 entry + buf.extend_from_slice(&0u32.to_le_bytes()); + // Entry with version 1 (unsupported) + buf.extend_from_slice(&1u16.to_le_bytes()); // version + buf.extend_from_slice(&8u16.to_le_bytes()); // size + buf.extend_from_slice(&[0u8; 4]); // padding to reach size + + let mut cursor = ReadCursor::new(&buf); + let result = RespGetDfsReferral::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("unsupported DFS referral version"), + "error was: {err}" + ); + } + + #[test] + fn resp_parse_truncated_header() { + // Only 4 bytes -- missing header_flags. + let buf = [0x00, 0x00, 0x01, 0x00]; + let mut cursor = ReadCursor::new(&buf); + assert!(RespGetDfsReferral::unpack(&mut cursor).is_err()); + } + + /// Regression: fuzz-found crash. A V2 entry that claims `entry_size = 16` + /// used to panic inside the entry-body read. The V2 body needs 18 bytes + /// (server_type+flags+proximity+ttl + three u16 offsets), but the guard + /// only ensured 16 bytes were available, so the final offset read would + /// slip past the buffer. See fuzz target + /// `fuzz_dfs_referral_response_parse` crash + /// `a6933afd5a1ccec7166d914caed66154416a2fcb`. + #[test] + fn resp_parse_v2_short_entry_returns_clean_error() { + let crash_input: [u8; 28] = [ + 0x10, 0x00, 0x01, 0x00, 0x22, 0x23, 0x00, 0x03, // header + 0x02, 0x00, 0x10, 0x00, 0x01, 0x00, 0x22, 0x23, // v2 entry start (size=16) + 0x00, 0x03, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, // body bytes + 0x00, 0x00, 0x00, 0x00, // tail + ]; + let mut cursor = ReadCursor::new(&crash_input); + let result = RespGetDfsReferral::unpack(&mut cursor); + assert!(result.is_err(), "expected clean error, got {result:?}"); + } + + // ── Test helpers ────────────────────────────────────────────────── + + /// Decode a hex string (no spaces, no 0x prefix) into bytes. + fn hex_to_bytes(hex: &str) -> Vec { + let hex: String = hex.chars().filter(|c| !c.is_whitespace()).collect(); + (0..hex.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap()) + .collect() + } + + /// Encode a string as null-terminated UTF-16LE bytes. + fn encode_null_utf16(s: &str) -> Vec { + let mut out = Vec::new(); + for cu in s.encode_utf16() { + out.extend_from_slice(&cu.to_le_bytes()); + } + out.extend_from_slice(&[0x00, 0x00]); // null terminator + out + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::arb_utf16_string; + use proptest::prelude::*; + + /// Generate a UTF-16 string without interior null (U+0000). The encoder + /// terminates with a 0x0000 code unit, so an interior null would end + /// the string early on decode. + fn arb_utf16_no_nul(max: usize) -> impl Strategy { + arb_utf16_string(max).prop_filter("string must not contain interior U+0000", |s| { + !s.contains('\0') + }) + } + + proptest! { + #[test] + fn req_get_dfs_referral_pack_unpack( + max_referral_level in any::(), + request_file_name in arb_utf16_no_nul(128), + ) { + let original = ReqGetDfsReferral { + max_referral_level, + request_file_name, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReqGetDfsReferral::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/echo.rs b/vendor/smb2/src/msg/echo.rs new file mode 100644 index 0000000..b1b27d1 --- /dev/null +++ b/vendor/smb2/src/msg/echo.rs @@ -0,0 +1,42 @@ +//! SMB2 ECHO request and response (spec sections 2.2.28, 2.2.29). +//! +//! Echo messages are used to check whether a server is processing requests. +//! Both request and response contain only a StructureSize field and a +//! reserved field, for a total of 4 bytes each. + +super::trivial_message! { + /// SMB2 ECHO request (spec section 2.2.28). + /// + /// Sent by the client to determine whether a server is processing requests. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct EchoRequest; +} + +super::trivial_message! { + /// SMB2 ECHO response (spec section 2.2.29). + /// + /// Sent by the server to confirm that an ECHO request was processed. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct EchoResponse; +} + +#[cfg(test)] +mod tests { + use super::*; + + super::super::trivial_message_tests!( + EchoRequest, + echo_request_known_bytes, + echo_request_roundtrip, + echo_request_wrong_structure_size, + echo_request_too_short + ); + + super::super::trivial_message_tests!( + EchoResponse, + echo_response_known_bytes, + echo_response_roundtrip, + echo_response_wrong_structure_size, + echo_response_too_short + ); +} diff --git a/vendor/smb2/src/msg/flush.rs b/vendor/smb2/src/msg/flush.rs new file mode 100644 index 0000000..f173c72 --- /dev/null +++ b/vendor/smb2/src/msg/flush.rs @@ -0,0 +1,254 @@ +//! SMB2 FLUSH request and response (spec sections 2.2.17, 2.2.18). +//! +//! Flush messages request that the server flush all cached file information +//! for a specified open to persistent storage. If the open refers to a +//! named pipe, the operation completes once all written data has been +//! consumed by a reader. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +/// SMB2 FLUSH request (spec section 2.2.17). +/// +/// Sent by the client to request that the server flush cached data for a file. +/// +/// Wire layout (24 bytes): +/// - StructureSize (2 bytes): must be 24 +/// - Reserved1 (2 bytes): must be 0 +/// - Reserved2 (4 bytes): must be 0 +/// - FileId (16 bytes): persistent (8 bytes) + volatile (8 bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FlushRequest { + pub file_id: FileId, +} + +impl FlushRequest { + pub const STRUCTURE_SIZE: u16 = 24; +} + +impl Pack for FlushRequest { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Reserved1 (2 bytes) + cursor.write_u16_le(0); + // Reserved2 (4 bytes) + cursor.write_u32_le(0); + // FileId: Persistent (8 bytes) + Volatile (8 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + } +} + +impl Unpack for FlushRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid FlushRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // Reserved1 (2 bytes) + let _reserved1 = cursor.read_u16_le()?; + + // Reserved2 (4 bytes) + let _reserved2 = cursor.read_u32_le()?; + + // FileId: Persistent (8 bytes) + Volatile (8 bytes) + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + + Ok(FlushRequest { + file_id: FileId { + persistent, + volatile, + }, + }) + } +} + +super::trivial_message! { + /// SMB2 FLUSH response (spec section 2.2.18). + /// + /// Sent by the server to confirm that a FLUSH request was processed. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct FlushResponse; +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── FlushRequest tests ───────────────────────────────────────── + + #[test] + fn flush_request_pack_produces_24_bytes() { + let req = FlushRequest { + file_id: FileId::default(), + }; + let mut cursor = WriteCursor::new(); + req.pack(&mut cursor); + let bytes = cursor.into_inner(); + assert_eq!(bytes.len(), 24); + } + + #[test] + fn flush_request_known_bytes() { + let req = FlushRequest { + file_id: FileId { + persistent: 0x0102_0304_0506_0708, + volatile: 0x090A_0B0C_0D0E_0F10, + }, + }; + let mut cursor = WriteCursor::new(); + req.pack(&mut cursor); + let bytes = cursor.into_inner(); + + #[rustfmt::skip] + let expected: [u8; 24] = [ + // StructureSize = 24 + 0x18, 0x00, + // Reserved1 = 0 + 0x00, 0x00, + // Reserved2 = 0 + 0x00, 0x00, 0x00, 0x00, + // FileId.Persistent (LE) + 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, + // FileId.Volatile (LE) + 0x10, 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, + ]; + assert_eq!(bytes, expected); + } + + #[test] + fn flush_request_unpack_known_bytes() { + #[rustfmt::skip] + let bytes: [u8; 24] = [ + // StructureSize = 24 + 0x18, 0x00, + // Reserved1 = 0 + 0x00, 0x00, + // Reserved2 = 0 + 0x00, 0x00, 0x00, 0x00, + // FileId.Persistent = 0xDEADBEEFCAFEBABE + 0xBE, 0xBA, 0xFE, 0xCA, 0xEF, 0xBE, 0xAD, 0xDE, + // FileId.Volatile = 0x1234567890ABCDEF + 0xEF, 0xCD, 0xAB, 0x90, 0x78, 0x56, 0x34, 0x12, + ]; + let mut cursor = ReadCursor::new(&bytes); + let req = FlushRequest::unpack(&mut cursor).unwrap(); + + assert_eq!(req.file_id.persistent, 0xDEAD_BEEF_CAFE_BABE); + assert_eq!(req.file_id.volatile, 0x1234_5678_90AB_CDEF); + assert!(cursor.is_empty()); + } + + #[test] + fn flush_request_roundtrip() { + let original = FlushRequest { + file_id: FileId { + persistent: 0xAAAA_BBBB_CCCC_DDDD, + volatile: 0x1111_2222_3333_4444, + }, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = FlushRequest::unpack(&mut r).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn flush_request_roundtrip_sentinel_file_id() { + let original = FlushRequest { + file_id: FileId::SENTINEL, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = FlushRequest::unpack(&mut r).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn flush_request_wrong_structure_size() { + let mut bytes = [0u8; 24]; + // Wrong structure size = 4 instead of 24 + bytes[0..2].copy_from_slice(&4u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&bytes); + let result = FlushRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn flush_request_too_short() { + let bytes = [0x18, 0x00, 0x00, 0x00]; + let mut cursor = ReadCursor::new(&bytes); + let result = FlushRequest::unpack(&mut cursor); + assert!(result.is_err()); + } + + #[test] + fn flush_request_ignores_reserved_values() { + #[rustfmt::skip] + let bytes: [u8; 24] = [ + // StructureSize = 24 + 0x18, 0x00, + // Reserved1 = 0xFFFF (non-zero, should be ignored) + 0xFF, 0xFF, + // Reserved2 = 0xFFFFFFFF (non-zero, should be ignored) + 0xFF, 0xFF, 0xFF, 0xFF, + // FileId.Persistent = 0 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // FileId.Volatile = 0 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + let mut cursor = ReadCursor::new(&bytes); + let req = FlushRequest::unpack(&mut cursor).unwrap(); + assert_eq!(req.file_id, FileId::default()); + } + + // ── FlushResponse tests ──────────────────────────────────────── + + super::super::trivial_message_tests!( + FlushResponse, + flush_response_known_bytes, + flush_response_roundtrip, + flush_response_wrong_structure_size, + flush_response_too_short + ); +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::arb_file_id; + use proptest::prelude::*; + + proptest! { + #[test] + fn flush_request_pack_unpack(file_id in arb_file_id()) { + let original = FlushRequest { file_id }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = FlushRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/header.rs b/vendor/smb2/src/msg/header.rs new file mode 100644 index 0000000..b1f86d1 --- /dev/null +++ b/vendor/smb2/src/msg/header.rs @@ -0,0 +1,669 @@ +//! SMB2 packet header (64 bytes) and error response. +//! +//! The SMB2 header has two variants that share the same 64-byte layout: +//! - **Sync header:** bytes 32-35 = Reserved (u32), bytes 36-39 = TreeId (u32) +//! - **Async header:** bytes 32-39 = AsyncId (u64) +//! +//! The choice is determined by the `SMB2_FLAGS_ASYNC_COMMAND` bit in the Flags field. +//! +//! Reference: MS-SMB2 sections 2.2.1, 2.2.1.1, 2.2.1.2, 2.2.2. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::flags::HeaderFlags; +use crate::types::status::NtStatus; +use crate::types::{Command, CreditCharge, MessageId, SessionId, TreeId}; +use crate::Error; + +/// The 4-byte protocol identifier at the start of every SMB2 message. +pub const PROTOCOL_ID: [u8; 4] = [0xFE, b'S', b'M', b'B']; + +/// SMB2 packet header (64 bytes). +/// +/// Contains both sync and async variants. The `flags` field determines +/// which interpretation of bytes 32-39 is correct. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Header { + /// Number of credits charged for this request. + pub credit_charge: CreditCharge, + /// In responses: NtStatus. In requests before SMB 3.x: Reserved. + /// In requests for SMB 3.x: ChannelSequence (u16) + Reserved (u16). + pub status: NtStatus, + /// The command code for this packet. + pub command: Command, + /// In requests: credits requested. In responses: credits granted. + pub credits: u16, + /// Flags indicating how to process the operation. + pub flags: HeaderFlags, + /// Offset to the next command in a compound chain (0 = last/only). + pub next_command: u32, + /// Unique message identifier for request/response correlation. + pub message_id: MessageId, + /// Sync-only: tree identifier. None if async. + pub tree_id: Option, + /// Async-only: async identifier. None if sync. + pub async_id: Option, + /// Session identifier. + pub session_id: SessionId, + /// 16-byte message signature. + pub signature: [u8; 16], +} + +impl Header { + pub const STRUCTURE_SIZE: u16 = 64; + + /// Total header size in bytes. + pub const SIZE: usize = 64; + + /// Create a new request header for a given command. + pub fn new_request(command: Command) -> Self { + Self { + credit_charge: CreditCharge(0), + status: NtStatus::SUCCESS, + command, + credits: 1, + flags: HeaderFlags::default(), + next_command: 0, + message_id: MessageId::default(), + tree_id: Some(TreeId::default()), + async_id: None, + session_id: SessionId::default(), + signature: [0u8; 16], + } + } + + /// Is this a response (vs request)? + pub fn is_response(&self) -> bool { + self.flags.is_response() + } +} + +impl Pack for Header { + fn pack(&self, cursor: &mut WriteCursor) { + // ProtocolId (4 bytes) + cursor.write_bytes(&PROTOCOL_ID); + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // CreditCharge (2 bytes) + cursor.write_u16_le(self.credit_charge.0); + // Status (4 bytes) + cursor.write_u32_le(self.status.0); + // Command (2 bytes) + cursor.write_u16_le(self.command.into()); + // CreditRequest/CreditResponse (2 bytes) + cursor.write_u16_le(self.credits); + // Flags (4 bytes) + cursor.write_u32_le(self.flags.bits()); + // NextCommand (4 bytes) + cursor.write_u32_le(self.next_command); + // MessageId (8 bytes) + cursor.write_u64_le(self.message_id.0); + + // Bytes 32-39: async or sync variant + if self.flags.is_async() { + // AsyncId (8 bytes) + cursor.write_u64_le(self.async_id.unwrap_or(0)); + } else { + // Reserved (4 bytes) + cursor.write_u32_le(0); + // TreeId (4 bytes) + cursor.write_u32_le(self.tree_id.map_or(0, |t| t.0)); + } + + // SessionId (8 bytes) + cursor.write_u64_le(self.session_id.0); + // Signature (16 bytes) + cursor.write_bytes(&self.signature); + } +} + +impl Unpack for Header { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // ProtocolId (4 bytes) + let proto = cursor.read_bytes(4)?; + if proto != PROTOCOL_ID { + return Err(Error::invalid_data(format!( + "invalid SMB2 protocol ID: expected {:02X?}, got {:02X?}", + PROTOCOL_ID, proto + ))); + } + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Header::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid SMB2 header structure size: expected {}, got {}", + Header::STRUCTURE_SIZE, + structure_size + ))); + } + + // CreditCharge (2 bytes) + let credit_charge = CreditCharge(cursor.read_u16_le()?); + + // Status (4 bytes) + let status = NtStatus(cursor.read_u32_le()?); + + // Command (2 bytes) + let command_raw = cursor.read_u16_le()?; + let command = Command::try_from(command_raw).map_err(|_| { + Error::invalid_data(format!("invalid SMB2 command code: 0x{:04X}", command_raw)) + })?; + + // CreditRequest/CreditResponse (2 bytes) + let credits = cursor.read_u16_le()?; + + // Flags (4 bytes) + let flags = HeaderFlags::new(cursor.read_u32_le()?); + + // NextCommand (4 bytes) + let next_command = cursor.read_u32_le()?; + + // MessageId (8 bytes) + let message_id = MessageId(cursor.read_u64_le()?); + + // Bytes 32-39: async or sync variant + let (tree_id, async_id) = if flags.is_async() { + let async_id = cursor.read_u64_le()?; + (None, Some(async_id)) + } else { + let _reserved = cursor.read_u32_le()?; + let tree_id = TreeId(cursor.read_u32_le()?); + (Some(tree_id), None) + }; + + // SessionId (8 bytes) + let session_id = SessionId(cursor.read_u64_le()?); + + // Signature (16 bytes) + let sig_bytes = cursor.read_bytes(16)?; + let mut signature = [0u8; 16]; + signature.copy_from_slice(sig_bytes); + + Ok(Header { + credit_charge, + status, + command, + credits, + flags, + next_command, + message_id, + tree_id, + async_id, + session_id, + signature, + }) + } +} + +/// SMB2 ERROR Response body (spec section 2.2.2). +/// +/// Sent by the server when a request fails. The structure is: +/// - StructureSize (2 bytes, must be 9) +/// - ErrorContextCount (1 byte) +/// - Reserved (1 byte) +/// - ByteCount (4 bytes) +/// - ErrorData (variable, ByteCount bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorResponse { + /// Number of error contexts (SMB 3.1.1 only, otherwise 0). + pub error_context_count: u8, + /// Variable-length error data. + pub error_data: Vec, +} + +impl ErrorResponse { + pub const STRUCTURE_SIZE: u16 = 9; +} + +impl Pack for ErrorResponse { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // ErrorContextCount (1 byte) + cursor.write_u8(self.error_context_count); + // Reserved (1 byte) + cursor.write_u8(0); + // ByteCount (4 bytes) + cursor.write_u32_le(self.error_data.len() as u32); + // ErrorData (variable) + cursor.write_bytes(&self.error_data); + } +} + +impl Unpack for ErrorResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid ErrorResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // ErrorContextCount (1 byte) + let error_context_count = cursor.read_u8()?; + + // Reserved (1 byte) + let _reserved = cursor.read_u8()?; + + // ByteCount (4 bytes) + let byte_count = cursor.read_u32_le()? as usize; + + // ErrorData (variable) + let error_data = if byte_count > 0 { + cursor.read_bytes_bounded(byte_count)?.to_vec() + } else { + Vec::new() + }; + + Ok(ErrorResponse { + error_context_count, + error_data, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── Header tests ──────────────────────────────────────────────── + + #[test] + fn pack_request_header_produces_64_bytes_with_correct_magic() { + let header = Header::new_request(Command::Negotiate); + let mut cursor = WriteCursor::new(); + header.pack(&mut cursor); + let bytes = cursor.into_inner(); + + assert_eq!(bytes.len(), Header::SIZE); + assert_eq!(&bytes[0..4], &PROTOCOL_ID); + } + + #[test] + fn unpack_known_64_byte_buffer() { + // Build a known buffer manually: sync Negotiate request + let mut buf = [0u8; 64]; + // ProtocolId + buf[0..4].copy_from_slice(&PROTOCOL_ID); + // StructureSize = 64 + buf[4..6].copy_from_slice(&64u16.to_le_bytes()); + // CreditCharge = 1 + buf[6..8].copy_from_slice(&1u16.to_le_bytes()); + // Status = SUCCESS (0) + buf[8..12].copy_from_slice(&0u32.to_le_bytes()); + // Command = Negotiate (0) + buf[12..14].copy_from_slice(&0u16.to_le_bytes()); + // Credits = 31 + buf[14..16].copy_from_slice(&31u16.to_le_bytes()); + // Flags = 0 (sync, request) + buf[16..20].copy_from_slice(&0u32.to_le_bytes()); + // NextCommand = 0 + buf[20..24].copy_from_slice(&0u32.to_le_bytes()); + // MessageId = 42 + buf[24..32].copy_from_slice(&42u64.to_le_bytes()); + // Reserved = 0 + buf[32..36].copy_from_slice(&0u32.to_le_bytes()); + // TreeId = 7 + buf[36..40].copy_from_slice(&7u32.to_le_bytes()); + // SessionId = 0x1234 + buf[40..48].copy_from_slice(&0x1234u64.to_le_bytes()); + // Signature = all zeros + // (already zero) + + let mut cursor = ReadCursor::new(&buf); + let header = Header::unpack(&mut cursor).unwrap(); + + assert_eq!(header.credit_charge, CreditCharge(1)); + assert_eq!(header.status, NtStatus::SUCCESS); + assert_eq!(header.command, Command::Negotiate); + assert_eq!(header.credits, 31); + assert!(!header.flags.is_async()); + assert!(!header.flags.is_response()); + assert_eq!(header.next_command, 0); + assert_eq!(header.message_id, MessageId(42)); + assert_eq!(header.tree_id, Some(TreeId(7))); + assert_eq!(header.async_id, None); + assert_eq!(header.session_id, SessionId(0x1234)); + assert_eq!(header.signature, [0u8; 16]); + } + + #[test] + fn roundtrip_sync_header() { + let original = Header { + credit_charge: CreditCharge(3), + status: NtStatus::ACCESS_DENIED, + command: Command::Read, + credits: 10, + flags: { + let mut f = HeaderFlags::default(); + f.set_response(); + f + }, + next_command: 0, + message_id: MessageId(99), + tree_id: Some(TreeId(42)), + async_id: None, + session_id: SessionId(0xDEAD_BEEF), + signature: [ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + 0x0F, 0x10, + ], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + assert_eq!(bytes.len(), Header::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = Header::unpack(&mut r).unwrap(); + + assert_eq!(decoded.credit_charge, original.credit_charge); + assert_eq!(decoded.status, original.status); + assert_eq!(decoded.command, original.command); + assert_eq!(decoded.credits, original.credits); + assert_eq!(decoded.flags.bits(), original.flags.bits()); + assert_eq!(decoded.next_command, original.next_command); + assert_eq!(decoded.message_id, original.message_id); + assert_eq!(decoded.tree_id, original.tree_id); + assert_eq!(decoded.async_id, original.async_id); + assert_eq!(decoded.session_id, original.session_id); + assert_eq!(decoded.signature, original.signature); + } + + #[test] + fn wrong_magic_bytes_returns_error() { + let mut buf = [0u8; 64]; + // Wrong magic + buf[0..4].copy_from_slice(&[0xFF, b'X', b'Y', b'Z']); + buf[4..6].copy_from_slice(&64u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = Header::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("protocol ID"), "error was: {err}"); + } + + #[test] + fn wrong_structure_size_returns_error() { + let mut buf = [0u8; 64]; + buf[0..4].copy_from_slice(&PROTOCOL_ID); + // Wrong structure size + buf[4..6].copy_from_slice(&32u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = Header::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn async_header_pack_unpack() { + let mut flags = HeaderFlags::default(); + flags.set_async(); + flags.set_response(); + + let original = Header { + credit_charge: CreditCharge(0), + status: NtStatus::PENDING, + command: Command::ChangeNotify, + credits: 1, + flags, + next_command: 0, + message_id: MessageId(8), + tree_id: None, + async_id: Some(0x0000_0000_0000_0008), + session_id: SessionId(0x0000_0000_0853_27D7), + signature: [0u8; 16], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + assert_eq!(bytes.len(), Header::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = Header::unpack(&mut r).unwrap(); + + assert!(decoded.flags.is_async()); + assert_eq!(decoded.async_id, Some(8)); + assert_eq!(decoded.tree_id, None); + assert_eq!(decoded.command, Command::ChangeNotify); + assert_eq!(decoded.status, NtStatus::PENDING); + assert_eq!(decoded.session_id, SessionId(0x0000_0000_0853_27D7)); + } + + #[test] + fn sync_header_has_tree_id_and_no_async_id() { + let header = Header::new_request(Command::Create); + + let mut w = WriteCursor::new(); + header.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = Header::unpack(&mut r).unwrap(); + + assert!(!decoded.flags.is_async()); + assert!(decoded.tree_id.is_some()); + assert_eq!(decoded.async_id, None); + } + + #[test] + fn signature_field_preserved() { + let sig = [ + 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, + 0x99, 0x00, + ]; + let mut header = Header::new_request(Command::Echo); + header.signature = sig; + + let mut w = WriteCursor::new(); + header.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = Header::unpack(&mut r).unwrap(); + + assert_eq!(decoded.signature, sig); + } + + #[test] + fn new_request_produces_correct_defaults() { + let header = Header::new_request(Command::Write); + + assert_eq!(header.command, Command::Write); + assert_eq!(header.credit_charge, CreditCharge(0)); + assert_eq!(header.status, NtStatus::SUCCESS); + assert_eq!(header.credits, 1); + assert!(!header.flags.is_response()); + assert!(!header.flags.is_async()); + assert_eq!(header.next_command, 0); + assert_eq!(header.message_id, MessageId(0)); + assert_eq!(header.tree_id, Some(TreeId(0))); + assert_eq!(header.async_id, None); + assert_eq!(header.session_id, SessionId(0)); + assert_eq!(header.signature, [0u8; 16]); + assert!(!header.is_response()); + } + + // ── ErrorResponse tests ───────────────────────────────────────── + + #[test] + fn error_response_pack_unpack_empty() { + let original = ErrorResponse { + error_context_count: 0, + error_data: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // StructureSize(2) + ErrorContextCount(1) + Reserved(1) + ByteCount(4) = 8 + assert_eq!(bytes.len(), 8); + + let mut r = ReadCursor::new(&bytes); + let decoded = ErrorResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.error_context_count, 0); + assert!(decoded.error_data.is_empty()); + } + + #[test] + fn error_response_pack_unpack_with_data() { + let data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE]; + let original = ErrorResponse { + error_context_count: 1, + error_data: data.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // 8 bytes fixed + 6 bytes data + assert_eq!(bytes.len(), 14); + + let mut r = ReadCursor::new(&bytes); + let decoded = ErrorResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.error_context_count, 1); + assert_eq!(decoded.error_data, data); + } + + #[test] + fn error_response_roundtrip() { + let original = ErrorResponse { + error_context_count: 2, + error_data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ErrorResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.error_context_count, original.error_context_count); + assert_eq!(decoded.error_data, original.error_data); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{ + arb_command, arb_credit_charge, arb_header_flags, arb_message_id, arb_nt_status, + arb_session_id, arb_small_bytes, arb_tree_id, + }; + use proptest::prelude::*; + + /// Generate a `Header` whose `flags.is_async()` matches which of + /// `tree_id`/`async_id` is set. Any other combination wouldn't round-trip + /// (pack writes one or the other based on flags, and clears the other on + /// unpack), so we never generate it. + fn arb_header() -> impl Strategy { + ( + arb_credit_charge(), + arb_nt_status(), + arb_command(), + any::(), + arb_header_flags(), + any::(), + arb_message_id(), + any::(), + arb_tree_id(), + any::(), + arb_session_id(), + any::<[u8; 16]>(), + ) + .prop_map( + |( + credit_charge, + status, + command, + credits, + raw_flags, + next_command, + message_id, + make_async, + tree_id, + async_id, + session_id, + signature, + )| { + // Force `flags.ASYNC_COMMAND` to match `make_async` so + // the pack path and the `Option` fields agree. + let flags = if make_async { + let mut f = raw_flags; + f.set(HeaderFlags::ASYNC_COMMAND); + f + } else { + let mut f = raw_flags; + f.clear(HeaderFlags::ASYNC_COMMAND); + f + }; + let (tree_id, async_id) = if make_async { + (None, Some(async_id)) + } else { + (Some(tree_id), None) + }; + Header { + credit_charge, + status, + command, + credits, + flags, + next_command, + message_id, + tree_id, + async_id, + session_id, + signature, + } + }, + ) + } + + proptest! { + #[test] + fn header_pack_unpack(header in arb_header()) { + let mut w = WriteCursor::new(); + header.pack(&mut w); + let bytes = w.into_inner(); + prop_assert_eq!(bytes.len(), Header::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = Header::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, header); + prop_assert!(r.is_empty()); + } + + #[test] + fn error_response_pack_unpack( + error_context_count in any::(), + error_data in arb_small_bytes(), + ) { + let original = ErrorResponse { + error_context_count, + error_data, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ErrorResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/ioctl.rs b/vendor/smb2/src/msg/ioctl.rs new file mode 100644 index 0000000..da046a1 --- /dev/null +++ b/vendor/smb2/src/msg/ioctl.rs @@ -0,0 +1,479 @@ +//! SMB2 IOCTL Request and Response (MS-SMB2 sections 2.2.31, 2.2.32). +//! +//! The IOCTL request sends a control code to a server, optionally with input +//! data. The response returns output data from the control operation. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +// ── IOCTL flags ──────────────────────────────────────────────────────── + +/// The request is a file system control (FSCTL) request. +pub const SMB2_0_IOCTL_IS_FSCTL: u32 = 0x0000_0001; + +// ── Common CtlCode values ────────────────────────────────────────────── + +/// Named pipe transceive operation. +pub const FSCTL_PIPE_TRANSCEIVE: u32 = 0x0011_C017; + +/// Server-side copy chunk (read handle). +pub const FSCTL_SRV_COPYCHUNK: u32 = 0x0014_40F2; + +/// Server-side copy chunk (write handle). +pub const FSCTL_SRV_COPYCHUNK_WRITE: u32 = 0x0014_80F2; + +/// DFS referral request. +pub const FSCTL_DFS_GET_REFERRALS: u32 = 0x0006_0194; + +/// Validate negotiate info (SMB 3.x). +pub const FSCTL_VALIDATE_NEGOTIATE_INFO: u32 = 0x0014_0204; + +// ── IoctlRequest ─────────────────────────────────────────────────────── + +/// SMB2 IOCTL Request (MS-SMB2 section 2.2.31). +/// +/// Sent by the client to issue a device or file system control command. +/// The fixed part is 56 bytes (StructureSize = 57 indicates 1 byte of +/// variable data is included in the fixed size, per SMB2 convention). +/// +/// Layout: +/// - StructureSize (2 bytes, must be 57) +/// - Reserved (2 bytes) +/// - CtlCode (4 bytes) +/// - FileId (16 bytes) +/// - InputOffset (4 bytes) +/// - InputCount (4 bytes) +/// - MaxInputResponse (4 bytes) +/// - OutputOffset (4 bytes) +/// - OutputCount (4 bytes) +/// - MaxOutputResponse (4 bytes) +/// - Flags (4 bytes) +/// - Reserved2 (4 bytes) +/// - Buffer (variable, InputCount bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IoctlRequest { + /// The control code for the operation. + pub ctl_code: u32, + /// The file handle for the operation. + pub file_id: FileId, + /// Maximum number of input bytes the server can return. + pub max_input_response: u32, + /// Maximum number of output bytes the server can return. + pub max_output_response: u32, + /// Flags for the request (for example, `SMB2_0_IOCTL_IS_FSCTL`). + pub flags: u32, + /// Input data buffer. + pub input_data: Vec, +} + +impl IoctlRequest { + pub const STRUCTURE_SIZE: u16 = 57; + + /// Fixed header size before the variable buffer (56 bytes). + const FIXED_SIZE: u32 = 56; +} + +impl Pack for IoctlRequest { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Reserved (2 bytes) + cursor.write_u16_le(0); + // CtlCode (4 bytes) + cursor.write_u32_le(self.ctl_code); + // FileId (16 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + + let input_count = self.input_data.len() as u32; + // Offset is from the beginning of the SMB2 header per spec. + // `start` is the cursor position at the beginning of the body; + // in a standalone request this equals Header::SIZE, in a compound + // it includes the preceding sub-requests. + let input_offset = if input_count > 0 { + (start as u32) + Self::FIXED_SIZE + } else { + 0 + }; + + // InputOffset (4 bytes) + cursor.write_u32_le(input_offset); + // InputCount (4 bytes) + cursor.write_u32_le(input_count); + // MaxInputResponse (4 bytes) + cursor.write_u32_le(self.max_input_response); + // OutputOffset (4 bytes) -- no output data in the request + cursor.write_u32_le(0); + // OutputCount (4 bytes) -- no output data in the request + cursor.write_u32_le(0); + // MaxOutputResponse (4 bytes) + cursor.write_u32_le(self.max_output_response); + // Flags (4 bytes) + cursor.write_u32_le(self.flags); + // Reserved2 (4 bytes) + cursor.write_u32_le(0); + // Buffer (variable) + cursor.write_bytes(&self.input_data); + } +} + +impl Unpack for IoctlRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid IoctlRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let _reserved = cursor.read_u16_le()?; + let ctl_code = cursor.read_u32_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let _input_offset = cursor.read_u32_le()?; + let input_count = cursor.read_u32_le()?; + let max_input_response = cursor.read_u32_le()?; + let _output_offset = cursor.read_u32_le()?; + let _output_count = cursor.read_u32_le()?; + let max_output_response = cursor.read_u32_le()?; + let flags = cursor.read_u32_le()?; + let _reserved2 = cursor.read_u32_le()?; + + let input_data = if input_count > 0 { + cursor.read_bytes_bounded(input_count as usize)?.to_vec() + } else { + Vec::new() + }; + + Ok(IoctlRequest { + ctl_code, + file_id: FileId { + persistent, + volatile, + }, + max_input_response, + max_output_response, + flags, + input_data, + }) + } +} + +// ── IoctlResponse ────────────────────────────────────────────────────── + +/// SMB2 IOCTL Response (MS-SMB2 section 2.2.32). +/// +/// Sent by the server to return the results of an IOCTL operation. +/// +/// Layout: +/// - StructureSize (2 bytes, must be 49) +/// - Reserved (2 bytes) +/// - CtlCode (4 bytes) +/// - FileId (16 bytes) +/// - InputOffset (4 bytes) +/// - InputCount (4 bytes) +/// - OutputOffset (4 bytes) +/// - OutputCount (4 bytes) +/// - Flags (4 bytes) +/// - Reserved2 (4 bytes) +/// - Buffer (variable -- may contain both input and output data) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IoctlResponse { + /// The control code echoed from the request. + pub ctl_code: u32, + /// The file handle echoed from the request. + pub file_id: FileId, + /// Flags echoed from the request. + pub flags: u32, + /// Output data buffer returned by the server. + pub output_data: Vec, +} + +impl IoctlResponse { + pub const STRUCTURE_SIZE: u16 = 49; + + /// Fixed header size before the variable buffer (48 bytes). + const FIXED_SIZE: u32 = 48; +} + +impl Pack for IoctlResponse { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Reserved (2 bytes) + cursor.write_u16_le(0); + // CtlCode (4 bytes) + cursor.write_u32_le(self.ctl_code); + // FileId (16 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + + let output_count = self.output_data.len() as u32; + // Offset is from the beginning of the SMB2 header per spec. + let output_offset = if output_count > 0 { + (start as u32) + Self::FIXED_SIZE + } else { + 0 + }; + + // InputOffset (4 bytes) -- no input data in the response + cursor.write_u32_le(0); + // InputCount (4 bytes) + cursor.write_u32_le(0); + // OutputOffset (4 bytes) + cursor.write_u32_le(output_offset); + // OutputCount (4 bytes) + cursor.write_u32_le(output_count); + // Flags (4 bytes) + cursor.write_u32_le(self.flags); + // Reserved2 (4 bytes) + cursor.write_u32_le(0); + // Buffer (variable) + cursor.write_bytes(&self.output_data); + } +} + +impl Unpack for IoctlResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid IoctlResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let _reserved = cursor.read_u16_le()?; + let ctl_code = cursor.read_u32_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let _input_offset = cursor.read_u32_le()?; + let _input_count = cursor.read_u32_le()?; + let _output_offset = cursor.read_u32_le()?; + let output_count = cursor.read_u32_le()?; + let flags = cursor.read_u32_le()?; + let _reserved2 = cursor.read_u32_le()?; + + let output_data = if output_count > 0 { + cursor.read_bytes_bounded(output_count as usize)?.to_vec() + } else { + Vec::new() + }; + + Ok(IoctlResponse { + ctl_code, + file_id: FileId { + persistent, + volatile, + }, + flags, + output_data, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── IoctlRequest tests ──────────────────────────────────────────── + + #[test] + fn ioctl_request_roundtrip_with_input_data() { + let original = IoctlRequest { + ctl_code: FSCTL_PIPE_TRANSCEIVE, + file_id: FileId { + persistent: 0x1122_3344_5566_7788, + volatile: 0xAABB_CCDD_EEFF_0011, + }, + max_input_response: 0, + max_output_response: 4096, + flags: SMB2_0_IOCTL_IS_FSCTL, + input_data: vec![0x01, 0x02, 0x03, 0x04, 0x05], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed 56 bytes + 5 bytes input data + assert_eq!(bytes.len(), 61); + + let mut r = ReadCursor::new(&bytes); + let decoded = IoctlRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.ctl_code, FSCTL_PIPE_TRANSCEIVE); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.max_input_response, 0); + assert_eq!(decoded.max_output_response, 4096); + assert_eq!(decoded.flags, SMB2_0_IOCTL_IS_FSCTL); + assert_eq!(decoded.input_data, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + } + + #[test] + fn ioctl_request_roundtrip_no_input_data() { + let original = IoctlRequest { + ctl_code: FSCTL_VALIDATE_NEGOTIATE_INFO, + file_id: FileId::SENTINEL, + max_input_response: 0, + max_output_response: 256, + flags: SMB2_0_IOCTL_IS_FSCTL, + input_data: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), 56); + + let mut r = ReadCursor::new(&bytes); + let decoded = IoctlRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.ctl_code, FSCTL_VALIDATE_NEGOTIATE_INFO); + assert_eq!(decoded.file_id, FileId::SENTINEL); + assert!(decoded.input_data.is_empty()); + } + + #[test] + fn ioctl_request_wrong_structure_size() { + let mut buf = [0u8; 56]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = IoctlRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── IoctlResponse tests ─────────────────────────────────────────── + + #[test] + fn ioctl_response_roundtrip_with_output_data() { + let original = IoctlResponse { + ctl_code: FSCTL_PIPE_TRANSCEIVE, + file_id: FileId { + persistent: 0x42, + volatile: 0x99, + }, + flags: SMB2_0_IOCTL_IS_FSCTL, + output_data: vec![0xDE, 0xAD, 0xBE, 0xEF], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed 48 bytes + 4 bytes output data + assert_eq!(bytes.len(), 52); + + let mut r = ReadCursor::new(&bytes); + let decoded = IoctlResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.ctl_code, FSCTL_PIPE_TRANSCEIVE); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.flags, SMB2_0_IOCTL_IS_FSCTL); + assert_eq!(decoded.output_data, vec![0xDE, 0xAD, 0xBE, 0xEF]); + } + + #[test] + fn ioctl_response_roundtrip_no_output_data() { + let original = IoctlResponse { + ctl_code: FSCTL_SRV_COPYCHUNK, + file_id: FileId::default(), + flags: 0, + output_data: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), 48); + + let mut r = ReadCursor::new(&bytes); + let decoded = IoctlResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.ctl_code, FSCTL_SRV_COPYCHUNK); + assert!(decoded.output_data.is_empty()); + } + + #[test] + fn ioctl_response_wrong_structure_size() { + let mut buf = [0u8; 48]; + buf[0..2].copy_from_slice(&42u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = IoctlResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id}; + use proptest::prelude::*; + + proptest! { + #[test] + fn ioctl_request_pack_unpack( + ctl_code in any::(), + file_id in arb_file_id(), + max_input_response in any::(), + max_output_response in any::(), + flags in any::(), + input_data in arb_bytes(), + ) { + let original = IoctlRequest { + ctl_code, + file_id, + max_input_response, + max_output_response, + flags, + input_data, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = IoctlRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn ioctl_response_pack_unpack( + ctl_code in any::(), + file_id in arb_file_id(), + flags in any::(), + output_data in arb_bytes(), + ) { + let original = IoctlResponse { + ctl_code, + file_id, + flags, + output_data, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = IoctlResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/lock.rs b/vendor/smb2/src/msg/lock.rs new file mode 100644 index 0000000..b261d56 --- /dev/null +++ b/vendor/smb2/src/msg/lock.rs @@ -0,0 +1,445 @@ +//! SMB2 LOCK Request and Response (MS-SMB2 sections 2.2.26, 2.2.27). +//! +//! The LOCK request locks or unlocks byte ranges within a file. +//! Multiple ranges can be locked/unlocked in a single request. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +/// Lock flag: shared lock (allows other readers). +pub const SMB2_LOCKFLAG_SHARED_LOCK: u32 = 0x0000_0001; + +/// Lock flag: exclusive lock (no other readers or writers). +pub const SMB2_LOCKFLAG_EXCLUSIVE_LOCK: u32 = 0x0000_0002; + +/// Lock flag: unlock a previously locked range. +pub const SMB2_LOCKFLAG_UNLOCK: u32 = 0x0000_0004; + +/// Lock flag: fail immediately if the lock conflicts. +pub const SMB2_LOCKFLAG_FAIL_IMMEDIATELY: u32 = 0x0000_0010; + +/// A single lock element describing a byte range to lock or unlock. +/// +/// Each element is 24 bytes on the wire: +/// - Offset (8 bytes) +/// - Length (8 bytes) +/// - Flags (4 bytes) +/// - Reserved (4 bytes) +/// +/// Reference: MS-SMB2 section 2.2.26.1. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LockElement { + /// Starting offset in bytes from where the range begins. + pub offset: u64, + /// Length of the range in bytes. + pub length: u64, + /// Flags describing how the range is locked or unlocked. + pub flags: u32, +} + +impl LockElement { + /// Wire size of a single lock element. + pub const SIZE: usize = 24; +} + +impl Pack for LockElement { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u64_le(self.offset); + cursor.write_u64_le(self.length); + cursor.write_u32_le(self.flags); + cursor.write_u32_le(0); // Reserved + } +} + +impl Unpack for LockElement { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let offset = cursor.read_u64_le()?; + let length = cursor.read_u64_le()?; + let flags = cursor.read_u32_le()?; + let _reserved = cursor.read_u32_le()?; + + Ok(LockElement { + offset, + length, + flags, + }) + } +} + +/// SMB2 LOCK Request (MS-SMB2 section 2.2.26). +/// +/// Sent by the client to lock or unlock byte ranges. The fixed portion +/// is 48 bytes (StructureSize=48, which includes one `LockElement`): +/// - StructureSize (2 bytes, must be 48) +/// - LockCount (2 bytes) +/// - LockSequenceNumber/Index (4 bytes) +/// - FileId (16 bytes) +/// - Locks (variable, LockCount x 24 bytes each) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LockRequest { + /// Combined lock sequence number (4 bits) and index (28 bits). + /// In SMB 2.0.2 this field is reserved (0). + pub lock_sequence: u32, + /// File handle to lock ranges on. + pub file_id: FileId, + /// Array of lock elements. Must contain at least one element. + pub locks: Vec, +} + +impl LockRequest { + pub const STRUCTURE_SIZE: u16 = 48; +} + +impl Pack for LockRequest { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u16_le(Self::STRUCTURE_SIZE); + cursor.write_u16_le(self.locks.len() as u16); // LockCount + cursor.write_u32_le(self.lock_sequence); + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + for lock in &self.locks { + lock.pack(cursor); + } + } +} + +impl Unpack for LockRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid LockRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let lock_count = cursor.read_u16_le()?; + let lock_sequence = cursor.read_u32_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + + let mut locks = Vec::with_capacity(lock_count as usize); + for _ in 0..lock_count { + locks.push(LockElement::unpack(cursor)?); + } + + Ok(LockRequest { + lock_sequence, + file_id: FileId { + persistent, + volatile, + }, + locks, + }) + } +} + +/// SMB2 LOCK Response (MS-SMB2 section 2.2.27). +/// +/// Sent by the server to confirm a lock operation. The structure is 4 bytes: +/// - StructureSize (2 bytes, must be 4) +/// - Reserved (2 bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LockResponse; + +impl LockResponse { + pub const STRUCTURE_SIZE: u16 = 4; +} + +impl Pack for LockResponse { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u16_le(Self::STRUCTURE_SIZE); + cursor.write_u16_le(0); // Reserved + } +} + +impl Unpack for LockResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid LockResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let _reserved = cursor.read_u16_le()?; + + Ok(LockResponse) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── LockElement tests ────────────────────────────────────────── + + #[test] + fn lock_element_roundtrip() { + let original = LockElement { + offset: 0x1000, + length: 0x2000, + flags: SMB2_LOCKFLAG_EXCLUSIVE_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), LockElement::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = LockElement::unpack(&mut r).unwrap(); + + assert_eq!(decoded, original); + } + + // ── LockRequest tests ────────────────────────────────────────── + + #[test] + fn lock_request_single_lock_roundtrip() { + let original = LockRequest { + lock_sequence: 0, + file_id: FileId { + persistent: 0xDEAD, + volatile: 0xBEEF, + }, + locks: vec![LockElement { + offset: 0, + length: 4096, + flags: SMB2_LOCKFLAG_SHARED_LOCK, + }], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 24 bytes + 1 lock element (24 bytes) = 48 bytes + assert_eq!(bytes.len(), 48); + + let mut r = ReadCursor::new(&bytes); + let decoded = LockRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.lock_sequence, original.lock_sequence); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.locks.len(), 1); + assert_eq!(decoded.locks[0], original.locks[0]); + } + + #[test] + fn lock_request_multiple_locks_roundtrip() { + let original = LockRequest { + lock_sequence: 0x1234_5678, + file_id: FileId { + persistent: 0x1111, + volatile: 0x2222, + }, + locks: vec![ + LockElement { + offset: 0, + length: 1024, + flags: SMB2_LOCKFLAG_EXCLUSIVE_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY, + }, + LockElement { + offset: 4096, + length: 2048, + flags: SMB2_LOCKFLAG_SHARED_LOCK, + }, + LockElement { + offset: 8192, + length: 512, + flags: SMB2_LOCKFLAG_UNLOCK, + }, + ], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 24 bytes + 3 lock elements (3 * 24) = 96 bytes + assert_eq!(bytes.len(), 96); + + let mut r = ReadCursor::new(&bytes); + let decoded = LockRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.lock_sequence, original.lock_sequence); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.locks.len(), 3); + assert_eq!(decoded.locks[0], original.locks[0]); + assert_eq!(decoded.locks[1], original.locks[1]); + assert_eq!(decoded.locks[2], original.locks[2]); + } + + #[test] + fn lock_request_known_bytes() { + let mut buf = Vec::new(); + // StructureSize = 48 + buf.extend_from_slice(&48u16.to_le_bytes()); + // LockCount = 1 + buf.extend_from_slice(&1u16.to_le_bytes()); + // LockSequence = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // FileId persistent = 0x10 + buf.extend_from_slice(&0x10u64.to_le_bytes()); + // FileId volatile = 0x20 + buf.extend_from_slice(&0x20u64.to_le_bytes()); + // LockElement: offset = 0, length = 100, flags = SHARED (1), reserved = 0 + buf.extend_from_slice(&0u64.to_le_bytes()); + buf.extend_from_slice(&100u64.to_le_bytes()); + buf.extend_from_slice(&1u32.to_le_bytes()); + buf.extend_from_slice(&0u32.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let req = LockRequest::unpack(&mut cursor).unwrap(); + + assert_eq!(req.file_id.persistent, 0x10); + assert_eq!(req.file_id.volatile, 0x20); + assert_eq!(req.locks.len(), 1); + assert_eq!(req.locks[0].offset, 0); + assert_eq!(req.locks[0].length, 100); + assert_eq!(req.locks[0].flags, SMB2_LOCKFLAG_SHARED_LOCK); + } + + #[test] + fn lock_request_wrong_structure_size() { + let mut buf = [0u8; 48]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = LockRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── LockResponse tests ───────────────────────────────────────── + + #[test] + fn lock_response_roundtrip() { + let original = LockResponse; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // 2 + 2 = 4 bytes + assert_eq!(bytes.len(), 4); + + let mut r = ReadCursor::new(&bytes); + let _decoded = LockResponse::unpack(&mut r).unwrap(); + } + + #[test] + fn lock_response_known_bytes() { + let mut buf = [0u8; 4]; + buf[0..2].copy_from_slice(&4u16.to_le_bytes()); + buf[2..4].copy_from_slice(&0u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let _resp = LockResponse::unpack(&mut cursor).unwrap(); + } + + #[test] + fn lock_response_wrong_structure_size() { + let mut buf = [0u8; 4]; + buf[0..2].copy_from_slice(&8u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = LockResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn lock_flags_combinations() { + // Verify flag constants are distinct and correct + assert_eq!(SMB2_LOCKFLAG_SHARED_LOCK, 0x01); + assert_eq!(SMB2_LOCKFLAG_EXCLUSIVE_LOCK, 0x02); + assert_eq!(SMB2_LOCKFLAG_UNLOCK, 0x04); + assert_eq!(SMB2_LOCKFLAG_FAIL_IMMEDIATELY, 0x10); + + // Shared + fail immediately + let combined = SMB2_LOCKFLAG_SHARED_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY; + assert_eq!(combined, 0x11); + + // Exclusive + fail immediately + let combined = SMB2_LOCKFLAG_EXCLUSIVE_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY; + assert_eq!(combined, 0x12); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::arb_file_id; + use proptest::prelude::*; + + fn arb_lock_element() -> impl Strategy { + (any::(), any::(), any::()).prop_map(|(offset, length, flags)| LockElement { + offset, + length, + flags, + }) + } + + proptest! { + #[test] + fn lock_element_pack_unpack(elem in arb_lock_element()) { + let mut w = WriteCursor::new(); + elem.pack(&mut w); + let bytes = w.into_inner(); + prop_assert_eq!(bytes.len(), LockElement::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = LockElement::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, elem); + prop_assert!(r.is_empty()); + } + + #[test] + fn lock_request_pack_unpack( + lock_sequence in any::(), + file_id in arb_file_id(), + // MS-SMB2: LockCount must be >= 1, so generate 1..=8. + locks in prop::collection::vec(arb_lock_element(), 1..=8), + ) { + let original = LockRequest { + lock_sequence, + file_id, + locks, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = LockRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn lock_response_pack_unpack(_ in any::()) { + // LockResponse is a unit struct; there's nothing to vary, but + // running it through the proptest harness keeps the coverage + // map uniform. + let original = LockResponse; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = LockResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/logoff.rs b/vendor/smb2/src/msg/logoff.rs new file mode 100644 index 0000000..1ed5c43 --- /dev/null +++ b/vendor/smb2/src/msg/logoff.rs @@ -0,0 +1,42 @@ +//! SMB2 LOGOFF request and response (spec sections 2.2.7, 2.2.8). +//! +//! Logoff messages request and confirm termination of a session. +//! Both request and response contain only a StructureSize field and a +//! reserved field, for a total of 4 bytes each. + +super::trivial_message! { + /// SMB2 LOGOFF request (spec section 2.2.7). + /// + /// Sent by the client to request termination of a particular session. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct LogoffRequest; +} + +super::trivial_message! { + /// SMB2 LOGOFF response (spec section 2.2.8). + /// + /// Sent by the server to confirm that a LOGOFF request was processed. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct LogoffResponse; +} + +#[cfg(test)] +mod tests { + use super::*; + + super::super::trivial_message_tests!( + LogoffRequest, + logoff_request_known_bytes, + logoff_request_roundtrip, + logoff_request_wrong_structure_size, + logoff_request_too_short + ); + + super::super::trivial_message_tests!( + LogoffResponse, + logoff_response_known_bytes, + logoff_response_roundtrip, + logoff_response_wrong_structure_size, + logoff_response_too_short + ); +} diff --git a/vendor/smb2/src/msg/mod.rs b/vendor/smb2/src/msg/mod.rs new file mode 100644 index 0000000..6da4316 --- /dev/null +++ b/vendor/smb2/src/msg/mod.rs @@ -0,0 +1,152 @@ +//! Wire format message structs for SMB2/3. +//! +//! Each sub-module corresponds to one SMB2 command type with its +//! request and response structures. +//! +//! Most users don't need this module directly -- use [`SmbClient`](crate::SmbClient) +//! for high-level file operations. + +// Wire format internals, comments would be pretty redundant. Public API docs are enforced at the crate level. +#![allow(missing_docs)] + +/// Generates a trivial 4-byte SMB2 stub message (StructureSize + Reserved). +/// +/// Many SMB2 commands (echo, cancel, logoff, tree_disconnect) have request +/// and/or response structs that are identical: 2-byte StructureSize (always 4) +/// plus 2-byte Reserved. This macro generates the struct definition and its +/// `Pack`/`Unpack` impls from a single declaration. +/// +/// # Usage +/// +/// ```ignore +/// trivial_message! { +/// /// Doc comment for the struct. +/// pub struct EchoRequest; +/// } +/// ``` +macro_rules! trivial_message { + ( + $(#[$meta:meta])* + pub struct $name:ident; + ) => { + $(#[$meta])* + #[derive(Debug, Clone, PartialEq, Eq)] + pub struct $name; + + impl $name { + pub const STRUCTURE_SIZE: u16 = 4; + } + + impl crate::pack::Pack for $name { + fn pack(&self, cursor: &mut crate::pack::WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Reserved (2 bytes) + cursor.write_u16_le(0); + } + } + + impl crate::pack::Unpack for $name { + fn unpack(cursor: &mut crate::pack::ReadCursor<'_>) -> crate::error::Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(crate::Error::invalid_data(format!( + "invalid {} structure size: expected {}, got {}", + stringify!($name), + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // Reserved (2 bytes) + let _reserved = cursor.read_u16_le()?; + + Ok($name) + } + } + }; +} + +pub(crate) use trivial_message; + +/// Generates a minimal test suite for a trivial 4-byte message type. +/// +/// Tests: known bytes, pack-unpack roundtrip, wrong structure size, and +/// truncated input. These four tests cover all interesting behavior for +/// types produced by [`trivial_message!`]. +#[cfg(test)] +macro_rules! trivial_message_tests { + ($type:ident, $known:ident, $roundtrip:ident, $wrong_size:ident, $short:ident) => { + #[test] + fn $known() { + let msg = $type; + let mut cursor = crate::pack::WriteCursor::new(); + crate::pack::Pack::pack(&msg, &mut cursor); + let bytes = cursor.into_inner(); + // StructureSize=4 (LE), Reserved=0 + assert_eq!(bytes, [0x04, 0x00, 0x00, 0x00]); + } + + #[test] + fn $roundtrip() { + let original = $type; + let mut w = crate::pack::WriteCursor::new(); + crate::pack::Pack::pack(&original, &mut w); + let bytes = w.into_inner(); + + let mut r = crate::pack::ReadCursor::new(&bytes); + let decoded = <$type as crate::pack::Unpack>::unpack(&mut r).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn $wrong_size() { + let bytes = [0x08, 0x00, 0x00, 0x00]; + let mut cursor = crate::pack::ReadCursor::new(&bytes); + let result = <$type as crate::pack::Unpack>::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn $short() { + let bytes = [0x04, 0x00]; + let mut cursor = crate::pack::ReadCursor::new(&bytes); + let result = <$type as crate::pack::Unpack>::unpack(&mut cursor); + assert!(result.is_err()); + } + }; +} + +#[cfg(test)] +pub(crate) use trivial_message_tests; + +#[cfg(test)] +pub(crate) mod roundtrip_strategies; + +pub mod cancel; +pub mod change_notify; +pub mod close; +pub mod create; +pub mod dfs; +pub mod echo; +pub mod flush; +pub mod header; +pub mod ioctl; +pub mod lock; +pub mod logoff; +pub mod negotiate; +pub mod oplock_break; +pub mod query_directory; +pub mod query_info; +pub mod read; +pub mod session_setup; +pub mod set_info; +pub mod transform; +pub mod tree_connect; +pub mod tree_disconnect; +pub mod write; + +pub use header::{ErrorResponse, Header, PROTOCOL_ID}; diff --git a/vendor/smb2/src/msg/negotiate.rs b/vendor/smb2/src/msg/negotiate.rs new file mode 100644 index 0000000..1f18b9c --- /dev/null +++ b/vendor/smb2/src/msg/negotiate.rs @@ -0,0 +1,1228 @@ +//! SMB2 NEGOTIATE request and response (spec sections 2.2.3, 2.2.4) +//! and negotiate context structures (spec section 2.2.3.1). +//! +//! Negotiate is the first exchange between client and server. The client +//! advertises which dialects and capabilities it supports, and the server +//! picks the highest mutually supported dialect and returns its own +//! capabilities. +//! +//! For SMB 3.1.1 (dialect 0x0311), both request and response carry a +//! variable-length list of negotiate contexts that negotiate features +//! such as preauthentication integrity, encryption, compression, and +//! signing algorithms. + +use crate::error::Result; +use crate::msg::header::Header; +use crate::pack::{Guid, Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::flags::{Capabilities, SecurityMode}; +use crate::types::Dialect; +use crate::Error; + +// ── Negotiate context type constants ─────────────────────────────────── + +/// Preauthentication integrity capabilities context type. +pub const NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY: u16 = 0x0001; +/// Encryption capabilities context type. +pub const NEGOTIATE_CONTEXT_ENCRYPTION: u16 = 0x0002; +/// Compression capabilities context type. +pub const NEGOTIATE_CONTEXT_COMPRESSION: u16 = 0x0003; +/// Signing capabilities context type. +pub const NEGOTIATE_CONTEXT_SIGNING: u16 = 0x0008; + +// ── Hash algorithm IDs (2.2.3.1.1) ──────────────────────────────────── + +/// SHA-512 hash algorithm for preauthentication integrity. +pub const HASH_ALGORITHM_SHA512: u16 = 0x0001; + +// ── Encryption cipher IDs (2.2.3.1.2) ───────────────────────────────── + +/// AES-128-CCM cipher. +pub const CIPHER_AES_128_CCM: u16 = 0x0001; +/// AES-128-GCM cipher. +pub const CIPHER_AES_128_GCM: u16 = 0x0002; +/// AES-256-CCM cipher. +pub const CIPHER_AES_256_CCM: u16 = 0x0003; +/// AES-256-GCM cipher. +pub const CIPHER_AES_256_GCM: u16 = 0x0004; + +// ── Signing algorithm IDs (2.2.3.1.7) ───────────────────────────────── + +/// HMAC-SHA256 signing algorithm. +pub const SIGNING_HMAC_SHA256: u16 = 0x0000; +/// AES-CMAC signing algorithm. +pub const SIGNING_AES_CMAC: u16 = 0x0001; +/// AES-GMAC signing algorithm. +pub const SIGNING_AES_GMAC: u16 = 0x0002; + +// ── Compression algorithm IDs (2.2.3.1.3) ───────────────────────────── + +/// No compression. +pub const COMPRESSION_NONE: u16 = 0x0000; +/// LZNT1 compression algorithm. +pub const COMPRESSION_LZNT1: u16 = 0x0001; +/// LZ77 compression algorithm. +pub const COMPRESSION_LZ77: u16 = 0x0002; +/// LZ77+Huffman compression algorithm. +pub const COMPRESSION_LZ77_HUFFMAN: u16 = 0x0003; +/// Pattern scanning algorithm. +pub const COMPRESSION_PATTERN_V1: u16 = 0x0004; +/// LZ4 compression algorithm. +pub const COMPRESSION_LZ4: u16 = 0x0005; + +// ── Compression capability flags ─────────────────────────────────────── + +/// Chained compression is not supported. +pub const COMPRESSION_FLAG_NONE: u32 = 0x0000_0000; +/// Chained compression is supported. +pub const COMPRESSION_FLAG_CHAINED: u32 = 0x0000_0001; + +// ── NegotiateContext ─────────────────────────────────────────────────── + +/// A single negotiate context entry (spec section 2.2.3.1). +/// +/// Each context has a type, reserved field, and type-specific data. +/// The four most important types are represented as dedicated variants; +/// unknown types are stored as raw bytes. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NegotiateContext { + /// Preauthentication integrity capabilities (type 0x0001). + PreauthIntegrity { + /// Supported hash algorithm IDs. + hash_algorithms: Vec, + /// Salt value. + salt: Vec, + }, + /// Encryption capabilities (type 0x0002). + Encryption { + /// Supported cipher IDs in preference order. + ciphers: Vec, + }, + /// Compression capabilities (type 0x0003). + Compression { + /// Compression capability flags. + flags: u32, + /// Supported compression algorithm IDs in preference order. + algorithms: Vec, + }, + /// Signing capabilities (type 0x0008). + Signing { + /// Supported signing algorithm IDs in preference order. + algorithms: Vec, + }, + /// Unknown or unsupported context type, stored as raw bytes. + Unknown { + /// The context type identifier. + context_type: u16, + /// The raw data bytes. + data: Vec, + }, +} + +/// Pack a single negotiate context's data (without the header). +fn pack_context_data(ctx: &NegotiateContext, cursor: &mut WriteCursor) { + match ctx { + NegotiateContext::PreauthIntegrity { + hash_algorithms, + salt, + } => { + // HashAlgorithmCount (2) + cursor.write_u16_le(hash_algorithms.len() as u16); + // SaltLength (2) + cursor.write_u16_le(salt.len() as u16); + // HashAlgorithms (variable) + for &alg in hash_algorithms { + cursor.write_u16_le(alg); + } + // Salt (variable) + cursor.write_bytes(salt); + } + NegotiateContext::Encryption { ciphers } => { + // CipherCount (2) + cursor.write_u16_le(ciphers.len() as u16); + // Ciphers (variable) + for &c in ciphers { + cursor.write_u16_le(c); + } + } + NegotiateContext::Compression { flags, algorithms } => { + // CompressionAlgorithmCount (2) + cursor.write_u16_le(algorithms.len() as u16); + // Padding (2) + cursor.write_u16_le(0); + // Flags (4) + cursor.write_u32_le(*flags); + // CompressionAlgorithms (variable) + for &a in algorithms { + cursor.write_u16_le(a); + } + } + NegotiateContext::Signing { algorithms } => { + // SigningAlgorithmCount (2) + cursor.write_u16_le(algorithms.len() as u16); + // SigningAlgorithms (variable) + for &a in algorithms { + cursor.write_u16_le(a); + } + } + NegotiateContext::Unknown { data, .. } => { + cursor.write_bytes(data); + } + } +} + +/// Return the context type ID for a negotiate context. +fn context_type_id(ctx: &NegotiateContext) -> u16 { + match ctx { + NegotiateContext::PreauthIntegrity { .. } => NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY, + NegotiateContext::Encryption { .. } => NEGOTIATE_CONTEXT_ENCRYPTION, + NegotiateContext::Compression { .. } => NEGOTIATE_CONTEXT_COMPRESSION, + NegotiateContext::Signing { .. } => NEGOTIATE_CONTEXT_SIGNING, + NegotiateContext::Unknown { context_type, .. } => *context_type, + } +} + +/// Compute the data length of a single negotiate context (without the 8-byte header). +fn context_data_len(ctx: &NegotiateContext) -> usize { + match ctx { + NegotiateContext::PreauthIntegrity { + hash_algorithms, + salt, + } => 2 + 2 + hash_algorithms.len() * 2 + salt.len(), + NegotiateContext::Encryption { ciphers } => 2 + ciphers.len() * 2, + NegotiateContext::Compression { algorithms, .. } => 2 + 2 + 4 + algorithms.len() * 2, + NegotiateContext::Signing { algorithms } => 2 + algorithms.len() * 2, + NegotiateContext::Unknown { data, .. } => data.len(), + } +} + +/// Pack a list of negotiate contexts, each preceded by its header and +/// 8-byte aligned. +fn pack_negotiate_contexts(contexts: &[NegotiateContext], cursor: &mut WriteCursor) { + for (i, ctx) in contexts.iter().enumerate() { + // Pad to 8-byte alignment before each context (except the first, + // which should already be aligned by the caller). + if i > 0 { + cursor.align_to(8); + } + + // ContextType (2) + cursor.write_u16_le(context_type_id(ctx)); + // DataLength (2) + cursor.write_u16_le(context_data_len(ctx) as u16); + // Reserved (4) + cursor.write_u32_le(0); + // Data (variable) + pack_context_data(ctx, cursor); + } +} + +/// Unpack a single negotiate context from the cursor. +fn unpack_negotiate_context(cursor: &mut ReadCursor<'_>) -> Result { + // ContextType (2) + let context_type = cursor.read_u16_le()?; + // DataLength (2) + let data_length = cursor.read_u16_le()? as usize; + // Reserved (4) + let _reserved = cursor.read_u32_le()?; + + match context_type { + NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY => { + let hash_count = cursor.read_u16_le()? as usize; + let salt_length = cursor.read_u16_le()? as usize; + let mut hash_algorithms = Vec::with_capacity(hash_count); + for _ in 0..hash_count { + hash_algorithms.push(cursor.read_u16_le()?); + } + let salt = cursor.read_bytes_bounded(salt_length)?.to_vec(); + Ok(NegotiateContext::PreauthIntegrity { + hash_algorithms, + salt, + }) + } + NEGOTIATE_CONTEXT_ENCRYPTION => { + let cipher_count = cursor.read_u16_le()? as usize; + let mut ciphers = Vec::with_capacity(cipher_count); + for _ in 0..cipher_count { + ciphers.push(cursor.read_u16_le()?); + } + Ok(NegotiateContext::Encryption { ciphers }) + } + NEGOTIATE_CONTEXT_COMPRESSION => { + let alg_count = cursor.read_u16_le()? as usize; + let _padding = cursor.read_u16_le()?; + let flags = cursor.read_u32_le()?; + let mut algorithms = Vec::with_capacity(alg_count); + for _ in 0..alg_count { + algorithms.push(cursor.read_u16_le()?); + } + Ok(NegotiateContext::Compression { flags, algorithms }) + } + NEGOTIATE_CONTEXT_SIGNING => { + let alg_count = cursor.read_u16_le()? as usize; + let mut algorithms = Vec::with_capacity(alg_count); + for _ in 0..alg_count { + algorithms.push(cursor.read_u16_le()?); + } + Ok(NegotiateContext::Signing { algorithms }) + } + _ => { + let data = cursor.read_bytes_bounded(data_length)?.to_vec(); + Ok(NegotiateContext::Unknown { context_type, data }) + } + } +} + +/// Unpack a list of negotiate contexts. +fn unpack_negotiate_contexts( + cursor: &mut ReadCursor<'_>, + count: usize, +) -> Result> { + let mut contexts = Vec::with_capacity(count); + for i in 0..count { + // Each context after the first must be 8-byte aligned. + if i > 0 { + let pos = cursor.position(); + let remainder = pos % 8; + if remainder != 0 { + cursor.skip(8 - remainder)?; + } + } + contexts.push(unpack_negotiate_context(cursor)?); + } + Ok(contexts) +} + +// ── NegotiateRequest ─────────────────────────────────────────────────── + +/// SMB2 NEGOTIATE request (spec section 2.2.3). +/// +/// Sent by the client to advertise which dialects and capabilities it +/// supports. For SMB 3.1.1, includes negotiate contexts. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegotiateRequest { + /// Security mode indicating signing requirements. + pub security_mode: SecurityMode, + /// Client capabilities. + pub capabilities: Capabilities, + /// Client GUID for identification. + pub client_guid: Guid, + /// Supported dialect revision numbers. + pub dialects: Vec, + /// Negotiate contexts (only for SMB 3.1.1). + pub negotiate_contexts: Vec, +} + +impl NegotiateRequest { + pub const STRUCTURE_SIZE: u16 = 36; + + /// Returns `true` if the dialects list includes SMB 3.1.1. + fn has_smb311(&self) -> bool { + self.dialects.contains(&Dialect::Smb3_1_1) + } +} + +impl Pack for NegotiateRequest { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // DialectCount (2) + cursor.write_u16_le(self.dialects.len() as u16); + // SecurityMode (2) + cursor.write_u16_le(self.security_mode.bits()); + // Reserved (2) + cursor.write_u16_le(0); + // Capabilities (4) + cursor.write_u32_le(self.capabilities.bits()); + // ClientGuid (16) + self.client_guid.pack(cursor); + + if self.has_smb311() { + // NegotiateContextOffset (4) -- will be backpatched + let ctx_offset_pos = cursor.position(); + cursor.write_u32_le(0); // placeholder + // NegotiateContextCount (2) + cursor.write_u16_le(self.negotiate_contexts.len() as u16); + // Reserved2 (2) + cursor.write_u16_le(0); + + // Dialects array + for &d in &self.dialects { + cursor.write_u16_le(d.into()); + } + + // Pad to 8-byte alignment (from start of SMB2 header). + // The offset is measured from the header start, so we align + // (Header::SIZE + current_struct_pos). + let abs_pos = Header::SIZE + (cursor.position() - start); + let remainder = abs_pos % 8; + if remainder != 0 { + cursor.write_zeros(8 - remainder); + } + + // Backpatch NegotiateContextOffset (from header start) + let ctx_offset = Header::SIZE + (cursor.position() - start); + cursor.set_u32_le_at(ctx_offset_pos, ctx_offset as u32); + + // Write negotiate contexts + pack_negotiate_contexts(&self.negotiate_contexts, cursor); + } else { + // ClientStartTime (8 bytes, must be 0) + cursor.write_u64_le(0); + + // Dialects array + for &d in &self.dialects { + cursor.write_u16_le(d.into()); + } + } + } +} + +impl Unpack for NegotiateRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid NegotiateRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // DialectCount (2) + let dialect_count = cursor.read_u16_le()? as usize; + // SecurityMode (2) + let security_mode = SecurityMode::new(cursor.read_u16_le()?); + // Reserved (2) + let _reserved = cursor.read_u16_le()?; + // Capabilities (4) + let capabilities = Capabilities::new(cursor.read_u32_le()?); + // ClientGuid (16) + let client_guid = Guid::unpack(cursor)?; + + // Read the 8-byte field that is either (offset, count, reserved2) + // or ClientStartTime -- we need to peek at the dialects to know. + let raw_8 = cursor.read_bytes(8)?; + + // Dialects array + let mut dialects = Vec::with_capacity(dialect_count); + for _ in 0..dialect_count { + let d = cursor.read_u16_le()?; + dialects.push( + Dialect::try_from(d) + .map_err(|_| Error::invalid_data(format!("invalid dialect: 0x{:04X}", d)))?, + ); + } + + let has_311 = dialects.contains(&Dialect::Smb3_1_1); + + let negotiate_contexts = if has_311 { + // Parse the 8-byte field as (offset, count, reserved2) + let ctx_offset = u32::from_le_bytes([raw_8[0], raw_8[1], raw_8[2], raw_8[3]]) as usize; + let ctx_count = u16::from_le_bytes([raw_8[4], raw_8[5]]) as usize; + + // Skip padding to reach the negotiate context list. + let current_abs = Header::SIZE + (cursor.position() - start); + if ctx_offset > current_abs { + cursor.skip(ctx_offset - current_abs)?; + } + + unpack_negotiate_contexts(cursor, ctx_count)? + } else { + Vec::new() + }; + + Ok(NegotiateRequest { + security_mode, + capabilities, + client_guid, + dialects, + negotiate_contexts, + }) + } +} + +// ── NegotiateResponse ────────────────────────────────────────────────── + +/// SMB2 NEGOTIATE response (spec section 2.2.4). +/// +/// Sent by the server to indicate the selected dialect, server capabilities, +/// and (for SMB 3.1.1) negotiate contexts. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegotiateResponse { + /// Server security mode. + pub security_mode: SecurityMode, + /// Selected dialect revision. + pub dialect_revision: Dialect, + /// Server GUID. + pub server_guid: Guid, + /// Server capabilities. + pub capabilities: Capabilities, + /// Maximum transact buffer size. + pub max_transact_size: u32, + /// Maximum read size. + pub max_read_size: u32, + /// Maximum write size. + pub max_write_size: u32, + /// Server system time as raw FILETIME value. + pub system_time: u64, + /// Server start time as raw FILETIME value. + pub server_start_time: u64, + /// Security buffer (GSS token). + pub security_buffer: Vec, + /// Negotiate contexts (only for SMB 3.1.1). + pub negotiate_contexts: Vec, +} + +impl NegotiateResponse { + pub const STRUCTURE_SIZE: u16 = 65; + + /// Returns `true` if the negotiated dialect is SMB 3.1.1. + fn is_smb311(&self) -> bool { + self.dialect_revision == Dialect::Smb3_1_1 + } +} + +impl Pack for NegotiateResponse { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // SecurityMode (2) + cursor.write_u16_le(self.security_mode.bits()); + // DialectRevision (2) + cursor.write_u16_le(self.dialect_revision.into()); + // NegotiateContextCount/Reserved (2) + if self.is_smb311() { + cursor.write_u16_le(self.negotiate_contexts.len() as u16); + } else { + cursor.write_u16_le(0); + } + // ServerGuid (16) + self.server_guid.pack(cursor); + // Capabilities (4) + cursor.write_u32_le(self.capabilities.bits()); + // MaxTransactSize (4) + cursor.write_u32_le(self.max_transact_size); + // MaxReadSize (4) + cursor.write_u32_le(self.max_read_size); + // MaxWriteSize (4) + cursor.write_u32_le(self.max_write_size); + // SystemTime (8) + cursor.write_u64_le(self.system_time); + // ServerStartTime (8) + cursor.write_u64_le(self.server_start_time); + + // SecurityBufferOffset (2) -- offset from header start to the buffer. + // Fixed part of the response struct is 64 bytes (fields above), so + // the buffer starts at Header::SIZE + 64. + let sec_buf_offset = (Header::SIZE + 64) as u16; + cursor.write_u16_le(sec_buf_offset); + // SecurityBufferLength (2) + cursor.write_u16_le(self.security_buffer.len() as u16); + + // NegotiateContextOffset/Reserved2 (4) + let ctx_offset_pos = cursor.position(); + cursor.write_u32_le(0); // placeholder (will backpatch for 3.1.1) + + // SecurityBuffer (variable) + cursor.write_bytes(&self.security_buffer); + + if self.is_smb311() && !self.negotiate_contexts.is_empty() { + // Pad to 8-byte alignment from header start + let abs_pos = Header::SIZE + (cursor.position() - start); + let remainder = abs_pos % 8; + if remainder != 0 { + cursor.write_zeros(8 - remainder); + } + + // Backpatch NegotiateContextOffset + let ctx_offset = Header::SIZE + (cursor.position() - start); + cursor.set_u32_le_at(ctx_offset_pos, ctx_offset as u32); + + // Write negotiate contexts + pack_negotiate_contexts(&self.negotiate_contexts, cursor); + } + } +} + +impl Unpack for NegotiateResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid NegotiateResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // SecurityMode (2) + let security_mode = SecurityMode::new(cursor.read_u16_le()?); + // DialectRevision (2) + let dialect_raw = cursor.read_u16_le()?; + let dialect_revision = Dialect::try_from(dialect_raw).map_err(|_| { + Error::invalid_data(format!("invalid dialect revision: 0x{:04X}", dialect_raw)) + })?; + // NegotiateContextCount/Reserved (2) + let negotiate_context_count = cursor.read_u16_le()? as usize; + // ServerGuid (16) + let server_guid = Guid::unpack(cursor)?; + // Capabilities (4) + let capabilities = Capabilities::new(cursor.read_u32_le()?); + // MaxTransactSize (4) + let max_transact_size = cursor.read_u32_le()?; + // MaxReadSize (4) + let max_read_size = cursor.read_u32_le()?; + // MaxWriteSize (4) + let max_write_size = cursor.read_u32_le()?; + // SystemTime (8) + let system_time = cursor.read_u64_le()?; + // ServerStartTime (8) + let server_start_time = cursor.read_u64_le()?; + // SecurityBufferOffset (2) + let _sec_buf_offset = cursor.read_u16_le()?; + // SecurityBufferLength (2) + let sec_buf_length = cursor.read_u16_le()? as usize; + // NegotiateContextOffset/Reserved2 (4) + let negotiate_context_offset = cursor.read_u32_le()? as usize; + + // SecurityBuffer (variable) + let security_buffer = if sec_buf_length > 0 { + cursor.read_bytes_bounded(sec_buf_length)?.to_vec() + } else { + Vec::new() + }; + + // Negotiate contexts (only for 3.1.1) + let negotiate_contexts = + if dialect_revision == Dialect::Smb3_1_1 && negotiate_context_count > 0 { + // Skip padding to reach the context list + let current_abs = Header::SIZE + (cursor.position() - start); + if negotiate_context_offset > current_abs { + cursor.skip(negotiate_context_offset - current_abs)?; + } + unpack_negotiate_contexts(cursor, negotiate_context_count)? + } else { + Vec::new() + }; + + Ok(NegotiateResponse { + security_mode, + dialect_revision, + server_guid, + capabilities, + max_transact_size, + max_read_size, + max_write_size, + system_time, + server_start_time, + security_buffer, + negotiate_contexts, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── Helpers ──────────────────────────────────────────────────── + + fn sample_guid() -> Guid { + Guid { + data1: 0x6BA7B810, + data2: 0x9DAD, + data3: 0x11D1, + data4: [0x80, 0xB4, 0x00, 0xC0, 0x4F, 0xD4, 0x30, 0xC8], + } + } + + // ── NegotiateRequest tests ───────────────────────────────────── + + #[test] + fn negotiate_request_roundtrip_without_contexts() { + let original = NegotiateRequest { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LARGE_MTU), + client_guid: sample_guid(), + dialects: vec![Dialect::Smb2_0_2, Dialect::Smb2_1, Dialect::Smb3_0], + negotiate_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.security_mode.bits(), original.security_mode.bits()); + assert_eq!(decoded.capabilities.bits(), original.capabilities.bits()); + assert_eq!(decoded.client_guid, original.client_guid); + assert_eq!(decoded.dialects, original.dialects); + assert!(decoded.negotiate_contexts.is_empty()); + } + + #[test] + fn negotiate_request_roundtrip_with_contexts() { + let original = NegotiateRequest { + security_mode: SecurityMode::new( + SecurityMode::SIGNING_ENABLED | SecurityMode::SIGNING_REQUIRED, + ), + capabilities: Capabilities::new( + Capabilities::DFS + | Capabilities::LEASING + | Capabilities::LARGE_MTU + | Capabilities::ENCRYPTION, + ), + client_guid: sample_guid(), + dialects: vec![ + Dialect::Smb2_0_2, + Dialect::Smb2_1, + Dialect::Smb3_0, + Dialect::Smb3_0_2, + Dialect::Smb3_1_1, + ], + negotiate_contexts: vec![ + NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xDE, 0xAD, 0xBE, 0xEF], + }, + NegotiateContext::Encryption { + ciphers: vec![CIPHER_AES_128_GCM, CIPHER_AES_128_CCM], + }, + NegotiateContext::Signing { + algorithms: vec![SIGNING_AES_GMAC, SIGNING_AES_CMAC], + }, + NegotiateContext::Compression { + flags: COMPRESSION_FLAG_CHAINED, + algorithms: vec![COMPRESSION_LZ77, COMPRESSION_LZNT1], + }, + ], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.security_mode.bits(), original.security_mode.bits()); + assert_eq!(decoded.capabilities.bits(), original.capabilities.bits()); + assert_eq!(decoded.client_guid, original.client_guid); + assert_eq!(decoded.dialects, original.dialects); + assert_eq!(decoded.negotiate_contexts.len(), 4); + assert_eq!(decoded.negotiate_contexts, original.negotiate_contexts); + } + + #[test] + fn negotiate_request_structure_size_field() { + let req = NegotiateRequest { + security_mode: SecurityMode::default(), + capabilities: Capabilities::default(), + client_guid: Guid::ZERO, + dialects: vec![Dialect::Smb2_0_2], + negotiate_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 36); + } + + #[test] + fn negotiate_request_wrong_structure_size() { + let mut buf = [0u8; 48]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = NegotiateRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn negotiate_request_single_dialect() { + let original = NegotiateRequest { + security_mode: SecurityMode::default(), + capabilities: Capabilities::default(), + client_guid: Guid::ZERO, + dialects: vec![Dialect::Smb3_0_2], + negotiate_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.dialects, vec![Dialect::Smb3_0_2]); + } + + #[test] + fn negotiate_request_smb311_only() { + let original = NegotiateRequest { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::default(), + client_guid: sample_guid(), + dialects: vec![Dialect::Smb3_1_1], + negotiate_contexts: vec![NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], + }], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.dialects, vec![Dialect::Smb3_1_1]); + assert_eq!(decoded.negotiate_contexts.len(), 1); + assert_eq!( + decoded.negotiate_contexts[0], + original.negotiate_contexts[0] + ); + } + + // ── NegotiateResponse tests ──────────────────────────────────── + + #[test] + fn negotiate_response_roundtrip_no_contexts() { + let original = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: Dialect::Smb3_0, + server_guid: sample_guid(), + capabilities: Capabilities::new( + Capabilities::DFS | Capabilities::LEASING | Capabilities::LARGE_MTU, + ), + max_transact_size: 8_388_608, + max_read_size: 8_388_608, + max_write_size: 8_388_608, + system_time: 133_485_408_000_000_000, + server_start_time: 0, + security_buffer: vec![0x60, 0x28, 0x06, 0x06], + negotiate_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.security_mode.bits(), original.security_mode.bits()); + assert_eq!(decoded.dialect_revision, Dialect::Smb3_0); + assert_eq!(decoded.server_guid, original.server_guid); + assert_eq!(decoded.capabilities.bits(), original.capabilities.bits()); + assert_eq!(decoded.max_transact_size, 8_388_608); + assert_eq!(decoded.max_read_size, 8_388_608); + assert_eq!(decoded.max_write_size, 8_388_608); + assert_eq!(decoded.system_time, original.system_time); + assert_eq!(decoded.server_start_time, 0); + assert_eq!(decoded.security_buffer, original.security_buffer); + assert!(decoded.negotiate_contexts.is_empty()); + } + + #[test] + fn negotiate_response_roundtrip_with_contexts() { + let original = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: Dialect::Smb3_1_1, + server_guid: sample_guid(), + capabilities: Capabilities::new(Capabilities::DFS | Capabilities::ENCRYPTION), + max_transact_size: 1_048_576, + max_read_size: 1_048_576, + max_write_size: 1_048_576, + system_time: 133_485_408_000_000_000, + server_start_time: 133_000_000_000_000_000, + security_buffer: vec![0x60, 0x28], + negotiate_contexts: vec![ + NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xAA, 0xBB, 0xCC, 0xDD], + }, + NegotiateContext::Encryption { + ciphers: vec![CIPHER_AES_128_GCM], + }, + NegotiateContext::Signing { + algorithms: vec![SIGNING_AES_GMAC], + }, + ], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.dialect_revision, Dialect::Smb3_1_1); + assert_eq!(decoded.negotiate_contexts.len(), 3); + assert_eq!(decoded.negotiate_contexts, original.negotiate_contexts); + assert_eq!(decoded.security_buffer, original.security_buffer); + } + + #[test] + fn negotiate_response_structure_size_field() { + let resp = NegotiateResponse { + security_mode: SecurityMode::default(), + dialect_revision: Dialect::Smb2_0_2, + server_guid: Guid::ZERO, + capabilities: Capabilities::default(), + max_transact_size: 0, + max_read_size: 0, + max_write_size: 0, + system_time: 0, + server_start_time: 0, + security_buffer: Vec::new(), + negotiate_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + resp.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 65); + } + + #[test] + fn negotiate_response_wrong_structure_size() { + let mut buf = [0u8; 70]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = NegotiateResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn negotiate_response_empty_security_buffer() { + let original = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: Dialect::Smb2_1, + server_guid: Guid::ZERO, + capabilities: Capabilities::default(), + max_transact_size: 65536, + max_read_size: 65536, + max_write_size: 65536, + system_time: 0, + server_start_time: 0, + security_buffer: Vec::new(), + negotiate_contexts: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateResponse::unpack(&mut r).unwrap(); + + assert!(decoded.security_buffer.is_empty()); + assert_eq!(decoded.dialect_revision, Dialect::Smb2_1); + } + + // ── Negotiate context roundtrip tests ────────────────────────── + + #[test] + fn context_preauth_integrity_roundtrip() { + let ctx = NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], + }; + + let mut w = WriteCursor::new(); + pack_negotiate_contexts(std::slice::from_ref(&ctx), &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, 1).unwrap(); + + assert_eq!(decoded.len(), 1); + assert_eq!(decoded[0], ctx); + } + + #[test] + fn context_encryption_roundtrip() { + let ctx = NegotiateContext::Encryption { + ciphers: vec![ + CIPHER_AES_128_GCM, + CIPHER_AES_128_CCM, + CIPHER_AES_256_GCM, + CIPHER_AES_256_CCM, + ], + }; + + let mut w = WriteCursor::new(); + pack_negotiate_contexts(std::slice::from_ref(&ctx), &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, 1).unwrap(); + + assert_eq!(decoded[0], ctx); + } + + #[test] + fn context_signing_roundtrip() { + let ctx = NegotiateContext::Signing { + algorithms: vec![SIGNING_AES_GMAC, SIGNING_AES_CMAC, SIGNING_HMAC_SHA256], + }; + + let mut w = WriteCursor::new(); + pack_negotiate_contexts(std::slice::from_ref(&ctx), &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, 1).unwrap(); + + assert_eq!(decoded[0], ctx); + } + + #[test] + fn context_compression_roundtrip() { + let ctx = NegotiateContext::Compression { + flags: COMPRESSION_FLAG_CHAINED, + algorithms: vec![COMPRESSION_LZ77, COMPRESSION_LZNT1, COMPRESSION_LZ4], + }; + + let mut w = WriteCursor::new(); + pack_negotiate_contexts(std::slice::from_ref(&ctx), &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, 1).unwrap(); + + assert_eq!(decoded[0], ctx); + } + + #[test] + fn context_unknown_roundtrip() { + let ctx = NegotiateContext::Unknown { + context_type: 0x00FF, + data: vec![0x01, 0x02, 0x03, 0x04], + }; + + let mut w = WriteCursor::new(); + pack_negotiate_contexts(std::slice::from_ref(&ctx), &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, 1).unwrap(); + + assert_eq!(decoded[0], ctx); + } + + #[test] + fn multiple_contexts_roundtrip() { + let contexts = vec![ + NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xAA; 32], + }, + NegotiateContext::Encryption { + ciphers: vec![CIPHER_AES_128_GCM], + }, + NegotiateContext::Compression { + flags: COMPRESSION_FLAG_NONE, + algorithms: vec![COMPRESSION_NONE], + }, + NegotiateContext::Signing { + algorithms: vec![SIGNING_HMAC_SHA256], + }, + ]; + + let mut w = WriteCursor::new(); + pack_negotiate_contexts(&contexts, &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, 4).unwrap(); + + assert_eq!(decoded, contexts); + } + + #[test] + fn context_alignment_is_8_bytes() { + // A PreauthIntegrity context with a 3-byte salt creates a data section + // that isn't 8-byte aligned. The next context should be padded. + let contexts = vec![ + NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0x01, 0x02, 0x03], // 3 bytes -> total data = 2+2+2+3 = 9 + }, + NegotiateContext::Encryption { + ciphers: vec![CIPHER_AES_128_GCM], + }, + ]; + + let mut w = WriteCursor::new(); + pack_negotiate_contexts(&contexts, &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, 2).unwrap(); + + assert_eq!(decoded, contexts); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{ + arb_capabilities, arb_dialect, arb_guid, arb_security_mode, arb_small_bytes, + }; + use proptest::prelude::*; + + fn arb_u16_vec(max: usize) -> impl Strategy> { + prop::collection::vec(any::(), 0..=max) + } + + /// Generate a `NegotiateContext`. The generator avoids the Unknown variant + /// with a context type that collides with a known one, since the decoder + /// would demote it back to the typed form and roundtrip would fail. + fn arb_negotiate_context() -> impl Strategy { + let known_types = [ + NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY, + NEGOTIATE_CONTEXT_ENCRYPTION, + NEGOTIATE_CONTEXT_COMPRESSION, + NEGOTIATE_CONTEXT_SIGNING, + ]; + + let preauth = (arb_u16_vec(8), prop::collection::vec(any::(), 0..=64)).prop_map( + |(hash_algorithms, salt)| NegotiateContext::PreauthIntegrity { + hash_algorithms, + salt, + }, + ); + let encryption = + arb_u16_vec(8).prop_map(|ciphers| NegotiateContext::Encryption { ciphers }); + let compression = (any::(), arb_u16_vec(8)) + .prop_map(|(flags, algorithms)| NegotiateContext::Compression { flags, algorithms }); + let signing = + arb_u16_vec(8).prop_map(|algorithms| NegotiateContext::Signing { algorithms }); + let unknown = (any::(), prop::collection::vec(any::(), 0..=64)) + .prop_filter( + "type must not collide with a known variant", + move |(t, _)| !known_types.contains(t), + ) + .prop_map(|(context_type, data)| NegotiateContext::Unknown { context_type, data }); + + prop_oneof![preauth, encryption, compression, signing, unknown] + } + + /// Generate the tuple `(dialects, negotiate_contexts)` in a mutually + /// consistent way: contexts are present iff 3.1.1 is in the dialect list. + fn arb_dialects_and_contexts() -> impl Strategy, Vec)> { + prop::collection::vec(arb_dialect(), 1..=5).prop_flat_map(|dialects| { + let has_311 = dialects.contains(&Dialect::Smb3_1_1); + let ctx_strat: BoxedStrategy> = if has_311 { + prop::collection::vec(arb_negotiate_context(), 0..=4).boxed() + } else { + Just(Vec::new()).boxed() + }; + (Just(dialects), ctx_strat) + }) + } + + proptest! { + #[test] + fn negotiate_context_list_roundtrip( + contexts in prop::collection::vec(arb_negotiate_context(), 0..=6), + ) { + let mut w = WriteCursor::new(); + pack_negotiate_contexts(&contexts, &mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = unpack_negotiate_contexts(&mut r, contexts.len()).unwrap(); + prop_assert_eq!(decoded, contexts); + } + + #[test] + fn negotiate_request_pack_unpack( + security_mode in arb_security_mode(), + capabilities in arb_capabilities(), + client_guid in arb_guid(), + (dialects, negotiate_contexts) in arb_dialects_and_contexts(), + ) { + let original = NegotiateRequest { + security_mode, + capabilities, + client_guid, + dialects, + negotiate_contexts, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + + #[test] + fn negotiate_response_pack_unpack( + security_mode in arb_security_mode(), + server_guid in arb_guid(), + capabilities in arb_capabilities(), + max_transact_size in any::(), + max_read_size in any::(), + max_write_size in any::(), + system_time in any::(), + server_start_time in any::(), + security_buffer in arb_small_bytes(), + dialect_revision in arb_dialect(), + contexts_if_311 in prop::collection::vec(arb_negotiate_context(), 0..=4), + ) { + // Contexts only present for SMB 3.1.1, per the spec / encoder. + let negotiate_contexts = if dialect_revision == Dialect::Smb3_1_1 { + contexts_if_311 + } else { + Vec::new() + }; + let original = NegotiateResponse { + security_mode, + dialect_revision, + server_guid, + capabilities, + max_transact_size, + max_read_size, + max_write_size, + system_time, + server_start_time, + security_buffer, + negotiate_contexts, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = NegotiateResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + } +} diff --git a/vendor/smb2/src/msg/oplock_break.rs b/vendor/smb2/src/msg/oplock_break.rs new file mode 100644 index 0000000..2a8b662 --- /dev/null +++ b/vendor/smb2/src/msg/oplock_break.rs @@ -0,0 +1,262 @@ +//! SMB2 Oplock Break Notification, Acknowledgment, and Response +//! (MS-SMB2 sections 2.2.23, 2.2.24, 2.2.25). +//! +//! All three oplock break messages share an identical 24-byte wire format: +//! - StructureSize (2 bytes, must be 24) +//! - OplockLevel (1 byte) +//! - Reserved (1 byte) +//! - Reserved2 (4 bytes) +//! - FileId (16 bytes) +//! +//! We define one shared struct and provide type aliases for each role. +//! +//! Note: Lease break notification/acknowledgment/response (sections 2.2.23.2, +//! 2.2.24.2, 2.2.25.2) use a different structure with LeaseKey, LeaseState, +//! etc. Lease break handling is deferred to a future implementation. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::{FileId, OplockLevel}; +use crate::Error; + +// ── OplockBreak (shared struct) ──────────────────────────────────────── + +/// Shared wire format for oplock break notification, acknowledgment, and +/// response messages (MS-SMB2 sections 2.2.23, 2.2.24, 2.2.25). +/// +/// All three messages have an identical 24-byte layout. The message's role +/// (notification vs acknowledgment vs response) is determined by the header's +/// command code and flags, not by this structure. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OplockBreak { + /// The oplock level. + pub oplock_level: OplockLevel, + /// The file handle associated with the oplock. + pub file_id: FileId, +} + +impl OplockBreak { + pub const STRUCTURE_SIZE: u16 = 24; +} + +impl Pack for OplockBreak { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // OplockLevel (1 byte) + cursor.write_u8(self.oplock_level as u8); + // Reserved (1 byte) + cursor.write_u8(0); + // Reserved2 (4 bytes) + cursor.write_u32_le(0); + // FileId (16 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + } +} + +impl Unpack for OplockBreak { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid OplockBreak structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let oplock_level = OplockLevel::try_from(cursor.read_u8()?)?; + let _reserved = cursor.read_u8()?; + let _reserved2 = cursor.read_u32_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + + Ok(OplockBreak { + oplock_level, + file_id: FileId { + persistent, + volatile, + }, + }) + } +} + +/// Oplock break notification (server to client, MS-SMB2 section 2.2.23). +/// +/// Arrives with `MessageId = 0xFFFFFFFFFFFFFFFF` (unsolicited). +pub type OplockBreakNotification = OplockBreak; + +/// Oplock break acknowledgment (client to server, MS-SMB2 section 2.2.24). +pub type OplockBreakAcknowledgment = OplockBreak; + +/// Oplock break response (server to client after ack, MS-SMB2 section 2.2.25). +pub type OplockBreakResponse = OplockBreak; + +#[cfg(test)] +mod tests { + use super::*; + + // ── OplockBreakNotification tests ───────────────────────────────── + + #[test] + fn oplock_break_notification_roundtrip() { + let original = OplockBreakNotification { + oplock_level: OplockLevel::LevelII, + file_id: FileId { + persistent: 0x1122_3344_5566_7788, + volatile: 0xAABB_CCDD_EEFF_0011, + }, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed 24 bytes + assert_eq!(bytes.len(), 24); + + let mut r = ReadCursor::new(&bytes); + let decoded = OplockBreakNotification::unpack(&mut r).unwrap(); + + assert_eq!(decoded.oplock_level, OplockLevel::LevelII); + assert_eq!(decoded.file_id, original.file_id); + } + + #[test] + fn oplock_break_notification_exclusive_level() { + let original = OplockBreakNotification { + oplock_level: OplockLevel::Exclusive, + file_id: FileId { + persistent: 0x42, + volatile: 0x99, + }, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = OplockBreakNotification::unpack(&mut r).unwrap(); + + assert_eq!(decoded.oplock_level, OplockLevel::Exclusive); + assert_eq!(decoded.file_id.persistent, 0x42); + assert_eq!(decoded.file_id.volatile, 0x99); + } + + // ── OplockBreakAcknowledgment tests ─────────────────────────────── + + #[test] + fn oplock_break_acknowledgment_roundtrip() { + let original = OplockBreakAcknowledgment { + oplock_level: OplockLevel::None, + file_id: FileId { + persistent: 0xDEAD, + volatile: 0xBEEF, + }, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), 24); + + let mut r = ReadCursor::new(&bytes); + let decoded = OplockBreakAcknowledgment::unpack(&mut r).unwrap(); + + assert_eq!(decoded.oplock_level, OplockLevel::None); + assert_eq!(decoded.file_id, original.file_id); + } + + // ── OplockBreakResponse tests ───────────────────────────────────── + + #[test] + fn oplock_break_response_roundtrip() { + let original = OplockBreakResponse { + oplock_level: OplockLevel::Batch, + file_id: FileId { + persistent: 0xCAFE, + volatile: 0xFACE, + }, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), 24); + + let mut r = ReadCursor::new(&bytes); + let decoded = OplockBreakResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.oplock_level, OplockLevel::Batch); + assert_eq!(decoded.file_id, original.file_id); + } + + // ── Error tests ─────────────────────────────────────────────────── + + #[test] + fn oplock_break_wrong_structure_size() { + let mut buf = [0u8; 24]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = OplockBreak::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // Roundtrip property tests live in `roundtrip_props` at file end. + + #[test] + fn oplock_break_reserved_fields_ignored() { + let mut buf = [0u8; 24]; + // StructureSize = 24 + buf[0..2].copy_from_slice(&24u16.to_le_bytes()); + // OplockLevel = LEVEL_II + buf[2] = OplockLevel::LevelII as u8; + // Reserved = 0xFF (should be ignored) + buf[3] = 0xFF; + // Reserved2 = 0xDEADBEEF (should be ignored) + buf[4..8].copy_from_slice(&0xDEAD_BEEFu32.to_le_bytes()); + // FileId persistent = 1 + buf[8..16].copy_from_slice(&1u64.to_le_bytes()); + // FileId volatile = 2 + buf[16..24].copy_from_slice(&2u64.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let decoded = OplockBreak::unpack(&mut cursor).unwrap(); + + assert_eq!(decoded.oplock_level, OplockLevel::LevelII); + assert_eq!(decoded.file_id.persistent, 1); + assert_eq!(decoded.file_id.volatile, 2); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_file_id, arb_oplock_level}; + use proptest::prelude::*; + + proptest! { + #[test] + fn oplock_break_pack_unpack( + oplock_level in arb_oplock_level(), + file_id in arb_file_id(), + ) { + let original = OplockBreak { oplock_level, file_id }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = OplockBreak::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/query_directory.rs b/vendor/smb2/src/msg/query_directory.rs new file mode 100644 index 0000000..1180f4a --- /dev/null +++ b/vendor/smb2/src/msg/query_directory.rs @@ -0,0 +1,476 @@ +//! SMB2 QUERY_DIRECTORY request and response (spec sections 2.2.33, 2.2.34). +//! +//! Used by the client to enumerate directory contents. The request specifies +//! a search pattern (typically `"*"`) and the response contains directory +//! entries in the requested information class format. + +use crate::error::Result; +use crate::msg::header::Header; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +// ── Enums / flags ──────────────────────────────────────────────────────── + +/// File information class for directory queries (MS-SMB2 2.2.33). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum FileInformationClass { + /// Basic directory information. + FileDirectoryInformation = 0x01, + /// Full directory information. + FileFullDirectoryInformation = 0x02, + /// Both short and long name information. + FileBothDirectoryInformation = 0x03, + /// File names only. + FileNamesInformation = 0x0C, + /// Both short and long name information with file IDs. + FileIdBothDirectoryInformation = 0x25, + /// Full directory information with file IDs. + FileIdFullDirectoryInformation = 0x26, +} + +impl TryFrom for FileInformationClass { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + 0x01 => Ok(Self::FileDirectoryInformation), + 0x02 => Ok(Self::FileFullDirectoryInformation), + 0x03 => Ok(Self::FileBothDirectoryInformation), + 0x0C => Ok(Self::FileNamesInformation), + 0x25 => Ok(Self::FileIdBothDirectoryInformation), + 0x26 => Ok(Self::FileIdFullDirectoryInformation), + _ => Err(Error::invalid_data(format!( + "invalid FileInformationClass: 0x{:02X}", + value + ))), + } + } +} + +/// Query directory flags (MS-SMB2 2.2.33). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct QueryDirectoryFlags(pub u8); + +impl QueryDirectoryFlags { + /// Restart the enumeration from the beginning. + pub const RESTART_SCANS: u8 = 0x01; + /// Return only a single entry. + pub const RETURN_SINGLE_ENTRY: u8 = 0x02; + /// Resume from the specified file index. + pub const INDEX_SPECIFIED: u8 = 0x04; + /// Reopen the directory and change the search pattern. + pub const REOPEN: u8 = 0x10; +} + +// ── QueryDirectoryRequest ──────────────────────────────────────────────── + +/// SMB2 QUERY_DIRECTORY request (spec section 2.2.33). +/// +/// Sent by the client to enumerate files in a directory. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryDirectoryRequest { + /// The type of information to return for each directory entry. + pub file_information_class: FileInformationClass, + /// Flags controlling the query behavior. + pub flags: QueryDirectoryFlags, + /// Byte offset within the directory to resume enumeration from. + pub file_index: u32, + /// Handle to the directory being queried. + pub file_id: FileId, + /// Maximum number of bytes the server can return. + pub output_buffer_length: u32, + /// Search pattern (for example, `"*"` for all files). + pub file_name: String, +} + +impl QueryDirectoryRequest { + pub const STRUCTURE_SIZE: u16 = 33; +} + +impl Pack for QueryDirectoryRequest { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // FileInformationClass (1 byte) + cursor.write_u8(self.file_information_class as u8); + // Flags (1 byte) + cursor.write_u8(self.flags.0); + // FileIndex (4 bytes) + cursor.write_u32_le(self.file_index); + // FileId (16 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + // FileNameOffset (2 bytes) -- placeholder + let name_offset_pos = cursor.position(); + cursor.write_u16_le(0); + // FileNameLength (2 bytes) -- placeholder + let name_length_pos = cursor.position(); + cursor.write_u16_le(0); + // OutputBufferLength (4 bytes) + cursor.write_u32_le(self.output_buffer_length); + + if self.file_name.is_empty() { + // No search pattern: FileNameOffset and FileNameLength stay 0 + // per spec section 2.2.33. Write 1 padding byte to satisfy + // StructureSize=33 (32 fixed + 1 byte buffer minimum). + cursor.write_u8(0); + } else { + // Buffer: filename pattern in UTF-16LE. + // Offset is from the beginning of the SMB2 header per spec. + let name_offset = Header::SIZE + (cursor.position() - start); + let name_start = cursor.position(); + cursor.write_utf16_le(&self.file_name); + let name_byte_len = cursor.position() - name_start; + + // Backpatch + cursor.set_u16_le_at(name_offset_pos, name_offset as u16); + cursor.set_u16_le_at(name_length_pos, name_byte_len as u16); + } + } +} + +impl Unpack for QueryDirectoryRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid QueryDirectoryRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // FileInformationClass (1 byte) + let info_class = FileInformationClass::try_from(cursor.read_u8()?)?; + // Flags (1 byte) + let flags = QueryDirectoryFlags(cursor.read_u8()?); + // FileIndex (4 bytes) + let file_index = cursor.read_u32_le()?; + // FileId (16 bytes) + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let file_id = FileId { + persistent, + volatile, + }; + // FileNameOffset (2 bytes) + let name_offset = cursor.read_u16_le()? as usize; + // FileNameLength (2 bytes) + let name_length = cursor.read_u16_le()? as usize; + // OutputBufferLength (4 bytes) + let output_buffer_length = cursor.read_u32_le()?; + + // Read filename + // Offset on the wire is from beginning of SMB2 header. + let file_name = if name_length > 0 { + let current = cursor.position(); + let body_offset = name_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_utf16_le(name_length)? + } else { + String::new() + }; + + Ok(QueryDirectoryRequest { + file_information_class: info_class, + flags, + file_index, + file_id, + output_buffer_length, + file_name, + }) + } +} + +// ── QueryDirectoryResponse ─────────────────────────────────────────────── + +/// SMB2 QUERY_DIRECTORY response (spec section 2.2.34). +/// +/// Contains directory enumeration data as raw bytes. The format depends +/// on the `FileInformationClass` from the request. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryDirectoryResponse { + /// Raw output buffer containing directory entries. + pub output_buffer: Vec, +} + +impl QueryDirectoryResponse { + pub const STRUCTURE_SIZE: u16 = 9; +} + +impl Pack for QueryDirectoryResponse { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // OutputBufferOffset (2 bytes) -- placeholder + let offset_pos = cursor.position(); + cursor.write_u16_le(0); + // OutputBufferLength (4 bytes) + cursor.write_u32_le(self.output_buffer.len() as u32); + + // Buffer + if !self.output_buffer.is_empty() { + // Offset is from the beginning of the SMB2 header per spec. + let buf_offset = Header::SIZE + (cursor.position() - start); + cursor.write_bytes(&self.output_buffer); + cursor.set_u16_le_at(offset_pos, buf_offset as u16); + } + } +} + +impl Unpack for QueryDirectoryResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid QueryDirectoryResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // OutputBufferOffset (2 bytes) + let buf_offset = cursor.read_u16_le()? as usize; + // OutputBufferLength (4 bytes) + let buf_length = cursor.read_u32_le()? as usize; + + // Read buffer + // Offset on the wire is from beginning of SMB2 header. + let output_buffer = if buf_length > 0 { + let current = cursor.position(); + let body_offset = buf_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_bytes_bounded(buf_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(QueryDirectoryResponse { output_buffer }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── QueryDirectoryRequest tests ────────────────────────────────── + + #[test] + fn query_directory_request_roundtrip_star_pattern() { + let original = QueryDirectoryRequest { + file_information_class: FileInformationClass::FileBothDirectoryInformation, + flags: QueryDirectoryFlags(QueryDirectoryFlags::RESTART_SCANS), + file_index: 0, + file_id: FileId { + persistent: 0xAAAA_BBBB_CCCC_DDDD, + volatile: 0x1111_2222_3333_4444, + }, + output_buffer_length: 65536, + file_name: "*".to_string(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryDirectoryRequest::unpack(&mut r).unwrap(); + + assert_eq!( + decoded.file_information_class, + FileInformationClass::FileBothDirectoryInformation + ); + assert_eq!(decoded.flags.0, QueryDirectoryFlags::RESTART_SCANS); + assert_eq!(decoded.file_index, 0); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.output_buffer_length, 65536); + assert_eq!(decoded.file_name, "*"); + } + + #[test] + fn query_directory_request_structure_size() { + let req = QueryDirectoryRequest { + file_information_class: FileInformationClass::FileDirectoryInformation, + flags: QueryDirectoryFlags::default(), + file_index: 0, + file_id: FileId::default(), + output_buffer_length: 1024, + file_name: "*".to_string(), + }; + + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes[0], 33); + assert_eq!(bytes[1], 0); + } + + #[test] + fn query_directory_request_wrong_structure_size() { + let mut buf = vec![0u8; 40]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = QueryDirectoryRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── QueryDirectoryResponse tests ───────────────────────────────── + + #[test] + fn query_directory_response_roundtrip_with_buffer() { + // Simulate raw directory entry data + let raw_entries = vec![ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + 0x0F, 0x10, + ]; + + let original = QueryDirectoryResponse { + output_buffer: raw_entries.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryDirectoryResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.output_buffer, raw_entries); + } + + #[test] + fn query_directory_response_empty_buffer() { + let original = QueryDirectoryResponse { + output_buffer: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // StructureSize(2) + Offset(2) + Length(4) = 8 bytes + assert_eq!(bytes.len(), 8); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryDirectoryResponse::unpack(&mut r).unwrap(); + + assert!(decoded.output_buffer.is_empty()); + } + + #[test] + fn query_directory_response_structure_size() { + let resp = QueryDirectoryResponse { + output_buffer: vec![0xFF], + }; + + let mut w = WriteCursor::new(); + resp.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes[0], 9); + assert_eq!(bytes[1], 0); + } + + #[test] + fn query_directory_response_wrong_structure_size() { + let mut buf = vec![0u8; 16]; + buf[0..2].copy_from_slice(&42u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = QueryDirectoryResponse::unpack(&mut cursor); + assert!(result.is_err()); + } + + // ── Enum tests ─────────────────────────────────────────────────── + + #[test] + fn file_information_class_roundtrip() { + for &class in &[ + FileInformationClass::FileDirectoryInformation, + FileInformationClass::FileFullDirectoryInformation, + FileInformationClass::FileBothDirectoryInformation, + FileInformationClass::FileNamesInformation, + FileInformationClass::FileIdFullDirectoryInformation, + FileInformationClass::FileIdBothDirectoryInformation, + ] { + let raw = class as u8; + let decoded = FileInformationClass::try_from(raw).unwrap(); + assert_eq!(decoded, class); + } + } + + #[test] + fn file_information_class_invalid() { + assert!(FileInformationClass::try_from(0xFF).is_err()); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{ + arb_bytes, arb_file_id, arb_file_information_class, arb_utf16_string, + }; + use proptest::prelude::*; + + proptest! { + #[test] + fn query_directory_request_pack_unpack( + file_information_class in arb_file_information_class(), + flags_raw in any::(), + file_index in any::(), + file_id in arb_file_id(), + output_buffer_length in any::(), + // Search pattern is UTF-16LE on the wire. Allow empty + typical. + file_name in arb_utf16_string(128), + ) { + let original = QueryDirectoryRequest { + file_information_class, + flags: QueryDirectoryFlags(flags_raw), + file_index, + file_id, + output_buffer_length, + file_name, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryDirectoryRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + + #[test] + fn query_directory_response_pack_unpack(output_buffer in arb_bytes()) { + let original = QueryDirectoryResponse { output_buffer }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryDirectoryResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + } +} diff --git a/vendor/smb2/src/msg/query_info.rs b/vendor/smb2/src/msg/query_info.rs new file mode 100644 index 0000000..2bb61ce --- /dev/null +++ b/vendor/smb2/src/msg/query_info.rs @@ -0,0 +1,479 @@ +//! SMB2 QUERY_INFO request and response (spec sections 2.2.37, 2.2.38). +//! +//! Used to query file, filesystem, security, or quota information. +//! The response buffer is stored as raw bytes -- parsing into specific +//! information classes is deferred. + +use crate::error::Result; +use crate::msg::header::Header; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +// ── Enums ──────────────────────────────────────────────────────────────── + +/// Info type for query/set info operations (MS-SMB2 2.2.37). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum InfoType { + /// Query file information. + File = 0x01, + /// Query filesystem information. + Filesystem = 0x02, + /// Query security information. + Security = 0x03, + /// Query quota information. + Quota = 0x04, +} + +impl TryFrom for InfoType { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + 0x01 => Ok(Self::File), + 0x02 => Ok(Self::Filesystem), + 0x03 => Ok(Self::Security), + 0x04 => Ok(Self::Quota), + _ => Err(Error::invalid_data(format!( + "invalid InfoType: 0x{:02X}", + value + ))), + } + } +} + +// ── QueryInfoRequest ───────────────────────────────────────────────────── + +/// SMB2 QUERY_INFO request (spec section 2.2.37). +/// +/// Sent by the client to query information about a file, filesystem, +/// security descriptor, or quota. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryInfoRequest { + /// The type of information being queried. + pub info_type: InfoType, + /// The file information class (interpretation depends on `info_type`). + pub file_info_class: u8, + /// Maximum number of output bytes the server may return. + pub output_buffer_length: u32, + /// Additional information flags (for example, security information flags). + pub additional_information: u32, + /// Query flags. + pub flags: u32, + /// Handle to the file or directory being queried. + pub file_id: FileId, + /// Optional input buffer (for example, for quota queries). + pub input_buffer: Vec, +} + +impl QueryInfoRequest { + pub const STRUCTURE_SIZE: u16 = 41; +} + +impl Pack for QueryInfoRequest { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // InfoType (1 byte) + cursor.write_u8(self.info_type as u8); + // FileInfoClass (1 byte) + cursor.write_u8(self.file_info_class); + // OutputBufferLength (4 bytes) + cursor.write_u32_le(self.output_buffer_length); + // InputBufferOffset (2 bytes) -- placeholder + let input_offset_pos = cursor.position(); + cursor.write_u16_le(0); + // Reserved (2 bytes) + cursor.write_u16_le(0); + // InputBufferLength (4 bytes) + cursor.write_u32_le(self.input_buffer.len() as u32); + // AdditionalInformation (4 bytes) + cursor.write_u32_le(self.additional_information); + // Flags (4 bytes) + cursor.write_u32_le(self.flags); + // FileId (16 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + + // Buffer (variable) + if !self.input_buffer.is_empty() { + // Offset is from the beginning of the SMB2 header per spec. + let buf_offset = Header::SIZE + (cursor.position() - start); + cursor.write_bytes(&self.input_buffer); + cursor.set_u16_le_at(input_offset_pos, buf_offset as u16); + } + } +} + +impl Unpack for QueryInfoRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid QueryInfoRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // InfoType (1 byte) + let info_type = InfoType::try_from(cursor.read_u8()?)?; + // FileInfoClass (1 byte) + let file_info_class = cursor.read_u8()?; + // OutputBufferLength (4 bytes) + let output_buffer_length = cursor.read_u32_le()?; + // InputBufferOffset (2 bytes) + let input_offset = cursor.read_u16_le()? as usize; + // Reserved (2 bytes) + let _reserved = cursor.read_u16_le()?; + // InputBufferLength (4 bytes) + let input_length = cursor.read_u32_le()? as usize; + // AdditionalInformation (4 bytes) + let additional_information = cursor.read_u32_le()?; + // Flags (4 bytes) + let flags = cursor.read_u32_le()?; + // FileId (16 bytes) + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let file_id = FileId { + persistent, + volatile, + }; + + // Read input buffer + // Offset on the wire is from beginning of SMB2 header. + let input_buffer = if input_length > 0 { + let current = cursor.position(); + let body_offset = input_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_bytes_bounded(input_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(QueryInfoRequest { + info_type, + file_info_class, + output_buffer_length, + additional_information, + flags, + file_id, + input_buffer, + }) + } +} + +// ── QueryInfoResponse ──────────────────────────────────────────────────── + +/// SMB2 QUERY_INFO response (spec section 2.2.38). +/// +/// Contains the queried information as raw bytes. The format depends +/// on the `InfoType` and `FileInfoClass` from the request. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryInfoResponse { + /// Raw output buffer containing the queried information. + pub output_buffer: Vec, +} + +impl QueryInfoResponse { + pub const STRUCTURE_SIZE: u16 = 9; +} + +impl Pack for QueryInfoResponse { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // OutputBufferOffset (2 bytes) -- placeholder + let offset_pos = cursor.position(); + cursor.write_u16_le(0); + // OutputBufferLength (4 bytes) + cursor.write_u32_le(self.output_buffer.len() as u32); + + // Buffer + if !self.output_buffer.is_empty() { + // Offset is from the beginning of the SMB2 header per spec. + let buf_offset = Header::SIZE + (cursor.position() - start); + cursor.write_bytes(&self.output_buffer); + cursor.set_u16_le_at(offset_pos, buf_offset as u16); + } + } +} + +impl Unpack for QueryInfoResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid QueryInfoResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // OutputBufferOffset (2 bytes) + let buf_offset = cursor.read_u16_le()? as usize; + // OutputBufferLength (4 bytes) + let buf_length = cursor.read_u32_le()? as usize; + + // Read buffer + // Offset on the wire is from beginning of SMB2 header. + let output_buffer = if buf_length > 0 { + let current = cursor.position(); + let body_offset = buf_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_bytes_bounded(buf_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(QueryInfoResponse { output_buffer }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── QueryInfoRequest tests ─────────────────────────────────────── + + #[test] + fn query_info_request_roundtrip_file_info() { + let original = QueryInfoRequest { + info_type: InfoType::File, + file_info_class: 0x12, // FileAllInformation + output_buffer_length: 4096, + additional_information: 0, + flags: 0, + file_id: FileId { + persistent: 0xDEAD_BEEF_CAFE_BABE, + volatile: 0x1234_5678_9ABC_DEF0, + }, + input_buffer: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryInfoRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.info_type, InfoType::File); + assert_eq!(decoded.file_info_class, 0x12); + assert_eq!(decoded.output_buffer_length, 4096); + assert_eq!(decoded.additional_information, 0); + assert_eq!(decoded.flags, 0); + assert_eq!(decoded.file_id, original.file_id); + assert!(decoded.input_buffer.is_empty()); + } + + #[test] + fn query_info_request_with_input_buffer() { + let input = vec![0x01, 0x02, 0x03, 0x04]; + let original = QueryInfoRequest { + info_type: InfoType::Quota, + file_info_class: 0x20, + output_buffer_length: 8192, + additional_information: 0x04, // SACL_SECURITY_INFORMATION + flags: 0, + file_id: FileId { + persistent: 1, + volatile: 2, + }, + input_buffer: input.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryInfoRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.info_type, InfoType::Quota); + assert_eq!(decoded.input_buffer, input); + } + + #[test] + fn query_info_request_structure_size() { + let req = QueryInfoRequest { + info_type: InfoType::File, + file_info_class: 0, + output_buffer_length: 0, + additional_information: 0, + flags: 0, + file_id: FileId::default(), + input_buffer: Vec::new(), + }; + + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes[0], 41); + assert_eq!(bytes[1], 0); + } + + #[test] + fn query_info_request_wrong_structure_size() { + let mut buf = vec![0u8; 48]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = QueryInfoRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── QueryInfoResponse tests ────────────────────────────────────── + + #[test] + fn query_info_response_roundtrip_with_data() { + let info_data = vec![ + 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, + ]; + + let original = QueryInfoResponse { + output_buffer: info_data.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryInfoResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.output_buffer, info_data); + } + + #[test] + fn query_info_response_empty() { + let original = QueryInfoResponse { + output_buffer: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // StructureSize(2) + Offset(2) + Length(4) = 8 + assert_eq!(bytes.len(), 8); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryInfoResponse::unpack(&mut r).unwrap(); + + assert!(decoded.output_buffer.is_empty()); + } + + #[test] + fn query_info_response_structure_size() { + let resp = QueryInfoResponse { + output_buffer: vec![0xFF], + }; + + let mut w = WriteCursor::new(); + resp.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes[0], 9); + assert_eq!(bytes[1], 0); + } + + #[test] + fn query_info_response_wrong_structure_size() { + let mut buf = vec![0u8; 16]; + buf[0..2].copy_from_slice(&42u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = QueryInfoResponse::unpack(&mut cursor); + assert!(result.is_err()); + } + + // ── Enum tests ─────────────────────────────────────────────────── + + #[test] + fn info_type_roundtrip() { + for &it in &[ + InfoType::File, + InfoType::Filesystem, + InfoType::Security, + InfoType::Quota, + ] { + let raw = it as u8; + let decoded = InfoType::try_from(raw).unwrap(); + assert_eq!(decoded, it); + } + } + + #[test] + fn info_type_invalid() { + assert!(InfoType::try_from(0x00).is_err()); + assert!(InfoType::try_from(0x05).is_err()); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id, arb_info_type}; + use proptest::prelude::*; + + proptest! { + #[test] + fn query_info_request_pack_unpack( + info_type in arb_info_type(), + file_info_class in any::(), + output_buffer_length in any::(), + additional_information in any::(), + flags in any::(), + file_id in arb_file_id(), + input_buffer in arb_bytes(), + ) { + let original = QueryInfoRequest { + info_type, + file_info_class, + output_buffer_length, + additional_information, + flags, + file_id, + input_buffer, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryInfoRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + + #[test] + fn query_info_response_pack_unpack(output_buffer in arb_bytes()) { + let original = QueryInfoResponse { output_buffer }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = QueryInfoResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + } +} diff --git a/vendor/smb2/src/msg/read.rs b/vendor/smb2/src/msg/read.rs new file mode 100644 index 0000000..aaab974 --- /dev/null +++ b/vendor/smb2/src/msg/read.rs @@ -0,0 +1,462 @@ +//! SMB2 READ Request and Response (MS-SMB2 sections 2.2.19, 2.2.20). +//! +//! The READ request reads data from a file or named pipe. +//! The response carries the read data in a variable-length buffer. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +/// Read flag: read data directly from underlying storage (SMB 3.0.2+). +pub const SMB2_READFLAG_READ_UNBUFFERED: u8 = 0x01; + +/// Read flag: request compressed response (SMB 3.1.1). +pub const SMB2_READFLAG_REQUEST_COMPRESSED: u8 = 0x02; + +/// Channel value: no channel information. +pub const SMB2_CHANNEL_NONE: u32 = 0x0000_0000; + +/// SMB2 READ Request (MS-SMB2 section 2.2.19). +/// +/// Sent by the client to read data from a file. The fixed portion is 49 bytes +/// (StructureSize says 49 regardless of the variable buffer length): +/// - StructureSize (2 bytes, must be 49) +/// - Padding (1 byte) +/// - Flags (1 byte) +/// - Length (4 bytes) +/// - Offset (8 bytes) +/// - FileId (16 bytes) +/// - MinimumCount (4 bytes) +/// - Channel (4 bytes) +/// - RemainingBytes (4 bytes) +/// - ReadChannelInfoOffset (2 bytes) +/// - ReadChannelInfoLength (2 bytes) +/// - Buffer (variable, typically empty for basic reads) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReadRequest { + /// Requested data placement offset in the response. + pub padding: u8, + /// Flags for the read operation. + pub flags: u8, + /// Number of bytes to read. + pub length: u32, + /// File offset to start reading from. + pub offset: u64, + /// File handle to read from. + pub file_id: FileId, + /// Minimum number of bytes for a successful read. + pub minimum_count: u32, + /// Channel for RDMA operations (typically `SMB2_CHANNEL_NONE`). + pub channel: u32, + /// Remaining bytes in a multi-part read. + pub remaining_bytes: u32, + /// Variable-length read channel info buffer. + pub read_channel_info: Vec, +} + +impl ReadRequest { + pub const STRUCTURE_SIZE: u16 = 49; +} + +impl Pack for ReadRequest { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u16_le(Self::STRUCTURE_SIZE); + cursor.write_u8(self.padding); + cursor.write_u8(self.flags); + cursor.write_u32_le(self.length); + cursor.write_u64_le(self.offset); + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + cursor.write_u32_le(self.minimum_count); + cursor.write_u32_le(self.channel); + cursor.write_u32_le(self.remaining_bytes); + + // ReadChannelInfoOffset/Length: relative to start of SMB2 header. + // For packing the body alone, we store offset as 0 when empty. + if self.read_channel_info.is_empty() { + cursor.write_u16_le(0); + cursor.write_u16_le(0); + } else { + // Offset from the SMB2 header = header (64) + fixed body (48) = 112. + // The fixed body before Buffer is 48 bytes (StructureSize 49 minus + // the 1 byte of Buffer that's counted in StructureSize). + cursor.write_u16_le(0); // Caller must backpatch if needed + cursor.write_u16_le(self.read_channel_info.len() as u16); + } + + // Buffer: at minimum 1 byte per the StructureSize=49 contract, + // but we write the actual channel info if present. + if self.read_channel_info.is_empty() { + // Write a single padding byte so the fixed part is 49 bytes + // (StructureSize includes this 1-byte minimum buffer). + cursor.write_u8(0); + } else { + cursor.write_bytes(&self.read_channel_info); + } + } +} + +impl Unpack for ReadRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid ReadRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let padding = cursor.read_u8()?; + let flags = cursor.read_u8()?; + let length = cursor.read_u32_le()?; + let offset = cursor.read_u64_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let minimum_count = cursor.read_u32_le()?; + let channel = cursor.read_u32_le()?; + let remaining_bytes = cursor.read_u32_le()?; + let _read_channel_info_offset = cursor.read_u16_le()?; + let read_channel_info_length = cursor.read_u16_le()?; + + // The buffer is at least 1 byte (per StructureSize=49). + // Read channel info from the buffer based on the length field. + let read_channel_info = if read_channel_info_length > 0 { + cursor + .read_bytes(read_channel_info_length as usize)? + .to_vec() + } else { + // Skip the minimum 1-byte buffer + cursor.skip(1)?; + Vec::new() + }; + + Ok(ReadRequest { + padding, + flags, + length, + offset, + file_id: FileId { + persistent, + volatile, + }, + minimum_count, + channel, + remaining_bytes, + read_channel_info, + }) + } +} + +/// SMB2 READ Response (MS-SMB2 section 2.2.20). +/// +/// Sent by the server with the requested data. The fixed portion is 17 bytes: +/// - StructureSize (2 bytes, must be 17) +/// - DataOffset (1 byte) +/// - Reserved (1 byte) +/// - DataLength (4 bytes) +/// - DataRemaining (4 bytes) +/// - Reserved2 (4 bytes) +/// - Buffer (variable, DataLength bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReadResponse { + /// Offset from the start of the SMB2 header to the data. + pub data_offset: u8, + /// Number of remaining bytes on the channel. + pub data_remaining: u32, + /// Flags/Reserved2 field (used in SMB 3.1.1, otherwise 0). + pub flags: u32, + /// The data that was read. + pub data: Vec, +} + +impl ReadResponse { + pub const STRUCTURE_SIZE: u16 = 17; +} + +impl Pack for ReadResponse { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u16_le(Self::STRUCTURE_SIZE); + cursor.write_u8(self.data_offset); + cursor.write_u8(0); // Reserved + cursor.write_u32_le(self.data.len() as u32); + cursor.write_u32_le(self.data_remaining); + cursor.write_u32_le(self.flags); // Reserved2/Flags + cursor.write_bytes(&self.data); + } +} + +impl Unpack for ReadResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid ReadResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let data_offset = cursor.read_u8()?; + let _reserved = cursor.read_u8()?; + let data_length = cursor.read_u32_le()?; + let data_remaining = cursor.read_u32_le()?; + let flags = cursor.read_u32_le()?; + + let data = if data_length > 0 { + cursor.read_bytes_bounded(data_length as usize)?.to_vec() + } else { + Vec::new() + }; + + Ok(ReadResponse { + data_offset, + data_remaining, + flags, + data, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── ReadRequest tests ────────────────────────────────────────── + + #[test] + fn read_request_roundtrip() { + let original = ReadRequest { + padding: 0x50, + flags: SMB2_READFLAG_READ_UNBUFFERED, + length: 65536, + offset: 0x1000, + file_id: FileId { + persistent: 0xAAAA_BBBB_CCCC_DDDD, + volatile: 0x1111_2222_3333_4444, + }, + minimum_count: 1024, + channel: SMB2_CHANNEL_NONE, + remaining_bytes: 0, + read_channel_info: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 48 bytes + 1-byte minimum buffer = 49 bytes + assert_eq!(bytes.len(), 49); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReadRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.padding, original.padding); + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.length, original.length); + assert_eq!(decoded.offset, original.offset); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.minimum_count, original.minimum_count); + assert_eq!(decoded.channel, original.channel); + assert_eq!(decoded.remaining_bytes, original.remaining_bytes); + assert!(decoded.read_channel_info.is_empty()); + } + + #[test] + fn read_request_with_channel_info_roundtrip() { + let channel_data = vec![0xDE, 0xAD, 0xBE, 0xEF]; + let original = ReadRequest { + padding: 0, + flags: 0, + length: 4096, + offset: 0, + file_id: FileId { + persistent: 1, + volatile: 2, + }, + minimum_count: 0, + channel: 0x0000_0001, // SMB2_CHANNEL_RDMA_V1 + remaining_bytes: 4096, + read_channel_info: channel_data.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 48 bytes + 4-byte channel info = 52 bytes + assert_eq!(bytes.len(), 52); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReadRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.read_channel_info, channel_data); + assert_eq!(decoded.channel, 0x0000_0001); + } + + #[test] + fn read_request_wrong_structure_size() { + let mut buf = [0u8; 49]; + buf[0..2].copy_from_slice(&50u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = ReadRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── ReadResponse tests ───────────────────────────────────────── + + #[test] + fn read_response_roundtrip() { + let original = ReadResponse { + data_offset: 0x50, // typical: 64 (header) + 16 (body fixed) = 80 = 0x50 + data_remaining: 0, + flags: 0, + data: vec![0x01, 0x02, 0x03, 0x04, 0x05], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 16 bytes + 5 bytes data = 21 bytes + assert_eq!(bytes.len(), 21); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReadResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.data_offset, original.data_offset); + assert_eq!(decoded.data_remaining, original.data_remaining); + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.data, original.data); + } + + #[test] + fn read_response_empty_data() { + let original = ReadResponse { + data_offset: 0, + data_remaining: 0, + flags: 0, + data: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 16 bytes, no data + assert_eq!(bytes.len(), 16); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReadResponse::unpack(&mut r).unwrap(); + + assert!(decoded.data.is_empty()); + } + + #[test] + fn read_response_known_bytes() { + let mut buf = Vec::new(); + // StructureSize = 17 + buf.extend_from_slice(&17u16.to_le_bytes()); + // DataOffset = 0x50 + buf.push(0x50); + // Reserved = 0 + buf.push(0x00); + // DataLength = 3 + buf.extend_from_slice(&3u32.to_le_bytes()); + // DataRemaining = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // Reserved2/Flags = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // Buffer = [0xAA, 0xBB, 0xCC] + buf.extend_from_slice(&[0xAA, 0xBB, 0xCC]); + + let mut cursor = ReadCursor::new(&buf); + let resp = ReadResponse::unpack(&mut cursor).unwrap(); + + assert_eq!(resp.data_offset, 0x50); + assert_eq!(resp.data, vec![0xAA, 0xBB, 0xCC]); + assert_eq!(resp.data_remaining, 0); + assert_eq!(resp.flags, 0); + } + + #[test] + fn read_response_wrong_structure_size() { + let mut buf = [0u8; 16]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = ReadResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id, arb_small_bytes}; + use proptest::prelude::*; + + proptest! { + #[test] + fn read_request_pack_unpack( + padding in any::(), + flags in any::(), + length in any::(), + offset in any::(), + file_id in arb_file_id(), + minimum_count in any::(), + channel in any::(), + remaining_bytes in any::(), + read_channel_info in arb_small_bytes(), + ) { + let original = ReadRequest { + padding, + flags, + length, + offset, + file_id, + minimum_count, + channel, + remaining_bytes, + read_channel_info, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReadRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn read_response_pack_unpack( + data_offset in any::(), + data_remaining in any::(), + flags in any::(), + data in arb_bytes(), + ) { + let original = ReadResponse { + data_offset, + data_remaining, + flags, + data, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = ReadResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/roundtrip_strategies.rs b/vendor/smb2/src/msg/roundtrip_strategies.rs new file mode 100644 index 0000000..af0dda4 --- /dev/null +++ b/vendor/smb2/src/msg/roundtrip_strategies.rs @@ -0,0 +1,250 @@ +//! Shared proptest strategies for wire-format roundtrip tests. +//! +//! Each strategy generates a value that a real encoder could emit. The goal +//! is not to stress-test the decoder against malformed input (that's fuzzing) +//! but to exercise encode/decode symmetry on well-formed inputs. +//! +//! Rules followed here: +//! - Typed enums always yield valid variants (no invalid discriminants). +//! - `Vec` lengths stay moderate (at most a few KB) to keep tests fast. +//! - Internally-dependent sizes (for example, a length field that must match a +//! sibling `Vec`) are produced via `prop_map` so generated instances are +//! always consistent. + +// Note: `#[cfg(test)]` is applied at the module declaration in `src/msg/mod.rs` +// (`#[cfg(test)] pub(crate) mod roundtrip_strategies;`). We don't repeat it +// here; clippy's `duplicated_attributes` lint rejects that. +#![allow(dead_code)] // Helpers might be unused while tests are being added. + +use proptest::prelude::*; + +use crate::pack::{FileTime, Guid}; +use crate::types::flags::{ + Capabilities, FileAccessMask, HeaderFlags, SecurityMode, ShareCapabilities, ShareFlags, +}; +use crate::types::status::NtStatus; +use crate::types::{ + Command, CreditCharge, Dialect, FileId, MessageId, OplockLevel, SessionId, TreeId, +}; + +/// Max size (in bytes) used for generated `Vec` buffers across tests. +/// Kept small so a 256-case proptest run stays well under a second. +pub const MAX_BUFFER_BYTES: usize = 1024; + +/// Moderate buffer for structs that usually carry small bodies. +pub const MAX_SMALL_BUFFER_BYTES: usize = 256; + +/// Generate a `Vec` up to `max` bytes long (including zero). +pub fn bytes_up_to(max: usize) -> impl Strategy> { + prop::collection::vec(any::(), 0..=max) +} + +/// A standard moderate-length byte buffer. +pub fn arb_bytes() -> impl Strategy> { + bytes_up_to(MAX_BUFFER_BYTES) +} + +/// A smaller byte buffer, for sub-fields or tightly-nested structures. +pub fn arb_small_bytes() -> impl Strategy> { + bytes_up_to(MAX_SMALL_BUFFER_BYTES) +} + +/// Generate a valid UTF-16-encodable String, up to `max_chars` chars. +/// +/// Excludes unpaired surrogates (U+D800..=U+DFFF) because UTF-16 decoding +/// would reject any surrogate that isn't part of a valid pair. We use the +/// BMP-minus-surrogates range plus occasional supplementary characters, so +/// both one-code-unit and two-code-unit forms are covered. +pub fn arb_utf16_string(max_chars: usize) -> impl Strategy { + prop::collection::vec( + prop::char::range('\u{0000}', '\u{D7FF}') + .prop_union(prop::char::range('\u{E000}', '\u{FFFF}')) + .or(prop::char::range('\u{1_0000}', '\u{10_FFFF}')), + 0..=max_chars, + ) + .prop_map(|chars| chars.into_iter().collect()) +} + +// ── Primitive newtype strategies ──────────────────────────────────── + +pub fn arb_session_id() -> impl Strategy { + any::().prop_map(SessionId) +} + +pub fn arb_message_id() -> impl Strategy { + any::().prop_map(MessageId) +} + +pub fn arb_tree_id() -> impl Strategy { + any::().prop_map(TreeId) +} + +pub fn arb_credit_charge() -> impl Strategy { + any::().prop_map(CreditCharge) +} + +pub fn arb_file_id() -> impl Strategy { + (any::(), any::()).prop_map(|(persistent, volatile)| FileId { + persistent, + volatile, + }) +} + +pub fn arb_file_time() -> impl Strategy { + any::().prop_map(FileTime) +} + +pub fn arb_guid() -> impl Strategy { + (any::(), any::(), any::(), any::<[u8; 8]>()).prop_map( + |(data1, data2, data3, data4)| Guid { + data1, + data2, + data3, + data4, + }, + ) +} + +pub fn arb_nt_status() -> impl Strategy { + any::().prop_map(NtStatus) +} + +// ── Flags ──────────────────────────────────────────────────────────── + +pub fn arb_header_flags() -> impl Strategy { + any::().prop_map(HeaderFlags::new) +} + +pub fn arb_security_mode() -> impl Strategy { + any::().prop_map(SecurityMode::new) +} + +pub fn arb_capabilities() -> impl Strategy { + any::().prop_map(Capabilities::new) +} + +pub fn arb_share_flags() -> impl Strategy { + any::().prop_map(ShareFlags::new) +} + +pub fn arb_share_capabilities() -> impl Strategy { + any::().prop_map(ShareCapabilities::new) +} + +pub fn arb_file_access_mask() -> impl Strategy { + any::().prop_map(FileAccessMask::new) +} + +// ── Typed enums: only valid variants ──────────────────────────────── + +pub fn arb_oplock_level() -> impl Strategy { + prop_oneof![ + Just(OplockLevel::None), + Just(OplockLevel::LevelII), + Just(OplockLevel::Exclusive), + Just(OplockLevel::Batch), + Just(OplockLevel::Lease), + ] +} + +pub fn arb_dialect() -> impl Strategy { + prop_oneof![ + Just(Dialect::Smb2_0_2), + Just(Dialect::Smb2_1), + Just(Dialect::Smb3_0), + Just(Dialect::Smb3_0_2), + Just(Dialect::Smb3_1_1), + ] +} + +pub fn arb_share_type() -> impl Strategy { + use crate::msg::tree_connect::ShareType; + prop_oneof![ + Just(ShareType::Disk), + Just(ShareType::Pipe), + Just(ShareType::Print), + ] +} + +pub fn arb_impersonation_level() -> impl Strategy { + use crate::msg::create::ImpersonationLevel; + prop_oneof![ + Just(ImpersonationLevel::Anonymous), + Just(ImpersonationLevel::Identification), + Just(ImpersonationLevel::Impersonation), + Just(ImpersonationLevel::Delegate), + ] +} + +pub fn arb_create_disposition() -> impl Strategy { + use crate::msg::create::CreateDisposition; + prop_oneof![ + Just(CreateDisposition::FileSupersede), + Just(CreateDisposition::FileOpen), + Just(CreateDisposition::FileCreate), + Just(CreateDisposition::FileOpenIf), + Just(CreateDisposition::FileOverwrite), + Just(CreateDisposition::FileOverwriteIf), + ] +} + +pub fn arb_create_action() -> impl Strategy { + use crate::msg::create::CreateAction; + prop_oneof![ + Just(CreateAction::FileSuperseded), + Just(CreateAction::FileOpened), + Just(CreateAction::FileCreated), + Just(CreateAction::FileOverwritten), + ] +} + +pub fn arb_share_access() -> impl Strategy { + any::().prop_map(crate::msg::create::ShareAccess) +} + +pub fn arb_info_type() -> impl Strategy { + use crate::msg::query_info::InfoType; + prop_oneof![ + Just(InfoType::File), + Just(InfoType::Filesystem), + Just(InfoType::Security), + Just(InfoType::Quota), + ] +} + +pub fn arb_file_information_class( +) -> impl Strategy { + use crate::msg::query_directory::FileInformationClass; + prop_oneof![ + Just(FileInformationClass::FileDirectoryInformation), + Just(FileInformationClass::FileFullDirectoryInformation), + Just(FileInformationClass::FileBothDirectoryInformation), + Just(FileInformationClass::FileNamesInformation), + Just(FileInformationClass::FileIdBothDirectoryInformation), + Just(FileInformationClass::FileIdFullDirectoryInformation), + ] +} + +pub fn arb_command() -> impl Strategy { + prop_oneof![ + Just(Command::Negotiate), + Just(Command::SessionSetup), + Just(Command::Logoff), + Just(Command::TreeConnect), + Just(Command::TreeDisconnect), + Just(Command::Create), + Just(Command::Close), + Just(Command::Flush), + Just(Command::Read), + Just(Command::Write), + Just(Command::Lock), + Just(Command::Ioctl), + Just(Command::Cancel), + Just(Command::Echo), + Just(Command::QueryDirectory), + Just(Command::ChangeNotify), + Just(Command::QueryInfo), + Just(Command::SetInfo), + Just(Command::OplockBreak), + ] +} diff --git a/vendor/smb2/src/msg/session_setup.rs b/vendor/smb2/src/msg/session_setup.rs new file mode 100644 index 0000000..fb4ec88 --- /dev/null +++ b/vendor/smb2/src/msg/session_setup.rs @@ -0,0 +1,481 @@ +//! SMB2 SESSION_SETUP request and response (spec sections 2.2.5, 2.2.6). +//! +//! Session setup messages are used to establish an authenticated session +//! between the client and the server. The request carries a security token +//! (for example, SPNEGO/NTLM) and the response carries the server's reply token +//! along with session flags. + +use crate::error::Result; +use crate::msg::header::Header; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::flags::{Capabilities, SecurityMode}; +use crate::Error; + +// ── Session setup request flags ──────────────────────────────────────── + +/// Flags for the SESSION_SETUP request (1 byte, spec section 2.2.5). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct SessionSetupRequestFlags(pub u8); + +impl SessionSetupRequestFlags { + /// Bind an existing session to a new connection (SMB 3.x only). + pub const BINDING: u8 = 0x01; + + /// Returns `true` if the binding flag is set. + #[inline] + pub fn is_binding(&self) -> bool { + self.0 & Self::BINDING != 0 + } +} + +// ── Session flags (response) ─────────────────────────────────────────── + +/// Session flags returned in the SESSION_SETUP response (spec section 2.2.6). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct SessionFlags(pub u16); + +impl SessionFlags { + /// The client has been authenticated as a guest user. + pub const IS_GUEST: u16 = 0x0001; + /// The client has been authenticated as an anonymous user. + pub const IS_NULL: u16 = 0x0002; + /// The server requires encryption of messages on this session (SMB 3.x only). + pub const ENCRYPT_DATA: u16 = 0x0004; + + /// Returns `true` if the guest flag is set. + #[inline] + pub fn is_guest(&self) -> bool { + self.0 & Self::IS_GUEST != 0 + } + + /// Returns `true` if the null session flag is set. + #[inline] + pub fn is_null(&self) -> bool { + self.0 & Self::IS_NULL != 0 + } + + /// Returns `true` if the encrypt-data flag is set. + #[inline] + pub fn encrypt_data(&self) -> bool { + self.0 & Self::ENCRYPT_DATA != 0 + } +} + +// ── SessionSetupRequest ──────────────────────────────────────────────── + +/// SMB2 SESSION_SETUP request (spec section 2.2.5). +/// +/// Sent by the client to establish an authenticated session. The security +/// buffer carries a GSS/SPNEGO token (or other auth protocol token). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionSetupRequest { + /// Flags controlling the request (for example, session binding). + pub flags: SessionSetupRequestFlags, + /// Security mode indicating signing requirements. + pub security_mode: SecurityMode, + /// Client capabilities. + pub capabilities: Capabilities, + /// Channel field (reserved, must be 0). + pub channel: u32, + /// Previously established session identifier for reconnection. + pub previous_session_id: u64, + /// Security buffer containing the authentication token. + pub security_buffer: Vec, +} + +impl SessionSetupRequest { + pub const STRUCTURE_SIZE: u16 = 25; +} + +impl Pack for SessionSetupRequest { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Flags (1 byte) + cursor.write_u8(self.flags.0); + // SecurityMode (1 byte) + cursor.write_u8(self.security_mode.bits() as u8); + // Capabilities (4 bytes) + cursor.write_u32_le(self.capabilities.bits()); + // Channel (4 bytes) + cursor.write_u32_le(self.channel); + + // SecurityBufferOffset (2 bytes) -- offset from start of SMB2 header + let offset = (Header::SIZE + 24) as u16; // 24 = bytes before the buffer in this struct + cursor.write_u16_le(offset); + // SecurityBufferLength (2 bytes) + cursor.write_u16_le(self.security_buffer.len() as u16); + // PreviousSessionId (8 bytes) + cursor.write_u64_le(self.previous_session_id); + // Buffer (variable) + cursor.write_bytes(&self.security_buffer); + } +} + +impl Unpack for SessionSetupRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid SessionSetupRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // Flags (1 byte) + let flags = SessionSetupRequestFlags(cursor.read_u8()?); + // SecurityMode (1 byte) + let security_mode = SecurityMode::new(cursor.read_u8()? as u16); + // Capabilities (4 bytes) + let capabilities = Capabilities::new(cursor.read_u32_le()?); + // Channel (4 bytes) + let channel = cursor.read_u32_le()?; + // SecurityBufferOffset (2 bytes) -- we ignore, read sequentially + let _offset = cursor.read_u16_le()?; + // SecurityBufferLength (2 bytes) + let buffer_length = cursor.read_u16_le()? as usize; + // PreviousSessionId (8 bytes) + let previous_session_id = cursor.read_u64_le()?; + // Buffer (variable) + let security_buffer = if buffer_length > 0 { + cursor.read_bytes_bounded(buffer_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(SessionSetupRequest { + flags, + security_mode, + capabilities, + channel, + previous_session_id, + security_buffer, + }) + } +} + +// ── SessionSetupResponse ─────────────────────────────────────────────── + +/// SMB2 SESSION_SETUP response (spec section 2.2.6). +/// +/// Sent by the server in response to a SESSION_SETUP request. Contains +/// session flags and a security buffer with the server's auth token. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionSetupResponse { + /// Flags indicating additional information about the session. + pub session_flags: SessionFlags, + /// Security buffer containing the server's authentication token. + pub security_buffer: Vec, +} + +impl SessionSetupResponse { + pub const STRUCTURE_SIZE: u16 = 9; +} + +impl Pack for SessionSetupResponse { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // SessionFlags (2 bytes) + cursor.write_u16_le(self.session_flags.0); + // SecurityBufferOffset (2 bytes) -- offset from start of SMB2 header + let offset = (Header::SIZE + 8) as u16; // 8 = fixed part of response struct + cursor.write_u16_le(offset); + // SecurityBufferLength (2 bytes) + cursor.write_u16_le(self.security_buffer.len() as u16); + // Buffer (variable) + cursor.write_bytes(&self.security_buffer); + } +} + +impl Unpack for SessionSetupResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid SessionSetupResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // SessionFlags (2 bytes) + let session_flags = SessionFlags(cursor.read_u16_le()?); + // SecurityBufferOffset (2 bytes) + let _offset = cursor.read_u16_le()?; + // SecurityBufferLength (2 bytes) + let buffer_length = cursor.read_u16_le()? as usize; + // Buffer (variable) + let security_buffer = if buffer_length > 0 { + cursor.read_bytes_bounded(buffer_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(SessionSetupResponse { + session_flags, + security_buffer, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── SessionSetupRequest tests ────────────────────────────────── + + #[test] + fn session_setup_request_roundtrip() { + let token = vec![0x60, 0x28, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05]; + let original = SessionSetupRequest { + flags: SessionSetupRequestFlags(0), + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::new(Capabilities::DFS), + channel: 0, + previous_session_id: 0, + security_buffer: token.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.security_mode.bits(), original.security_mode.bits()); + assert_eq!(decoded.capabilities.bits(), original.capabilities.bits()); + assert_eq!(decoded.channel, 0); + assert_eq!(decoded.previous_session_id, 0); + assert_eq!(decoded.security_buffer, token); + } + + #[test] + fn session_setup_request_with_binding_flag() { + let original = SessionSetupRequest { + flags: SessionSetupRequestFlags(SessionSetupRequestFlags::BINDING), + security_mode: SecurityMode::new( + SecurityMode::SIGNING_ENABLED | SecurityMode::SIGNING_REQUIRED, + ), + capabilities: Capabilities::default(), + channel: 0, + previous_session_id: 0xDEAD_BEEF_CAFE_BABE, + security_buffer: vec![0xAA, 0xBB], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupRequest::unpack(&mut r).unwrap(); + + assert!(decoded.flags.is_binding()); + assert!(decoded.security_mode.signing_enabled()); + assert!(decoded.security_mode.signing_required()); + assert_eq!(decoded.previous_session_id, 0xDEAD_BEEF_CAFE_BABE); + assert_eq!(decoded.security_buffer, vec![0xAA, 0xBB]); + } + + #[test] + fn session_setup_request_empty_buffer() { + let original = SessionSetupRequest { + flags: SessionSetupRequestFlags(0), + security_mode: SecurityMode::default(), + capabilities: Capabilities::default(), + channel: 0, + previous_session_id: 0, + security_buffer: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupRequest::unpack(&mut r).unwrap(); + + assert!(decoded.security_buffer.is_empty()); + } + + #[test] + fn session_setup_request_structure_size_field() { + let req = SessionSetupRequest { + flags: SessionSetupRequestFlags(0), + security_mode: SecurityMode::default(), + capabilities: Capabilities::default(), + channel: 0, + previous_session_id: 0, + security_buffer: vec![0x01], + }; + + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + + // First 2 bytes are structure size = 25 + assert_eq!(bytes[0], 25); + assert_eq!(bytes[1], 0); + } + + #[test] + fn session_setup_request_wrong_structure_size() { + let mut buf = [0u8; 26]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = SessionSetupRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── SessionSetupResponse tests ───────────────────────────────── + + #[test] + fn session_setup_response_roundtrip() { + let token = vec![0xA1, 0x81, 0xB0, 0x30, 0x81, 0xAD]; + let original = SessionSetupResponse { + session_flags: SessionFlags(0), + security_buffer: token.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.session_flags, original.session_flags); + assert_eq!(decoded.security_buffer, token); + } + + #[test] + fn session_setup_response_with_flags() { + let original = SessionSetupResponse { + session_flags: SessionFlags(SessionFlags::IS_GUEST | SessionFlags::ENCRYPT_DATA), + security_buffer: vec![0x01, 0x02, 0x03], + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupResponse::unpack(&mut r).unwrap(); + + assert!(decoded.session_flags.is_guest()); + assert!(!decoded.session_flags.is_null()); + assert!(decoded.session_flags.encrypt_data()); + assert_eq!(decoded.security_buffer, vec![0x01, 0x02, 0x03]); + } + + #[test] + fn session_setup_response_null_session() { + let original = SessionSetupResponse { + session_flags: SessionFlags(SessionFlags::IS_NULL), + security_buffer: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupResponse::unpack(&mut r).unwrap(); + + assert!(decoded.session_flags.is_null()); + assert!(!decoded.session_flags.is_guest()); + assert!(decoded.security_buffer.is_empty()); + } + + #[test] + fn session_setup_response_structure_size_field() { + let resp = SessionSetupResponse { + session_flags: SessionFlags(0), + security_buffer: Vec::new(), + }; + + let mut w = WriteCursor::new(); + resp.pack(&mut w); + let bytes = w.into_inner(); + + // First 2 bytes are structure size = 9 + assert_eq!(bytes[0], 9); + assert_eq!(bytes[1], 0); + } + + #[test] + fn session_setup_response_wrong_structure_size() { + let mut buf = [0u8; 10]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = SessionSetupResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_capabilities, arb_small_bytes}; + use proptest::prelude::*; + + proptest! { + #[test] + fn session_setup_request_pack_unpack( + flags_raw in any::(), + // SESSION_SETUP packs SecurityMode as a single byte, so only the + // low 8 bits survive the roundtrip. Generate u8 values to avoid + // producing inputs the encoder would never emit from a real caller. + security_mode_raw in any::(), + capabilities in arb_capabilities(), + channel in any::(), + previous_session_id in any::(), + security_buffer in arb_small_bytes(), + ) { + let original = SessionSetupRequest { + flags: SessionSetupRequestFlags(flags_raw), + security_mode: SecurityMode::new(security_mode_raw as u16), + capabilities, + channel, + previous_session_id, + security_buffer, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn session_setup_response_pack_unpack( + session_flags_raw in any::(), + security_buffer in arb_small_bytes(), + ) { + let original = SessionSetupResponse { + session_flags: SessionFlags(session_flags_raw), + security_buffer, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SessionSetupResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/set_info.rs b/vendor/smb2/src/msg/set_info.rs new file mode 100644 index 0000000..9690794 --- /dev/null +++ b/vendor/smb2/src/msg/set_info.rs @@ -0,0 +1,328 @@ +//! SMB2 SET_INFO request and response (spec sections 2.2.39, 2.2.40). +//! +//! Used to set file, filesystem, security, or quota information. +//! The request buffer contains the information to set, stored as raw bytes. +//! The response is a minimal 2-byte structure. + +use crate::error::Result; +use crate::msg::header::Header; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +// Re-use InfoType from query_info +pub use super::query_info::InfoType; + +// ── SetInfoRequest ─────────────────────────────────────────────────────── + +/// SMB2 SET_INFO request (spec section 2.2.39). +/// +/// Sent by the client to set information on a file, filesystem, +/// security descriptor, or quota. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SetInfoRequest { + /// The type of information being set. + pub info_type: InfoType, + /// The file information class (interpretation depends on `info_type`). + pub file_info_class: u8, + /// Additional information flags (for example, security information flags). + pub additional_information: u32, + /// Handle to the file or directory. + pub file_id: FileId, + /// Raw buffer containing the information to set. + pub buffer: Vec, +} + +impl SetInfoRequest { + pub const STRUCTURE_SIZE: u16 = 33; +} + +impl Pack for SetInfoRequest { + fn pack(&self, cursor: &mut WriteCursor) { + let start = cursor.position(); + + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // InfoType (1 byte) + cursor.write_u8(self.info_type as u8); + // FileInfoClass (1 byte) + cursor.write_u8(self.file_info_class); + // BufferLength (4 bytes) + cursor.write_u32_le(self.buffer.len() as u32); + // BufferOffset (2 bytes) -- placeholder + let offset_pos = cursor.position(); + cursor.write_u16_le(0); + // Reserved (2 bytes) + cursor.write_u16_le(0); + // AdditionalInformation (4 bytes) + cursor.write_u32_le(self.additional_information); + // FileId (16 bytes) + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + + // Buffer (variable) + if !self.buffer.is_empty() { + // Offset is from the beginning of the SMB2 header per spec. + let buf_offset = Header::SIZE + (cursor.position() - start); + cursor.write_bytes(&self.buffer); + cursor.set_u16_le_at(offset_pos, buf_offset as u16); + } + } +} + +impl Unpack for SetInfoRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let start = cursor.position(); + + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid SetInfoRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // InfoType (1 byte) + let info_type = InfoType::try_from(cursor.read_u8()?)?; + // FileInfoClass (1 byte) + let file_info_class = cursor.read_u8()?; + // BufferLength (4 bytes) + let buffer_length = cursor.read_u32_le()? as usize; + // BufferOffset (2 bytes) + let buf_offset = cursor.read_u16_le()? as usize; + // Reserved (2 bytes) + let _reserved = cursor.read_u16_le()?; + // AdditionalInformation (4 bytes) + let additional_information = cursor.read_u32_le()?; + // FileId (16 bytes) + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let file_id = FileId { + persistent, + volatile, + }; + + // Read buffer + // Offset on the wire is from beginning of SMB2 header. + let buffer = if buffer_length > 0 { + let current = cursor.position(); + let body_offset = buf_offset.saturating_sub(Header::SIZE); + let target = start + body_offset; + if target > current { + cursor.skip(target - current)?; + } + cursor.read_bytes_bounded(buffer_length)?.to_vec() + } else { + Vec::new() + }; + + Ok(SetInfoRequest { + info_type, + file_info_class, + additional_information, + file_id, + buffer, + }) + } +} + +// ── SetInfoResponse ────────────────────────────────────────────────────── + +/// SMB2 SET_INFO response (spec section 2.2.40). +/// +/// A minimal response indicating that the set operation succeeded. +/// Contains only the 2-byte StructureSize field. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SetInfoResponse; + +impl SetInfoResponse { + pub const STRUCTURE_SIZE: u16 = 2; +} + +impl Pack for SetInfoResponse { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + } +} + +impl Unpack for SetInfoResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid SetInfoResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + Ok(SetInfoResponse) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── SetInfoRequest tests ───────────────────────────────────────── + + #[test] + fn set_info_request_roundtrip_with_buffer() { + let info_data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04]; + + let original = SetInfoRequest { + info_type: InfoType::File, + file_info_class: 0x04, // FileBasicInformation + additional_information: 0, + file_id: FileId { + persistent: 0xAAAA_BBBB_CCCC_DDDD, + volatile: 0x1111_2222_3333_4444, + }, + buffer: info_data.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SetInfoRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.info_type, InfoType::File); + assert_eq!(decoded.file_info_class, 0x04); + assert_eq!(decoded.additional_information, 0); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.buffer, info_data); + } + + #[test] + fn set_info_request_security_info() { + let sd_data = vec![0x01, 0x00, 0x04, 0x80, 0x00, 0x00, 0x00, 0x00]; + + let original = SetInfoRequest { + info_type: InfoType::Security, + file_info_class: 0, + additional_information: 0x04, // DACL_SECURITY_INFORMATION + file_id: FileId { + persistent: 42, + volatile: 99, + }, + buffer: sd_data.clone(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SetInfoRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.info_type, InfoType::Security); + assert_eq!(decoded.additional_information, 0x04); + assert_eq!(decoded.buffer, sd_data); + } + + #[test] + fn set_info_request_structure_size() { + let req = SetInfoRequest { + info_type: InfoType::File, + file_info_class: 0, + additional_information: 0, + file_id: FileId::default(), + buffer: vec![0x01], + }; + + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes[0], 33); + assert_eq!(bytes[1], 0); + } + + #[test] + fn set_info_request_wrong_structure_size() { + let mut buf = vec![0u8; 48]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = SetInfoRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + // ── SetInfoResponse tests ──────────────────────────────────────── + + #[test] + fn set_info_response_roundtrip() { + let original = SetInfoResponse; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Only 2 bytes + assert_eq!(bytes.len(), 2); + assert_eq!(bytes, [0x02, 0x00]); + + let mut r = ReadCursor::new(&bytes); + let decoded = SetInfoResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded, SetInfoResponse); + } + + #[test] + fn set_info_response_wrong_structure_size() { + let bytes = [0x04, 0x00]; + let mut cursor = ReadCursor::new(&bytes); + let result = SetInfoResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn set_info_response_too_short() { + let bytes = [0x02]; + let mut cursor = ReadCursor::new(&bytes); + let result = SetInfoResponse::unpack(&mut cursor); + assert!(result.is_err()); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id, arb_info_type}; + use proptest::prelude::*; + + proptest! { + #[test] + fn set_info_request_pack_unpack( + info_type in arb_info_type(), + file_info_class in any::(), + additional_information in any::(), + file_id in arb_file_id(), + buffer in arb_bytes(), + ) { + let original = SetInfoRequest { + info_type, + file_info_class, + additional_information, + file_id, + buffer, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = SetInfoRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + } + } +} diff --git a/vendor/smb2/src/msg/transform.rs b/vendor/smb2/src/msg/transform.rs new file mode 100644 index 0000000..07bd8d2 --- /dev/null +++ b/vendor/smb2/src/msg/transform.rs @@ -0,0 +1,452 @@ +//! SMB2 TRANSFORM_HEADER and COMPRESSION_TRANSFORM_HEADER +//! (MS-SMB2 sections 2.2.41, 2.2.42). +//! +//! These headers wrap (encrypted or compressed) SMB2 messages. They are NOT +//! SMB2 messages themselves -- they precede the actual message data. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::SessionId; +use crate::Error; + +// ── Transform header protocol IDs ────────────────────────────────────── + +/// Protocol identifier for the encryption transform header (0xFD 'S' 'M' 'B'). +/// Note: this is NOT the normal SMB2 protocol ID (0xFE). +pub const TRANSFORM_PROTOCOL_ID: [u8; 4] = [0xFD, b'S', b'M', b'B']; + +/// Protocol identifier for the compression transform header (0xFC 'S' 'M' 'B'). +pub const COMPRESSION_PROTOCOL_ID: [u8; 4] = [0xFC, b'S', b'M', b'B']; + +// ── Transform header flags ──────────────────────────────────────────── + +/// The message is encrypted. +pub const SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED: u16 = 0x0001; + +// ── CompressionAlgorithm values ──────────────────────────────────────── + +/// No compression. +pub const COMPRESSION_ALGORITHM_NONE: u16 = 0x0000; + +/// LZNT1 compression. +pub const COMPRESSION_ALGORITHM_LZNT1: u16 = 0x0001; + +/// LZ77 compression. +pub const COMPRESSION_ALGORITHM_LZ77: u16 = 0x0002; + +/// LZ77 with Huffman encoding. +pub const COMPRESSION_ALGORITHM_LZ77_HUFFMAN: u16 = 0x0003; + +/// Pattern_V1 compression. +pub const COMPRESSION_ALGORITHM_PATTERN_V1: u16 = 0x0004; + +/// LZ4 compression. +pub const COMPRESSION_ALGORITHM_LZ4: u16 = 0x0005; + +// ── Compression flags ────────────────────────────────────────────────── + +/// No compression flags. +pub const SMB2_COMPRESSION_FLAG_NONE: u16 = 0x0000; + +/// Chained compression (multiple segments). +pub const SMB2_COMPRESSION_FLAG_CHAINED: u16 = 0x0001; + +// ── TransformHeader ──────────────────────────────────────────────────── + +/// SMB2 TRANSFORM_HEADER (MS-SMB2 section 2.2.41). +/// +/// An encryption wrapper that precedes an encrypted SMB2 message. +/// The total header is 52 bytes: +/// - ProtocolId (4 bytes, must be 0xFD 'S' 'M' 'B') +/// - Signature (16 bytes) +/// - Nonce (16 bytes -- first 11 bytes used for AES-CCM, first 12 for AES-GCM) +/// - OriginalMessageSize (4 bytes) +/// - Reserved (2 bytes) +/// - Flags (2 bytes) +/// - SessionId (8 bytes) +/// +/// The encrypted message data follows immediately after this header. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TransformHeader { + /// 16-byte AES signature over the encrypted message. + pub signature: [u8; 16], + /// 16-byte nonce. Only the first 11 bytes are used for AES-CCM, + /// and the first 12 bytes for AES-GCM. The remaining bytes must be zero. + pub nonce: [u8; 16], + /// Size of the original (unencrypted) SMB2 message in bytes. + pub original_message_size: u32, + /// Flags for the transform header. Use + /// `SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED`. + pub flags: u16, + /// Session identifier for the encrypted message. + pub session_id: SessionId, +} + +impl TransformHeader { + /// Total header size in bytes (52). + pub const SIZE: usize = 52; +} + +impl Pack for TransformHeader { + fn pack(&self, cursor: &mut WriteCursor) { + // ProtocolId (4 bytes) + cursor.write_bytes(&TRANSFORM_PROTOCOL_ID); + // Signature (16 bytes) + cursor.write_bytes(&self.signature); + // Nonce (16 bytes) + cursor.write_bytes(&self.nonce); + // OriginalMessageSize (4 bytes) + cursor.write_u32_le(self.original_message_size); + // Reserved (2 bytes) + cursor.write_u16_le(0); + // Flags (2 bytes) + cursor.write_u16_le(self.flags); + // SessionId (8 bytes) + cursor.write_u64_le(self.session_id.0); + } +} + +impl Unpack for TransformHeader { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // ProtocolId (4 bytes) + let proto = cursor.read_bytes(4)?; + if proto != TRANSFORM_PROTOCOL_ID { + return Err(Error::invalid_data(format!( + "invalid transform header protocol ID: expected {:02X?}, got {:02X?}", + TRANSFORM_PROTOCOL_ID, proto + ))); + } + + // Signature (16 bytes) + let sig_bytes = cursor.read_bytes(16)?; + let mut signature = [0u8; 16]; + signature.copy_from_slice(sig_bytes); + + // Nonce (16 bytes) + let nonce_bytes = cursor.read_bytes(16)?; + let mut nonce = [0u8; 16]; + nonce.copy_from_slice(nonce_bytes); + + // OriginalMessageSize (4 bytes) + let original_message_size = cursor.read_u32_le()?; + + // Reserved (2 bytes) + let _reserved = cursor.read_u16_le()?; + + // Flags (2 bytes) + let flags = cursor.read_u16_le()?; + + // SessionId (8 bytes) + let session_id = SessionId(cursor.read_u64_le()?); + + Ok(TransformHeader { + signature, + nonce, + original_message_size, + flags, + session_id, + }) + } +} + +// ── CompressionTransformHeader ───────────────────────────────────────── + +/// SMB2 COMPRESSION_TRANSFORM_HEADER (MS-SMB2 section 2.2.42). +/// +/// A compression wrapper that precedes a compressed SMB2 message. +/// This implements the unchained variant (Flags = 0) only. The total +/// header is 16 bytes: +/// - ProtocolId (4 bytes, must be 0xFC 'S' 'M' 'B') +/// - OriginalCompressedSegmentSize (4 bytes) +/// - CompressionAlgorithm (2 bytes) +/// - Flags (2 bytes) +/// - Offset (4 bytes) -- offset from the end of this header to the +/// start of compressed data +/// +/// Note: The chained variant (Flags = SMB2_COMPRESSION_FLAG_CHAINED) +/// interprets the last 4 bytes as Length instead of Offset. Chained +/// compression is deferred to a future implementation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompressionTransformHeader { + /// Size of the original uncompressed data segment. + pub original_compressed_segment_size: u32, + /// The compression algorithm used. + pub compression_algorithm: u16, + /// Compression flags. Currently only unchained (0x0000) is supported. + pub flags: u16, + /// For unchained: offset from end of this header to the start of + /// compressed data. For chained: length of the original uncompressed + /// segment (chained is not yet implemented). + pub offset_or_length: u32, +} + +impl CompressionTransformHeader { + /// Total header size in bytes (16). + pub const SIZE: usize = 16; +} + +impl Pack for CompressionTransformHeader { + fn pack(&self, cursor: &mut WriteCursor) { + // ProtocolId (4 bytes) + cursor.write_bytes(&COMPRESSION_PROTOCOL_ID); + // OriginalCompressedSegmentSize (4 bytes) + cursor.write_u32_le(self.original_compressed_segment_size); + // CompressionAlgorithm (2 bytes) + cursor.write_u16_le(self.compression_algorithm); + // Flags (2 bytes) + cursor.write_u16_le(self.flags); + // Offset/Length (4 bytes) + cursor.write_u32_le(self.offset_or_length); + } +} + +impl Unpack for CompressionTransformHeader { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // ProtocolId (4 bytes) + let proto = cursor.read_bytes(4)?; + if proto != COMPRESSION_PROTOCOL_ID { + return Err(Error::invalid_data(format!( + "invalid compression transform header protocol ID: expected {:02X?}, got {:02X?}", + COMPRESSION_PROTOCOL_ID, proto + ))); + } + + // OriginalCompressedSegmentSize (4 bytes) + let original_compressed_segment_size = cursor.read_u32_le()?; + + // CompressionAlgorithm (2 bytes) + let compression_algorithm = cursor.read_u16_le()?; + + // Flags (2 bytes) + let flags = cursor.read_u16_le()?; + + // Offset/Length (4 bytes) + let offset_or_length = cursor.read_u32_le()?; + + Ok(CompressionTransformHeader { + original_compressed_segment_size, + compression_algorithm, + flags, + offset_or_length, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── TransformHeader tests ───────────────────────────────────────── + + #[test] + fn transform_header_roundtrip() { + let mut nonce = [0u8; 16]; + nonce[0..12].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + + let original = TransformHeader { + signature: [ + 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, + 0x99, 0x00, + ], + nonce, + original_message_size: 1024, + flags: SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED, + session_id: SessionId(0xDEAD_BEEF_CAFE_FACE), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), TransformHeader::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = TransformHeader::unpack(&mut r).unwrap(); + + assert_eq!(decoded.signature, original.signature); + assert_eq!(decoded.nonce, original.nonce); + assert_eq!(decoded.original_message_size, 1024); + assert_eq!(decoded.flags, SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED); + assert_eq!(decoded.session_id, SessionId(0xDEAD_BEEF_CAFE_FACE)); + } + + #[test] + fn transform_header_protocol_id_is_0xfd() { + let original = TransformHeader { + signature: [0u8; 16], + nonce: [0u8; 16], + original_message_size: 0, + flags: 0, + session_id: SessionId(0), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // First 4 bytes must be 0xFD 'S' 'M' 'B', NOT 0xFE + assert_eq!(bytes[0], 0xFD); + assert_eq!(bytes[1], b'S'); + assert_eq!(bytes[2], b'M'); + assert_eq!(bytes[3], b'B'); + assert_ne!(bytes[0], 0xFE, "transform header must use 0xFD, not 0xFE"); + } + + #[test] + fn transform_header_wrong_protocol_id() { + let mut buf = [0u8; TransformHeader::SIZE]; + // Use the normal SMB2 protocol ID (0xFE) instead of 0xFD + buf[0..4].copy_from_slice(&[0xFE, b'S', b'M', b'B']); + + let mut cursor = ReadCursor::new(&buf); + let result = TransformHeader::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("protocol ID"), "error was: {err}"); + } + + // ── CompressionTransformHeader tests ────────────────────────────── + + #[test] + fn compression_transform_header_roundtrip_unchained() { + let original = CompressionTransformHeader { + original_compressed_segment_size: 4096, + compression_algorithm: COMPRESSION_ALGORITHM_LZ77, + flags: SMB2_COMPRESSION_FLAG_NONE, + offset_or_length: 64, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + assert_eq!(bytes.len(), CompressionTransformHeader::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = CompressionTransformHeader::unpack(&mut r).unwrap(); + + assert_eq!(decoded.original_compressed_segment_size, 4096); + assert_eq!(decoded.compression_algorithm, COMPRESSION_ALGORITHM_LZ77); + assert_eq!(decoded.flags, SMB2_COMPRESSION_FLAG_NONE); + assert_eq!(decoded.offset_or_length, 64); + } + + #[test] + fn compression_transform_header_protocol_id_is_0xfc() { + let original = CompressionTransformHeader { + original_compressed_segment_size: 0, + compression_algorithm: COMPRESSION_ALGORITHM_NONE, + flags: SMB2_COMPRESSION_FLAG_NONE, + offset_or_length: 0, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // First 4 bytes must be 0xFC 'S' 'M' 'B' + assert_eq!(bytes[0], 0xFC); + assert_eq!(bytes[1], b'S'); + assert_eq!(bytes[2], b'M'); + assert_eq!(bytes[3], b'B'); + assert_ne!( + bytes[0], 0xFE, + "compression transform header must use 0xFC, not 0xFE" + ); + } + + #[test] + fn compression_transform_header_wrong_protocol_id() { + let mut buf = [0u8; CompressionTransformHeader::SIZE]; + // Use wrong protocol ID + buf[0..4].copy_from_slice(&[0xFE, b'S', b'M', b'B']); + + let mut cursor = ReadCursor::new(&buf); + let result = CompressionTransformHeader::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("protocol ID"), "error was: {err}"); + } + + #[test] + fn compression_transform_header_lz77_huffman() { + let original = CompressionTransformHeader { + original_compressed_segment_size: 8192, + compression_algorithm: COMPRESSION_ALGORITHM_LZ77_HUFFMAN, + flags: SMB2_COMPRESSION_FLAG_NONE, + offset_or_length: 128, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = CompressionTransformHeader::unpack(&mut r).unwrap(); + + assert_eq!( + decoded.compression_algorithm, + COMPRESSION_ALGORITHM_LZ77_HUFFMAN + ); + assert_eq!(decoded.original_compressed_segment_size, 8192); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::arb_session_id; + use proptest::prelude::*; + + proptest! { + #[test] + fn transform_header_pack_unpack( + signature in any::<[u8; 16]>(), + nonce in any::<[u8; 16]>(), + original_message_size in any::(), + flags in any::(), + session_id in arb_session_id(), + ) { + let original = TransformHeader { + signature, + nonce, + original_message_size, + flags, + session_id, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + prop_assert_eq!(bytes.len(), TransformHeader::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = TransformHeader::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn compression_transform_header_pack_unpack( + original_compressed_segment_size in any::(), + compression_algorithm in any::(), + flags in any::(), + offset_or_length in any::(), + ) { + let original = CompressionTransformHeader { + original_compressed_segment_size, + compression_algorithm, + flags, + offset_or_length, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + prop_assert_eq!(bytes.len(), CompressionTransformHeader::SIZE); + + let mut r = ReadCursor::new(&bytes); + let decoded = CompressionTransformHeader::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/tree_connect.rs b/vendor/smb2/src/msg/tree_connect.rs new file mode 100644 index 0000000..f9882c1 --- /dev/null +++ b/vendor/smb2/src/msg/tree_connect.rs @@ -0,0 +1,477 @@ +//! SMB2 TREE_CONNECT request and response (spec sections 2.2.9, 2.2.10). +//! +//! Tree connect messages establish access to a share on the server. +//! The request contains a UTF-16LE encoded share path (for example, +//! `\\server\share`), and the response contains share metadata such as +//! the share type, flags, capabilities, and maximal access rights. + +use crate::error::Result; +use crate::msg::header::Header; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::flags::{ShareCapabilities, ShareFlags}; +use crate::Error; + +// ── Share type ───────────────────────────────────────────────────────── + +/// Type of share being accessed (spec section 2.2.10). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum ShareType { + /// Physical disk share. + Disk = 0x01, + /// Named pipe share. + Pipe = 0x02, + /// Printer share. + Print = 0x03, +} + +impl ShareType { + /// Try to convert a raw `u8` to a `ShareType`. + pub fn try_from_u8(val: u8) -> Result { + match val { + 0x01 => Ok(ShareType::Disk), + 0x02 => Ok(ShareType::Pipe), + 0x03 => Ok(ShareType::Print), + other => Err(Error::invalid_data(format!( + "invalid share type: 0x{:02X}", + other + ))), + } + } +} + +// ── Tree connect request flags ───────────────────────────────────────── + +/// Flags for the TREE_CONNECT request (spec section 2.2.9, SMB 3.1.1 only). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct TreeConnectRequestFlags(pub u16); + +impl TreeConnectRequestFlags { + /// Client has previously connected to the specified cluster share. + pub const CLUSTER_RECONNECT: u16 = 0x0001; + /// Client can handle synchronous share redirects. + pub const REDIRECT_TO_OWNER: u16 = 0x0002; + /// Tree connect request extension is present. + pub const EXTENSION_PRESENT: u16 = 0x0004; +} + +// ── TreeConnectRequest ───────────────────────────────────────────────── + +/// SMB2 TREE_CONNECT request (spec section 2.2.9). +/// +/// Sent by the client to request access to a particular share on the +/// server. The path is a Unicode string in the form `\\server\share`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TreeConnectRequest { + /// Flags controlling the request (SMB 3.1.1 only, otherwise 0). + pub flags: TreeConnectRequestFlags, + /// Full share path name in UTF-8 (encoded as UTF-16LE on the wire). + pub path: String, +} + +impl TreeConnectRequest { + pub const STRUCTURE_SIZE: u16 = 9; +} + +impl Pack for TreeConnectRequest { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // Flags/Reserved (2 bytes) + cursor.write_u16_le(self.flags.0); + + // Compute path length in UTF-16LE bytes + let path_u16: Vec = self.path.encode_utf16().collect(); + let path_byte_len = path_u16.len() * 2; + + // PathOffset (2 bytes) -- offset from start of SMB2 header + let offset = (Header::SIZE + 8) as u16; // 8 = fixed part of this struct + cursor.write_u16_le(offset); + // PathLength (2 bytes) + cursor.write_u16_le(path_byte_len as u16); + // Buffer: path in UTF-16LE + cursor.write_utf16_le(&self.path); + } +} + +impl Unpack for TreeConnectRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid TreeConnectRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // Flags/Reserved (2 bytes) + let flags = TreeConnectRequestFlags(cursor.read_u16_le()?); + // PathOffset (2 bytes) -- we ignore, read sequentially + let _offset = cursor.read_u16_le()?; + // PathLength (2 bytes) + let path_length = cursor.read_u16_le()? as usize; + // Buffer: path in UTF-16LE + if path_length > ReadCursor::MAX_UNPACK_BUFFER { + return Err(Error::invalid_data(format!( + "buffer size {} exceeds maximum {} bytes", + path_length, + ReadCursor::MAX_UNPACK_BUFFER + ))); + } + let path = cursor.read_utf16_le(path_length)?; + + Ok(TreeConnectRequest { flags, path }) + } +} + +// ── TreeConnectResponse ──────────────────────────────────────────────── + +/// SMB2 TREE_CONNECT response (spec section 2.2.10). +/// +/// Sent by the server when a TREE_CONNECT request is processed +/// successfully. Contains share metadata. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TreeConnectResponse { + /// The type of share being accessed (disk, pipe, or print). + pub share_type: ShareType, + /// Properties for this share. + pub share_flags: ShareFlags, + /// Capabilities for this share. + pub capabilities: ShareCapabilities, + /// Maximum access rights for the connecting user. + pub maximal_access: u32, +} + +impl TreeConnectResponse { + pub const STRUCTURE_SIZE: u16 = 16; +} + +impl Pack for TreeConnectResponse { + fn pack(&self, cursor: &mut WriteCursor) { + // StructureSize (2 bytes) + cursor.write_u16_le(Self::STRUCTURE_SIZE); + // ShareType (1 byte) + cursor.write_u8(self.share_type as u8); + // Reserved (1 byte) + cursor.write_u8(0); + // ShareFlags (4 bytes) + cursor.write_u32_le(self.share_flags.bits()); + // Capabilities (4 bytes) + cursor.write_u32_le(self.capabilities.bits()); + // MaximalAccess (4 bytes) + cursor.write_u32_le(self.maximal_access); + } +} + +impl Unpack for TreeConnectResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + // StructureSize (2 bytes) + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid TreeConnectResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + // ShareType (1 byte) + let share_type = ShareType::try_from_u8(cursor.read_u8()?)?; + // Reserved (1 byte) + let _reserved = cursor.read_u8()?; + // ShareFlags (4 bytes) + let share_flags = ShareFlags::new(cursor.read_u32_le()?); + // Capabilities (4 bytes) + let capabilities = ShareCapabilities::new(cursor.read_u32_le()?); + // MaximalAccess (4 bytes) + let maximal_access = cursor.read_u32_le()?; + + Ok(TreeConnectResponse { + share_type, + share_flags, + capabilities, + maximal_access, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── TreeConnectRequest tests ─────────────────────────────────── + + #[test] + fn tree_connect_request_roundtrip() { + let original = TreeConnectRequest { + flags: TreeConnectRequestFlags::default(), + path: r"\\server\share".to_string(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.path, original.path); + } + + #[test] + fn tree_connect_request_with_utf16_path() { + let path = r"\\myserver.example.com\IPC$"; + let original = TreeConnectRequest { + flags: TreeConnectRequestFlags::default(), + path: path.to_string(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.path, path); + } + + #[test] + fn tree_connect_request_structure_size_field() { + let req = TreeConnectRequest { + flags: TreeConnectRequestFlags::default(), + path: r"\\s\d".to_string(), + }; + + let mut w = WriteCursor::new(); + req.pack(&mut w); + let bytes = w.into_inner(); + + // First 2 bytes are structure size = 9 + assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 9); + } + + #[test] + fn tree_connect_request_wrong_structure_size() { + let mut buf = [0u8; 20]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + let mut cursor = ReadCursor::new(&buf); + let result = TreeConnectRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn tree_connect_request_with_flags() { + let original = TreeConnectRequest { + flags: TreeConnectRequestFlags(TreeConnectRequestFlags::CLUSTER_RECONNECT), + path: r"\\s\d".to_string(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.flags.0, TreeConnectRequestFlags::CLUSTER_RECONNECT); + } + + // ── TreeConnectResponse tests ────────────────────────────────── + + #[test] + fn tree_connect_response_roundtrip_disk() { + let original = TreeConnectResponse { + share_type: ShareType::Disk, + share_flags: ShareFlags::new(ShareFlags::DFS | ShareFlags::ACCESS_BASED_DIRECTORY_ENUM), + capabilities: ShareCapabilities::new(ShareCapabilities::DFS), + maximal_access: 0x001F_01FF, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.share_type, ShareType::Disk); + assert_eq!(decoded.share_flags.bits(), original.share_flags.bits()); + assert_eq!(decoded.capabilities.bits(), original.capabilities.bits()); + assert_eq!(decoded.maximal_access, 0x001F_01FF); + } + + #[test] + fn tree_connect_response_roundtrip_pipe() { + let original = TreeConnectResponse { + share_type: ShareType::Pipe, + share_flags: ShareFlags::default(), + capabilities: ShareCapabilities::default(), + maximal_access: 0x0012_019F, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.share_type, ShareType::Pipe); + assert_eq!(decoded.maximal_access, 0x0012_019F); + } + + #[test] + fn tree_connect_response_roundtrip_print() { + let original = TreeConnectResponse { + share_type: ShareType::Print, + share_flags: ShareFlags::new(ShareFlags::ENCRYPT_DATA), + capabilities: ShareCapabilities::new( + ShareCapabilities::CONTINUOUS_AVAILABILITY | ShareCapabilities::CLUSTER, + ), + maximal_access: 0, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.share_type, ShareType::Print); + assert!(decoded.share_flags.contains(ShareFlags::ENCRYPT_DATA)); + assert!(decoded + .capabilities + .contains(ShareCapabilities::CONTINUOUS_AVAILABILITY)); + assert!(decoded.capabilities.contains(ShareCapabilities::CLUSTER)); + } + + #[test] + fn tree_connect_response_structure_size_field() { + let resp = TreeConnectResponse { + share_type: ShareType::Disk, + share_flags: ShareFlags::default(), + capabilities: ShareCapabilities::default(), + maximal_access: 0, + }; + + let mut w = WriteCursor::new(); + resp.pack(&mut w); + let bytes = w.into_inner(); + + // First 2 bytes are structure size = 16 + assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 16); + // Total packed size: 2 + 1 + 1 + 4 + 4 + 4 = 16 + assert_eq!(bytes.len(), 16); + } + + #[test] + fn tree_connect_response_wrong_structure_size() { + let mut buf = [0u8; 16]; + buf[0..2].copy_from_slice(&99u16.to_le_bytes()); + buf[2] = 0x01; // valid share type + let mut cursor = ReadCursor::new(&buf); + let result = TreeConnectResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn tree_connect_response_invalid_share_type() { + let mut buf = [0u8; 16]; + buf[0..2].copy_from_slice(&16u16.to_le_bytes()); + buf[2] = 0xFF; // invalid share type + let mut cursor = ReadCursor::new(&buf); + let result = TreeConnectResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("share type"), "error was: {err}"); + } + + // Roundtrip property tests live in `roundtrip_props` at file end. + + #[test] + fn tree_connect_response_known_bytes() { + // Known bytes from smb-rs test: share_type=Disk, share_flags=0x00000800, + // capabilities=0, maximal_access=0x001f01ff + let bytes: Vec = vec![ + 0x10, 0x00, // StructureSize = 16 + 0x01, // ShareType = Disk + 0x00, // Reserved + 0x00, 0x08, 0x00, 0x00, // ShareFlags = 0x00000800 + 0x00, 0x00, 0x00, 0x00, // Capabilities = 0 + 0xFF, 0x01, 0x1F, 0x00, // MaximalAccess = 0x001f01ff + ]; + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.share_type, ShareType::Disk); + assert!(decoded + .share_flags + .contains(ShareFlags::ACCESS_BASED_DIRECTORY_ENUM)); + assert_eq!(decoded.maximal_access, 0x001F_01FF); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{ + arb_share_capabilities, arb_share_flags, arb_share_type, arb_utf16_string, + }; + use proptest::prelude::*; + + proptest! { + #[test] + fn tree_connect_request_pack_unpack( + flags_raw in any::(), + // Path is sent as UTF-16LE. Generate strings that survive that + // encoding cleanly (no unpaired surrogates). + path in arb_utf16_string(128), + ) { + let original = TreeConnectRequest { + flags: TreeConnectRequestFlags(flags_raw), + path, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn tree_connect_response_pack_unpack( + share_type in arb_share_type(), + share_flags in arb_share_flags(), + capabilities in arb_share_capabilities(), + maximal_access in any::(), + ) { + let original = TreeConnectResponse { + share_type, + share_flags, + capabilities, + maximal_access, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = TreeConnectResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/msg/tree_disconnect.rs b/vendor/smb2/src/msg/tree_disconnect.rs new file mode 100644 index 0000000..8bb24f8 --- /dev/null +++ b/vendor/smb2/src/msg/tree_disconnect.rs @@ -0,0 +1,43 @@ +//! SMB2 TREE_DISCONNECT request and response (spec sections 2.2.11, 2.2.12). +//! +//! Tree disconnect messages request and confirm disconnection from a share. +//! Both request and response contain only a StructureSize field and a +//! reserved field, for a total of 4 bytes each. + +super::trivial_message! { + /// SMB2 TREE_DISCONNECT request (spec section 2.2.11). + /// + /// Sent by the client to request that the tree connect specified in the + /// TreeId within the SMB2 header be disconnected. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct TreeDisconnectRequest; +} + +super::trivial_message! { + /// SMB2 TREE_DISCONNECT response (spec section 2.2.12). + /// + /// Sent by the server to confirm that a TREE_DISCONNECT request was processed. + /// Contains only StructureSize (2 bytes) and Reserved (2 bytes). + pub struct TreeDisconnectResponse; +} + +#[cfg(test)] +mod tests { + use super::*; + + super::super::trivial_message_tests!( + TreeDisconnectRequest, + tree_disconnect_request_known_bytes, + tree_disconnect_request_roundtrip, + tree_disconnect_request_wrong_structure_size, + tree_disconnect_request_too_short + ); + + super::super::trivial_message_tests!( + TreeDisconnectResponse, + tree_disconnect_response_known_bytes, + tree_disconnect_response_roundtrip, + tree_disconnect_response_wrong_structure_size, + tree_disconnect_response_too_short + ); +} diff --git a/vendor/smb2/src/msg/write.rs b/vendor/smb2/src/msg/write.rs new file mode 100644 index 0000000..7ca5915 --- /dev/null +++ b/vendor/smb2/src/msg/write.rs @@ -0,0 +1,446 @@ +//! SMB2 WRITE Request and Response (MS-SMB2 sections 2.2.21, 2.2.22). +//! +//! The WRITE request writes data to a file or named pipe. +//! The response reports how many bytes were written. + +use crate::error::Result; +use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::types::FileId; +use crate::Error; + +/// Write flag: server performs write-through (SMB 2.1+). +pub const SMB2_WRITEFLAG_WRITE_THROUGH: u32 = 0x0000_0001; + +/// Write flag: file buffering is not performed (SMB 3.0.2+). +pub const SMB2_WRITEFLAG_WRITE_UNBUFFERED: u32 = 0x0000_0002; + +/// SMB2 WRITE Request (MS-SMB2 section 2.2.21). +/// +/// Sent by the client to write data to a file. The fixed portion is 49 bytes +/// (StructureSize says 49 regardless of the variable buffer length): +/// - StructureSize (2 bytes, must be 49) +/// - DataOffset (2 bytes) +/// - Length (4 bytes) +/// - Offset (8 bytes) +/// - FileId (16 bytes) +/// - Channel (4 bytes) +/// - RemainingBytes (4 bytes) +/// - WriteChannelInfoOffset (2 bytes) +/// - WriteChannelInfoLength (2 bytes) +/// - Flags (4 bytes) +/// - Buffer (variable, Length bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WriteRequest { + /// Offset from the beginning of the SMB2 header to the write data. + pub data_offset: u16, + /// File offset to start writing at. + pub offset: u64, + /// File handle to write to. + pub file_id: FileId, + /// Channel for RDMA operations (typically 0 = SMB2_CHANNEL_NONE). + pub channel: u32, + /// Remaining bytes in a multi-part write. + pub remaining_bytes: u32, + /// Write channel info offset (typically 0). + pub write_channel_info_offset: u16, + /// Write channel info length (typically 0). + pub write_channel_info_length: u16, + /// Flags for the write operation. + pub flags: u32, + /// The data to write. + pub data: Vec, +} + +impl WriteRequest { + pub const STRUCTURE_SIZE: u16 = 49; +} + +impl Pack for WriteRequest { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u16_le(Self::STRUCTURE_SIZE); + cursor.write_u16_le(self.data_offset); + cursor.write_u32_le(self.data.len() as u32); // Length + cursor.write_u64_le(self.offset); + cursor.write_u64_le(self.file_id.persistent); + cursor.write_u64_le(self.file_id.volatile); + cursor.write_u32_le(self.channel); + cursor.write_u32_le(self.remaining_bytes); + cursor.write_u16_le(self.write_channel_info_offset); + cursor.write_u16_le(self.write_channel_info_length); + cursor.write_u32_le(self.flags); + + // Buffer: write the data (may be empty for zero-length writes). + // Per StructureSize=49 contract, at least 1 byte is implied. + if self.data.is_empty() { + cursor.write_u8(0); + } else { + cursor.write_bytes(&self.data); + } + } +} + +impl Unpack for WriteRequest { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid WriteRequest structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let data_offset = cursor.read_u16_le()?; + let length = cursor.read_u32_le()?; + let offset = cursor.read_u64_le()?; + let persistent = cursor.read_u64_le()?; + let volatile = cursor.read_u64_le()?; + let channel = cursor.read_u32_le()?; + let remaining_bytes = cursor.read_u32_le()?; + let write_channel_info_offset = cursor.read_u16_le()?; + let write_channel_info_length = cursor.read_u16_le()?; + let flags = cursor.read_u32_le()?; + + let data = if length > 0 { + cursor.read_bytes_bounded(length as usize)?.to_vec() + } else { + // Skip the minimum 1-byte buffer + cursor.skip(1)?; + Vec::new() + }; + + Ok(WriteRequest { + data_offset, + offset, + file_id: FileId { + persistent, + volatile, + }, + channel, + remaining_bytes, + write_channel_info_offset, + write_channel_info_length, + flags, + data, + }) + } +} + +/// SMB2 WRITE Response (MS-SMB2 section 2.2.22). +/// +/// Sent by the server to confirm a write. The structure is 17 bytes: +/// - StructureSize (2 bytes, must be 17) +/// - Reserved (2 bytes) +/// - Count (4 bytes) +/// - Remaining (4 bytes) +/// - WriteChannelInfoOffset (2 bytes) +/// - WriteChannelInfoLength (2 bytes) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WriteResponse { + /// Number of bytes written. + pub count: u32, + /// Reserved remaining field (must be 0). + pub remaining: u32, + /// Reserved write channel info offset (must be 0). + pub write_channel_info_offset: u16, + /// Reserved write channel info length (must be 0). + pub write_channel_info_length: u16, +} + +impl WriteResponse { + pub const STRUCTURE_SIZE: u16 = 17; +} + +impl Pack for WriteResponse { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u16_le(Self::STRUCTURE_SIZE); + cursor.write_u16_le(0); // Reserved + cursor.write_u32_le(self.count); + cursor.write_u32_le(self.remaining); + cursor.write_u16_le(self.write_channel_info_offset); + cursor.write_u16_le(self.write_channel_info_length); + } +} + +impl Unpack for WriteResponse { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let structure_size = cursor.read_u16_le()?; + if structure_size != Self::STRUCTURE_SIZE { + return Err(Error::invalid_data(format!( + "invalid WriteResponse structure size: expected {}, got {}", + Self::STRUCTURE_SIZE, + structure_size + ))); + } + + let _reserved = cursor.read_u16_le()?; + let count = cursor.read_u32_le()?; + let remaining = cursor.read_u32_le()?; + let write_channel_info_offset = cursor.read_u16_le()?; + let write_channel_info_length = cursor.read_u16_le()?; + + Ok(WriteResponse { + count, + remaining, + write_channel_info_offset, + write_channel_info_length, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── WriteRequest tests ───────────────────────────────────────── + + #[test] + fn write_request_roundtrip() { + let original = WriteRequest { + data_offset: 0x70, // 64 (header) + 48 (fixed body) = 112 = 0x70 + offset: 0x2000, + file_id: FileId { + persistent: 0xAAAA_BBBB_CCCC_DDDD, + volatile: 0x1111_2222_3333_4444, + }, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: SMB2_WRITEFLAG_WRITE_THROUGH, + data: vec![0x48, 0x65, 0x6C, 0x6C, 0x6F], // "Hello" + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 48 bytes + 5 bytes data = 53 bytes + assert_eq!(bytes.len(), 53); + + let mut r = ReadCursor::new(&bytes); + let decoded = WriteRequest::unpack(&mut r).unwrap(); + + assert_eq!(decoded.data_offset, original.data_offset); + assert_eq!(decoded.offset, original.offset); + assert_eq!(decoded.file_id, original.file_id); + assert_eq!(decoded.channel, original.channel); + assert_eq!(decoded.remaining_bytes, original.remaining_bytes); + assert_eq!(decoded.flags, original.flags); + assert_eq!(decoded.data, original.data); + } + + #[test] + fn write_request_empty_data_roundtrip() { + let original = WriteRequest { + data_offset: 0x70, + offset: 0, + file_id: FileId { + persistent: 1, + volatile: 2, + }, + channel: 0, + remaining_bytes: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + flags: 0, + data: Vec::new(), + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // Fixed: 48 bytes + 1-byte minimum buffer = 49 bytes + assert_eq!(bytes.len(), 49); + + let mut r = ReadCursor::new(&bytes); + let decoded = WriteRequest::unpack(&mut r).unwrap(); + + assert!(decoded.data.is_empty()); + assert_eq!(decoded.file_id, original.file_id); + } + + #[test] + fn write_request_wrong_structure_size() { + let mut buf = [0u8; 49]; + buf[0..2].copy_from_slice(&48u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = WriteRequest::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } + + #[test] + fn write_request_known_bytes() { + let mut buf = Vec::new(); + // StructureSize = 49 + buf.extend_from_slice(&49u16.to_le_bytes()); + // DataOffset = 0x70 + buf.extend_from_slice(&0x70u16.to_le_bytes()); + // Length = 2 + buf.extend_from_slice(&2u32.to_le_bytes()); + // Offset = 0 + buf.extend_from_slice(&0u64.to_le_bytes()); + // FileId persistent = 0x10 + buf.extend_from_slice(&0x10u64.to_le_bytes()); + // FileId volatile = 0x20 + buf.extend_from_slice(&0x20u64.to_le_bytes()); + // Channel = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // RemainingBytes = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // WriteChannelInfoOffset = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // WriteChannelInfoLength = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // Flags = WRITE_THROUGH + buf.extend_from_slice(&1u32.to_le_bytes()); + // Buffer = [0xAA, 0xBB] + buf.extend_from_slice(&[0xAA, 0xBB]); + + let mut cursor = ReadCursor::new(&buf); + let req = WriteRequest::unpack(&mut cursor).unwrap(); + + assert_eq!(req.data_offset, 0x70); + assert_eq!(req.file_id.persistent, 0x10); + assert_eq!(req.file_id.volatile, 0x20); + assert_eq!(req.flags, SMB2_WRITEFLAG_WRITE_THROUGH); + assert_eq!(req.data, vec![0xAA, 0xBB]); + } + + // ── WriteResponse tests ──────────────────────────────────────── + + #[test] + fn write_response_roundtrip() { + let original = WriteResponse { + count: 65536, + remaining: 0, + write_channel_info_offset: 0, + write_channel_info_length: 0, + }; + + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + // 2 + 2 + 4 + 4 + 2 + 2 = 16 bytes + assert_eq!(bytes.len(), 16); + + let mut r = ReadCursor::new(&bytes); + let decoded = WriteResponse::unpack(&mut r).unwrap(); + + assert_eq!(decoded.count, original.count); + assert_eq!(decoded.remaining, original.remaining); + assert_eq!( + decoded.write_channel_info_offset, + original.write_channel_info_offset + ); + assert_eq!( + decoded.write_channel_info_length, + original.write_channel_info_length + ); + } + + #[test] + fn write_response_known_bytes() { + let mut buf = Vec::new(); + // StructureSize = 17 + buf.extend_from_slice(&17u16.to_le_bytes()); + // Reserved = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // Count = 1024 + buf.extend_from_slice(&1024u32.to_le_bytes()); + // Remaining = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // WriteChannelInfoOffset = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // WriteChannelInfoLength = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let resp = WriteResponse::unpack(&mut cursor).unwrap(); + + assert_eq!(resp.count, 1024); + assert_eq!(resp.remaining, 0); + } + + #[test] + fn write_response_wrong_structure_size() { + let mut buf = [0u8; 16]; + buf[0..2].copy_from_slice(&16u16.to_le_bytes()); + + let mut cursor = ReadCursor::new(&buf); + let result = WriteResponse::unpack(&mut cursor); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("structure size"), "error was: {err}"); + } +} + +#[cfg(test)] +mod roundtrip_props { + use super::*; + use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id}; + use proptest::prelude::*; + + proptest! { + #[test] + fn write_request_pack_unpack( + data_offset in any::(), + offset in any::(), + file_id in arb_file_id(), + channel in any::(), + remaining_bytes in any::(), + write_channel_info_offset in any::(), + write_channel_info_length in any::(), + flags in any::(), + data in arb_bytes(), + ) { + let original = WriteRequest { + data_offset, + offset, + file_id, + channel, + remaining_bytes, + write_channel_info_offset, + write_channel_info_length, + flags, + data, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = WriteRequest::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + + #[test] + fn write_response_pack_unpack( + count in any::(), + remaining in any::(), + write_channel_info_offset in any::(), + write_channel_info_length in any::(), + ) { + let original = WriteResponse { + count, + remaining, + write_channel_info_offset, + write_channel_info_length, + }; + let mut w = WriteCursor::new(); + original.pack(&mut w); + let bytes = w.into_inner(); + + let mut r = ReadCursor::new(&bytes); + let decoded = WriteResponse::unpack(&mut r).unwrap(); + prop_assert_eq!(decoded, original); + prop_assert!(r.is_empty()); + } + } +} diff --git a/vendor/smb2/src/pack/CLAUDE.md b/vendor/smb2/src/pack/CLAUDE.md new file mode 100644 index 0000000..3956de1 --- /dev/null +++ b/vendor/smb2/src/pack/CLAUDE.md @@ -0,0 +1,45 @@ +# Pack -- binary serialization primitives + +Cursor-based binary reader/writer for SMB2 wire format. Hand-rolled, no proc macros. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | `ReadCursor`, `WriteCursor`, `Pack`/`Unpack` traits, primitive read/write methods | +| `guid.rs` | GUID pack/unpack with mixed-endian layout | +| `filetime.rs` | Windows FILETIME (100ns ticks since 1601-01-01) to/from `SystemTime` | + +## Core types + +- **`ReadCursor<'a>`**: Reads from `&[u8]` with position tracking. Returns `Error` on buffer overrun (no panics). All reads are little-endian. +- **`WriteCursor`**: Writes into a growable `Vec`. Supports backpatching (`set_u16_le_at`, `set_u32_le_at`) for length fields written before their values are known. `align_to(n)` pads with zeros to n-byte boundary. +- **`Pack` trait**: `fn pack(&self, cursor: &mut WriteCursor)` -- serialize to binary. +- **`Unpack` trait**: `fn unpack(cursor: &mut ReadCursor) -> Result` -- deserialize from binary. + +## GUID mixed-endian layout + +Windows GUIDs have a mixed-endian wire format: +- `data1` (u32): little-endian +- `data2` (u16): little-endian +- `data3` (u16): little-endian +- `data4` ([u8; 8]): raw bytes (no endian conversion) + +This matches the COM/DCOM convention. Not the same as RFC 4122 UUID byte order. + +## FileTime conversion + +Windows FILETIME: 100-nanosecond intervals since 1601-01-01 00:00:00 UTC. +Unix epoch: 1970-01-01 00:00:00 UTC. +Offset: 11,644,473,600 seconds (116,444,736,000,000,000 ticks). + +## Key decisions + +- **Hand-rolled instead of proc macros**: Full control over wire format details (offsets, alignment, backpatching). Easier to debug. No build-time dependency. +- **`MAX_UNPACK_BUFFER` (16 MB)**: `read_bytes_bounded` refuses allocations larger than 16 MB. Prevents OOM from malicious packets claiming huge lengths. + +## Gotchas + +- **Everything is little-endian**: Except TCP framing (see transport module). ReadCursor/WriteCursor only do LE. +- **UTF-16LE byte length must be even**: `read_utf16_le` returns an error on odd byte counts. +- **Backpatching requires placeholder**: Write a zero first, then `set_u32_le_at` to overwrite once the real value is known. Common pattern for length-prefixed fields. diff --git a/vendor/smb2/src/pack/filetime.rs b/vendor/smb2/src/pack/filetime.rs new file mode 100644 index 0000000..7414bc7 --- /dev/null +++ b/vendor/smb2/src/pack/filetime.rs @@ -0,0 +1,175 @@ +//! Windows FILETIME type for SMB2. +//! +//! A FILETIME is a 64-bit value representing 100-nanosecond intervals +//! since 1601-01-01 00:00:00 UTC. + +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use super::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::error::Result; + +/// Difference between the Windows epoch (1601-01-01) and Unix epoch (1970-01-01) +/// in 100-nanosecond intervals. +const EPOCH_DIFF_100NS: u64 = 116_444_736_000_000_000; + +/// Windows FILETIME: 100-nanosecond intervals since 1601-01-01 00:00:00 UTC. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct FileTime( + /// The raw 100-nanosecond tick count. + pub u64, +); + +impl FileTime { + /// A zero filetime, meaning "not set" or "unknown". + pub const ZERO: Self = Self(0); + + /// Convert a [`SystemTime`] to a `FileTime`. + /// + /// Uses the Unix epoch offset (116,444,736,000,000,000 intervals of + /// 100 ns) to translate between the two epoch origins. + pub fn from_system_time(t: SystemTime) -> Self { + match t.duration_since(UNIX_EPOCH) { + Ok(dur) => { + let intervals = dur.as_nanos() / 100; + Self(intervals as u64 + EPOCH_DIFF_100NS) + } + Err(e) => { + // Time is before Unix epoch. The duration tells us how far before. + let before = e.duration(); + let intervals = before.as_nanos() / 100; + // If the pre-Unix time is still after the Windows epoch, compute it. + Self(EPOCH_DIFF_100NS.saturating_sub(intervals as u64)) + } + } + } + + /// Convert this `FileTime` to a [`SystemTime`]. + /// + /// Returns `None` if the filetime represents a date before the Unix epoch, + /// since [`SystemTime`] cannot represent dates before that. + pub fn to_system_time(self) -> Option { + if self.0 < EPOCH_DIFF_100NS { + return None; + } + let intervals_since_unix = self.0 - EPOCH_DIFF_100NS; + let nanos = (intervals_since_unix as u128) * 100; + let dur = Duration::new( + (nanos / 1_000_000_000) as u64, + (nanos % 1_000_000_000) as u32, + ); + Some(UNIX_EPOCH + dur) + } +} + +impl Pack for FileTime { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u64_le(self.0); + } +} + +impl Unpack for FileTime { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let val = cursor.read_u64_le()?; + Ok(Self(val)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn zero_filetime() { + assert_eq!(FileTime::ZERO, FileTime(0)); + } + + #[test] + fn pack_zero() { + let mut w = WriteCursor::new(); + FileTime::ZERO.pack(&mut w); + assert_eq!(w.as_bytes(), &[0u8; 8]); + } + + #[test] + fn unpack_zero() { + let bytes = [0u8; 8]; + let mut r = ReadCursor::new(&bytes); + let ft = FileTime::unpack(&mut r).unwrap(); + assert_eq!(ft, FileTime::ZERO); + } + + #[test] + fn known_value_2024_01_01() { + // 2024-01-01 00:00:00 UTC = FileTime(133_485_408_000_000_000) + // (Unix timestamp 1_704_067_200 * 10_000_000 + 116_444_736_000_000_000) + let expected_raw: u64 = 133_485_408_000_000_000; + let ft = FileTime(expected_raw); + + // Pack and verify roundtrip + let mut w = WriteCursor::new(); + ft.pack(&mut w); + let mut r = ReadCursor::new(w.as_bytes()); + let unpacked = FileTime::unpack(&mut r).unwrap(); + assert_eq!(unpacked, ft); + + // Verify SystemTime conversion + // 2024-01-01 00:00:00 UTC = Unix timestamp 1_704_067_200 + let st = ft.to_system_time().unwrap(); + let unix_dur = st.duration_since(UNIX_EPOCH).unwrap(); + assert_eq!(unix_dur.as_secs(), 1_704_067_200); + assert_eq!(unix_dur.subsec_nanos(), 0); + } + + #[test] + fn from_system_time_roundtrip() { + // Use a known Unix timestamp: 2024-01-01 00:00:00 UTC + let unix_secs = 1_704_067_200u64; + let st = UNIX_EPOCH + Duration::from_secs(unix_secs); + let ft = FileTime::from_system_time(st); + assert_eq!(ft.0, 133_485_408_000_000_000); + + let st2 = ft.to_system_time().unwrap(); + let dur = st2.duration_since(UNIX_EPOCH).unwrap(); + assert_eq!(dur.as_secs(), unix_secs); + } + + #[test] + fn pre_unix_epoch_returns_none() { + // A FILETIME value that represents a date before 1970-01-01 + let ft = FileTime(EPOCH_DIFF_100NS - 1); + assert!(ft.to_system_time().is_none()); + + // Zero is also before Unix epoch + assert!(FileTime::ZERO.to_system_time().is_none()); + } + + #[test] + fn unix_epoch_exactly() { + let ft = FileTime(EPOCH_DIFF_100NS); + let st = ft.to_system_time().unwrap(); + assert_eq!(st, UNIX_EPOCH); + } + + #[test] + fn from_system_time_unix_epoch() { + let ft = FileTime::from_system_time(UNIX_EPOCH); + assert_eq!(ft.0, EPOCH_DIFF_100NS); + } + + #[test] + fn pack_unpack_roundtrip() { + let ft = FileTime(133_476_576_000_000_000); + let mut w = WriteCursor::new(); + ft.pack(&mut w); + let mut r = ReadCursor::new(w.as_bytes()); + let unpacked = FileTime::unpack(&mut r).unwrap(); + assert_eq!(unpacked, ft); + } + + #[test] + fn unpack_insufficient_bytes() { + let bytes = [0u8; 4]; // need 8 + let mut r = ReadCursor::new(&bytes); + assert!(FileTime::unpack(&mut r).is_err()); + } +} diff --git a/vendor/smb2/src/pack/guid.rs b/vendor/smb2/src/pack/guid.rs new file mode 100644 index 0000000..91e5f2e --- /dev/null +++ b/vendor/smb2/src/pack/guid.rs @@ -0,0 +1,176 @@ +//! GUID (Globally Unique Identifier) type for SMB2. +//! +//! GUIDs follow the mixed-endian layout defined in MS-DTYP section 2.3.4: +//! - Bytes 0-3: `data1` (`u32`, little-endian) +//! - Bytes 4-5: `data2` (`u16`, little-endian) +//! - Bytes 6-7: `data3` (`u16`, little-endian) +//! - Bytes 8-15: `data4` (8 raw bytes, big-endian order) + +use std::fmt; + +use super::{Pack, ReadCursor, Unpack, WriteCursor}; +use crate::error::Result; + +/// A 128-bit GUID in mixed-endian wire format (MS-DTYP 2.3.4). +/// +/// With the `serde` feature on, the JSON form mirrors the in-memory +/// field shape (`{data1, data2, data3, data4}`), **not** the wire byte +/// order — the wire layout is mixed-endian and round-tripping it through +/// JSON would just be confusing. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct Guid { + /// First component (bytes 0-3, little-endian on wire). + pub data1: u32, + /// Second component (bytes 4-5, little-endian on wire). + pub data2: u16, + /// Third component (bytes 6-7, little-endian on wire). + pub data3: u16, + /// Fourth component (bytes 8-15, raw byte order on wire). + pub data4: [u8; 8], +} + +impl Guid { + /// The NULL GUID: `{00000000-0000-0000-0000-000000000000}`. + pub const ZERO: Self = Self { + data1: 0, + data2: 0, + data3: 0, + data4: [0; 8], + }; +} + +impl Pack for Guid { + fn pack(&self, cursor: &mut WriteCursor) { + cursor.write_u32_le(self.data1); + cursor.write_u16_le(self.data2); + cursor.write_u16_le(self.data3); + cursor.write_bytes(&self.data4); + } +} + +impl Unpack for Guid { + fn unpack(cursor: &mut ReadCursor<'_>) -> Result { + let data1 = cursor.read_u32_le()?; + let data2 = cursor.read_u16_le()?; + let data3 = cursor.read_u16_le()?; + let raw = cursor.read_bytes(8)?; + let mut data4 = [0u8; 8]; + data4.copy_from_slice(raw); + Ok(Self { + data1, + data2, + data3, + data4, + }) + } +} + +impl fmt::Display for Guid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{{{:08x}-{:04x}-{:04x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}}}", + self.data1, + self.data2, + self.data3, + self.data4[0], + self.data4[1], + self.data4[2], + self.data4[3], + self.data4[4], + self.data4[5], + self.data4[6], + self.data4[7], + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn unpack_null_guid() { + let bytes = [0u8; 16]; + let mut cursor = ReadCursor::new(&bytes); + let guid = Guid::unpack(&mut cursor).unwrap(); + assert_eq!(guid, Guid::ZERO); + } + + #[test] + fn pack_null_guid() { + let mut cursor = WriteCursor::new(); + Guid::ZERO.pack(&mut cursor); + assert_eq!(cursor.as_bytes(), &[0u8; 16]); + } + + #[test] + fn roundtrip_known_guid() { + let guid = Guid { + data1: 0x6BA7B810, + data2: 0x9DAD, + data3: 0x11D1, + data4: [0x80, 0xB4, 0x00, 0xC0, 0x4F, 0xD4, 0x30, 0xC8], + }; + + let mut w = WriteCursor::new(); + guid.pack(&mut w); + let mut r = ReadCursor::new(w.as_bytes()); + let unpacked = Guid::unpack(&mut r).unwrap(); + assert_eq!(unpacked, guid); + } + + #[test] + fn display_format() { + let guid = Guid { + data1: 0x6BA7B810, + data2: 0x9DAD, + data3: 0x11D1, + data4: [0x80, 0xB4, 0x00, 0xC0, 0x4F, 0xD4, 0x30, 0xC8], + }; + assert_eq!(guid.to_string(), "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}"); + } + + #[test] + fn display_null_guid() { + assert_eq!( + Guid::ZERO.to_string(), + "{00000000-0000-0000-0000-000000000000}" + ); + } + + #[test] + fn mixed_endian_byte_ordering() { + // Build a GUID with known values and verify the wire bytes directly. + let guid = Guid { + data1: 0x04030201, + data2: 0x0605, + data3: 0x0807, + data4: [0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10], + }; + + let mut w = WriteCursor::new(); + guid.pack(&mut w); + let bytes = w.as_bytes(); + + // data1: u32 LE -> 01 02 03 04 + assert_eq!(&bytes[0..4], &[0x01, 0x02, 0x03, 0x04]); + // data2: u16 LE -> 05 06 + assert_eq!(&bytes[4..6], &[0x05, 0x06]); + // data3: u16 LE -> 07 08 + assert_eq!(&bytes[6..8], &[0x07, 0x08]); + // data4: raw bytes -> 09 0A 0B 0C 0D 0E 0F 10 + assert_eq!( + &bytes[8..16], + &[0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10] + ); + } + + #[test] + fn unpack_insufficient_bytes() { + let bytes = [0u8; 10]; // need 16 + let mut cursor = ReadCursor::new(&bytes); + assert!(Guid::unpack(&mut cursor).is_err()); + } +} diff --git a/vendor/smb2/src/pack/mod.rs b/vendor/smb2/src/pack/mod.rs new file mode 100644 index 0000000..017fa9b --- /dev/null +++ b/vendor/smb2/src/pack/mod.rs @@ -0,0 +1,649 @@ +//! Binary serialization/deserialization primitives for SMB2. +//! +//! Provides [`ReadCursor`] and [`WriteCursor`] for reading and writing +//! little-endian binary data, plus [`Pack`] and [`Unpack`] traits for +//! structured types. +//! +//! Most users don't need this module directly -- use [`SmbClient`](crate::SmbClient) +//! for high-level file operations. + +pub mod filetime; +pub mod guid; + +pub use filetime::FileTime; +pub use guid::Guid; + +use crate::error::Result; +use crate::Error; + +/// Trait for types that can serialize themselves into binary format. +pub trait Pack: Send + Sync { + /// Write this value into the cursor. + fn pack(&self, cursor: &mut WriteCursor); +} + +/// Trait for types that can deserialize themselves from binary format. +pub trait Unpack: Sized { + /// Read a value from the cursor, advancing its position. + fn unpack(cursor: &mut ReadCursor<'_>) -> Result; +} + +// --------------------------------------------------------------------------- +// ReadCursor +// --------------------------------------------------------------------------- + +/// A cursor for reading little-endian binary data from a byte slice. +/// +/// Tracks the current read position and returns errors on buffer overruns +/// rather than panicking. +pub struct ReadCursor<'a> { + data: &'a [u8], + pos: usize, +} + +impl<'a> ReadCursor<'a> { + /// Create a new read cursor starting at position 0. + pub fn new(data: &'a [u8]) -> Self { + Self { data, pos: 0 } + } + + /// Read a single byte. + pub fn read_u8(&mut self) -> Result { + self.ensure(1)?; + let val = self.data[self.pos]; + self.pos += 1; + Ok(val) + } + + /// Read a little-endian `u16`. + pub fn read_u16_le(&mut self) -> Result { + let bytes = self.read_array::<2>()?; + Ok(u16::from_le_bytes(bytes)) + } + + /// Read a little-endian `u32`. + pub fn read_u32_le(&mut self) -> Result { + let bytes = self.read_array::<4>()?; + Ok(u32::from_le_bytes(bytes)) + } + + /// Read a little-endian `u64`. + pub fn read_u64_le(&mut self) -> Result { + let bytes = self.read_array::<8>()?; + Ok(u64::from_le_bytes(bytes)) + } + + /// Read a little-endian `u128`. + pub fn read_u128_le(&mut self) -> Result { + let bytes = self.read_array::<16>()?; + Ok(u128::from_le_bytes(bytes)) + } + + /// Read exactly `n` bytes, returning a sub-slice. + pub fn read_bytes(&mut self, n: usize) -> Result<&'a [u8]> { + self.ensure(n)?; + let slice = &self.data[self.pos..self.pos + n]; + self.pos += n; + Ok(slice) + } + + /// Read `byte_len` bytes of UTF-16LE data and decode to a [`String`]. + /// + /// `byte_len` must be even (each code unit is 2 bytes). + pub fn read_utf16_le(&mut self, byte_len: usize) -> Result { + if byte_len % 2 != 0 { + return Err(Error::invalid_data(format!( + "UTF-16LE byte length must be even, got {}", + byte_len + ))); + } + let raw = self.read_bytes(byte_len)?; + let code_units: Vec = raw + .chunks_exact(2) + .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .collect(); + String::from_utf16(&code_units) + .map_err(|_| Error::invalid_data("invalid UTF-16LE encoding")) + } + + /// Skip `n` bytes without reading them. + pub fn skip(&mut self, n: usize) -> Result<()> { + self.ensure(n)?; + self.pos += n; + Ok(()) + } + + /// Return the number of bytes remaining. + pub fn remaining(&self) -> usize { + self.data.len() - self.pos + } + + /// Return the current byte position. + pub fn position(&self) -> usize { + self.pos + } + + /// Return `true` if no bytes remain. + pub fn is_empty(&self) -> bool { + self.remaining() == 0 + } + + /// Maximum buffer size we'll allocate from untrusted data (16 MB). + pub const MAX_UNPACK_BUFFER: usize = 16 * 1024 * 1024; + + /// Read `n` bytes, but refuse if `n` exceeds [`Self::MAX_UNPACK_BUFFER`]. + pub fn read_bytes_bounded(&mut self, n: usize) -> Result<&'a [u8]> { + if n > Self::MAX_UNPACK_BUFFER { + return Err(Error::invalid_data(format!( + "buffer size {} exceeds maximum {} bytes", + n, + Self::MAX_UNPACK_BUFFER + ))); + } + self.read_bytes(n) + } + + // -- private helpers -- + + fn ensure(&self, n: usize) -> Result<()> { + if self.remaining() < n { + Err(Error::invalid_data(format!( + "need {} bytes but only {} remain at offset {}", + n, + self.remaining(), + self.pos + ))) + } else { + Ok(()) + } + } + + fn read_array(&mut self) -> Result<[u8; N]> { + self.ensure(N)?; + let mut arr = [0u8; N]; + arr.copy_from_slice(&self.data[self.pos..self.pos + N]); + self.pos += N; + Ok(arr) + } +} + +// --------------------------------------------------------------------------- +// WriteCursor +// --------------------------------------------------------------------------- + +/// A cursor for writing little-endian binary data into a growable buffer. +pub struct WriteCursor { + buf: Vec, +} + +impl WriteCursor { + /// Create an empty write cursor. + pub fn new() -> Self { + Self { buf: Vec::new() } + } + + /// Create a write cursor with pre-allocated capacity. + pub fn with_capacity(cap: usize) -> Self { + Self { + buf: Vec::with_capacity(cap), + } + } + + /// Write a single byte. + pub fn write_u8(&mut self, val: u8) { + self.buf.push(val); + } + + /// Write a little-endian `u16`. + pub fn write_u16_le(&mut self, val: u16) { + self.buf.extend_from_slice(&val.to_le_bytes()); + } + + /// Write a little-endian `u32`. + pub fn write_u32_le(&mut self, val: u32) { + self.buf.extend_from_slice(&val.to_le_bytes()); + } + + /// Write a little-endian `u64`. + pub fn write_u64_le(&mut self, val: u64) { + self.buf.extend_from_slice(&val.to_le_bytes()); + } + + /// Write a little-endian `u128`. + pub fn write_u128_le(&mut self, val: u128) { + self.buf.extend_from_slice(&val.to_le_bytes()); + } + + /// Write a raw byte slice. + pub fn write_bytes(&mut self, data: &[u8]) { + self.buf.extend_from_slice(data); + } + + /// Encode a string as UTF-16LE and write the bytes. + pub fn write_utf16_le(&mut self, s: &str) { + for code_unit in s.encode_utf16() { + self.buf.extend_from_slice(&code_unit.to_le_bytes()); + } + } + + /// Write `n` zero bytes. + pub fn write_zeros(&mut self, n: usize) { + self.buf.resize(self.buf.len() + n, 0); + } + + /// Pad with zero bytes until the position is a multiple of `alignment`. + /// + /// Does nothing if `alignment` is 0 or 1, or if already aligned. + pub fn align_to(&mut self, alignment: usize) { + if alignment <= 1 { + return; + } + let remainder = self.buf.len() % alignment; + if remainder != 0 { + self.write_zeros(alignment - remainder); + } + } + + /// Return the current write position (number of bytes written so far). + pub fn position(&self) -> usize { + self.buf.len() + } + + /// Overwrite a `u16` at a previous position (little-endian). + /// + /// # Panics + /// + /// Panics if `pos + 2 > self.position()`. + pub fn set_u16_le_at(&mut self, pos: usize, val: u16) { + self.buf[pos..pos + 2].copy_from_slice(&val.to_le_bytes()); + } + + /// Overwrite a `u32` at a previous position (little-endian). + /// + /// # Panics + /// + /// Panics if `pos + 4 > self.position()`. + pub fn set_u32_le_at(&mut self, pos: usize, val: u32) { + self.buf[pos..pos + 4].copy_from_slice(&val.to_le_bytes()); + } + + /// Consume the cursor and return the underlying buffer. + pub fn into_inner(self) -> Vec { + self.buf + } + + /// Return a reference to the bytes written so far. + pub fn as_bytes(&self) -> &[u8] { + &self.buf + } +} + +impl Default for WriteCursor { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + // -- ReadCursor tests -- + + #[test] + fn read_u8_from_known_bytes() { + let data = [0x42]; + let mut cursor = ReadCursor::new(&data); + assert_eq!(cursor.read_u8().unwrap(), 0x42); + assert!(cursor.is_empty()); + } + + #[test] + fn read_u16_le_from_known_bytes() { + let data = [0x34, 0x12]; + let mut cursor = ReadCursor::new(&data); + assert_eq!(cursor.read_u16_le().unwrap(), 0x1234); + } + + #[test] + fn read_u32_le_from_known_bytes() { + let data = [0x78, 0x56, 0x34, 0x12]; + let mut cursor = ReadCursor::new(&data); + assert_eq!(cursor.read_u32_le().unwrap(), 0x12345678); + } + + #[test] + fn read_u64_le_from_known_bytes() { + let data = [0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]; + let mut cursor = ReadCursor::new(&data); + assert_eq!(cursor.read_u64_le().unwrap(), 0x0102030405060708); + } + + #[test] + fn read_u128_le_from_known_bytes() { + let mut data = [0u8; 16]; + data[0] = 0x01; + data[15] = 0x80; + let mut cursor = ReadCursor::new(&data); + let val = cursor.read_u128_le().unwrap(); + assert_eq!(val, 0x80000000_00000000_00000000_00000001); + } + + #[test] + fn read_past_end_returns_error() { + let data = [0x00]; + let mut cursor = ReadCursor::new(&data); + assert!(cursor.read_u16_le().is_err()); + + let empty: &[u8] = &[]; + let mut cursor = ReadCursor::new(empty); + assert!(cursor.read_u8().is_err()); + } + + #[test] + fn remaining_and_position_track_correctly() { + let data = [0x01, 0x02, 0x03, 0x04, 0x05]; + let mut cursor = ReadCursor::new(&data); + assert_eq!(cursor.position(), 0); + assert_eq!(cursor.remaining(), 5); + + cursor.read_u8().unwrap(); + assert_eq!(cursor.position(), 1); + assert_eq!(cursor.remaining(), 4); + + cursor.read_u16_le().unwrap(); + assert_eq!(cursor.position(), 3); + assert_eq!(cursor.remaining(), 2); + } + + #[test] + fn skip_advances_position() { + let data = [0x01, 0x02, 0x03, 0x04]; + let mut cursor = ReadCursor::new(&data); + cursor.skip(2).unwrap(); + assert_eq!(cursor.position(), 2); + assert_eq!(cursor.read_u8().unwrap(), 0x03); + + // Skip past end is error + assert!(cursor.skip(10).is_err()); + } + + #[test] + fn read_bytes_returns_correct_slice() { + let data = [0x0A, 0x0B, 0x0C, 0x0D]; + let mut cursor = ReadCursor::new(&data); + cursor.skip(1).unwrap(); + let slice = cursor.read_bytes(2).unwrap(); + assert_eq!(slice, &[0x0B, 0x0C]); + assert_eq!(cursor.position(), 3); + } + + #[test] + fn read_utf16_le_decodes_hello() { + // "hello" in UTF-16LE + let data = [0x68, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00]; + let mut cursor = ReadCursor::new(&data); + let s = cursor.read_utf16_le(10).unwrap(); + assert_eq!(s, "hello"); + } + + #[test] + fn read_utf16_le_odd_byte_len_is_error() { + let data = [0x68, 0x00, 0x65]; + let mut cursor = ReadCursor::new(&data); + assert!(cursor.read_utf16_le(3).is_err()); + } + + // -- WriteCursor tests -- + + #[test] + fn write_u8_produces_correct_byte() { + let mut cursor = WriteCursor::new(); + cursor.write_u8(0xFF); + assert_eq!(cursor.as_bytes(), &[0xFF]); + } + + #[test] + fn write_u16_le_produces_correct_bytes() { + let mut cursor = WriteCursor::new(); + cursor.write_u16_le(0x1234); + assert_eq!(cursor.as_bytes(), &[0x34, 0x12]); + } + + #[test] + fn write_u32_le_produces_correct_bytes() { + let mut cursor = WriteCursor::new(); + cursor.write_u32_le(0x12345678); + assert_eq!(cursor.as_bytes(), &[0x78, 0x56, 0x34, 0x12]); + } + + #[test] + fn write_u64_le_produces_correct_bytes() { + let mut cursor = WriteCursor::new(); + cursor.write_u64_le(0x0102030405060708); + assert_eq!( + cursor.as_bytes(), + &[0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01] + ); + } + + #[test] + fn write_u128_le_produces_correct_bytes() { + let mut cursor = WriteCursor::new(); + cursor.write_u128_le(0x01); + let bytes = cursor.as_bytes(); + assert_eq!(bytes.len(), 16); + assert_eq!(bytes[0], 0x01); + assert!(bytes[1..].iter().all(|&b| b == 0)); + } + + #[test] + fn align_to_pads_correctly() { + // From position 0 -> already aligned + let mut cursor = WriteCursor::new(); + cursor.align_to(8); + assert_eq!(cursor.position(), 0); + + // From position 3 -> pad to 8 + let mut cursor = WriteCursor::new(); + cursor.write_bytes(&[0x01, 0x02, 0x03]); + cursor.align_to(8); + assert_eq!(cursor.position(), 8); + // Padding bytes should be zeros + assert_eq!(&cursor.as_bytes()[3..8], &[0, 0, 0, 0, 0]); + + // From position 8 -> already aligned + cursor.align_to(8); + assert_eq!(cursor.position(), 8); + + // From position 1 -> pad to 4 + let mut cursor = WriteCursor::new(); + cursor.write_u8(0xAA); + cursor.align_to(4); + assert_eq!(cursor.position(), 4); + } + + #[test] + fn set_u32_le_at_backpatches_correctly() { + let mut cursor = WriteCursor::new(); + cursor.write_u32_le(0); // placeholder + cursor.write_u32_le(0xDEADBEEF); + cursor.set_u32_le_at(0, 0x12345678); + assert_eq!( + cursor.as_bytes(), + &[0x78, 0x56, 0x34, 0x12, 0xEF, 0xBE, 0xAD, 0xDE] + ); + } + + #[test] + fn set_u16_le_at_backpatches_correctly() { + let mut cursor = WriteCursor::new(); + cursor.write_u16_le(0); + cursor.write_u16_le(0xBEEF); + cursor.set_u16_le_at(0, 0x1234); + assert_eq!(cursor.as_bytes(), &[0x34, 0x12, 0xEF, 0xBE]); + } + + #[test] + fn write_utf16_le_encodes_correctly() { + let mut cursor = WriteCursor::new(); + cursor.write_utf16_le("hello"); + assert_eq!( + cursor.as_bytes(), + &[0x68, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00] + ); + } + + #[test] + fn write_zeros_produces_correct_count() { + let mut cursor = WriteCursor::new(); + cursor.write_zeros(5); + assert_eq!(cursor.as_bytes(), &[0, 0, 0, 0, 0]); + assert_eq!(cursor.position(), 5); + } + + #[test] + fn into_inner_returns_buffer() { + let mut cursor = WriteCursor::new(); + cursor.write_u8(0x42); + let buf = cursor.into_inner(); + assert_eq!(buf, vec![0x42]); + } + + #[test] + fn with_capacity_works() { + let cursor = WriteCursor::with_capacity(1024); + assert_eq!(cursor.position(), 0); + } + + // -- Roundtrip tests -- + + #[test] + fn roundtrip_u8() { + let mut w = WriteCursor::new(); + w.write_u8(0xAB); + let mut r = ReadCursor::new(w.as_bytes()); + assert_eq!(r.read_u8().unwrap(), 0xAB); + } + + #[test] + fn roundtrip_u16() { + let mut w = WriteCursor::new(); + w.write_u16_le(0xCAFE); + let mut r = ReadCursor::new(w.as_bytes()); + assert_eq!(r.read_u16_le().unwrap(), 0xCAFE); + } + + #[test] + fn roundtrip_u32() { + let mut w = WriteCursor::new(); + w.write_u32_le(0xDEADBEEF); + let mut r = ReadCursor::new(w.as_bytes()); + assert_eq!(r.read_u32_le().unwrap(), 0xDEADBEEF); + } + + #[test] + fn roundtrip_u64() { + let mut w = WriteCursor::new(); + w.write_u64_le(0x0102030405060708); + let mut r = ReadCursor::new(w.as_bytes()); + assert_eq!(r.read_u64_le().unwrap(), 0x0102030405060708); + } + + #[test] + fn roundtrip_u128() { + let val: u128 = 0x0102030405060708090A0B0C0D0E0F10; + let mut w = WriteCursor::new(); + w.write_u128_le(val); + let mut r = ReadCursor::new(w.as_bytes()); + assert_eq!(r.read_u128_le().unwrap(), val); + } + + #[test] + fn roundtrip_utf16_le() { + let mut w = WriteCursor::new(); + w.write_utf16_le("Hello, world!"); + let bytes = w.into_inner(); + let mut r = ReadCursor::new(&bytes); + let s = r.read_utf16_le(bytes.len()).unwrap(); + assert_eq!(s, "Hello, world!"); + } + + #[test] + fn roundtrip_utf16_le_emoji() { + let mut w = WriteCursor::new(); + w.write_utf16_le("\u{1F600}"); + let bytes = w.into_inner(); + let mut r = ReadCursor::new(&bytes); + let s = r.read_utf16_le(bytes.len()).unwrap(); + assert_eq!(s, "\u{1F600}"); + } + + // -- Property-based tests -- + + fn valid_utf16_string() -> impl Strategy { + prop::collection::vec( + prop::char::range('\u{0000}', '\u{D7FF}') + .prop_union(prop::char::range('\u{E000}', '\u{FFFF}')), + 0..100, + ) + .prop_map(|chars| chars.into_iter().collect()) + } + + proptest! { + #[test] + fn prop_roundtrip_u8(val: u8) { + let mut w = WriteCursor::new(); + w.write_u8(val); + let mut r = ReadCursor::new(w.as_bytes()); + prop_assert_eq!(r.read_u8().unwrap(), val); + } + + #[test] + fn prop_roundtrip_u16(val: u16) { + let mut w = WriteCursor::new(); + w.write_u16_le(val); + let mut r = ReadCursor::new(w.as_bytes()); + prop_assert_eq!(r.read_u16_le().unwrap(), val); + } + + #[test] + fn prop_roundtrip_u32(val: u32) { + let mut w = WriteCursor::new(); + w.write_u32_le(val); + let mut r = ReadCursor::new(w.as_bytes()); + prop_assert_eq!(r.read_u32_le().unwrap(), val); + } + + #[test] + fn prop_roundtrip_u64(val: u64) { + let mut w = WriteCursor::new(); + w.write_u64_le(val); + let mut r = ReadCursor::new(w.as_bytes()); + prop_assert_eq!(r.read_u64_le().unwrap(), val); + } + + #[test] + fn prop_roundtrip_u128(val: u128) { + let mut w = WriteCursor::new(); + w.write_u128_le(val); + let mut r = ReadCursor::new(w.as_bytes()); + prop_assert_eq!(r.read_u128_le().unwrap(), val); + } + + #[test] + fn prop_roundtrip_utf16_le(s in valid_utf16_string()) { + let mut w = WriteCursor::new(); + w.write_utf16_le(&s); + let bytes = w.into_inner(); + let mut r = ReadCursor::new(&bytes); + let decoded = r.read_utf16_le(bytes.len()).unwrap(); + prop_assert_eq!(decoded, s); + } + } +} diff --git a/vendor/smb2/src/rpc/CLAUDE.md b/vendor/smb2/src/rpc/CLAUDE.md new file mode 100644 index 0000000..3beb5e1 --- /dev/null +++ b/vendor/smb2/src/rpc/CLAUDE.md @@ -0,0 +1,51 @@ +# RPC -- named pipe RPC for share enumeration + +DCE/RPC over SMB2 named pipes. Used to list shares on a server via the srvsvc interface. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | RPC PDU building/parsing: BIND, BIND_ACK, REQUEST, RESPONSE | +| `srvsvc.rs` | NDR encoding for `NetShareEnumAll` (opnum 15), `ShareInfo` type | + +## Protocol flow + +1. Tree connect to `IPC$` +2. CREATE `srvsvc` pipe (server prepends `\pipe\`) +3. WRITE: RPC BIND (call_id=1, srvsvc UUID + NDR transfer syntax) +4. READ: RPC BIND_ACK -- verify context accepted +5. WRITE: RPC REQUEST (call_id=2, opnum=15, NDR-encoded NetShareEnumAll) +6. READ: RPC RESPONSE -- NDR-decode share list +7. CLOSE pipe +8. Tree disconnect IPC$ + +Used by `client/shares.rs` which orchestrates the full flow via `SmbClient::list_shares()`. + +## NDR encoding + +`srvsvc.rs` handles NDR (Network Data Representation) encoding/decoding: +- Conformant arrays: max_count prefix, then elements +- Conformant varying strings: max_count + offset + actual_count + UTF-16LE data +- Referent pointers: non-zero pointer ID, then deferred data +- All 4-byte aligned + +## Key decisions + +- **call_id convention**: 1 for BIND, 2 for REQUEST. Arbitrary but consistent with smb-rs. +- **Max fragment size 4280**: Default `MAX_XMIT_FRAG` / `MAX_RECV_FRAG`. Matches common implementations. + +## Response reassembly (two independent layers) + +A `NetShareEnum` reply can be split two different ways, and the client handles both. They compose: a fragment loop wrapping a buffer-overflow loop. + +- **DCE/RPC fragmentation (MS-RPCE 2.2.2.6)**: a large response may arrive as several RESPONSE PDUs, each its own pipe message, with `PFC_LAST_FRAG` set only on the last. `parse_response_fragment` returns `(stub, is_last)`; `client/shares.rs` loops reading PDUs and concatenating stubs until `is_last`, then NDR-decodes the joined stub via `srvsvc::parse_net_share_enum_all_stub`. `parse_response` is the single-fragment convenience wrapper (`parse_response_fragment(..).map(|(s, _)| s)`). +- **SMB pipe `STATUS_BUFFER_OVERFLOW` (MS-SMB2 3.3.5.10)**: a single pipe message larger than our 64 KiB read buffer comes back as overflow reads (partial data) terminated by a `SUCCESS` read. `client::shares::read_pipe_message` follows this, appending chunks until `SUCCESS`. The two phenomena are usually mutually exclusive in practice (fragments ≤ `MAX_RECV_FRAG` 4280 fit in one read; a server that ignores the frag cap sends one big PDU that overflows), but the code handles either or both. + +## Gotchas + +- **Pipe name is `srvsvc`**: The server prepends `\pipe\` automatically. Don't include it in the CREATE request. +- **Admin shares filtered out**: `list_shares` filters shares ending with `$` (IPC$, ADMIN$, C$). Only disk shares returned by default. +- **RPC version is 5.0**: Connection-oriented RPC. `PFC_FIRST_FRAG | PFC_LAST_FRAG` together mark a complete single-fragment PDU; a cleared `PFC_LAST_FRAG` means more fragments follow (see reassembly above). +- **NDR string alignment**: After each string, pad to 4-byte boundary. Missing alignment causes the server to reject the request silently. +- **Don't gate pipe reads on `SUCCESS` only**: `STATUS_BUFFER_OVERFLOW` is a warning (partial data), not a failure. Use `NtStatus::is_success_or_partial` and read again, or you truncate/error on large replies from servers that chunk them. This previously made `list_shares` fail on servers whose listing exceeded one read or one fragment. diff --git a/vendor/smb2/src/rpc/mod.rs b/vendor/smb2/src/rpc/mod.rs new file mode 100644 index 0000000..988b4a4 --- /dev/null +++ b/vendor/smb2/src/rpc/mod.rs @@ -0,0 +1,549 @@ +//! Named pipe RPC (MS-RPCE / NDR) for share enumeration. +//! +//! This module encodes and decodes DCE/RPC PDUs used over SMB2 named pipes. +//! The exchange for share enumeration is: +//! +//! 1. Open `\pipe\srvsvc` via CREATE +//! 2. Send RPC BIND request (type 11) +//! 3. Receive RPC BIND_ACK response (type 12) +//! 4. Send RPC REQUEST with NetShareEnumAll (type 0, opnum 15) +//! 5. Receive RPC RESPONSE with results (type 2) +//! 6. CLOSE the pipe +//! +//! Most users don't need this module directly -- use +//! [`SmbClient::list_shares`](crate::SmbClient::list_shares) instead. +//! The [`ShareInfo`](crate::ShareInfo) type is re-exported at the crate root. + +pub mod srvsvc; + +use crate::error::Result; +use crate::pack::guid::Guid; +use crate::pack::{Pack, ReadCursor, WriteCursor}; +use crate::Error; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// RPC version 5.0 (connection-oriented). +const RPC_VERSION_MAJOR: u8 = 5; +/// RPC minor version. +const RPC_VERSION_MINOR: u8 = 0; + +/// Data representation: little-endian, ASCII character set, IEEE floating point. +const DATA_REP: [u8; 4] = [0x10, 0x00, 0x00, 0x00]; + +/// RPC PDU type: REQUEST. +const PDU_TYPE_REQUEST: u8 = 0; +/// RPC PDU type: RESPONSE. +const PDU_TYPE_RESPONSE: u8 = 2; +/// RPC PDU type: BIND. +const PDU_TYPE_BIND: u8 = 11; +/// RPC PDU type: BIND_ACK. +const PDU_TYPE_BIND_ACK: u8 = 12; + +/// Default maximum transmit fragment size. +const MAX_XMIT_FRAG: u16 = 4280; +/// Default maximum receive fragment size. +const MAX_RECV_FRAG: u16 = 4280; + +/// PFC flags: first fragment. +const PFC_FIRST_FRAG: u8 = 0x01; +/// PFC flags: last fragment. +const PFC_LAST_FRAG: u8 = 0x02; + +/// srvsvc abstract syntax UUID: `4B324FC8-1670-01D3-1278-5A47BF6EE188`. +const SRVSVC_UUID: Guid = Guid { + data1: 0x4B324FC8, + data2: 0x1670, + data3: 0x01D3, + data4: [0x12, 0x78, 0x5A, 0x47, 0xBF, 0x6E, 0xE1, 0x88], +}; +/// srvsvc abstract syntax version. +const SRVSVC_VERSION: u32 = 3; + +/// NDR transfer syntax UUID: `8A885D04-1CEB-11C9-9FE8-08002B104860`. +const NDR_UUID: Guid = Guid { + data1: 0x8A885D04, + data2: 0x1CEB, + data3: 0x11C9, + data4: [0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60], +}; +/// NDR transfer syntax version. +const NDR_VERSION: u32 = 2; + +// --------------------------------------------------------------------------- +// RPC PDU common header size +// --------------------------------------------------------------------------- + +/// Size of the RPC PDU common header (16 bytes). +const RPC_HEADER_SIZE: usize = 16; + +// --------------------------------------------------------------------------- +// Build functions +// --------------------------------------------------------------------------- + +/// Build an RPC BIND request for the srvsvc interface. +/// +/// The BIND PDU negotiates the presentation context, binding the srvsvc +/// abstract syntax with the NDR transfer syntax. +pub fn build_srvsvc_bind(call_id: u32) -> Vec { + let mut w = WriteCursor::with_capacity(72); + + // Common header (16 bytes) -- FragLength will be backpatched + w.write_u8(RPC_VERSION_MAJOR); + w.write_u8(RPC_VERSION_MINOR); + w.write_u8(PDU_TYPE_BIND); + w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG); + w.write_bytes(&DATA_REP); + let frag_len_pos = w.position(); + w.write_u16_le(0); // FragLength placeholder + w.write_u16_le(0); // AuthLength + w.write_u32_le(call_id); + + // BIND-specific fields + w.write_u16_le(MAX_XMIT_FRAG); + w.write_u16_le(MAX_RECV_FRAG); + w.write_u32_le(0); // AssocGroup + + // Presentation context list + w.write_u8(1); // NumCtxItems + w.write_bytes(&[0, 0, 0]); // Reserved + + // Context item 0 + w.write_u16_le(0); // ContextId + w.write_u8(1); // NumTransferSyntaxes + w.write_u8(0); // Reserved + + // Abstract syntax: srvsvc + SRVSVC_UUID.pack(&mut w); + w.write_u32_le(SRVSVC_VERSION); + + // Transfer syntax: NDR + NDR_UUID.pack(&mut w); + w.write_u32_le(NDR_VERSION); + + // Backpatch FragLength + let total_len = w.position(); + w.set_u16_le_at(frag_len_pos, total_len as u16); + + w.into_inner() +} + +/// Parse an RPC BIND_ACK response. +/// +/// Verifies that the server accepted the presentation context (result == 0). +/// Returns `Ok(())` on success, or an error if the bind was rejected or +/// the response is malformed. +pub fn parse_bind_ack(data: &[u8]) -> Result<()> { + let mut r = ReadCursor::new(data); + + // Common header + let version = r.read_u8()?; + let version_minor = r.read_u8()?; + if version != RPC_VERSION_MAJOR || version_minor != RPC_VERSION_MINOR { + return Err(Error::invalid_data(format!( + "unexpected RPC version {version}.{version_minor}, expected 5.0" + ))); + } + + let ptype = r.read_u8()?; + if ptype != PDU_TYPE_BIND_ACK { + return Err(Error::invalid_data(format!( + "expected BIND_ACK (type 12), got type {ptype}" + ))); + } + + let _flags = r.read_u8()?; + let _data_rep = r.read_bytes(4)?; + let _frag_length = r.read_u16_le()?; + let _auth_length = r.read_u16_le()?; + let _call_id = r.read_u32_le()?; + + // BIND_ACK specific fields + let _max_xmit_frag = r.read_u16_le()?; + let _max_recv_frag = r.read_u16_le()?; + let _assoc_group = r.read_u32_le()?; + + // Secondary address (variable length, padded to 4 bytes) + let sec_addr_len = r.read_u16_le()?; + r.skip(sec_addr_len as usize)?; + // Align to 4 bytes after secondary address (the 2-byte length + string) + let consumed = 2 + sec_addr_len as usize; + let padding = (4 - (consumed % 4)) % 4; + r.skip(padding)?; + + // Result list + let num_results = r.read_u8()?; + r.skip(3)?; // Reserved + + if num_results == 0 { + return Err(Error::invalid_data("BIND_ACK has no context results")); + } + + // Check first result + let result = r.read_u16_le()?; + if result != 0 { + let reason = r.read_u16_le()?; + return Err(Error::invalid_data(format!( + "BIND rejected: result={result}, reason={reason}" + ))); + } + + Ok(()) +} + +/// Build an RPC REQUEST PDU wrapping the given stub data. +/// +/// The caller provides the NDR-encoded stub (the operation payload) and the +/// operation number. +pub fn build_request(call_id: u32, opnum: u16, stub_data: &[u8]) -> Vec { + let mut w = WriteCursor::with_capacity(RPC_HEADER_SIZE + 8 + stub_data.len()); + + // Common header + w.write_u8(RPC_VERSION_MAJOR); + w.write_u8(RPC_VERSION_MINOR); + w.write_u8(PDU_TYPE_REQUEST); + w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG); + w.write_bytes(&DATA_REP); + let frag_len_pos = w.position(); + w.write_u16_le(0); // FragLength placeholder + w.write_u16_le(0); // AuthLength + w.write_u32_le(call_id); + + // REQUEST specific fields + w.write_u32_le(stub_data.len() as u32); // AllocHint + w.write_u16_le(0); // ContextId + w.write_u16_le(opnum); + + // Stub data + w.write_bytes(stub_data); + + // Backpatch FragLength + let total_len = w.position(); + w.set_u16_le_at(frag_len_pos, total_len as u16); + + w.into_inner() +} + +/// Parse a single RPC RESPONSE PDU, returning its stub data and whether it is +/// the final fragment (`PFC_LAST_FRAG` set). +/// +/// DCE/RPC servers may split a large response across several fragment PDUs, +/// clearing `PFC_LAST_FRAG` on every fragment but the last (MS-RPCE 2.2.2.6). +/// Callers reassemble by concatenating each fragment's stub until `is_last` is +/// `true`. See `client::shares` for the read-and-reassemble loop. +pub fn parse_response_fragment(data: &[u8]) -> Result<(&[u8], bool)> { + let mut r = ReadCursor::new(data); + + // Common header + let version = r.read_u8()?; + let version_minor = r.read_u8()?; + if version != RPC_VERSION_MAJOR || version_minor != RPC_VERSION_MINOR { + return Err(Error::invalid_data(format!( + "unexpected RPC version {version}.{version_minor}, expected 5.0" + ))); + } + + let ptype = r.read_u8()?; + if ptype != PDU_TYPE_RESPONSE { + return Err(Error::invalid_data(format!( + "expected RESPONSE (type 2), got type {ptype}" + ))); + } + + let flags = r.read_u8()?; + let _data_rep = r.read_bytes(4)?; + let frag_length = r.read_u16_le()? as usize; + let _auth_length = r.read_u16_le()?; + let _call_id = r.read_u32_le()?; + + // RESPONSE specific fields + let _alloc_hint = r.read_u32_le()?; + let _context_id = r.read_u16_le()?; + let _cancel_count = r.read_u8()?; + let _reserved = r.read_u8()?; + + // Stub data is the rest (up to frag_length). + let header_consumed = r.position(); + if frag_length < header_consumed { + return Err(Error::invalid_data(format!( + "RPC frag_length {frag_length} shorter than header {header_consumed}" + ))); + } + let stub_data = r.read_bytes(frag_length - header_consumed)?; + + let is_last = flags & PFC_LAST_FRAG != 0; + Ok((stub_data, is_last)) +} + +/// Parse an RPC RESPONSE PDU, returning the stub data. +/// +/// Validates the PDU header and extracts the embedded stub data for +/// further NDR decoding. Assumes a single, complete fragment; for fragmented +/// responses use [`parse_response_fragment`] and reassemble. +pub fn parse_response(data: &[u8]) -> Result<&[u8]> { + parse_response_fragment(data).map(|(stub, _is_last)| stub) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::pack::Unpack; + + #[test] + fn bind_request_has_correct_header() { + let pdu = build_srvsvc_bind(1); + + assert_eq!(pdu[0], RPC_VERSION_MAJOR, "version major"); + assert_eq!(pdu[1], RPC_VERSION_MINOR, "version minor"); + assert_eq!(pdu[2], PDU_TYPE_BIND, "packet type"); + assert_eq!(pdu[3], PFC_FIRST_FRAG | PFC_LAST_FRAG, "flags"); + + // Data representation + assert_eq!(&pdu[4..8], &DATA_REP); + + // FragLength should match actual PDU length + let frag_len = u16::from_le_bytes([pdu[8], pdu[9]]); + assert_eq!(frag_len as usize, pdu.len()); + + // AuthLength = 0 + let auth_len = u16::from_le_bytes([pdu[10], pdu[11]]); + assert_eq!(auth_len, 0); + + // CallId = 1 + let call_id = u32::from_le_bytes([pdu[12], pdu[13], pdu[14], pdu[15]]); + assert_eq!(call_id, 1); + } + + #[test] + fn bind_request_contains_srvsvc_uuid() { + let pdu = build_srvsvc_bind(1); + + // After common header (16) + MaxXmitFrag(2) + MaxRecvFrag(2) + AssocGroup(4) + + // NumCtxItems(1) + Reserved(3) + ContextId(2) + NumTransferSyntaxes(1) + Reserved(1) = 32 + let uuid_offset = 32; + + // Extract the abstract syntax UUID bytes + let mut cursor = ReadCursor::new(&pdu[uuid_offset..]); + let guid = Guid::unpack(&mut cursor).unwrap(); + assert_eq!(guid, SRVSVC_UUID); + + let version = cursor.read_u32_le().unwrap(); + assert_eq!(version, SRVSVC_VERSION); + } + + #[test] + fn bind_request_contains_ndr_transfer_syntax() { + let pdu = build_srvsvc_bind(1); + + // Transfer syntax starts after abstract syntax (UUID=16 + version=4 = 20 bytes after uuid_offset) + let transfer_offset = 32 + 20; + + let mut cursor = ReadCursor::new(&pdu[transfer_offset..]); + let guid = Guid::unpack(&mut cursor).unwrap(); + assert_eq!(guid, NDR_UUID); + + let version = cursor.read_u32_le().unwrap(); + assert_eq!(version, NDR_VERSION); + } + + #[test] + fn bind_request_total_length() { + let pdu = build_srvsvc_bind(1); + // 16 (header) + 4 (max frags) + 4 (assoc) + 4 (ctx list header) + + // 4 (ctx item header) + 20 (abstract) + 20 (transfer) = 72 + assert_eq!(pdu.len(), 72); + } + + #[test] + fn parse_valid_bind_ack() { + let ack = build_test_bind_ack(0); // result = 0 = accepted + assert!(parse_bind_ack(&ack).is_ok()); + } + + #[test] + fn parse_rejected_bind_ack() { + let ack = build_test_bind_ack(2); // result = 2 = provider_rejection + let err = parse_bind_ack(&ack).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("rejected"), + "error should mention rejection: {msg}" + ); + } + + #[test] + fn parse_bind_ack_wrong_version() { + let mut ack = build_test_bind_ack(0); + ack[0] = 4; // wrong version + assert!(parse_bind_ack(&ack).is_err()); + } + + #[test] + fn parse_bind_ack_wrong_type() { + let mut ack = build_test_bind_ack(0); + ack[2] = PDU_TYPE_BIND; // wrong type + assert!(parse_bind_ack(&ack).is_err()); + } + + #[test] + fn request_pdu_has_correct_opnum() { + let stub = vec![0xAA, 0xBB, 0xCC]; + let pdu = build_request(1, 15, &stub); + + // OpNum is at offset 22 (header=16 + AllocHint=4 + ContextId=2) + let opnum = u16::from_le_bytes([pdu[22], pdu[23]]); + assert_eq!(opnum, 15); + } + + #[test] + fn request_pdu_has_correct_alloc_hint() { + let stub = vec![0xAA, 0xBB, 0xCC]; + let pdu = build_request(1, 15, &stub); + + let alloc_hint = u32::from_le_bytes([pdu[16], pdu[17], pdu[18], pdu[19]]); + assert_eq!(alloc_hint, 3); + } + + #[test] + fn request_pdu_contains_stub_data() { + let stub = vec![0xAA, 0xBB, 0xCC]; + let pdu = build_request(1, 15, &stub); + + // Stub starts at offset 24 (header=16 + request fields=8) + assert_eq!(&pdu[24..], &[0xAA, 0xBB, 0xCC]); + } + + #[test] + fn request_pdu_frag_length_matches() { + let stub = vec![0xAA, 0xBB, 0xCC]; + let pdu = build_request(1, 15, &stub); + + let frag_len = u16::from_le_bytes([pdu[8], pdu[9]]); + assert_eq!(frag_len as usize, pdu.len()); + } + + #[test] + fn parse_response_extracts_stub() { + let stub = b"hello stub data"; + let response_pdu = build_test_response(1, stub); + + let extracted = parse_response(&response_pdu).unwrap(); + assert_eq!(extracted, stub); + } + + #[test] + fn parse_response_wrong_version() { + let mut pdu = build_test_response(1, b"data"); + pdu[0] = 4; // wrong version + assert!(parse_response(&pdu).is_err()); + } + + #[test] + fn parse_response_fragment_reports_last_flag() { + // build_test_response sets PFC_FIRST_FRAG | PFC_LAST_FRAG. + let pdu = build_test_response(1, b"stub"); + let (stub, is_last) = parse_response_fragment(&pdu).unwrap(); + assert_eq!(stub, b"stub"); + assert!(is_last, "FIRST|LAST PDU should be the last fragment"); + + // Clear PFC_LAST_FRAG in the flags byte: now it's a non-final fragment. + let mut frag = pdu.clone(); + frag[3] &= !PFC_LAST_FRAG; + let (stub, is_last) = parse_response_fragment(&frag).unwrap(); + assert_eq!(stub, b"stub"); + assert!(!is_last, "FIRST-only PDU should not be the last fragment"); + } + + #[test] + fn parse_response_rejects_frag_length_shorter_than_header() { + let mut pdu = build_test_response(1, b"data"); + // FragLength lives at offset 8 (u16 LE); set it below the 24-byte header. + pdu[8] = 4; + pdu[9] = 0; + assert!(parse_response(&pdu).is_err()); + } + + #[test] + fn parse_response_wrong_type() { + let mut pdu = build_test_response(1, b"data"); + pdu[2] = PDU_TYPE_REQUEST; // wrong type + assert!(parse_response(&pdu).is_err()); + } + + // -- Test helpers -- + + /// Build a minimal BIND_ACK for testing. + fn build_test_bind_ack(result: u16) -> Vec { + let mut w = WriteCursor::with_capacity(64); + + // Common header + w.write_u8(RPC_VERSION_MAJOR); + w.write_u8(RPC_VERSION_MINOR); + w.write_u8(PDU_TYPE_BIND_ACK); + w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG); + w.write_bytes(&DATA_REP); + let frag_len_pos = w.position(); + w.write_u16_le(0); // FragLength placeholder + w.write_u16_le(0); // AuthLength + w.write_u32_le(1); // CallId + + // BIND_ACK specific + w.write_u16_le(MAX_XMIT_FRAG); + w.write_u16_le(MAX_RECV_FRAG); + w.write_u32_le(0x12345); // AssocGroup + + // Secondary address: "\pipe\srvsvc\0" (empty for simplicity -- use length 0) + w.write_u16_le(0); // SecAddrLen = 0 + w.write_bytes(&[0, 0]); // Padding to 4-byte alignment + + // Result list + w.write_u8(1); // NumResults + w.write_bytes(&[0, 0, 0]); // Reserved + + // Result entry + w.write_u16_le(result); // Result + w.write_u16_le(0); // Reason + // Transfer syntax (16 bytes UUID + 4 bytes version) + NDR_UUID.pack(&mut w); + w.write_u32_le(NDR_VERSION); + + let total_len = w.position(); + w.set_u16_le_at(frag_len_pos, total_len as u16); + + w.into_inner() + } + + /// Build a minimal RPC RESPONSE PDU wrapping the given stub data. + fn build_test_response(call_id: u32, stub: &[u8]) -> Vec { + let mut w = WriteCursor::with_capacity(RPC_HEADER_SIZE + 8 + stub.len()); + + w.write_u8(RPC_VERSION_MAJOR); + w.write_u8(RPC_VERSION_MINOR); + w.write_u8(PDU_TYPE_RESPONSE); + w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG); + w.write_bytes(&DATA_REP); + let frag_len_pos = w.position(); + w.write_u16_le(0); // FragLength placeholder + w.write_u16_le(0); // AuthLength + w.write_u32_le(call_id); + + // RESPONSE specific + w.write_u32_le(stub.len() as u32); // AllocHint + w.write_u16_le(0); // ContextId + w.write_u8(0); // CancelCount + w.write_u8(0); // Reserved + + w.write_bytes(stub); + + let total_len = w.position(); + w.set_u16_le_at(frag_len_pos, total_len as u16); + + w.into_inner() + } +} diff --git a/vendor/smb2/src/rpc/srvsvc.rs b/vendor/smb2/src/rpc/srvsvc.rs new file mode 100644 index 0000000..b7370cb --- /dev/null +++ b/vendor/smb2/src/rpc/srvsvc.rs @@ -0,0 +1,554 @@ +//! NetShareEnumAll NDR encoding/decoding for the srvsvc interface. +//! +//! Encodes the NetrShareEnum request (opnum 15) and decodes the response, +//! extracting share names, types, and comments. + +use crate::error::Result; +use crate::pack::{ReadCursor, WriteCursor}; +use crate::Error; + +/// Share type: disk share. +pub const STYPE_DISKTREE: u32 = 0x0000_0000; +/// Share type: printer queue. +pub const STYPE_PRINTQ: u32 = 0x0000_0001; +/// Share type: device. +pub const STYPE_DEVICE: u32 = 0x0000_0002; +/// Share type: IPC (inter-process communication). +pub const STYPE_IPC: u32 = 0x0000_0003; +/// Share type modifier: special/admin share (combined with above via OR). +pub const STYPE_SPECIAL: u32 = 0x8000_0000; + +/// Mask for the base share type (low bits). +const STYPE_BASE_MASK: u32 = 0x0000_FFFF; + +/// Information about a single network share. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ShareInfo { + /// The share name (for example, "Documents" or "IPC$"). + pub name: String, + /// The share type as a raw u32 (see `STYPE_*` constants). + pub share_type: u32, + /// An optional comment/description for the share. + pub comment: String, +} + +/// Build the NDR-encoded stub data for a NetShareEnumAll request. +/// +/// The stub is meant to be wrapped in an RPC REQUEST PDU with opnum 15. +pub fn build_net_share_enum_all_stub(server_name: &str) -> Vec { + let mut w = WriteCursor::with_capacity(128); + + // ServerName: NDR unique pointer to conformant+varying string (UTF-16LE, null-terminated) + // Referent ID (non-null pointer) + w.write_u32_le(0x0002_0000); // referent ID + + // Encode the server name as a conformant+varying NDR string + let name_utf16: Vec = server_name + .encode_utf16() + .chain(std::iter::once(0)) + .collect(); + let char_count = name_utf16.len() as u32; + + // MaxCount + w.write_u32_le(char_count); + // Offset + w.write_u32_le(0); + // ActualCount + w.write_u32_le(char_count); + // String data (UTF-16LE) + for &code_unit in &name_utf16 { + w.write_u16_le(code_unit); + } + // Align to 4 bytes after string data + w.align_to(4); + + // InfoStruct: SHARE_ENUM_STRUCT + // Level = 1 (we want SHARE_INFO_1) + w.write_u32_le(1); + + // ShareInfo union discriminant = 1 (matches level) + w.write_u32_le(1); + + // Pointer to SHARE_INFO_1_CONTAINER (unique pointer) + w.write_u32_le(0x0002_0004); // referent ID + + // SHARE_INFO_1_CONTAINER (deferred pointer data) + // EntriesRead = 0 (server fills this) + w.write_u32_le(0); + // Buffer pointer = NULL (let server allocate) + w.write_u32_le(0); + + // PreferedMaximumLength = 0xFFFFFFFF (no limit) + w.write_u32_le(0xFFFF_FFFF); + + // ResumeHandle: unique pointer to u32 + // NULL pointer (no resume) + w.write_u32_le(0); + + w.into_inner() +} + +/// Build a complete RPC REQUEST PDU for NetShareEnumAll. +/// +/// Combines the RPC REQUEST header (opnum 15) with the NDR stub data. +pub fn build_net_share_enum_all(call_id: u32, server_name: &str) -> Vec { + let stub = build_net_share_enum_all_stub(server_name); + super::build_request(call_id, 15, &stub) +} + +/// Parse the NDR stub data from a NetShareEnumAll RPC RESPONSE. +/// +/// Extracts all share entries from the response. The caller should use +/// [`filter_disk_shares`] to get only disk shares. +pub fn parse_net_share_enum_all_response(data: &[u8]) -> Result> { + // First, parse the RPC RESPONSE envelope to get the stub data + let stub = super::parse_response(data)?; + parse_net_share_enum_all_stub(stub) +} + +/// Parse the NDR stub data directly (without the RPC envelope). +/// +/// Used by the share-enumeration reassembly path, which concatenates the stub +/// of each RPC fragment before decoding. +pub(crate) fn parse_net_share_enum_all_stub(stub: &[u8]) -> Result> { + let mut r = ReadCursor::new(stub); + + // Level (u32) -- should be 1 + let level = r.read_u32_le()?; + if level != 1 { + return Err(Error::invalid_data(format!( + "expected share info level 1, got {level}" + ))); + } + + // Union discriminant (u32) -- should be 1 + let discriminant = r.read_u32_le()?; + if discriminant != 1 { + return Err(Error::invalid_data(format!( + "expected union discriminant 1, got {discriminant}" + ))); + } + + // Pointer to SHARE_INFO_1_CONTAINER + let container_ptr = r.read_u32_le()?; + if container_ptr == 0 { + return Ok(Vec::new()); + } + + // SHARE_INFO_1_CONTAINER + let count = r.read_u32_le()?; + + // Pointer to array of SHARE_INFO_1 + let array_ptr = r.read_u32_le()?; + if array_ptr == 0 || count == 0 { + return Ok(Vec::new()); + } + + // Array: MaxCount header + let max_count = r.read_u32_le()?; + if max_count < count { + return Err(Error::invalid_data(format!( + "array max_count ({max_count}) < entries ({count})" + ))); + } + + // Read the fixed-size parts of each SHARE_INFO_1 entry: + // Each entry has: name_ptr (u32), type (u32), comment_ptr (u32) + struct RawEntry { + name_ptr: u32, + share_type: u32, + comment_ptr: u32, + } + + let mut entries = Vec::with_capacity(count as usize); + for _ in 0..count { + let name_ptr = r.read_u32_le()?; + let share_type = r.read_u32_le()?; + let comment_ptr = r.read_u32_le()?; + entries.push(RawEntry { + name_ptr, + share_type, + comment_ptr, + }); + } + + // Now read the deferred pointer data (conformant+varying strings) + let mut shares = Vec::with_capacity(count as usize); + for entry in &entries { + let name = if entry.name_ptr != 0 { + read_ndr_string(&mut r)? + } else { + String::new() + }; + + let comment = if entry.comment_ptr != 0 { + read_ndr_string(&mut r)? + } else { + String::new() + }; + + shares.push(ShareInfo { + name, + share_type: entry.share_type, + comment, + }); + } + + Ok(shares) +} + +/// Read an NDR conformant+varying UTF-16LE string from the cursor. +/// +/// Format: MaxCount(u32) + Offset(u32) + ActualCount(u32) + UTF-16LE data. +/// The string is null-terminated on the wire; we strip the null. +fn read_ndr_string(r: &mut ReadCursor<'_>) -> Result { + let _max_count = r.read_u32_le()?; + let _offset = r.read_u32_le()?; + let actual_count = r.read_u32_le()?; + + if actual_count == 0 { + return Ok(String::new()); + } + + let byte_len = actual_count as usize * 2; + let s = r.read_utf16_le(byte_len)?; + + // Align to 4 bytes after reading string data + let pos = r.position(); + let padding = (4 - (pos % 4)) % 4; + if padding > 0 && r.remaining() >= padding { + r.skip(padding)?; + } + + // Strip trailing null + Ok(s.trim_end_matches('\0').to_string()) +} + +/// Filter shares, keeping only disk shares and excluding admin shares (ending with `$`). +pub fn filter_disk_shares(shares: Vec) -> Vec { + shares + .into_iter() + .filter(|s| { + let base_type = s.share_type & STYPE_BASE_MASK; + let is_disk = base_type == STYPE_DISKTREE; + let is_special = (s.share_type & STYPE_SPECIAL) != 0; + let ends_with_dollar = s.name.ends_with('$'); + is_disk && !is_special && !ends_with_dollar + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_request_has_opnum_15() { + let pdu = build_net_share_enum_all(1, r"\\server"); + // OpNum is at offset 22 in the RPC REQUEST PDU + let opnum = u16::from_le_bytes([pdu[22], pdu[23]]); + assert_eq!(opnum, 15); + } + + #[test] + fn build_request_stub_contains_server_name() { + let stub = build_net_share_enum_all_stub(r"\\server"); + // The server name should appear as UTF-16LE somewhere in the stub + let expected_utf16: Vec = r"\\server" + .encode_utf16() + .flat_map(|c| c.to_le_bytes()) + .collect(); + + let found = stub + .windows(expected_utf16.len()) + .any(|window| window == expected_utf16.as_slice()); + assert!(found, "stub should contain the server name in UTF-16LE"); + } + + #[test] + fn parse_response_with_three_shares() { + let response_pdu = build_test_enum_response(&[ + ("Documents", STYPE_DISKTREE, "Shared docs"), + ("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"), + ("C$", STYPE_DISKTREE | STYPE_SPECIAL, "Default share"), + ]); + + let shares = parse_net_share_enum_all_response(&response_pdu).unwrap(); + assert_eq!(shares.len(), 3); + assert_eq!(shares[0].name, "Documents"); + assert_eq!(shares[0].share_type, STYPE_DISKTREE); + assert_eq!(shares[0].comment, "Shared docs"); + assert_eq!(shares[1].name, "IPC$"); + assert_eq!(shares[2].name, "C$"); + } + + #[test] + fn filter_keeps_disk_shares() { + let shares = vec![ + ShareInfo { + name: "Documents".to_string(), + share_type: STYPE_DISKTREE, + comment: "Shared docs".to_string(), + }, + ShareInfo { + name: "Photos".to_string(), + share_type: STYPE_DISKTREE, + comment: String::new(), + }, + ]; + + let filtered = filter_disk_shares(shares); + assert_eq!(filtered.len(), 2); + } + + #[test] + fn filter_removes_ipc() { + let shares = vec![ShareInfo { + name: "IPC$".to_string(), + share_type: STYPE_IPC | STYPE_SPECIAL, + comment: "Remote IPC".to_string(), + }]; + + let filtered = filter_disk_shares(shares); + assert!(filtered.is_empty()); + } + + #[test] + fn filter_removes_admin_shares() { + let shares = vec![ + ShareInfo { + name: "C$".to_string(), + share_type: STYPE_DISKTREE | STYPE_SPECIAL, + comment: "Default share".to_string(), + }, + ShareInfo { + name: "ADMIN$".to_string(), + share_type: STYPE_DISKTREE | STYPE_SPECIAL, + comment: "Remote Admin".to_string(), + }, + ]; + + let filtered = filter_disk_shares(shares); + assert!(filtered.is_empty()); + } + + #[test] + fn filter_mixed_shares() { + let shares = vec![ + ShareInfo { + name: "Documents".to_string(), + share_type: STYPE_DISKTREE, + comment: "Shared docs".to_string(), + }, + ShareInfo { + name: "IPC$".to_string(), + share_type: STYPE_IPC | STYPE_SPECIAL, + comment: "Remote IPC".to_string(), + }, + ShareInfo { + name: "C$".to_string(), + share_type: STYPE_DISKTREE | STYPE_SPECIAL, + comment: "Default share".to_string(), + }, + ShareInfo { + name: "Photos".to_string(), + share_type: STYPE_DISKTREE, + comment: String::new(), + }, + ShareInfo { + name: "Printer".to_string(), + share_type: STYPE_PRINTQ, + comment: "Office printer".to_string(), + }, + ]; + + let filtered = filter_disk_shares(shares); + assert_eq!(filtered.len(), 2); + assert_eq!(filtered[0].name, "Documents"); + assert_eq!(filtered[1].name, "Photos"); + } + + #[test] + fn parse_empty_share_list() { + let response_pdu = build_test_enum_response(&[]); + let shares = parse_net_share_enum_all_response(&response_pdu).unwrap(); + assert!(shares.is_empty()); + } + + #[test] + fn parse_share_with_unicode_name() { + let response_pdu = build_test_enum_response(&[( + "\u{00C4}rchive", + STYPE_DISKTREE, + "Archiv f\u{00FC}r Dateien", + )]); + + let shares = parse_net_share_enum_all_response(&response_pdu).unwrap(); + assert_eq!(shares.len(), 1); + assert_eq!(shares[0].name, "\u{00C4}rchive"); + assert_eq!(shares[0].comment, "Archiv f\u{00FC}r Dateien"); + } + + #[test] + fn parse_share_with_cjk_characters() { + let response_pdu = build_test_enum_response(&[( + "\u{5171}\u{6709}", + STYPE_DISKTREE, + "\u{5171}\u{6709}\u{30D5}\u{30A9}\u{30EB}\u{30C0}", + )]); + + let shares = parse_net_share_enum_all_response(&response_pdu).unwrap(); + assert_eq!(shares.len(), 1); + assert_eq!(shares[0].name, "\u{5171}\u{6709}"); + assert_eq!( + shares[0].comment, + "\u{5171}\u{6709}\u{30D5}\u{30A9}\u{30EB}\u{30C0}" + ); + } + + #[test] + fn roundtrip_build_and_parse() { + // Build a request, then manually construct a response and parse it + let _request = build_net_share_enum_all(1, r"\\testserver"); + + let response_pdu = build_test_enum_response(&[ + ("Share1", STYPE_DISKTREE, "First share"), + ("Share2", STYPE_DISKTREE, "Second share"), + ]); + + let shares = parse_net_share_enum_all_response(&response_pdu).unwrap(); + assert_eq!(shares.len(), 2); + assert_eq!(shares[0].name, "Share1"); + assert_eq!(shares[0].comment, "First share"); + assert_eq!(shares[1].name, "Share2"); + assert_eq!(shares[1].comment, "Second share"); + } + + #[test] + fn filter_preserves_non_dollar_disk_shares_only() { + // A share named "My$hare" (dollar in middle) should be kept + let shares = vec![ShareInfo { + name: "My$hare".to_string(), + share_type: STYPE_DISKTREE, + comment: String::new(), + }]; + + let filtered = filter_disk_shares(shares); + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].name, "My$hare"); + } + + // -- Test helpers -- + + /// Write an NDR conformant+varying UTF-16LE string into the cursor. + fn write_ndr_string(w: &mut WriteCursor, s: &str) { + let utf16: Vec = s.encode_utf16().chain(std::iter::once(0)).collect(); + let char_count = utf16.len() as u32; + + w.write_u32_le(char_count); // MaxCount + w.write_u32_le(0); // Offset + w.write_u32_le(char_count); // ActualCount + for &code_unit in &utf16 { + w.write_u16_le(code_unit); + } + w.align_to(4); + } + + /// Build a complete RPC RESPONSE PDU containing the given shares. + /// + /// This constructs valid NDR stub data wrapped in an RPC RESPONSE envelope. + fn build_test_enum_response(shares: &[(&str, u32, &str)]) -> Vec { + let stub = build_test_enum_stub(shares); + build_test_response_pdu(1, &stub) + } + + /// Build NDR stub data for a NetShareEnumAll response. + fn build_test_enum_stub(shares: &[(&str, u32, &str)]) -> Vec { + let mut w = WriteCursor::with_capacity(512); + let count = shares.len() as u32; + + // Level = 1 + w.write_u32_le(1); + // Union discriminant = 1 + w.write_u32_le(1); + + if count == 0 { + // Null container pointer + w.write_u32_le(0); + // TotalEntries + w.write_u32_le(0); + // ResumeHandle pointer (null) + w.write_u32_le(0); + // Return value (Windows error code, 0 = success) + w.write_u32_le(0); + return w.into_inner(); + } + + // Container pointer (non-null) + w.write_u32_le(0x0002_0000); + + // SHARE_INFO_1_CONTAINER + w.write_u32_le(count); // EntriesRead + w.write_u32_le(0x0002_0004); // Array pointer (non-null) + + // Array: MaxCount + w.write_u32_le(count); + + // Fixed-size entries: name_ptr, type, comment_ptr + for (i, &(_, share_type, _)) in shares.iter().enumerate() { + w.write_u32_le(0x0002_0008 + (i as u32) * 2); // name referent ID + w.write_u32_le(share_type); + w.write_u32_le(0x0002_0108 + (i as u32) * 2); // comment referent ID + } + + // Deferred string data (name then comment for each entry) + for &(name, _, comment) in shares { + write_ndr_string(&mut w, name); + write_ndr_string(&mut w, comment); + } + + // TotalEntries + w.write_u32_le(count); + // ResumeHandle pointer (null) + w.write_u32_le(0); + // Return value (0 = success) + w.write_u32_le(0); + + w.into_inner() + } + + /// Build a minimal RPC RESPONSE PDU wrapping stub data. + fn build_test_response_pdu(call_id: u32, stub: &[u8]) -> Vec { + use crate::pack::WriteCursor; + + let mut w = WriteCursor::with_capacity(24 + stub.len()); + + // Common header + w.write_u8(5); // Version + w.write_u8(0); // VersionMinor + w.write_u8(2); // PacketType = RESPONSE + w.write_u8(0x03); // Flags (first + last) + w.write_bytes(&[0x10, 0x00, 0x00, 0x00]); // DataRep + let frag_len_pos = w.position(); + w.write_u16_le(0); // FragLength placeholder + w.write_u16_le(0); // AuthLength + w.write_u32_le(call_id); + + // RESPONSE specific + w.write_u32_le(stub.len() as u32); // AllocHint + w.write_u16_le(0); // ContextId + w.write_u8(0); // CancelCount + w.write_u8(0); // Reserved + + w.write_bytes(stub); + + let total_len = w.position(); + w.set_u16_le_at(frag_len_pos, total_len as u16); + + w.into_inner() + } +} diff --git a/vendor/smb2/src/testing/CLAUDE.md b/vendor/smb2/src/testing/CLAUDE.md new file mode 100644 index 0000000..33c5fd1 --- /dev/null +++ b/vendor/smb2/src/testing/CLAUDE.md @@ -0,0 +1,48 @@ +# Testing module -- Docker-based SMB test servers + +Feature-gated (`testing` feature flag). Provides Docker-based Samba containers for consumers (apps that depend on smb2) to test their SMB integration. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | `TestServers`, `Error`, port constants, embedded Docker files, `write_compose_files()` | + +## Architecture + +Three-layer testing model: + +1. **Layer 1 (Rust)**: `TestServers::start()` / `start_all()` / `start_blocking()` return a struct with `*_client()` methods that connect to Docker containers. +2. **Layer 2 (E2E)**: `write_compose_files(dir)` extracts embedded Docker infrastructure to disk for non-Rust test frameworks. +3. **Layer 3 (Manual QA)**: Same compose files, run manually. + +## Embedded files + +All 35 Docker files (compose, Dockerfiles, smb.conf, scripts) are embedded via `include_str!` at compile time. At runtime, `write_compose_files()` writes them to a temp directory. Docker Compose runs from there. + +## Port scheme + +15 containers on ports 10480-10494. Each port has an env-var override (`SMB_CONSUMER_*_PORT`). The `port()` function checks the env var, falls back to the hardcoded default. + +## Profiles + +- **Minimal**: guest + auth only (2 containers, fast startup). +- **All**: all 15 containers. + +Calling a `*_client()` method for a container not in the current profile returns `Error::ContainerNotStarted`. + +## Key decisions + +| Decision | Choice | Why | +|---|---|---| +| No extra deps | `std::process::Command` for Docker | Keep the crate lean | +| Temp dir via `std::env::temp_dir()` | No `tempfile` crate | No extra deps | +| Embedded files via `include_str!` | Self-contained published crate | Consumers don't need smb2 source tree | +| Separate error type | `testing::Error` vs `smb2::Error` | Docker failures are not protocol errors | +| Best-effort cleanup in Drop | `docker compose down` | LazyLock statics never drop, so this is convenience only | + +## Gotchas + +- **LazyLock statics never drop**: `TestServers::drop()` won't run at process exit. CI should use explicit cleanup steps. +- **Flaky container has no health check**: The 5s-up/5s-down cycle means health checks would randomly fail. `wait_healthy()` skips it. +- **DFS is disabled on test clients**: Consumer containers don't set up DFS. The `connect_guest` / `connect_auth` helpers set `dfs_enabled: false`. diff --git a/vendor/smb2/src/testing/mod.rs b/vendor/smb2/src/testing/mod.rs new file mode 100644 index 0000000..0bfc6b0 --- /dev/null +++ b/vendor/smb2/src/testing/mod.rs @@ -0,0 +1,1275 @@ +//! Docker-based SMB test servers for integration testing. +//! +//! Provides [`TestServers`] for starting Samba containers on demand, +//! with factory methods that return connected [`SmbClient`] instances. +//! Enable the `testing` feature flag to use this module. +//! +//! # Three-layer testing model +//! +//! **Layer 1: Rust integration tests** -- Use [`TestServers`] to get +//! pre-connected clients in `#[tokio::test]` functions. +//! +//! **Layer 2: E2E tests** -- Use [`write_compose_files`] to extract +//! embedded Docker infrastructure, then run `docker compose up` from +//! your test framework (Playwright, Cypress, etc.). +//! +//! **Layer 3: Manual QA** -- Extract compose files once, run containers +//! manually, browse virtual servers in your app during development. +//! +//! # Example +//! +//! ```rust,no_run +//! use std::sync::LazyLock; +//! use smb2::testing::TestServers; +//! +//! static SERVERS: LazyLock = LazyLock::new(|| { +//! TestServers::start_blocking().unwrap() +//! }); +//! +//! # async fn example() { +//! let mut guest = SERVERS.guest_client().await.unwrap(); +//! let shares = guest.list_shares().await.unwrap(); +//! # } +//! ``` + +use std::fmt; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::time::Duration; + +use log::{debug, info}; + +use crate::client::{ClientConfig, SmbClient}; + +// ── Error type ────────────────────────────────────────────────────────── + +/// Errors from the test infrastructure (Docker, process, health checks). +/// +/// Separate from [`crate::Error`] because these are test-setup failures, +/// not protocol errors. +#[derive(Debug)] +pub enum Error { + /// Docker compose command failed. + Docker(std::io::Error), + /// Container didn't pass health check in time. + HealthCheckTimeout { + /// Name of the container that timed out. + container: String, + }, + /// Requested a client for a container that isn't running. + ContainerNotStarted { + /// Name of the container that was requested. + container: String, + /// Suggestion for how to fix this. + hint: String, + }, + /// SMB connection or operation failed. + Smb(crate::Error), + /// Failed to write embedded files to disk. + Io(std::io::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Docker(e) => write!(f, "docker command failed: {e}"), + Error::HealthCheckTimeout { container } => { + write!(f, "health check timed out for container: {container}") + } + Error::ContainerNotStarted { container, hint } => { + write!(f, "container not started: {container} ({hint})") + } + Error::Smb(e) => write!(f, "smb connection failed: {e}"), + Error::Io(e) => write!(f, "failed to write compose files: {e}"), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Docker(e) | Error::Io(e) => Some(e), + Error::Smb(e) => Some(e), + _ => None, + } + } +} + +/// Result type for test infrastructure operations. +pub type Result = std::result::Result; + +// ── Port constants ────────────────────────────────────────────────────── + +const DEFAULT_GUEST_PORT: u16 = 10480; +const DEFAULT_AUTH_PORT: u16 = 10481; +const DEFAULT_BOTH_PORT: u16 = 10482; +const DEFAULT_50SHARES_PORT: u16 = 10483; +const DEFAULT_UNICODE_PORT: u16 = 10484; +const DEFAULT_LONGNAMES_PORT: u16 = 10485; +const DEFAULT_DEEPNEST_PORT: u16 = 10486; +const DEFAULT_MANYFILES_PORT: u16 = 10487; +const DEFAULT_READONLY_PORT: u16 = 10488; +const DEFAULT_WINDOWS_PORT: u16 = 10489; +const DEFAULT_SYNOLOGY_PORT: u16 = 10490; +const DEFAULT_LINUX_PORT: u16 = 10491; +const DEFAULT_FLAKY_PORT: u16 = 10492; +const DEFAULT_SLOW_PORT: u16 = 10493; +const DEFAULT_MAXREADSIZE_PORT: u16 = 10494; + +/// Resolve a port from an environment variable, falling back to a default. +fn port(env_var: &str, default: u16) -> u16 { + std::env::var(env_var) + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(default) +} + +/// Port for the guest-access container. +pub fn guest_port() -> u16 { + port("SMB_CONSUMER_GUEST_PORT", DEFAULT_GUEST_PORT) +} + +/// Port for the auth-required container. +pub fn auth_port() -> u16 { + port("SMB_CONSUMER_AUTH_PORT", DEFAULT_AUTH_PORT) +} + +/// Port for the mixed auth container. +pub fn both_port() -> u16 { + port("SMB_CONSUMER_BOTH_PORT", DEFAULT_BOTH_PORT) +} + +/// Port for the 50-shares container. +pub fn many_shares_port() -> u16 { + port("SMB_CONSUMER_50SHARES_PORT", DEFAULT_50SHARES_PORT) +} + +/// Port for the unicode container. +pub fn unicode_port() -> u16 { + port("SMB_CONSUMER_UNICODE_PORT", DEFAULT_UNICODE_PORT) +} + +/// Port for the long-names container. +pub fn longnames_port() -> u16 { + port("SMB_CONSUMER_LONGNAMES_PORT", DEFAULT_LONGNAMES_PORT) +} + +/// Port for the deep-nesting container. +pub fn deepnest_port() -> u16 { + port("SMB_CONSUMER_DEEPNEST_PORT", DEFAULT_DEEPNEST_PORT) +} + +/// Port for the many-files container. +pub fn manyfiles_port() -> u16 { + port("SMB_CONSUMER_MANYFILES_PORT", DEFAULT_MANYFILES_PORT) +} + +/// Port for the read-only container. +pub fn readonly_port() -> u16 { + port("SMB_CONSUMER_READONLY_PORT", DEFAULT_READONLY_PORT) +} + +/// Port for the Windows-like container. +pub fn windows_port() -> u16 { + port("SMB_CONSUMER_WINDOWS_PORT", DEFAULT_WINDOWS_PORT) +} + +/// Port for the Synology-like container. +pub fn synology_port() -> u16 { + port("SMB_CONSUMER_SYNOLOGY_PORT", DEFAULT_SYNOLOGY_PORT) +} + +/// Port for the Linux container. +pub fn linux_port() -> u16 { + port("SMB_CONSUMER_LINUX_PORT", DEFAULT_LINUX_PORT) +} + +/// Port for the flaky container. +pub fn flaky_port() -> u16 { + port("SMB_CONSUMER_FLAKY_PORT", DEFAULT_FLAKY_PORT) +} + +/// Port for the slow container. +pub fn slow_port() -> u16 { + port("SMB_CONSUMER_SLOW_PORT", DEFAULT_SLOW_PORT) +} + +/// Port for the max-read-size container. +/// +/// The server enforces `smb2 max read = 65536` and `smb2 max write = 65536`, +/// so every transfer larger than 64 KB is chunked. Consumers can target this +/// fixture to exercise the streaming write/read fallback paths without +/// needing the internal-fixture `smb-maxreadsize` container. +pub fn maxreadsize_port() -> u16 { + port("SMB_CONSUMER_MAXREADSIZE_PORT", DEFAULT_MAXREADSIZE_PORT) +} + +// ── Embedded files ────────────────────────────────────────────────────── + +// docker-compose.yml +const COMPOSE_YML: &str = include_str!("../../tests/docker/consumer/docker-compose.yml"); + +// smb-consumer-guest +const GUEST_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-guest/Dockerfile"); +const GUEST_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-guest/smb.conf"); + +// smb-consumer-auth +const AUTH_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-auth/Dockerfile"); +const AUTH_SMB_CONF: &str = include_str!("../../tests/docker/consumer/smb-consumer-auth/smb.conf"); + +// smb-consumer-both +const BOTH_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-both/Dockerfile"); +const BOTH_SMB_CONF: &str = include_str!("../../tests/docker/consumer/smb-consumer-both/smb.conf"); + +// smb-consumer-50shares +const SHARES50_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-50shares/Dockerfile"); +const SHARES50_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-50shares/smb.conf"); +const SHARES50_GENERATE_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-50shares/generate-conf.sh"); + +// smb-consumer-unicode +const UNICODE_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-unicode/Dockerfile"); +const UNICODE_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-unicode/smb.conf"); +const UNICODE_POPULATE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-unicode/populate.sh"); + +// smb-consumer-longnames +const LONGNAMES_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-longnames/Dockerfile"); +const LONGNAMES_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-longnames/smb.conf"); +const LONGNAMES_POPULATE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-longnames/populate.sh"); + +// smb-consumer-deepnest +const DEEPNEST_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-deepnest/Dockerfile"); +const DEEPNEST_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-deepnest/smb.conf"); +const DEEPNEST_POPULATE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-deepnest/populate.sh"); + +// smb-consumer-manyfiles +const MANYFILES_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-manyfiles/Dockerfile"); +const MANYFILES_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-manyfiles/smb.conf"); + +// smb-consumer-readonly +const READONLY_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-readonly/Dockerfile"); +const READONLY_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-readonly/smb.conf"); + +// smb-consumer-windows +const WINDOWS_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-windows/Dockerfile"); +const WINDOWS_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-windows/smb.conf"); + +// smb-consumer-synology +const SYNOLOGY_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-synology/Dockerfile"); +const SYNOLOGY_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-synology/smb.conf"); + +// smb-consumer-linux +const LINUX_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-linux/Dockerfile"); +const LINUX_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-linux/smb.conf"); + +// smb-consumer-flaky +const FLAKY_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-flaky/Dockerfile"); +const FLAKY_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-flaky/smb.conf"); +const FLAKY_CYCLE: &str = include_str!("../../tests/docker/consumer/smb-consumer-flaky/cycle.sh"); + +// smb-consumer-slow +const SLOW_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-slow/Dockerfile"); +const SLOW_SMB_CONF: &str = include_str!("../../tests/docker/consumer/smb-consumer-slow/smb.conf"); +const SLOW_ENTRYPOINT: &str = + include_str!("../../tests/docker/consumer/smb-consumer-slow/entrypoint.sh"); + +// smb-consumer-maxreadsize +const MAXREADSIZE_DOCKERFILE: &str = + include_str!("../../tests/docker/consumer/smb-consumer-maxreadsize/Dockerfile"); +const MAXREADSIZE_SMB_CONF: &str = + include_str!("../../tests/docker/consumer/smb-consumer-maxreadsize/smb.conf"); + +// ── Embedded file manifest ────────────────────────────────────────────── + +/// A file to write into the compose directory. +struct EmbeddedFile { + /// Path relative to the compose directory root. + relative_path: &'static str, + /// File contents. + contents: &'static str, + /// Whether the file should be executable (shell scripts). + executable: bool, +} + +/// All files needed to reproduce the consumer Docker infrastructure. +fn embedded_files() -> Vec { + vec![ + EmbeddedFile { + relative_path: "docker-compose.yml", + contents: COMPOSE_YML, + executable: false, + }, + // guest + EmbeddedFile { + relative_path: "smb-consumer-guest/Dockerfile", + contents: GUEST_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-guest/smb.conf", + contents: GUEST_SMB_CONF, + executable: false, + }, + // auth + EmbeddedFile { + relative_path: "smb-consumer-auth/Dockerfile", + contents: AUTH_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-auth/smb.conf", + contents: AUTH_SMB_CONF, + executable: false, + }, + // both + EmbeddedFile { + relative_path: "smb-consumer-both/Dockerfile", + contents: BOTH_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-both/smb.conf", + contents: BOTH_SMB_CONF, + executable: false, + }, + // 50shares + EmbeddedFile { + relative_path: "smb-consumer-50shares/Dockerfile", + contents: SHARES50_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-50shares/smb.conf", + contents: SHARES50_SMB_CONF, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-50shares/generate-conf.sh", + contents: SHARES50_GENERATE_CONF, + executable: true, + }, + // unicode + EmbeddedFile { + relative_path: "smb-consumer-unicode/Dockerfile", + contents: UNICODE_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-unicode/smb.conf", + contents: UNICODE_SMB_CONF, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-unicode/populate.sh", + contents: UNICODE_POPULATE, + executable: true, + }, + // longnames + EmbeddedFile { + relative_path: "smb-consumer-longnames/Dockerfile", + contents: LONGNAMES_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-longnames/smb.conf", + contents: LONGNAMES_SMB_CONF, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-longnames/populate.sh", + contents: LONGNAMES_POPULATE, + executable: true, + }, + // deepnest + EmbeddedFile { + relative_path: "smb-consumer-deepnest/Dockerfile", + contents: DEEPNEST_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-deepnest/smb.conf", + contents: DEEPNEST_SMB_CONF, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-deepnest/populate.sh", + contents: DEEPNEST_POPULATE, + executable: true, + }, + // manyfiles + EmbeddedFile { + relative_path: "smb-consumer-manyfiles/Dockerfile", + contents: MANYFILES_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-manyfiles/smb.conf", + contents: MANYFILES_SMB_CONF, + executable: false, + }, + // readonly + EmbeddedFile { + relative_path: "smb-consumer-readonly/Dockerfile", + contents: READONLY_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-readonly/smb.conf", + contents: READONLY_SMB_CONF, + executable: false, + }, + // windows + EmbeddedFile { + relative_path: "smb-consumer-windows/Dockerfile", + contents: WINDOWS_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-windows/smb.conf", + contents: WINDOWS_SMB_CONF, + executable: false, + }, + // synology + EmbeddedFile { + relative_path: "smb-consumer-synology/Dockerfile", + contents: SYNOLOGY_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-synology/smb.conf", + contents: SYNOLOGY_SMB_CONF, + executable: false, + }, + // linux + EmbeddedFile { + relative_path: "smb-consumer-linux/Dockerfile", + contents: LINUX_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-linux/smb.conf", + contents: LINUX_SMB_CONF, + executable: false, + }, + // flaky + EmbeddedFile { + relative_path: "smb-consumer-flaky/Dockerfile", + contents: FLAKY_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-flaky/smb.conf", + contents: FLAKY_SMB_CONF, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-flaky/cycle.sh", + contents: FLAKY_CYCLE, + executable: true, + }, + // slow + EmbeddedFile { + relative_path: "smb-consumer-slow/Dockerfile", + contents: SLOW_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-slow/smb.conf", + contents: SLOW_SMB_CONF, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-slow/entrypoint.sh", + contents: SLOW_ENTRYPOINT, + executable: true, + }, + // maxreadsize + EmbeddedFile { + relative_path: "smb-consumer-maxreadsize/Dockerfile", + contents: MAXREADSIZE_DOCKERFILE, + executable: false, + }, + EmbeddedFile { + relative_path: "smb-consumer-maxreadsize/smb.conf", + contents: MAXREADSIZE_SMB_CONF, + executable: false, + }, + ] +} + +// ── File writing ──────────────────────────────────────────────────────── + +/// Write all embedded Docker files to the given directory. +/// +/// Creates the directory structure Docker Compose expects: +/// +/// ```text +/// / +/// docker-compose.yml +/// smb-consumer-guest/ +/// Dockerfile +/// smb.conf +/// smb-consumer-auth/ +/// Dockerfile +/// smb.conf +/// ... +/// ``` +/// +/// Use this for Layer 2 (E2E tests) or Layer 3 (manual QA) where you +/// run `docker compose up` outside of Rust. +pub fn write_compose_files(dir: &Path) -> Result<()> { + let files = embedded_files(); + for file in &files { + let path = dir.join(file.relative_path); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).map_err(Error::Io)?; + } + fs::write(&path, file.contents).map_err(Error::Io)?; + + #[cfg(unix)] + if file.executable { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o755); + fs::set_permissions(&path, perms).map_err(Error::Io)?; + } + } + debug!("wrote {} embedded files to {}", files.len(), dir.display()); + Ok(()) +} + +// ── Profile ───────────────────────────────────────────────────────────── + +/// Which containers to start. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Profile { + /// Guest + auth only (fast startup). + Minimal, + /// All 14 containers. + All, +} + +impl Profile { + /// Service names for `docker compose up`. + fn services(self) -> &'static [&'static str] { + match self { + Profile::Minimal => &["smb-consumer-guest", "smb-consumer-auth"], + Profile::All => &[ + "smb-consumer-guest", + "smb-consumer-auth", + "smb-consumer-both", + "smb-consumer-50shares", + "smb-consumer-unicode", + "smb-consumer-longnames", + "smb-consumer-deepnest", + "smb-consumer-manyfiles", + "smb-consumer-readonly", + "smb-consumer-windows", + "smb-consumer-synology", + "smb-consumer-linux", + "smb-consumer-flaky", + "smb-consumer-slow", + ], + } + } +} + +// ── TestServers ───────────────────────────────────────────────────────── + +/// Docker-based SMB test servers for integration testing. +/// +/// Starts Samba containers on construction, stops on drop. Each server +/// type has a factory method returning a connected [`SmbClient`]. +/// +/// Consumers can also skip `TestServers` entirely and use the compose +/// files directly for E2E or manual testing via [`write_compose_files`]. +pub struct TestServers { + compose_dir: PathBuf, + profile: Profile, +} + +impl TestServers { + /// Start the minimal set: guest + auth containers. + /// + /// This is the fastest option (~2 seconds). Use [`start_all`](Self::start_all) + /// if you need all 14 containers. + pub async fn start() -> Result { + let servers = Self::prepare(Profile::Minimal)?; + servers.compose_up()?; + servers.wait_healthy()?; + Ok(servers) + } + + /// Start all 14 consumer containers. + pub async fn start_all() -> Result { + let servers = Self::prepare(Profile::All)?; + servers.compose_up()?; + servers.wait_healthy()?; + Ok(servers) + } + + /// Blocking version of [`start_all`](Self::start_all) for use in + /// [`LazyLock`](std::sync::LazyLock) statics. + /// + /// # Example + /// + /// ```rust,no_run + /// use std::sync::LazyLock; + /// use smb2::testing::TestServers; + /// + /// static SERVERS: LazyLock = LazyLock::new(|| { + /// TestServers::start_blocking().unwrap() + /// }); + /// ``` + pub fn start_blocking() -> Result { + let servers = Self::prepare(Profile::All)?; + servers.compose_up()?; + servers.wait_healthy()?; + Ok(servers) + } + + /// Guest-access server. No credentials needed. + pub async fn guest_client(&self) -> Result { + self.require_service("smb-consumer-guest")?; + let addr = format!("127.0.0.1:{}", guest_port()); + connect_guest(&addr).await + } + + /// Auth-required server. Needs username and password. + pub async fn auth_client(&self, user: &str, pass: &str) -> Result { + self.require_service("smb-consumer-auth")?; + let addr = format!("127.0.0.1:{}", auth_port()); + connect_auth(&addr, user, pass).await + } + + /// Mixed server, guest connection. Can access the "public" share only. + pub async fn both_client(&self) -> Result { + self.require_service("smb-consumer-both")?; + let addr = format!("127.0.0.1:{}", both_port()); + connect_guest(&addr).await + } + + /// Mixed server, authenticated connection. Can access both "public" + /// and "private" shares. + pub async fn both_client_auth(&self, user: &str, pass: &str) -> Result { + self.require_service("smb-consumer-both")?; + let addr = format!("127.0.0.1:{}", both_port()); + connect_auth(&addr, user, pass).await + } + + /// Read-only server. Writes return errors. + pub async fn readonly_client(&self) -> Result { + self.require_service("smb-consumer-readonly")?; + let addr = format!("127.0.0.1:{}", readonly_port()); + connect_guest(&addr).await + } + + /// Server with 50 shares for testing share enumeration at scale. + pub async fn many_shares_client(&self) -> Result { + self.require_service("smb-consumer-50shares")?; + let addr = format!("127.0.0.1:{}", many_shares_port()); + connect_guest(&addr).await + } + + /// Server with unicode share and file names (CJK, emoji, accented characters). + pub async fn unicode_client(&self) -> Result { + self.require_service("smb-consumer-unicode")?; + let addr = format!("127.0.0.1:{}", unicode_port()); + connect_guest(&addr).await + } + + /// Server with 200+ character filenames. Tests path truncation. + pub async fn longnames_client(&self) -> Result { + self.require_service("smb-consumer-longnames")?; + let addr = format!("127.0.0.1:{}", longnames_port()); + connect_guest(&addr).await + } + + /// Server with 50-level deep directory tree. Tests navigation overflow. + pub async fn deepnest_client(&self) -> Result { + self.require_service("smb-consumer-deepnest")?; + let addr = format!("127.0.0.1:{}", deepnest_port()); + connect_guest(&addr).await + } + + /// Server with 10,000+ files in one directory. + pub async fn many_files_client(&self) -> Result { + self.require_service("smb-consumer-manyfiles")?; + let addr = format!("127.0.0.1:{}", manyfiles_port()); + connect_guest(&addr).await + } + + /// Windows-like server (server string in smb.conf). Tests OS detection. + pub async fn windows_client(&self) -> Result { + self.require_service("smb-consumer-windows")?; + let addr = format!("127.0.0.1:{}", windows_port()); + connect_guest(&addr).await + } + + /// Synology-like server (server string in smb.conf). Tests NAS-specific UI. + pub async fn synology_client(&self) -> Result { + self.require_service("smb-consumer-synology")?; + let addr = format!("127.0.0.1:{}", synology_port()); + connect_guest(&addr).await + } + + /// Generic Linux Samba server. Most common real-world server type. + pub async fn linux_client(&self) -> Result { + self.require_service("smb-consumer-linux")?; + let addr = format!("127.0.0.1:{}", linux_port()); + connect_guest(&addr).await + } + + /// Flaky server (5 seconds up, 5 seconds down). Tests error recovery UI. + pub async fn flaky_client(&self) -> Result { + self.require_service("smb-consumer-flaky")?; + let addr = format!("127.0.0.1:{}", flaky_port()); + connect_guest(&addr).await + } + + /// Slow server (200ms latency). Tests loading states and timeouts. + pub async fn slow_client(&self) -> Result { + self.require_service("smb-consumer-slow")?; + let addr = format!("127.0.0.1:{}", slow_port()); + connect_guest(&addr).await + } + + // ── Internal helpers ──────────────────────────────────────────── + + /// Create a temp directory, write embedded files, return the struct. + fn prepare(profile: Profile) -> Result { + let compose_dir = std::env::temp_dir().join(format!("smb2-testing-{}", std::process::id())); + write_compose_files(&compose_dir)?; + info!("prepared compose files in {}", compose_dir.display()); + Ok(Self { + compose_dir, + profile, + }) + } + + /// Run `docker compose up` for the selected profile. + fn compose_up(&self) -> Result<()> { + let services = self.profile.services(); + info!("starting {} container(s)", services.len()); + + let mut cmd = Command::new("docker"); + cmd.arg("compose") + .arg("-f") + .arg(self.compose_dir.join("docker-compose.yml")) + .arg("up") + .arg("-d") + .arg("--build"); + for svc in services { + cmd.arg(svc); + } + + debug!("running: {:?}", cmd); + let output = cmd.output().map_err(Error::Docker)?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + debug!("docker compose up stderr: {stderr}"); + return Err(Error::Docker(std::io::Error::other(format!( + "docker compose up failed: {stderr}" + )))); + } + Ok(()) + } + + /// Wait for all started containers to pass Docker health checks. + fn wait_healthy(&self) -> Result<()> { + let services = self.profile.services(); + let timeout = Duration::from_secs(30); + let poll_interval = Duration::from_millis(500); + let start = std::time::Instant::now(); + + for service in services { + // Skip health check for flaky container (it intentionally cycles). + if *service == "smb-consumer-flaky" { + debug!("skipping health check for {service} (intentionally flaky)"); + continue; + } + + loop { + if start.elapsed() > timeout { + return Err(Error::HealthCheckTimeout { + container: service.to_string(), + }); + } + + let output = Command::new("docker") + .arg("compose") + .arg("-f") + .arg(self.compose_dir.join("docker-compose.yml")) + .arg("ps") + .arg("--format") + .arg("{{.Health}}") + .arg(service) + .output() + .map_err(Error::Docker)?; + + let status = String::from_utf8_lossy(&output.stdout) + .trim() + .to_lowercase(); + if status.contains("healthy") { + debug!("{service} is healthy"); + break; + } + + debug!("{service} health: {status:?}, waiting..."); + std::thread::sleep(poll_interval); + } + } + + info!("all containers healthy"); + Ok(()) + } + + /// Check that a service is part of the current profile. + fn require_service(&self, service: &str) -> Result<()> { + if self.profile.services().contains(&service) { + Ok(()) + } else { + Err(Error::ContainerNotStarted { + container: service.to_string(), + hint: "call start_all() to start all containers".to_string(), + }) + } + } + + /// Run `docker compose down` (best-effort). + fn compose_down(&self) { + debug!("stopping containers in {}", self.compose_dir.display()); + let result = Command::new("docker") + .arg("compose") + .arg("-f") + .arg(self.compose_dir.join("docker-compose.yml")) + .arg("down") + .arg("--timeout") + .arg("5") + .output(); + + match result { + Ok(output) if output.status.success() => { + info!("containers stopped"); + } + Ok(output) => { + let stderr = String::from_utf8_lossy(&output.stderr); + debug!("docker compose down stderr: {stderr}"); + } + Err(e) => { + debug!("failed to run docker compose down: {e}"); + } + } + } + + /// Clean up the temp directory (best-effort). + fn cleanup_dir(&self) { + if self.compose_dir.exists() { + if let Err(e) = fs::remove_dir_all(&self.compose_dir) { + debug!("failed to clean up {}: {e}", self.compose_dir.display()); + } + } + } +} + +impl Drop for TestServers { + fn drop(&mut self) { + self.compose_down(); + self.cleanup_dir(); + } +} + +// ── Connection helpers ────────────────────────────────────────────────── + +async fn connect_guest(addr: &str) -> Result { + SmbClient::connect(ClientConfig { + addr: addr.to_string(), + timeout: Duration::from_secs(10), + username: String::new(), + password: String::new(), + domain: String::new(), + auto_reconnect: false, + compression: true, + dfs_enabled: false, + dfs_target_overrides: std::collections::HashMap::new(), + }) + .await + .map_err(Error::Smb) +} + +async fn connect_auth(addr: &str, user: &str, pass: &str) -> Result { + SmbClient::connect(ClientConfig { + addr: addr.to_string(), + timeout: Duration::from_secs(10), + username: user.to_string(), + password: pass.to_string(), + domain: String::new(), + auto_reconnect: false, + compression: true, + dfs_enabled: false, + dfs_target_overrides: std::collections::HashMap::new(), + }) + .await + .map_err(Error::Smb) +} + +// ── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + // ── Port resolution ───────────────────────────────────────────── + + #[test] + fn port_returns_default_when_env_unset() { + // Use a unique env var name that won't collide with real env. + let val = port("SMB2_TEST_NONEXISTENT_PORT_12345", 9999); + assert_eq!(val, 9999); + } + + #[test] + fn port_returns_env_value_when_set() { + let key = "SMB2_TEST_PORT_OVERRIDE_CHECK"; + std::env::set_var(key, "12345"); + let val = port(key, 9999); + std::env::remove_var(key); + assert_eq!(val, 12345); + } + + #[test] + fn port_returns_default_for_non_numeric_env() { + let key = "SMB2_TEST_PORT_BAD_VALUE"; + std::env::set_var(key, "not_a_number"); + let val = port(key, 7777); + std::env::remove_var(key); + assert_eq!(val, 7777); + } + + #[test] + fn port_returns_default_for_empty_env() { + let key = "SMB2_TEST_PORT_EMPTY"; + std::env::set_var(key, ""); + let val = port(key, 5555); + std::env::remove_var(key); + assert_eq!(val, 5555); + } + + // ── Default port values ───────────────────────────────────────── + + #[test] + fn default_ports_are_in_consumer_range() { + let ports = [ + DEFAULT_GUEST_PORT, + DEFAULT_AUTH_PORT, + DEFAULT_BOTH_PORT, + DEFAULT_50SHARES_PORT, + DEFAULT_UNICODE_PORT, + DEFAULT_LONGNAMES_PORT, + DEFAULT_DEEPNEST_PORT, + DEFAULT_MANYFILES_PORT, + DEFAULT_READONLY_PORT, + DEFAULT_WINDOWS_PORT, + DEFAULT_SYNOLOGY_PORT, + DEFAULT_LINUX_PORT, + DEFAULT_FLAKY_PORT, + DEFAULT_SLOW_PORT, + ]; + for p in ports { + assert!( + (10480..=10493).contains(&p), + "port {p} outside expected range 10480-10493" + ); + } + } + + #[test] + fn default_ports_are_unique() { + let ports = [ + DEFAULT_GUEST_PORT, + DEFAULT_AUTH_PORT, + DEFAULT_BOTH_PORT, + DEFAULT_50SHARES_PORT, + DEFAULT_UNICODE_PORT, + DEFAULT_LONGNAMES_PORT, + DEFAULT_DEEPNEST_PORT, + DEFAULT_MANYFILES_PORT, + DEFAULT_READONLY_PORT, + DEFAULT_WINDOWS_PORT, + DEFAULT_SYNOLOGY_PORT, + DEFAULT_LINUX_PORT, + DEFAULT_FLAKY_PORT, + DEFAULT_SLOW_PORT, + ]; + let mut seen = std::collections::HashSet::new(); + for p in ports { + assert!(seen.insert(p), "duplicate port: {p}"); + } + } + + // ── Error formatting ──────────────────────────────────────────── + + #[test] + fn error_display_docker() { + let err = Error::Docker(std::io::Error::new( + std::io::ErrorKind::NotFound, + "docker not found", + )); + let msg = err.to_string(); + assert!(msg.contains("docker command failed"), "got: {msg}"); + assert!(msg.contains("docker not found"), "got: {msg}"); + } + + #[test] + fn error_display_health_check_timeout() { + let err = Error::HealthCheckTimeout { + container: "smb-consumer-guest".to_string(), + }; + let msg = err.to_string(); + assert!(msg.contains("health check timed out"), "got: {msg}"); + assert!(msg.contains("smb-consumer-guest"), "got: {msg}"); + } + + #[test] + fn error_display_container_not_started() { + let err = Error::ContainerNotStarted { + container: "smb-consumer-unicode".to_string(), + hint: "call start_all()".to_string(), + }; + let msg = err.to_string(); + assert!(msg.contains("container not started"), "got: {msg}"); + assert!(msg.contains("smb-consumer-unicode"), "got: {msg}"); + assert!(msg.contains("start_all()"), "got: {msg}"); + } + + #[test] + fn error_display_io() { + let err = Error::Io(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "permission denied", + )); + let msg = err.to_string(); + assert!(msg.contains("write compose files"), "got: {msg}"); + } + + #[test] + fn error_debug_is_implemented() { + let err = Error::HealthCheckTimeout { + container: "test".to_string(), + }; + // Just verify Debug doesn't panic. + let _ = format!("{err:?}"); + } + + // ── write_compose_files ───────────────────────────────────────── + + #[test] + fn write_compose_files_creates_expected_structure() { + let dir = std::env::temp_dir().join(format!("smb2-test-write-{}", std::process::id())); + // Clean up from any previous run. + let _ = fs::remove_dir_all(&dir); + + write_compose_files(&dir).unwrap(); + + // Verify top-level compose file. + assert!(dir.join("docker-compose.yml").exists()); + + // Verify all 15 container directories exist with Dockerfiles. + let containers = [ + "smb-consumer-guest", + "smb-consumer-auth", + "smb-consumer-both", + "smb-consumer-50shares", + "smb-consumer-unicode", + "smb-consumer-longnames", + "smb-consumer-deepnest", + "smb-consumer-manyfiles", + "smb-consumer-readonly", + "smb-consumer-windows", + "smb-consumer-synology", + "smb-consumer-linux", + "smb-consumer-flaky", + "smb-consumer-slow", + "smb-consumer-maxreadsize", + ]; + for name in containers { + let dockerfile = dir.join(name).join("Dockerfile"); + assert!(dockerfile.exists(), "missing Dockerfile for {name}"); + let smb_conf = dir.join(name).join("smb.conf"); + assert!(smb_conf.exists(), "missing smb.conf for {name}"); + } + + // Verify extra scripts exist. + assert!(dir.join("smb-consumer-50shares/generate-conf.sh").exists()); + assert!(dir.join("smb-consumer-unicode/populate.sh").exists()); + assert!(dir.join("smb-consumer-longnames/populate.sh").exists()); + assert!(dir.join("smb-consumer-deepnest/populate.sh").exists()); + assert!(dir.join("smb-consumer-flaky/cycle.sh").exists()); + assert!(dir.join("smb-consumer-slow/entrypoint.sh").exists()); + + // Clean up. + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn write_compose_files_content_matches_embedded() { + let dir = std::env::temp_dir().join(format!("smb2-test-content-{}", std::process::id())); + let _ = fs::remove_dir_all(&dir); + + write_compose_files(&dir).unwrap(); + + let compose = fs::read_to_string(dir.join("docker-compose.yml")).unwrap(); + assert!( + compose.contains("smb-consumer-guest"), + "compose file should reference guest service" + ); + assert!( + compose.contains("10480"), + "compose file should contain default guest port" + ); + + let guest_conf = fs::read_to_string(dir.join("smb-consumer-guest/smb.conf")).unwrap(); + assert!( + guest_conf.contains("[public]"), + "guest smb.conf should have [public] share" + ); + + let _ = fs::remove_dir_all(&dir); + } + + #[cfg(unix)] + #[test] + fn write_compose_files_scripts_are_executable() { + use std::os::unix::fs::PermissionsExt; + + let dir = std::env::temp_dir().join(format!("smb2-test-exec-{}", std::process::id())); + let _ = fs::remove_dir_all(&dir); + + write_compose_files(&dir).unwrap(); + + let scripts = [ + "smb-consumer-50shares/generate-conf.sh", + "smb-consumer-unicode/populate.sh", + "smb-consumer-longnames/populate.sh", + "smb-consumer-deepnest/populate.sh", + "smb-consumer-flaky/cycle.sh", + "smb-consumer-slow/entrypoint.sh", + ]; + for script in scripts { + let path = dir.join(script); + let mode = fs::metadata(&path).unwrap().permissions().mode(); + assert!( + mode & 0o111 != 0, + "{script} should be executable (mode: {mode:#o})" + ); + } + + let _ = fs::remove_dir_all(&dir); + } + + // ── Profile / require_service ─────────────────────────────────── + + #[test] + fn minimal_profile_includes_guest_and_auth() { + let services = Profile::Minimal.services(); + assert!(services.contains(&"smb-consumer-guest")); + assert!(services.contains(&"smb-consumer-auth")); + assert_eq!(services.len(), 2); + } + + #[test] + fn all_profile_includes_14_services() { + let services = Profile::All.services(); + assert_eq!(services.len(), 14); + } + + #[test] + fn require_service_ok_for_minimal_profile() { + let servers = TestServers { + compose_dir: PathBuf::from("/tmp/fake"), + profile: Profile::Minimal, + }; + assert!(servers.require_service("smb-consumer-guest").is_ok()); + assert!(servers.require_service("smb-consumer-auth").is_ok()); + } + + #[test] + fn require_service_fails_for_non_minimal_container() { + let servers = TestServers { + compose_dir: PathBuf::from("/tmp/fake"), + profile: Profile::Minimal, + }; + let err = servers.require_service("smb-consumer-unicode").unwrap_err(); + match err { + Error::ContainerNotStarted { container, hint } => { + assert_eq!(container, "smb-consumer-unicode"); + assert!(hint.contains("start_all()")); + } + other => panic!("expected ContainerNotStarted, got: {other:?}"), + } + } + + #[test] + fn require_service_ok_for_all_profile() { + let servers = TestServers { + compose_dir: PathBuf::from("/tmp/fake"), + profile: Profile::All, + }; + // Should succeed for every container. + for svc in Profile::All.services() { + assert!( + servers.require_service(svc).is_ok(), + "require_service failed for {svc}" + ); + } + } + + // ── Embedded file count ───────────────────────────────────────── + + #[test] + fn embedded_files_count() { + let files = embedded_files(); + // 1 compose + 14 containers * (Dockerfile + smb.conf) = 29 + // + 6 extra scripts = 35 + assert_eq!(files.len(), 35, "expected 35 embedded files"); + } + + #[test] + fn embedded_files_no_empty_contents() { + for file in embedded_files() { + assert!( + !file.contents.is_empty(), + "embedded file {} has empty contents", + file.relative_path + ); + } + } +} diff --git a/vendor/smb2/src/transport/CLAUDE.md b/vendor/smb2/src/transport/CLAUDE.md new file mode 100644 index 0000000..52a9169 --- /dev/null +++ b/vendor/smb2/src/transport/CLAUDE.md @@ -0,0 +1,53 @@ +# Transport -- send/receive abstraction + +Split transport traits for SMB2 message I/O. Two implementations: TCP and mock. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | `TransportSend`, `TransportReceive`, `Transport` traits | +| `tcp.rs` | `TcpTransport` -- direct TCP to port 445, handles framing | +| `mock.rs` | `MockTransport` -- FIFO response queue for testing | + +## Split traits + +`TransportSend` and `TransportReceive` are separate traits. This avoids deadlock in the pipeline's `tokio::select!` loop where one task sends requests while another concurrently reads responses on the same connection. A single `Transport` trait would require `&mut self` for both directions, making concurrent send+receive impossible without `Arc`. + +The blanket impl `Transport` combines both halves. `Connection` stores `Box` and `Box` separately. + +## TCP framing + +``` +[0x00] [length: 3 bytes, big-endian] [SMB2 message(s)] +``` + +- First byte must be `0x00` +- Next 3 bytes: message length in big-endian (network byte order) +- Maximum frame size: 16 MB +- This is the ONLY big-endian value in SMB2 + +`TcpTransport::send` prepends the 4-byte header. `TcpTransport::receive` reads the header, then `read_exact` for the payload. + +## Who reads the transport + +`TransportReceive::receive()` is called by exactly one owner: the background receiver task spawned by `Connection::from_transport` (Phase 2 actor refactor). No other code path calls `receive()` in production. This is the invariant that makes per-`MessageId` routing sound — there's a single serialized read of the wire, then demux to per-request `oneshot::Sender`s. See `src/client/CLAUDE.md` § "Connection internals: receiver task + `oneshot` routing". + +`TransportSend::send()` is called from the caller thread (the one holding `&mut Connection`). `TcpTransport`'s internal Mutex on the write half serializes sends — relevant for Phase 3 once `Connection` becomes `Clone`. + +## MockTransport + +Used by all unit tests. Stores sent messages for inspection and returns queued responses in FIFO order. Thread-safe via `std::sync::Mutex`. + +Phase 2 changed `receive()` from "return `Err(Disconnected)` immediately when the queue is empty" to "block on `tokio::sync::Notify` until data is queued or `close()` is called". Required because the Connection's receiver task calls `receive()` in a loop — a premature `Disconnected` would kill the task while a test was still setting up responses. + +- `queue_response(data)` / `queue_responses(vec)` push to the queue and call `notify_one()`. `notify_one` stores a permit if no receiver is parked, so the next `.notified().await` returns immediately. +- `close()` sets an atomic `closed` flag and calls BOTH `notify_one()` (covers the wake-loss race where `receive()` is between `closed.load()` and `.notified().await`) and `notify_waiters()` (wakes already-parked waiters). +- External consumers using `MockTransport` in their own tests must call `close()` to get an explicit end-of-stream; the implicit "empty queue = disconnected" behavior is gone. + +## Gotchas + +- **Partial TCP reads**: Always use `read_exact` to read the full frame. TCP can deliver partial data in any `read()` call. +- **16 MB max frame**: Reject frames larger than 16 MB to prevent OOM from malicious servers. +- **Frame may contain multiple messages**: Compound responses arrive in a single frame. The Connection's receiver task splits them by `NextCommand` offsets and routes each sub-response by `MessageId` independently. +- **`MockTransport::close()` wake-loss**: `notify_waiters()` alone only wakes already-parked waiters; if `close()` fires between `receive()`'s `closed.load()` check and its `notified().await`, the signal is lost. `close()` therefore also calls `notify_one()` to store a permit — next `.notified().await` returns immediately and the loop re-observes `closed=true`. Noticed via code review after Phase 2. diff --git a/vendor/smb2/src/transport/mock.rs b/vendor/smb2/src/transport/mock.rs new file mode 100644 index 0000000..52337d0 --- /dev/null +++ b/vendor/smb2/src/transport/mock.rs @@ -0,0 +1,507 @@ +//! Mock transport for testing. +//! +//! Provides a [`MockTransport`] that queues canned responses and records +//! sent messages, enabling test-driven development of higher layers +//! without needing a real SMB server. + +use async_trait::async_trait; +use std::collections::VecDeque; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Mutex; + +use tokio::sync::Notify; + +use crate::error::{Error, Result}; +use crate::transport::{TransportReceive, TransportSend}; + +/// A mock transport that queues responses and records sent messages. +/// +/// Use this in tests to simulate server conversations without a real +/// network connection. Responses are returned in FIFO order. +/// +/// `receive()` awaits on an internal `Notify` when the queue is empty, +/// so the background receiver task doesn't exit prematurely between +/// `queue_response` calls. Explicit disconnect is triggered by calling +/// [`Self::close`]. +pub struct MockTransport { + /// Responses to return on `receive()`, in order. + responses: Mutex>>, + /// Messages that were sent, for assertions. + sent: Mutex>>, + /// How many times `receive()` was called successfully (returning Ok). + receive_count: Mutex, + /// Wakes receivers when a response is queued or `close()` is called. + notify: Notify, + /// Set by `close()` to signal end-of-stream. + closed: AtomicBool, + /// When `true`, `receive()` rewrites each response sub-frame's + /// `MessageId` to match the `MessageId` of the next pending sent request + /// (and consumes it). See [`Self::enable_auto_rewrite_msg_id`]. + auto_rewrite: AtomicBool, + /// FIFO of `MessageId`s observed in `send()` that haven't yet been + /// consumed by a `receive()` rewrite. Only used when `auto_rewrite` + /// is on. + pending_sent_msg_ids: Mutex>, + /// Signaled whenever a new send is recorded or a close happens — used + /// by `receive()` in auto-rewrite mode to wait for a sent msg_id to + /// pair with a queued response. + send_notify: Notify, +} + +impl MockTransport { + /// Create a new mock with no queued responses. + pub fn new() -> Self { + Self { + responses: Mutex::new(VecDeque::new()), + sent: Mutex::new(Vec::new()), + receive_count: Mutex::new(0), + notify: Notify::new(), + closed: AtomicBool::new(false), + auto_rewrite: AtomicBool::new(false), + pending_sent_msg_ids: Mutex::new(VecDeque::new()), + send_notify: Notify::new(), + } + } + + /// Enable msg_id rewriting: when `true`, `receive()` rewrites each + /// response sub-frame's `MessageId` in-place to match the `MessageId` + /// of the next request recorded by `send()` (FIFO pairing). + /// + /// Without this, canned response builders hardcode `MessageId(0)` and + /// won't match the caller's allocated msg_ids — the receiver task + /// drops them as orphans and every caller hangs. This mode is the + /// test-fixture replacement for the pre-Phase-3 orphan-filter-off + /// path. Compound responses (multiple sub-frames chained via + /// `NextCommand`) each consume one sent msg_id in order. + /// + /// The receive side blocks until both a queued response and a sent + /// msg_id are available, so tests can queue responses before or + /// after the caller sends. + pub fn enable_auto_rewrite_msg_id(&self) { + self.auto_rewrite.store(true, Ordering::Release); + } + + /// Queue a response to be returned by the next `receive()` call. + pub fn queue_response(&self, data: Vec) { + self.responses.lock().unwrap().push_back(data); + self.notify.notify_one(); + } + + /// Queue multiple responses to be returned in order. + pub fn queue_responses(&self, responses: Vec>) { + let mut guard = self.responses.lock().unwrap(); + let count = responses.len(); + for r in responses { + guard.push_back(r); + } + drop(guard); + for _ in 0..count { + self.notify.notify_one(); + } + } + + /// Signal end-of-stream: after all queued responses are drained, + /// `receive()` returns `Err(Error::Disconnected)`. + pub fn close(&self) { + self.closed.store(true, Ordering::Release); + // Use `notify_one` (stores a permit for the next `notified().await`) + // in addition to `notify_waiters` (wakes currently-parked waiters). + // `notify_waiters` alone loses the signal if `close()` fires + // between `receive()`'s `closed.load()` check and its + // `notified().await` — no waiter is parked yet, so nothing gets + // woken. The stored permit from `notify_one` covers that gap. + self.notify.notify_one(); + self.notify.notify_waiters(); + // Same treatment for the send-notification used by auto-rewrite: + // close should wake a receive that's blocked waiting for a paired + // sent msg_id so it observes `closed` and bails out. + self.send_notify.notify_one(); + self.send_notify.notify_waiters(); + } + + /// Get all messages that were sent. + pub fn sent_messages(&self) -> Vec> { + self.sent.lock().unwrap().clone() + } + + /// Get the nth sent message, or `None` if out of bounds. + pub fn sent_message(&self, n: usize) -> Option> { + self.sent.lock().unwrap().get(n).cloned() + } + + /// How many messages have been sent. + pub fn sent_count(&self) -> usize { + self.sent.lock().unwrap().len() + } + + /// Clear all recorded sent messages. + pub fn clear_sent(&self) { + self.sent.lock().unwrap().clear(); + } + + /// How many times `receive()` was called successfully (returned Ok). + pub fn received_count(&self) -> usize { + *self.receive_count.lock().unwrap() + } + + /// How many responses are still queued and unread. + /// + /// Useful in tests that want to assert the code-under-test consumed + /// every response it was expected to, without leaking any to a + /// later test or leaving stale state that could mask a bug. + pub fn pending_responses(&self) -> usize { + self.responses.lock().unwrap().len() + } + + /// Assert that every queued response has been consumed. + /// + /// Panics with a descriptive message if any responses remain in the + /// queue. Use at the end of a test to catch the "caller forgot to + /// receive" pattern that produces response-pipe pollution in + /// real usage. + #[track_caller] + pub fn assert_fully_consumed(&self) { + let remaining = self.pending_responses(); + assert_eq!( + remaining, 0, + "MockTransport has {} queued response(s) the code-under-test never read. \ + This usually means a caller sent a request but never received its response, \ + which in real usage leaves an orphan on the wire and corrupts the next op.", + remaining + ); + } +} + +impl Default for MockTransport { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TransportSend for MockTransport { + async fn send(&self, data: &[u8]) -> Result<()> { + // In auto-rewrite mode, capture the MessageId of each sub-frame + // so `receive()` can rewrite a queued response to match. + if self.auto_rewrite.load(Ordering::Acquire) { + for msg_id in extract_msg_ids(data) { + self.pending_sent_msg_ids.lock().unwrap().push_back(msg_id); + self.send_notify.notify_one(); + } + } + self.sent.lock().unwrap().push(data.to_vec()); + Ok(()) + } +} + +#[async_trait] +impl TransportReceive for MockTransport { + async fn receive(&self) -> Result> { + loop { + let auto = self.auto_rewrite.load(Ordering::Acquire); + // Wait for a queued response first (auto mode and plain mode + // both need one to exist). + let has_response = !self.responses.lock().unwrap().is_empty(); + if !has_response { + if self.closed.load(Ordering::Acquire) { + return Err(Error::Disconnected); + } + self.notify.notified().await; + continue; + } + + if auto { + // We have a response; peek its sub-frame count and wait + // for at least that many sent msg_ids to be queued + // (one consumed per sub-frame, even ones that already + // have non-zero msg_ids, so pairing stays 1:1). + let needed = { + let guard = self.responses.lock().unwrap(); + match guard.front() { + Some(frame) => count_sub_frames(frame), + None => continue, + } + }; + if needed > 0 { + loop { + let have = self.pending_sent_msg_ids.lock().unwrap().len(); + if have >= needed { + break; + } + if self.closed.load(Ordering::Acquire) { + return Err(Error::Disconnected); + } + self.send_notify.notified().await; + } + } + // Consume one response and `needed` sent msg_ids, + // rewriting each sub-frame's zero msg_id to match the + // corresponding sent msg_id. + let mut data = match self.responses.lock().unwrap().pop_front() { + Some(d) => d, + None => continue, + }; + let mut ids = self.pending_sent_msg_ids.lock().unwrap(); + rewrite_msg_ids(&mut data, &mut ids); + drop(ids); + *self.receive_count.lock().unwrap() += 1; + return Ok(data); + } + + // Plain mode: just pop and return. + let data = match self.responses.lock().unwrap().pop_front() { + Some(d) => d, + None => continue, + }; + *self.receive_count.lock().unwrap() += 1; + return Ok(data); + } + } +} + +/// Extract `MessageId`s from a packed SMB2 request frame (possibly compound). +/// Returns one msg_id per sub-frame, following `NextCommand` offsets. +/// Returns an empty Vec if the data isn't a recognizable SMB2 frame — +/// e.g. when `send()` is used with arbitrary bytes in transport-level tests. +fn extract_msg_ids(data: &[u8]) -> Vec { + const HEADER_MIN: usize = 64; + if data.len() < HEADER_MIN { + return Vec::new(); + } + // Not an SMB2 header — skip (non-SMB2 tests call send with arbitrary bytes). + if &data[0..4] != b"\xFESMB" { + return Vec::new(); + } + let mut ids = Vec::new(); + let mut offset = 0usize; + loop { + if offset + HEADER_MIN > data.len() { + break; + } + let msg_id = + u64::from_le_bytes(data[offset + 24..offset + 32].try_into().unwrap_or([0; 8])); + ids.push(msg_id); + let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4])); + if next == 0 { + break; + } + offset += next as usize; + } + ids +} + +/// Count sub-frames in a packed SMB2 response frame by walking +/// `NextCommand` offsets. Returns 0 for non-SMB2 frames, otherwise the +/// total sub-frame count. `rewrite_msg_ids` consumes one sent msg_id +/// per sub-frame (even those with already-set msg_ids) to keep +/// send→receive pairing strictly 1:1 and avoid queue drift in tests +/// that hardcode some but not all msg_ids. +fn count_sub_frames(data: &[u8]) -> usize { + const HEADER_MIN: usize = 64; + if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" { + return 0; + } + let mut count = 0usize; + let mut offset = 0usize; + loop { + if offset + HEADER_MIN > data.len() { + break; + } + count += 1; + let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4])); + if next == 0 { + break; + } + offset += next as usize; + } + count +} + +/// Rewrite each sub-frame's `MessageId` in-place, consuming one id from +/// `ids` per sub-frame in FIFO order. Sub-frames whose msg_id is +/// already non-zero keep their hardcoded id (so tests exercising out-of- +/// order routing still work) but STILL consume one id from the queue +/// to keep send→receive pairing 1:1. +fn rewrite_msg_ids(data: &mut [u8], ids: &mut VecDeque) { + const HEADER_MIN: usize = 64; + if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" { + return; + } + let mut offset = 0usize; + loop { + if offset + HEADER_MIN > data.len() { + break; + } + let existing = + u64::from_le_bytes(data[offset + 24..offset + 32].try_into().unwrap_or([0; 8])); + let consumed = ids.pop_front(); + if existing == 0 { + if let Some(id) = consumed { + data[offset + 24..offset + 32].copy_from_slice(&id.to_le_bytes()); + } else { + break; + } + } + let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4])); + if next == 0 { + break; + } + offset += next as usize; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn queue_response_and_receive_it() { + let mock = MockTransport::new(); + let data = vec![0x01, 0x02, 0x03]; + mock.queue_response(data.clone()); + + let received = mock.receive().await.unwrap(); + assert_eq!(received, data); + } + + #[tokio::test] + async fn queue_multiple_responses_received_in_order() { + let mock = MockTransport::new(); + mock.queue_responses(vec![vec![0x01], vec![0x02, 0x03], vec![0x04, 0x05, 0x06]]); + + assert_eq!(mock.receive().await.unwrap(), vec![0x01]); + assert_eq!(mock.receive().await.unwrap(), vec![0x02, 0x03]); + assert_eq!(mock.receive().await.unwrap(), vec![0x04, 0x05, 0x06]); + } + + #[tokio::test] + async fn close_causes_receive_to_return_disconnected() { + let mock = MockTransport::new(); + mock.close(); + + let result = mock.receive().await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, Error::Disconnected), + "expected Disconnected, got: {err}" + ); + } + + #[tokio::test] + async fn send_records_message() { + let mock = MockTransport::new(); + let msg = vec![0xAA, 0xBB, 0xCC]; + + mock.send(&msg).await.unwrap(); + + let sent = mock.sent_messages(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0], msg); + } + + #[tokio::test] + async fn sent_count_tracks_correctly() { + let mock = MockTransport::new(); + assert_eq!(mock.sent_count(), 0); + + mock.send(&[0x01]).await.unwrap(); + assert_eq!(mock.sent_count(), 1); + + mock.send(&[0x02]).await.unwrap(); + assert_eq!(mock.sent_count(), 2); + + mock.send(&[0x03]).await.unwrap(); + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn sent_message_returns_nth() { + let mock = MockTransport::new(); + mock.send(&[0x0A]).await.unwrap(); + mock.send(&[0x0B]).await.unwrap(); + mock.send(&[0x0C]).await.unwrap(); + + assert_eq!(mock.sent_message(0), Some(vec![0x0A])); + assert_eq!(mock.sent_message(1), Some(vec![0x0B])); + assert_eq!(mock.sent_message(2), Some(vec![0x0C])); + assert_eq!(mock.sent_message(3), None); + } + + #[tokio::test] + async fn clear_sent_removes_all_recorded_messages() { + let mock = MockTransport::new(); + mock.send(&[0x01]).await.unwrap(); + mock.send(&[0x02]).await.unwrap(); + assert_eq!(mock.sent_count(), 2); + + mock.clear_sent(); + assert_eq!(mock.sent_count(), 0); + assert!(mock.sent_messages().is_empty()); + } + + #[tokio::test] + async fn interleaved_send_and_receive() { + let mock = MockTransport::new(); + mock.queue_responses(vec![vec![0xF1], vec![0xF2], vec![0xF3]]); + + // Send a request, receive a response, repeat. + mock.send(&[0x01]).await.unwrap(); + assert_eq!(mock.receive().await.unwrap(), vec![0xF1]); + + mock.send(&[0x02]).await.unwrap(); + assert_eq!(mock.receive().await.unwrap(), vec![0xF2]); + + mock.send(&[0x03]).await.unwrap(); + assert_eq!(mock.receive().await.unwrap(), vec![0xF3]); + + // No more responses. Close to cause Disconnected. + mock.close(); + assert!(mock.receive().await.is_err()); + + // All three sends recorded. + assert_eq!(mock.sent_count(), 3); + } + + #[tokio::test] + async fn concurrent_send_and_receive() { + use std::sync::Arc; + + let mock = Arc::new(MockTransport::new()); + mock.queue_responses(vec![vec![0xAA]; 10]); + + let send_mock = Arc::clone(&mock); + let send_task = tokio::spawn(async move { + for i in 0..10u8 { + send_mock.send(&[i]).await.unwrap(); + } + }); + + let recv_mock = Arc::clone(&mock); + let recv_task = tokio::spawn(async move { + let mut received = Vec::new(); + for _ in 0..10 { + received.push(recv_mock.receive().await.unwrap()); + } + received + }); + + send_task.await.unwrap(); + let received = recv_task.await.unwrap(); + + assert_eq!(received.len(), 10); + assert_eq!(mock.sent_count(), 10); + } + + #[tokio::test] + async fn empty_message_can_be_sent_and_received() { + let mock = MockTransport::new(); + mock.queue_response(vec![]); + + mock.send(&[]).await.unwrap(); + let received = mock.receive().await.unwrap(); + + assert!(received.is_empty()); + assert_eq!(mock.sent_message(0), Some(vec![])); + } +} diff --git a/vendor/smb2/src/transport/mod.rs b/vendor/smb2/src/transport/mod.rs new file mode 100644 index 0000000..492ab13 --- /dev/null +++ b/vendor/smb2/src/transport/mod.rs @@ -0,0 +1,215 @@ +//! Transport abstraction for sending and receiving SMB2 messages. +//! +//! The transport layer handles framing (TCP's 4-byte length-prefix header) +//! and provides split send/receive traits to avoid deadlocks in the +//! pipeline's `tokio::select!` loop. +//! +//! Two implementations are provided: +//! - [`TcpTransport`] -- direct TCP connection to an SMB server (port 445) +//! - [`MockTransport`] -- canned responses for testing +//! +//! Most users don't need this module directly -- use [`SmbClient`](crate::SmbClient) +//! which handles transport setup internally. + +pub mod mock; +pub mod tcp; + +pub use mock::MockTransport; +pub use tcp::TcpTransport; + +use crate::error::Result; +use async_trait::async_trait; + +/// Send half of a transport connection. +#[async_trait] +pub trait TransportSend: Send + Sync { + /// Send a complete SMB2 message (the implementation adds framing). + async fn send(&self, data: &[u8]) -> Result<()>; +} + +/// Receive half of a transport connection. +#[async_trait] +pub trait TransportReceive: Send + Sync { + /// Receive one complete SMB2 transport frame. + /// + /// The implementation handles the TCP framing (4-byte header: + /// 1 zero byte + 3-byte big-endian length). The returned buffer + /// contains the SMB2 message(s) without the framing header. + /// + /// The buffer may contain multiple compounded responses linked + /// by NextCommand in the SMB2 headers -- the caller must split them. + async fn receive(&self) -> Result>; +} + +/// A combined transport that can both send and receive. +pub trait Transport: TransportSend + TransportReceive {} + +// Blanket implementation: anything that implements both halves is a Transport. +impl Transport for T {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::msg::header::{Header, PROTOCOL_ID}; + use crate::msg::negotiate::{ + NegotiateContext, NegotiateRequest, NegotiateResponse, HASH_ALGORITHM_SHA512, + }; + use crate::pack::{Guid, Pack, ReadCursor, Unpack, WriteCursor}; + use crate::types::flags::{Capabilities, SecurityMode}; + use crate::types::{Command, Dialect}; + + /// Pack a header + body into raw SMB2 message bytes (no transport framing). + fn pack_message(header: &Header, body: &dyn Pack) -> Vec { + let mut cursor = WriteCursor::new(); + header.pack(&mut cursor); + body.pack(&mut cursor); + cursor.into_inner() + } + + #[tokio::test] + async fn cross_module_negotiate_via_mock_transport() { + // Build a NegotiateRequest, send it through MockTransport, + // receive a canned NegotiateResponse, and verify unpacking. + + let mock = MockTransport::new(); + + // Build a negotiate request. + let req_header = Header::new_request(Command::Negotiate); + let req_body = NegotiateRequest { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::default(), + client_guid: Guid { + data1: 0xDEAD_BEEF, + data2: 0xCAFE, + data3: 0xF00D, + data4: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], + }, + dialects: vec![Dialect::Smb2_0_2, Dialect::Smb2_1, Dialect::Smb3_1_1], + negotiate_contexts: vec![NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xAA; 32], + }], + }; + let req_msg = pack_message(&req_header, &req_body); + + // Build a canned NegotiateResponse. + let resp_header = { + let mut h = Header::new_request(Command::Negotiate); + h.flags.set_response(); + h.credits = 1; + h + }; + let resp_body = NegotiateResponse { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + dialect_revision: Dialect::Smb3_1_1, + server_guid: Guid { + data1: 0x1111_2222, + data2: 0x3333, + data3: 0x4444, + data4: [0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC], + }, + capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING), + max_transact_size: 65536, + max_read_size: 65536, + max_write_size: 65536, + system_time: 132_000_000_000_000_000, + server_start_time: 131_000_000_000_000_000, + security_buffer: vec![0x60, 0x00], // minimal placeholder + negotiate_contexts: vec![NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xBB; 32], + }], + }; + let resp_msg = pack_message(&resp_header, &resp_body); + + // Queue the canned response. + mock.queue_response(resp_msg); + + // Send the request through the mock. + mock.send(&req_msg).await.unwrap(); + + // Receive the canned response. + let received = mock.receive().await.unwrap(); + + // Unpack and verify. + let mut cursor = ReadCursor::new(&received); + let hdr = Header::unpack(&mut cursor).unwrap(); + assert!(hdr.is_response()); + assert_eq!(hdr.command, Command::Negotiate); + + let body = NegotiateResponse::unpack(&mut cursor).unwrap(); + assert_eq!(body.dialect_revision, Dialect::Smb3_1_1); + assert_eq!(body.max_read_size, 65536); + assert!(body.security_mode.signing_enabled()); + + // Verify the request was recorded. + assert_eq!(mock.sent_count(), 1); + let sent = mock.sent_message(0).unwrap(); + + // Verify we can unpack what was sent. + let mut cursor = ReadCursor::new(&sent); + let sent_hdr = Header::unpack(&mut cursor).unwrap(); + assert_eq!(sent_hdr.command, Command::Negotiate); + assert!(!sent_hdr.is_response()); + + let sent_body = NegotiateRequest::unpack(&mut cursor).unwrap(); + assert_eq!(sent_body.dialects.len(), 3); + assert!(sent_body.dialects.contains(&Dialect::Smb3_1_1)); + } + + #[tokio::test] + #[ignore] // Requires NAS at 192.168.1.111 + async fn negotiate_via_tcp_transport() { + use std::time::Duration; + + let transport = TcpTransport::connect("192.168.1.111:445", Duration::from_secs(5)) + .await + .expect("failed to connect to NAS"); + + // Build a negotiate request. + let header = Header::new_request(Command::Negotiate); + let request = NegotiateRequest { + security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED), + capabilities: Capabilities::new( + Capabilities::DFS | Capabilities::LEASING | Capabilities::LARGE_MTU, + ), + client_guid: Guid { + data1: 0xDEAD_BEEF, + data2: 0xCAFE, + data3: 0xF00D, + data4: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], + }, + dialects: vec![ + Dialect::Smb2_0_2, + Dialect::Smb2_1, + Dialect::Smb3_0, + Dialect::Smb3_0_2, + Dialect::Smb3_1_1, + ], + negotiate_contexts: vec![NegotiateContext::PreauthIntegrity { + hash_algorithms: vec![HASH_ALGORITHM_SHA512], + salt: vec![0xAA; 32], + }], + }; + + let msg = pack_message(&header, &request); + + // Send through transport (framing added automatically). + transport.send(&msg).await.unwrap(); + + // Receive response (framing stripped automatically). + let resp_bytes = transport.receive().await.unwrap(); + + // Verify we got a valid response. + assert!(resp_bytes[0..4] == PROTOCOL_ID); + + let mut cursor = ReadCursor::new(&resp_bytes); + let resp_header = Header::unpack(&mut cursor).unwrap(); + assert!(resp_header.is_response()); + assert_eq!(resp_header.command, Command::Negotiate); + + let resp_body = NegotiateResponse::unpack(&mut cursor).unwrap(); + assert!(Dialect::ALL.contains(&resp_body.dialect_revision)); + assert!(resp_body.max_read_size >= 65536); + } +} diff --git a/vendor/smb2/src/transport/tcp.rs b/vendor/smb2/src/transport/tcp.rs new file mode 100644 index 0000000..56ba3be --- /dev/null +++ b/vendor/smb2/src/transport/tcp.rs @@ -0,0 +1,485 @@ +//! Direct TCP transport for SMB2 (port 445). +//! +//! Implements the SMB2 transport framing defined in MS-SMB2 section 2.1: +//! each message is preceded by a 4-byte header consisting of 1 zero byte +//! followed by 3 bytes of big-endian length. This is the ONLY big-endian +//! encoding in the entire SMB2 protocol. + +use async_trait::async_trait; +use log::{debug, error, trace}; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio::sync::Mutex; + +use crate::error::{Error, Result}; +use crate::transport::{TransportReceive, TransportSend}; + +/// Maximum frame size we accept (16 MB). +/// +/// Prevents denial-of-service from corrupt or malicious length fields. +/// Real SMB2 messages are typically much smaller (the largest negotiated +/// MaxReadSize/MaxWriteSize is usually 8 MB). +const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024; + +/// Direct TCP transport for SMB2. +/// +/// Wraps a TCP connection and handles the 4-byte framing header. +/// The connection is split into independent read and write halves +/// so that send and receive can proceed concurrently without contention +/// (required by the pipeline's `tokio::select!` loop). +#[derive(Debug)] +pub struct TcpTransport { + /// The read half of the TCP connection, behind a mutex for `&self` access. + reader: Mutex, + /// The write half of the TCP connection, behind a mutex for `&self` access. + writer: Mutex, +} + +impl TcpTransport { + /// Connect to an SMB server over TCP. + /// + /// Applies the given timeout to the connection attempt. Once connected, + /// the socket is split into independent read/write halves. + pub async fn connect(addr: impl ToSocketAddrs, timeout: Duration) -> Result { + let stream = tokio::time::timeout(timeout, TcpStream::connect(addr)) + .await + .map_err(|_| Error::Timeout)? + .map_err(Error::Io)?; + + // Disable Nagle's algorithm for lower latency on small messages. + stream.set_nodelay(true).map_err(Error::Io)?; + + debug!("tcp: connected, nodelay=true"); + let (reader, writer) = stream.into_split(); + + Ok(Self { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }) + } +} + +#[async_trait] +impl TransportSend for TcpTransport { + async fn send(&self, data: &[u8]) -> Result<()> { + let len = data.len(); + if len > MAX_FRAME_SIZE { + return Err(Error::invalid_data(format!( + "message size {} exceeds maximum frame size {}", + len, MAX_FRAME_SIZE + ))); + } + + // Build the 4-byte framing header: 0x00 + 3-byte BE length. + let mut frame_header = [0u8; 4]; + frame_header[0] = 0x00; + frame_header[1] = (len >> 16) as u8; + frame_header[2] = (len >> 8) as u8; + frame_header[3] = len as u8; + + let mut writer = self.writer.lock().await; + writer.write_all(&frame_header).await.map_err(Error::Io)?; + writer.write_all(data).await.map_err(Error::Io)?; + writer.flush().await.map_err(Error::Io)?; + + trace!("tcp: sent frame, len={}", len); + Ok(()) + } +} + +#[async_trait] +impl TransportReceive for TcpTransport { + async fn receive(&self) -> Result> { + let mut reader = self.reader.lock().await; + + // Read the 4-byte framing header. + let mut frame_header = [0u8; 4]; + reader.read_exact(&mut frame_header).await.map_err(|e| { + if e.kind() == std::io::ErrorKind::UnexpectedEof { + Error::Disconnected + } else { + Error::Io(e) + } + })?; + + // Validate the first byte is 0x00. + if frame_header[0] != 0x00 { + error!("tcp: invalid frame, first byte=0x{:02X}", frame_header[0]); + return Err(Error::invalid_data(format!( + "invalid transport frame: first byte must be 0x00, got 0x{:02X}", + frame_header[0] + ))); + } + + // Extract the 3-byte big-endian length. + let msg_len = ((frame_header[1] as usize) << 16) + | ((frame_header[2] as usize) << 8) + | (frame_header[3] as usize); + + // Validate against the maximum frame size. + if msg_len > MAX_FRAME_SIZE { + return Err(Error::invalid_data(format!( + "frame length {} exceeds maximum {}", + msg_len, MAX_FRAME_SIZE + ))); + } + + trace!("tcp: receiving frame, len={}", msg_len); + + // Read the message body. + let mut buf = vec![0u8; msg_len]; + reader.read_exact(&mut buf).await.map_err(|e| { + if e.kind() == std::io::ErrorKind::UnexpectedEof { + Error::Disconnected + } else { + Error::Io(e) + } + })?; + + trace!("tcp: received frame, len={}", msg_len); + Ok(buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a framed message (4-byte header + payload). + fn frame_message(payload: &[u8]) -> Vec { + let len = payload.len(); + let mut frame = Vec::with_capacity(4 + len); + frame.push(0x00); + frame.push((len >> 16) as u8); + frame.push((len >> 8) as u8); + frame.push(len as u8); + frame.extend_from_slice(payload); + frame + } + + // ── Send framing tests ────────────────────────────────────────── + + #[test] + fn frame_header_format_small_message() { + let payload = vec![0xFE, 0x53, 0x4D, 0x42]; // "SMB2 magic" + let framed = frame_message(&payload); + + // Header: [0x00, 0x00, 0x00, 0x04] + assert_eq!(framed[0], 0x00, "first byte must be 0x00"); + assert_eq!(framed[1], 0x00, "length high byte"); + assert_eq!(framed[2], 0x00, "length mid byte"); + assert_eq!(framed[3], 0x04, "length low byte = 4"); + assert_eq!(&framed[4..], &payload); + } + + #[test] + fn frame_header_format_medium_message() { + // 300 bytes -> 0x00, 0x00, 0x01, 0x2C + let payload = vec![0xAA; 300]; + let framed = frame_message(&payload); + + assert_eq!(framed[0], 0x00); + assert_eq!(framed[1], 0x00); + assert_eq!(framed[2], 0x01); + assert_eq!(framed[3], 0x2C); + assert_eq!(framed.len(), 304); + } + + #[test] + fn frame_header_format_large_message() { + // 0x010203 = 66051 bytes + let payload = vec![0xBB; 66051]; + let framed = frame_message(&payload); + + assert_eq!(framed[0], 0x00); + assert_eq!(framed[1], 0x01); + assert_eq!(framed[2], 0x02); + assert_eq!(framed[3], 0x03); + } + + #[test] + fn frame_header_empty_payload() { + let framed = frame_message(&[]); + assert_eq!(framed, vec![0x00, 0x00, 0x00, 0x00]); + } + + // ── Receive framing tests (using tokio_test-style mock streams) ── + + /// A helper that creates a pair of connected streams via a TCP listener + /// on localhost, then writes data to one side and reads from the other. + async fn receive_from_bytes(data: &[u8]) -> Result> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let data = data.to_vec(); + let writer_task = tokio::spawn(async move { + let mut stream = TcpStream::connect(addr).await.unwrap(); + stream.write_all(&data).await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let (stream, _) = listener.accept().await.unwrap(); + let (reader, writer) = stream.into_split(); + let transport = TcpTransport { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }; + + let result = transport.receive().await; + writer_task.await.unwrap(); + result + } + + #[tokio::test] + async fn receive_valid_frame() { + let payload = vec![0xFE, 0x53, 0x4D, 0x42, 0x01, 0x02]; + let framed = frame_message(&payload); + + let received = receive_from_bytes(&framed).await.unwrap(); + assert_eq!(received, payload); + } + + #[tokio::test] + async fn receive_empty_payload() { + let framed = frame_message(&[]); + let received = receive_from_bytes(&framed).await.unwrap(); + assert!(received.is_empty()); + } + + #[tokio::test] + async fn receive_first_byte_not_zero_returns_error() { + // First byte is 0x01 instead of 0x00. + let data = vec![0x01, 0x00, 0x00, 0x04, 0xAA, 0xBB, 0xCC, 0xDD]; + + let result = receive_from_bytes(&data).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("first byte must be 0x00"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn receive_length_exceeds_max_returns_error() { + // Length = 0xFFFFFF = 16777215 > MAX_FRAME_SIZE (16 * 1024 * 1024 = 16777216) + // Wait, 0xFFFFFF = 16777215 < 16777216. Let's use a length just over. + // MAX_FRAME_SIZE = 16 * 1024 * 1024 = 16_777_216 + // We need > 16_777_216, but max 3-byte value is 16_777_215. + // So 3 bytes can't exceed 16 MB. But the spec says 16 MB is the max. + // Let's set MAX_FRAME_SIZE to slightly less, or test at the boundary. + // Actually MAX_FRAME_SIZE = 16 * 1024 * 1024 = 16_777_216. + // Max 3-byte value = 0xFFFFFF = 16_777_215 which is < MAX_FRAME_SIZE. + // So a 3-byte length can never exceed our MAX_FRAME_SIZE. + // This test verifies that the max 3-byte value IS accepted (no error). + // But what if someone sends a broken frame? The first byte check + // catches that. For the length check specifically, we'd need a + // smaller MAX_FRAME_SIZE to exercise the branch. For now, let's test + // with an internal test. The important thing is the check exists. + + // Actually, the more realistic concern is a malicious server sending + // large values. 0xFFFFFF = ~16 MB is fine by our limit. Let's verify + // the boundary: 0xFFFFFF should be accepted because 16_777_215 < 16_777_216. + // We can't test > MAX_FRAME_SIZE with only 3 bytes, but the check + // is there for defense-in-depth (the first byte could be non-zero + // and interpreted as part of length if we didn't validate it). + + // Let's test a frame with length 0xFFFFFF but not enough payload data, + // which should return Disconnected (not a crash from huge allocation). + let data = vec![0x00, 0xFF, 0xFF, 0xFF]; // Length = 16_777_215 bytes, no payload. + + let result = receive_from_bytes(&data).await; + assert!(result.is_err()); + // Should get Disconnected because the payload read fails. + let err = result.unwrap_err(); + assert!( + matches!(err, Error::Disconnected), + "expected Disconnected for truncated large frame, got: {err}" + ); + } + + #[tokio::test] + async fn receive_disconnected_on_eof() { + // Empty data = immediate EOF. + let result = receive_from_bytes(&[]).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, Error::Disconnected), + "expected Disconnected, got: {err}" + ); + } + + #[tokio::test] + async fn receive_partial_header_returns_disconnected() { + // Only 2 bytes of the 4-byte header. + let result = receive_from_bytes(&[0x00, 0x00]).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, Error::Disconnected), + "expected Disconnected for partial header, got: {err}" + ); + } + + #[tokio::test] + async fn receive_partial_payload_returns_disconnected() { + // Header says 10 bytes, but only 3 bytes of payload follow. + let data = vec![0x00, 0x00, 0x00, 0x0A, 0x01, 0x02, 0x03]; + + let result = receive_from_bytes(&data).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, Error::Disconnected), + "expected Disconnected for truncated payload, got: {err}" + ); + } + + #[tokio::test] + async fn send_and_receive_roundtrip() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let send_task = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (reader, writer) = stream.into_split(); + let transport = TcpTransport { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }; + + let payload = vec![0xFE, 0x53, 0x4D, 0x42, 0xDE, 0xAD]; + transport.send(&payload).await.unwrap(); + }); + + let (stream, _) = listener.accept().await.unwrap(); + let (reader, writer) = stream.into_split(); + let recv_transport = TcpTransport { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }; + + let received = recv_transport.receive().await.unwrap(); + assert_eq!(received, vec![0xFE, 0x53, 0x4D, 0x42, 0xDE, 0xAD]); + + send_task.await.unwrap(); + } + + #[tokio::test] + async fn send_and_receive_multiple_messages() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let send_task = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (reader, writer) = stream.into_split(); + let transport = TcpTransport { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }; + + transport.send(&[0x01, 0x02]).await.unwrap(); + transport.send(&[0x03, 0x04, 0x05]).await.unwrap(); + transport.send(&[0x06]).await.unwrap(); + }); + + let (stream, _) = listener.accept().await.unwrap(); + let (reader, writer) = stream.into_split(); + let recv_transport = TcpTransport { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }; + + assert_eq!(recv_transport.receive().await.unwrap(), vec![0x01, 0x02]); + assert_eq!( + recv_transport.receive().await.unwrap(), + vec![0x03, 0x04, 0x05] + ); + assert_eq!(recv_transport.receive().await.unwrap(), vec![0x06]); + + send_task.await.unwrap(); + } + + #[tokio::test] + async fn partial_reads_are_handled_by_read_exact() { + // This test exercises the read_exact behavior by sending data + // through a real TCP connection. Under the hood, TCP may deliver + // data in arbitrary chunk sizes, especially with Nagle disabled. + // While we can't force byte-at-a-time delivery reliably, we + // verify correctness with a larger payload that's more likely + // to arrive in multiple reads. + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let payload: Vec = (0..=255).cycle().take(8192).collect(); + let payload_clone = payload.clone(); + + let send_task = tokio::spawn(async move { + let stream = TcpStream::connect(addr).await.unwrap(); + let (reader, writer) = stream.into_split(); + let transport = TcpTransport { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }; + + transport.send(&payload_clone).await.unwrap(); + }); + + let (stream, _) = listener.accept().await.unwrap(); + let (reader, writer) = stream.into_split(); + let recv_transport = TcpTransport { + reader: Mutex::new(reader), + writer: Mutex::new(writer), + }; + + let received = recv_transport.receive().await.unwrap(); + assert_eq!(received.len(), payload.len()); + assert_eq!(received, payload); + + send_task.await.unwrap(); + } + + #[tokio::test] + async fn connect_with_timeout() { + // Connect to localhost listener with a generous timeout. + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let transport = TcpTransport::connect(addr, Duration::from_secs(5)) + .await + .unwrap(); + + // Accept the connection on the server side. + let (server_stream, _) = listener.accept().await.unwrap(); + let (server_reader, mut server_writer) = server_stream.into_split(); + drop(server_reader); + + // Send a framed message from the "server" side. + let payload = vec![0xDE, 0xAD, 0xBE, 0xEF]; + let mut frame = vec![0x00, 0x00, 0x00, 0x04]; + frame.extend_from_slice(&payload); + server_writer.write_all(&frame).await.unwrap(); + server_writer.flush().await.unwrap(); + + // Receive through the transport. + let received = transport.receive().await.unwrap(); + assert_eq!(received, payload); + } + + #[tokio::test] + async fn connect_timeout_fires() { + // Try to connect to a non-routable address. This should time out. + // 192.0.2.1 is a TEST-NET address (RFC 5737) that should be unreachable. + let result = TcpTransport::connect("192.0.2.1:445", Duration::from_millis(100)).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + // Could be Timeout or Io depending on OS behavior. + assert!( + matches!(err, Error::Timeout | Error::Io(_)), + "expected Timeout or Io error, got: {err}" + ); + } +} diff --git a/vendor/smb2/src/types/CLAUDE.md b/vendor/smb2/src/types/CLAUDE.md new file mode 100644 index 0000000..728ee90 --- /dev/null +++ b/vendor/smb2/src/types/CLAUDE.md @@ -0,0 +1,38 @@ +# Types -- protocol newtypes and enums + +Zero-cost newtype wrappers for protocol IDs, command/dialect enums, and bitflag types. + +## Key files + +| File | Purpose | +|---|---| +| `mod.rs` | `SessionId`, `TreeId`, `FileId`, `MessageId`, `CreditCharge`, `Command`, `Dialect`, `OplockLevel` | +| `flags.rs` | Bitflag types: `HeaderFlags`, `Capabilities`, `SecurityMode`, `FileAccessMask`, etc. | +| `status.rs` | `NtStatus` enum (from MS-ERREF) with severity/facility helpers | + +## Newtype IDs + +All protocol IDs are newtypes over their raw integer: +- `SessionId(u64)` -- has `NONE` sentinel (0) +- `MessageId(u64)` -- has `UNSOLICITED` sentinel (0xFFFFFFFFFFFFFFFF) for oplock breaks +- `TreeId(u32)` +- `CreditCharge(u16)` +- `FileId { persistent: u64, volatile: u64 }` -- has `SENTINEL` (all-F's) for compound related requests + +All implement `Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`, `Display`. + +## Command and Dialect enums + +- `Command`: 19 variants (Negotiate through OplockBreak), `repr(u16)`, uses `num_enum` for `TryFrom`/`Into` +- `Dialect`: 5 variants (2.0.2 through 3.1.1), `repr(u16)`, ordered (`PartialOrd`/`Ord`). `Dialect::ALL` is a sorted slice. + +## Key decisions + +- **Newtypes over raw u32/u64**: Prevents accidentally passing a TreeId where a SessionId is expected. Zero runtime cost. +- **`num_enum` for command/dialect**: Avoids manual match arms for TryFrom. Compile-time checked exhaustive conversions. + +## Gotchas + +- **`MORE_PROCESSING_REQUIRED` has error severity bits but isn't an error**: `NtStatus` severity is encoded in bits 30-31. `MORE_PROCESSING_REQUIRED` (0xC0000016) has severity=3 (error), but it's a normal part of the session setup flow. Use `is_more_processing_required()` instead of checking `is_error()`. +- **`STATUS_BUFFER_OVERFLOW` is a warning, not an error**: Returns valid partial data. Don't discard the response body. +- **FileId::SENTINEL vs FileId::default()**: SENTINEL is all-F's (used in compound requests). Default is all-zeros (unused). Don't mix them up. diff --git a/vendor/smb2/src/types/flags.rs b/vendor/smb2/src/types/flags.rs new file mode 100644 index 0000000..de4ea6f --- /dev/null +++ b/vendor/smb2/src/types/flags.rs @@ -0,0 +1,465 @@ +//! Bitflag types for SMB2/3 protocol fields. + +use std::ops::{BitAnd, BitOr, BitOrAssign}; + +// ── Macro to reduce boilerplate for flag types ────────────────────────── + +macro_rules! impl_flags { + ($name:ident, $inner:ty) => { + impl $name { + /// Create a new flags value from a raw integer. + #[inline] + pub const fn new(raw: $inner) -> Self { + Self(raw) + } + + /// Return the raw bits. + #[inline] + pub const fn bits(&self) -> $inner { + self.0 + } + + /// Check whether a particular flag bit is set. + #[inline] + pub const fn contains(&self, flag: $inner) -> bool { + self.0 & flag == flag + } + + /// Set a flag bit. + #[inline] + pub fn set(&mut self, flag: $inner) { + self.0 |= flag; + } + + /// Clear a flag bit. + #[inline] + pub fn clear(&mut self, flag: $inner) { + self.0 &= !flag; + } + } + + impl BitOr for $name { + type Output = Self; + #[inline] + fn bitor(self, rhs: Self) -> Self { + Self(self.0 | rhs.0) + } + } + + impl BitAnd for $name { + type Output = Self; + #[inline] + fn bitand(self, rhs: Self) -> Self { + Self(self.0 & rhs.0) + } + } + + impl BitOrAssign for $name { + #[inline] + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.0; + } + } + }; +} + +// ── HeaderFlags ───────────────────────────────────────────────────────── + +/// SMB2 packet header flags (32-bit field from MS-SMB2 2.2.1). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct HeaderFlags(pub u32); + +impl HeaderFlags { + /// The message is a response rather than a request. + pub const SERVER_TO_REDIR: u32 = 0x0000_0001; + /// The message is an async SMB2 header. + pub const ASYNC_COMMAND: u32 = 0x0000_0002; + /// The message is part of a compounded chain. + pub const RELATED_OPERATIONS: u32 = 0x0000_0004; + /// The message is signed. + pub const SIGNED: u32 = 0x0000_0008; + /// Priority value mask (SMB 3.1.1). + pub const PRIORITY_MASK: u32 = 0x0000_0070; + /// The command is a DFS operation. + pub const DFS_OPERATIONS: u32 = 0x1000_0000; + /// The command is a replay operation (SMB 3.x). + pub const REPLAY_OPERATION: u32 = 0x2000_0000; + + /// Returns `true` if this is a response (server-to-redirector). + #[inline] + pub fn is_response(&self) -> bool { + self.contains(Self::SERVER_TO_REDIR) + } + + /// Returns `true` if the async flag is set. + #[inline] + pub fn is_async(&self) -> bool { + self.contains(Self::ASYNC_COMMAND) + } + + /// Returns `true` if the related-operations flag is set. + #[inline] + pub fn is_related(&self) -> bool { + self.contains(Self::RELATED_OPERATIONS) + } + + /// Returns `true` if the signed flag is set. + #[inline] + pub fn is_signed(&self) -> bool { + self.contains(Self::SIGNED) + } + + /// Set the response flag. + #[inline] + pub fn set_response(&mut self) { + self.set(Self::SERVER_TO_REDIR); + } + + /// Set the async flag. + #[inline] + pub fn set_async(&mut self) { + self.set(Self::ASYNC_COMMAND); + } + + /// Set the related-operations flag. + #[inline] + pub fn set_related(&mut self) { + self.set(Self::RELATED_OPERATIONS); + } + + /// Set the signed flag. + #[inline] + pub fn set_signed(&mut self) { + self.set(Self::SIGNED); + } +} + +impl_flags!(HeaderFlags, u32); + +// ── SecurityMode ──────────────────────────────────────────────────────── + +/// Security mode flags (16-bit field from MS-SMB2 2.2.3/2.2.4). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct SecurityMode(pub u16); + +impl SecurityMode { + /// Signing is supported (enabled). + pub const SIGNING_ENABLED: u16 = 0x0001; + /// Signing is required. + pub const SIGNING_REQUIRED: u16 = 0x0002; + + /// Returns `true` if signing is enabled. + #[inline] + pub fn signing_enabled(&self) -> bool { + self.contains(Self::SIGNING_ENABLED) + } + + /// Returns `true` if signing is required. + #[inline] + pub fn signing_required(&self) -> bool { + self.contains(Self::SIGNING_REQUIRED) + } +} + +impl_flags!(SecurityMode, u16); + +// ── Capabilities ──────────────────────────────────────────────────────── + +/// Server/client capability flags (32-bit field from MS-SMB2 2.2.3/2.2.4). +/// +/// With the `serde` feature on, this serializes as the underlying `u32` +/// bits, **not** a JSON object of named flags. Decode against the +/// associated constants on this type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct Capabilities(pub u32); + +#[cfg(feature = "serde")] +impl serde::Serialize for Capabilities { + fn serialize(&self, s: S) -> std::result::Result { + s.serialize_u32(self.0) + } +} + +impl Capabilities { + /// Distributed File System (DFS) support. + pub const DFS: u32 = 0x0000_0001; + /// Leasing support. + pub const LEASING: u32 = 0x0000_0002; + /// Multi-credit (large MTU) support. + pub const LARGE_MTU: u32 = 0x0000_0004; + /// Multi-channel support. + pub const MULTI_CHANNEL: u32 = 0x0000_0008; + /// Persistent handle support. + pub const PERSISTENT_HANDLES: u32 = 0x0000_0010; + /// Directory leasing support. + pub const DIRECTORY_LEASING: u32 = 0x0000_0020; + /// Encryption support. + pub const ENCRYPTION: u32 = 0x0000_0040; +} + +impl_flags!(Capabilities, u32); + +// ── ShareFlags ────────────────────────────────────────────────────────── + +/// Share property flags (32-bit field from MS-SMB2 2.2.10). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct ShareFlags(pub u32); + +impl ShareFlags { + /// The share is in a DFS tree structure. + pub const DFS: u32 = 0x0000_0001; + /// The share is a DFS root. + pub const DFS_ROOT: u32 = 0x0000_0002; + + // Offline caching policies (mutually exclusive, stored in bits 4-5). + + /// The client can cache files explicitly selected by the user. + pub const MANUAL_CACHING: u32 = 0x0000_0000; + /// The client can automatically cache files used by the user. + pub const AUTO_CACHING: u32 = 0x0000_0010; + /// Auto-cache with offline mode even when the share is available. + pub const VDO_CACHING: u32 = 0x0000_0020; + /// Offline caching must not occur. + pub const NO_CACHING: u32 = 0x0000_0030; + + /// Disallows exclusive file opens that deny reads. + pub const RESTRICT_EXCLUSIVE_OPENS: u32 = 0x0000_0100; + /// Disallows exclusive opens that prevent deletion. + pub const FORCE_SHARED_DELETE: u32 = 0x0000_0200; + /// Allow namespace caching (client must ignore). + pub const ALLOW_NAMESPACE_CACHING: u32 = 0x0000_0400; + /// Server filters directory entries based on access permissions. + pub const ACCESS_BASED_DIRECTORY_ENUM: u32 = 0x0000_0800; + /// Server will not issue exclusive caching rights. + pub const FORCE_LEVELII_OPLOCK: u32 = 0x0000_1000; + /// Hash generation v1 for branch cache (not valid for SMB 2.0.2). + pub const ENABLE_HASH_V1: u32 = 0x0000_2000; + /// Hash generation v2 for branch cache. + pub const ENABLE_HASH_V2: u32 = 0x0000_4000; + /// Encryption of remote file access messages required (SMB 3.x). + pub const ENCRYPT_DATA: u32 = 0x0000_8000; + /// The share supports identity remoting. + pub const IDENTITY_REMOTING: u32 = 0x0004_0000; + /// The server supports compression on this share (SMB 3.1.1). + pub const COMPRESS_DATA: u32 = 0x0010_0000; + /// Prefer isolated transport for this share (advisory). + pub const ISOLATED_TRANSPORT: u32 = 0x0020_0000; +} + +impl_flags!(ShareFlags, u32); + +// ── ShareCapabilities ─────────────────────────────────────────────────── + +/// Share capability flags (32-bit field from MS-SMB2 2.2.10). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct ShareCapabilities(pub u32); + +impl ShareCapabilities { + /// The share is part of a DFS tree. + pub const DFS: u32 = 0x0000_0008; + /// The share has continuously available file handles. + pub const CONTINUOUS_AVAILABILITY: u32 = 0x0000_0010; + /// The share is a scale-out share. + pub const SCALEOUT: u32 = 0x0000_0020; + /// The share is a cluster share. + pub const CLUSTER: u32 = 0x0000_0040; + /// The share is an asymmetric share. + pub const ASYMMETRIC: u32 = 0x0000_0080; + /// The share supports redirect to owner. + pub const REDIRECT_TO_OWNER: u32 = 0x0000_0100; +} + +impl_flags!(ShareCapabilities, u32); + +// ── FileAccessMask ────────────────────────────────────────────────────── + +/// File access rights mask (32-bit, from MS-SMB2 2.2.13.1). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct FileAccessMask(pub u32); + +impl FileAccessMask { + /// Read data from the file. + pub const FILE_READ_DATA: u32 = 0x0000_0001; + /// Write data to the file. + pub const FILE_WRITE_DATA: u32 = 0x0000_0002; + /// Append data to the file. + pub const FILE_APPEND_DATA: u32 = 0x0000_0004; + /// Read extended attributes. + pub const FILE_READ_EA: u32 = 0x0000_0008; + /// Write extended attributes. + pub const FILE_WRITE_EA: u32 = 0x0000_0010; + /// Execute the file. + pub const FILE_EXECUTE: u32 = 0x0000_0020; + /// Read file attributes. + pub const FILE_READ_ATTRIBUTES: u32 = 0x0000_0080; + /// Write file attributes. + pub const FILE_WRITE_ATTRIBUTES: u32 = 0x0000_0100; + /// Delete the object. + pub const DELETE: u32 = 0x0001_0000; + /// Read the security descriptor. + pub const READ_CONTROL: u32 = 0x0002_0000; + /// Modify the DACL. + pub const WRITE_DAC: u32 = 0x0004_0000; + /// Change the owner. + pub const WRITE_OWNER: u32 = 0x0008_0000; + /// Synchronize access. + pub const SYNCHRONIZE: u32 = 0x0010_0000; + /// Request maximum allowed access. + pub const MAXIMUM_ALLOWED: u32 = 0x0200_0000; + /// All possible access rights. + pub const GENERIC_ALL: u32 = 0x1000_0000; + /// Execute access. + pub const GENERIC_EXECUTE: u32 = 0x2000_0000; + /// Write access. + pub const GENERIC_WRITE: u32 = 0x4000_0000; + /// Read access. + pub const GENERIC_READ: u32 = 0x8000_0000; +} + +impl_flags!(FileAccessMask, u32); + +#[cfg(test)] +mod tests { + use super::*; + + // ── HeaderFlags ───────────────────────────────────────────────── + + #[test] + fn header_flags_default_is_zero() { + let f = HeaderFlags::default(); + assert_eq!(f.bits(), 0); + assert!(!f.is_response()); + assert!(!f.is_async()); + assert!(!f.is_related()); + assert!(!f.is_signed()); + } + + #[test] + fn header_flags_set_and_check() { + let mut f = HeaderFlags::default(); + f.set_response(); + assert!(f.is_response()); + assert!(!f.is_async()); + + f.set_signed(); + assert!(f.is_signed()); + assert!(f.is_response()); + } + + #[test] + fn header_flags_clear() { + let mut f = HeaderFlags::new(0xFFFF_FFFF); + assert!(f.is_response()); + f.clear(HeaderFlags::SERVER_TO_REDIR); + assert!(!f.is_response()); + assert!(f.is_async()); // other flags untouched + } + + #[test] + fn header_flags_contains() { + let f = HeaderFlags::new(HeaderFlags::SIGNED | HeaderFlags::ASYNC_COMMAND); + assert!(f.contains(HeaderFlags::SIGNED)); + assert!(f.contains(HeaderFlags::ASYNC_COMMAND)); + assert!(!f.contains(HeaderFlags::SERVER_TO_REDIR)); + } + + #[test] + fn header_flags_bitor() { + let a = HeaderFlags::new(HeaderFlags::SERVER_TO_REDIR); + let b = HeaderFlags::new(HeaderFlags::SIGNED); + let c = a | b; + assert!(c.is_response()); + assert!(c.is_signed()); + } + + #[test] + fn header_flags_bitand() { + let a = HeaderFlags::new(HeaderFlags::SERVER_TO_REDIR | HeaderFlags::SIGNED); + let b = HeaderFlags::new(HeaderFlags::SIGNED); + let c = a & b; + assert!(!c.is_response()); + assert!(c.is_signed()); + } + + #[test] + fn header_flags_bitor_assign() { + let mut a = HeaderFlags::new(HeaderFlags::SERVER_TO_REDIR); + a |= HeaderFlags::new(HeaderFlags::ASYNC_COMMAND); + assert!(a.is_response()); + assert!(a.is_async()); + } + + // ── SecurityMode ──────────────────────────────────────────────── + + #[test] + fn security_mode_signing_enabled() { + let m = SecurityMode::new(SecurityMode::SIGNING_ENABLED); + assert!(m.signing_enabled()); + assert!(!m.signing_required()); + } + + #[test] + fn security_mode_signing_required() { + let m = SecurityMode::new(SecurityMode::SIGNING_ENABLED | SecurityMode::SIGNING_REQUIRED); + assert!(m.signing_enabled()); + assert!(m.signing_required()); + } + + // ── Capabilities ──────────────────────────────────────────────── + + #[test] + fn capabilities_combine_with_bitor() { + let a = Capabilities::new(Capabilities::DFS); + let b = Capabilities::new(Capabilities::ENCRYPTION); + let c = a | b; + assert!(c.contains(Capabilities::DFS)); + assert!(c.contains(Capabilities::ENCRYPTION)); + assert!(!c.contains(Capabilities::LEASING)); + } + + #[test] + fn capabilities_set_and_clear() { + let mut c = Capabilities::default(); + c.set(Capabilities::LARGE_MTU); + assert!(c.contains(Capabilities::LARGE_MTU)); + c.clear(Capabilities::LARGE_MTU); + assert!(!c.contains(Capabilities::LARGE_MTU)); + } + + // ── ShareFlags ────────────────────────────────────────────────── + + #[test] + fn share_flags_encrypt_data() { + let f = ShareFlags::new(ShareFlags::ENCRYPT_DATA | ShareFlags::DFS); + assert!(f.contains(ShareFlags::ENCRYPT_DATA)); + assert!(f.contains(ShareFlags::DFS)); + assert!(!f.contains(ShareFlags::COMPRESS_DATA)); + } + + // ── ShareCapabilities ─────────────────────────────────────────── + + #[test] + fn share_capabilities_dfs() { + let c = ShareCapabilities::new(ShareCapabilities::DFS); + assert!(c.contains(ShareCapabilities::DFS)); + assert!(!c.contains(ShareCapabilities::CLUSTER)); + } + + // ── FileAccessMask ────────────────────────────────────────────── + + #[test] + fn file_access_mask_generic_read() { + let m = FileAccessMask::new(FileAccessMask::GENERIC_READ); + assert!(m.contains(FileAccessMask::GENERIC_READ)); + assert!(!m.contains(FileAccessMask::GENERIC_WRITE)); + } + + #[test] + fn file_access_mask_combine() { + let m = + FileAccessMask::new(FileAccessMask::FILE_READ_DATA | FileAccessMask::FILE_WRITE_DATA); + assert!(m.contains(FileAccessMask::FILE_READ_DATA)); + assert!(m.contains(FileAccessMask::FILE_WRITE_DATA)); + assert!(!m.contains(FileAccessMask::DELETE)); + } +} diff --git a/vendor/smb2/src/types/mod.rs b/vendor/smb2/src/types/mod.rs new file mode 100644 index 0000000..c991d5e --- /dev/null +++ b/vendor/smb2/src/types/mod.rs @@ -0,0 +1,364 @@ +//! Newtypes, enums, and common data structures for SMB2/3 protocol fields. +//! +//! Most users don't need to import from this module directly -- the commonly +//! used types are re-exported at the crate root. + +pub mod flags; +pub mod status; + +use std::fmt; + +use crate::Error; + +/// Requested oplock level (MS-SMB2 2.2.13, 2.2.23). +/// +/// Used across CREATE requests/responses and oplock break messages. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum OplockLevel { + /// No oplock is requested. + None = 0x00, + /// Level II oplock is requested. + LevelII = 0x01, + /// Exclusive oplock is requested. + Exclusive = 0x08, + /// Batch oplock is requested. + Batch = 0x09, + /// Lease is requested. + Lease = 0xFF, +} + +impl TryFrom for OplockLevel { + type Error = Error; + + fn try_from(value: u8) -> crate::error::Result { + match value { + 0x00 => Ok(Self::None), + 0x01 => Ok(Self::LevelII), + 0x08 => Ok(Self::Exclusive), + 0x09 => Ok(Self::Batch), + 0xFF => Ok(Self::Lease), + _ => Err(Error::invalid_data(format!( + "invalid OplockLevel: 0x{:02X}", + value + ))), + } + } +} + +/// 64-bit session identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(transparent))] +pub struct SessionId(pub u64); + +impl SessionId { + /// Sentinel value indicating no session. + pub const NONE: Self = Self(0); +} + +impl fmt::Display for SessionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SessionId(0x{:016X})", self.0) + } +} + +/// 64-bit message identifier for request/response correlation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(transparent))] +pub struct MessageId(pub u64); + +impl MessageId { + /// Unsolicited message ID used for oplock/lease break notifications. + pub const UNSOLICITED: Self = Self(0xFFFF_FFFF_FFFF_FFFF); +} + +impl fmt::Display for MessageId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MessageId(0x{:016X})", self.0) + } +} + +/// 32-bit tree connect identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(transparent))] +pub struct TreeId(pub u32); + +impl fmt::Display for TreeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TreeId(0x{:08X})", self.0) + } +} + +/// 16-bit credit charge for multi-credit requests. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct CreditCharge(pub u16); + +impl fmt::Display for CreditCharge { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "CreditCharge({})", self.0) + } +} + +/// 128-bit file identifier consisting of two 64-bit parts. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct FileId { + /// Persistent portion of the file handle. + pub persistent: u64, + /// Volatile portion of the file handle. + pub volatile: u64, +} + +impl FileId { + /// Sentinel value used in related compound requests. + pub const SENTINEL: Self = Self { + persistent: 0xFFFF_FFFF_FFFF_FFFF, + volatile: 0xFFFF_FFFF_FFFF_FFFF, + }; +} + +impl fmt::Display for FileId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "FileId(0x{:016X}:0x{:016X})", + self.persistent, self.volatile + ) + } +} + +/// SMB2 command codes from MS-SMB2 section 2.2.1. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, num_enum::TryFromPrimitive, num_enum::IntoPrimitive, +)] +#[repr(u16)] +pub enum Command { + /// Negotiate protocol version and capabilities. + Negotiate = 0x0000, + /// Set up an authenticated session. + SessionSetup = 0x0001, + /// Log off a session. + Logoff = 0x0002, + /// Connect to a share. + TreeConnect = 0x0003, + /// Disconnect from a share. + TreeDisconnect = 0x0004, + /// Open or create a file. + Create = 0x0005, + /// Close a file handle. + Close = 0x0006, + /// Flush cached data to stable storage. + Flush = 0x0007, + /// Read data from a file. + Read = 0x0008, + /// Write data to a file. + Write = 0x0009, + /// Lock or unlock byte ranges. + Lock = 0x000A, + /// Issue a device control or file system control command. + Ioctl = 0x000B, + /// Cancel a previously sent request. + Cancel = 0x000C, + /// Check server liveness. + Echo = 0x000D, + /// Enumerate directory contents. + QueryDirectory = 0x000E, + /// Request change notifications on a directory. + ChangeNotify = 0x000F, + /// Query file or filesystem information. + QueryInfo = 0x0010, + /// Set file or filesystem information. + SetInfo = 0x0011, + /// Oplock or lease break notification/acknowledgment. + OplockBreak = 0x0012, +} + +impl fmt::Display for Command { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +/// SMB2 dialect revision identifiers from MS-SMB2 section 2.2.3. +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + num_enum::TryFromPrimitive, + num_enum::IntoPrimitive, +)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +#[repr(u16)] +pub enum Dialect { + /// SMB 2.0.2 dialect. + Smb2_0_2 = 0x0202, + /// SMB 2.1 dialect. + Smb2_1 = 0x0210, + /// SMB 3.0 dialect. + Smb3_0 = 0x0300, + /// SMB 3.0.2 dialect. + Smb3_0_2 = 0x0302, + /// SMB 3.1.1 dialect. + Smb3_1_1 = 0x0311, +} + +impl Dialect { + /// All supported dialect revisions, in ascending order. + pub const ALL: &[Dialect] = &[ + Dialect::Smb2_0_2, + Dialect::Smb2_1, + Dialect::Smb3_0, + Dialect::Smb3_0_2, + Dialect::Smb3_1_1, + ]; +} + +impl fmt::Display for Dialect { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Dialect::Smb2_0_2 => f.write_str("SMB 2.0.2"), + Dialect::Smb2_1 => f.write_str("SMB 2.1"), + Dialect::Smb3_0 => f.write_str("SMB 3.0"), + Dialect::Smb3_0_2 => f.write_str("SMB 3.0.2"), + Dialect::Smb3_1_1 => f.write_str("SMB 3.1.1"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── Newtype tests ─────────────────────────────────────────────── + + #[test] + fn session_id_none_is_zero() { + assert_eq!(SessionId::NONE, SessionId(0)); + assert_eq!(SessionId::NONE.0, 0); + } + + #[test] + fn message_id_unsolicited() { + assert_eq!(MessageId::UNSOLICITED.0, 0xFFFF_FFFF_FFFF_FFFF); + } + + #[test] + fn file_id_sentinel() { + assert_eq!(FileId::SENTINEL.persistent, 0xFFFF_FFFF_FFFF_FFFF); + assert_eq!(FileId::SENTINEL.volatile, 0xFFFF_FFFF_FFFF_FFFF); + } + + #[test] + fn newtype_display_formatting() { + assert_eq!( + SessionId(0x1234).to_string(), + "SessionId(0x0000000000001234)" + ); + assert_eq!( + MessageId(0xABCD).to_string(), + "MessageId(0x000000000000ABCD)" + ); + assert_eq!(TreeId(0x42).to_string(), "TreeId(0x00000042)"); + assert_eq!(CreditCharge(5).to_string(), "CreditCharge(5)"); + assert_eq!( + FileId { + persistent: 0x11, + volatile: 0x22 + } + .to_string(), + "FileId(0x0000000000000011:0x0000000000000022)" + ); + } + + // ── Command tests ─────────────────────────────────────────────── + + #[test] + fn command_roundtrip_via_u16() { + assert_eq!(Command::try_from(0x0005u16), Ok(Command::Create)); + assert_eq!(u16::from(Command::Create), 0x0005); + } + + #[test] + fn command_all_variants_correct_values() { + assert_eq!(u16::from(Command::Negotiate), 0x0000); + assert_eq!(u16::from(Command::SessionSetup), 0x0001); + assert_eq!(u16::from(Command::Logoff), 0x0002); + assert_eq!(u16::from(Command::TreeConnect), 0x0003); + assert_eq!(u16::from(Command::TreeDisconnect), 0x0004); + assert_eq!(u16::from(Command::Create), 0x0005); + assert_eq!(u16::from(Command::Close), 0x0006); + assert_eq!(u16::from(Command::Flush), 0x0007); + assert_eq!(u16::from(Command::Read), 0x0008); + assert_eq!(u16::from(Command::Write), 0x0009); + assert_eq!(u16::from(Command::Lock), 0x000A); + assert_eq!(u16::from(Command::Ioctl), 0x000B); + assert_eq!(u16::from(Command::Cancel), 0x000C); + assert_eq!(u16::from(Command::Echo), 0x000D); + assert_eq!(u16::from(Command::QueryDirectory), 0x000E); + assert_eq!(u16::from(Command::ChangeNotify), 0x000F); + assert_eq!(u16::from(Command::QueryInfo), 0x0010); + assert_eq!(u16::from(Command::SetInfo), 0x0011); + assert_eq!(u16::from(Command::OplockBreak), 0x0012); + } + + #[test] + fn command_invalid_u16_is_error() { + assert!(Command::try_from(0xFFFFu16).is_err()); + assert!(Command::try_from(0x0013u16).is_err()); + } + + #[test] + fn command_display() { + assert_eq!(Command::Create.to_string(), "Create"); + assert_eq!(Command::OplockBreak.to_string(), "OplockBreak"); + } + + // ── Dialect tests ─────────────────────────────────────────────── + + #[test] + fn dialect_ordering() { + assert!(Dialect::Smb2_0_2 < Dialect::Smb2_1); + assert!(Dialect::Smb2_1 < Dialect::Smb3_0); + assert!(Dialect::Smb3_0 < Dialect::Smb3_0_2); + assert!(Dialect::Smb3_0_2 < Dialect::Smb3_1_1); + } + + #[test] + fn dialect_roundtrip_via_u16() { + assert_eq!(Dialect::try_from(0x0311u16), Ok(Dialect::Smb3_1_1)); + assert_eq!(u16::from(Dialect::Smb3_1_1), 0x0311); + } + + #[test] + fn dialect_invalid_u16_is_error() { + assert!(Dialect::try_from(0x0000u16).is_err()); + assert!(Dialect::try_from(0x0201u16).is_err()); + } + + #[test] + fn dialect_display() { + assert_eq!(Dialect::Smb2_0_2.to_string(), "SMB 2.0.2"); + assert_eq!(Dialect::Smb2_1.to_string(), "SMB 2.1"); + assert_eq!(Dialect::Smb3_0.to_string(), "SMB 3.0"); + assert_eq!(Dialect::Smb3_0_2.to_string(), "SMB 3.0.2"); + assert_eq!(Dialect::Smb3_1_1.to_string(), "SMB 3.1.1"); + } + + #[test] + fn dialect_all_has_five_variants() { + assert_eq!(Dialect::ALL.len(), 5); + assert_eq!(Dialect::ALL[0], Dialect::Smb2_0_2); + assert_eq!(Dialect::ALL[4], Dialect::Smb3_1_1); + } + + #[test] + fn dialect_all_is_sorted() { + for w in Dialect::ALL.windows(2) { + assert!(w[0] < w[1]); + } + } +} diff --git a/vendor/smb2/src/types/status.rs b/vendor/smb2/src/types/status.rs new file mode 100644 index 0000000..e790413 --- /dev/null +++ b/vendor/smb2/src/types/status.rs @@ -0,0 +1,384 @@ +//! NTSTATUS codes used by SMB2/3 (from MS-ERREF). + +use std::fmt; + +/// Defines `NtStatus` associated constants and the `name()` match arms from +/// a single table, so adding a new status code only requires one edit. +macro_rules! nt_status_codes { + ( + $( + $(#[$meta:meta])* + $name:ident = $value:expr, $display:expr; + )* + ) => { + impl NtStatus { + $( + $(#[$meta])* + pub const $name: Self = Self($value); + )* + + /// Returns a human-readable name for known status codes, + /// or `None` for unknown codes. + fn name(&self) -> Option<&'static str> { + match self.0 { + $( $value => Some($display), )* + _ => None, + } + } + } + }; +} + +/// NT status code returned in SMB2 response headers. +/// +/// The top two bits encode severity: +/// - `00` = success +/// - `01` = informational +/// - `10` = warning +/// - `11` = error +#[derive(Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct NtStatus(pub u32); + +nt_status_codes! { + // -- Success (severity 0b00) -- + + /// The operation completed successfully. + SUCCESS = 0x0000_0000, "STATUS_SUCCESS"; + + /// The operation that was requested is pending completion. + PENDING = 0x0000_0103, "STATUS_PENDING"; + + /// Oplock break notification (informational). + NOTIFY_ENUM_DIR = 0x0000_010C, "STATUS_NOTIFY_ENUM_DIR"; + + // -- Informational (severity 0b00, facility-specific) -- + + /// The authentication exchange is not complete -- send the next + /// SESSION_SETUP with the GSS token from this response. + /// + /// **Important:** The severity bits are 0b11 (error), so `is_error()` + /// returns `true`. But this is NOT a real error -- it's a "keep going" + /// signal during NTLM/SPNEGO auth. Auth code must check + /// `is_more_processing_required()` before checking `is_error()`. + MORE_PROCESSING_REQUIRED = 0xC000_0016, "STATUS_MORE_PROCESSING_REQUIRED"; + + // -- Warnings (severity 0b10) -- + + /// The data was too large to fit into the specified buffer. + /// This is a warning -- the response body contains valid partial data. + BUFFER_OVERFLOW = 0x8000_0005, "STATUS_BUFFER_OVERFLOW"; + + /// No more files were found which match the file specification. + NO_MORE_FILES = 0x8000_0006, "STATUS_NO_MORE_FILES"; + + // -- Errors (severity 0b11) -- + + /// The requested operation was unsuccessful. + UNSUCCESSFUL = 0xC000_0001, "STATUS_UNSUCCESSFUL"; + + /// The requested operation is not implemented. + NOT_IMPLEMENTED = 0xC000_0002, "STATUS_NOT_IMPLEMENTED"; + + /// An invalid parameter was passed to a service or function. + INVALID_PARAMETER = 0xC000_000D, "STATUS_INVALID_PARAMETER"; + + /// A device that does not exist was specified. + NO_SUCH_DEVICE = 0xC000_000E, "STATUS_NO_SUCH_DEVICE"; + + /// The file does not exist. + NO_SUCH_FILE = 0xC000_000F, "STATUS_NO_SUCH_FILE"; + + /// The specified request is not a valid operation for the target device. + INVALID_DEVICE_REQUEST = 0xC000_0010, "STATUS_INVALID_DEVICE_REQUEST"; + + /// The end-of-file marker has been reached. + END_OF_FILE = 0xC000_0011, "STATUS_END_OF_FILE"; + + /// A process has requested access to an object but has not been + /// granted those access rights. + ACCESS_DENIED = 0xC000_0022, "STATUS_ACCESS_DENIED"; + + /// The buffer is too small to contain the entry. + BUFFER_TOO_SMALL = 0xC000_0023, "STATUS_BUFFER_TOO_SMALL"; + + /// The object name is not found. + OBJECT_NAME_NOT_FOUND = 0xC000_0034, "STATUS_OBJECT_NAME_NOT_FOUND"; + + /// The object name already exists. + OBJECT_NAME_COLLISION = 0xC000_0035, "STATUS_OBJECT_NAME_COLLISION"; + + /// The path does not exist. + OBJECT_PATH_NOT_FOUND = 0xC000_003A, "STATUS_OBJECT_PATH_NOT_FOUND"; + + /// A file cannot be opened because the share access flags + /// are incompatible. + SHARING_VIOLATION = 0xC000_0043, "STATUS_SHARING_VIOLATION"; + + /// A requested read/write cannot be granted due to a conflicting + /// file lock. + FILE_LOCK_CONFLICT = 0xC000_0054, "STATUS_FILE_LOCK_CONFLICT"; + + /// A non-close operation has been requested of a file object that + /// has a delete pending. + DELETE_PENDING = 0xC000_0056, "STATUS_DELETE_PENDING"; + + /// The disk is full. + DISK_FULL = 0xC000_007F, "STATUS_DISK_FULL"; + + /// The attempted logon is invalid. + LOGON_FAILURE = 0xC000_006D, "STATUS_LOGON_FAILURE"; + + /// The referenced account is currently disabled. + ACCOUNT_DISABLED = 0xC000_0072, "STATUS_ACCOUNT_DISABLED"; + + /// Insufficient system resources exist to complete the API. + INSUFFICIENT_RESOURCES = 0xC000_009A, "STATUS_INSUFFICIENT_RESOURCES"; + + /// The file that was specified as a target is a directory. + FILE_IS_A_DIRECTORY = 0xC000_00BA, "STATUS_FILE_IS_A_DIRECTORY"; + + /// The network path cannot be located. + BAD_NETWORK_PATH = 0xC000_00BE, "STATUS_BAD_NETWORK_PATH"; + + /// The network name was deleted. + NETWORK_NAME_DELETED = 0xC000_00C9, "STATUS_NETWORK_NAME_DELETED"; + + /// The specified share name cannot be found on the remote server. + BAD_NETWORK_NAME = 0xC000_00CC, "STATUS_BAD_NETWORK_NAME"; + + /// No more connections can be made to this remote computer at this time. + REQUEST_NOT_ACCEPTED = 0xC000_00D0, "STATUS_REQUEST_NOT_ACCEPTED"; + + /// A requested opened file is not a directory. + NOT_A_DIRECTORY = 0xC000_0103, "STATUS_NOT_A_DIRECTORY"; + + /// The I/O request was canceled. + CANCELLED = 0xC000_0120, "STATUS_CANCELLED"; + + /// An I/O request other than close was attempted using a file object + /// that had already been closed. + FILE_CLOSED = 0xC000_0128, "STATUS_FILE_CLOSED"; + + /// The remote user session has been deleted. + USER_SESSION_DELETED = 0xC000_0203, "STATUS_USER_SESSION_DELETED"; + + /// Insufficient server resources exist to complete the request. + INSUFF_SERVER_RESOURCES = 0xC000_0205, "STATUS_INSUFF_SERVER_RESOURCES"; + + /// The object was not found. + NOT_FOUND = 0xC000_0225, "STATUS_NOT_FOUND"; + + /// The contacted server does not support the indicated part + /// of the DFS namespace. + PATH_NOT_COVERED = 0xC000_0257, "STATUS_PATH_NOT_COVERED"; + + /// The client session has expired; the client must re-authenticate. + NETWORK_SESSION_EXPIRED = 0xC000_035C, "STATUS_NETWORK_SESSION_EXPIRED"; +} + +impl NtStatus { + // -- Helper methods -- + + /// Returns the severity bits (top 2 bits): 0 = success, 1 = info, + /// 2 = warning, 3 = error. + #[inline] + pub fn severity(&self) -> u8 { + (self.0 >> 30) as u8 + } + + /// Returns `true` if the status indicates success (severity 0b00). + #[inline] + pub fn is_success(&self) -> bool { + self.severity() == 0 + } + + /// Returns `true` if the status is a warning (severity 0b10). + #[inline] + pub fn is_warning(&self) -> bool { + self.severity() == 2 + } + + /// Returns `true` if the status is an error (severity 0b11). + #[inline] + pub fn is_error(&self) -> bool { + self.severity() == 3 + } + + /// Returns `true` if this is `STATUS_PENDING`. + #[inline] + pub fn is_pending(&self) -> bool { + *self == Self::PENDING + } + + /// Returns `true` if this status indicates the operation produced usable data. + /// + /// This includes `SUCCESS` and warnings like `BUFFER_OVERFLOW` where partial + /// data is valid and should be parsed. + #[inline] + pub fn is_success_or_partial(&self) -> bool { + self.is_success() || *self == Self::BUFFER_OVERFLOW + } + + /// Returns `true` if the server wants another SESSION_SETUP round-trip. + /// + /// Check this BEFORE `is_error()` during authentication -- it has + /// error severity bits but is not a real error. + #[inline] + pub fn is_more_processing_required(&self) -> bool { + *self == Self::MORE_PROCESSING_REQUIRED + } +} + +impl fmt::Debug for NtStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.name() { + Some(name) => write!(f, "NtStatus({name})"), + None => write!(f, "NtStatus(0x{:08X})", self.0), + } + } +} + +impl fmt::Display for NtStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.name() { + Some(name) => f.write_str(name), + None => write!(f, "0x{:08X}", self.0), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn success_is_success() { + assert!(NtStatus::SUCCESS.is_success()); + assert!(!NtStatus::SUCCESS.is_error()); + assert!(!NtStatus::SUCCESS.is_warning()); + assert_eq!(NtStatus::SUCCESS.severity(), 0); + } + + #[test] + fn access_denied_is_error() { + assert!(NtStatus::ACCESS_DENIED.is_error()); + assert!(!NtStatus::ACCESS_DENIED.is_success()); + assert!(!NtStatus::ACCESS_DENIED.is_warning()); + assert_eq!(NtStatus::ACCESS_DENIED.severity(), 3); + } + + #[test] + fn buffer_overflow_is_warning() { + assert!(NtStatus::BUFFER_OVERFLOW.is_warning()); + assert!(!NtStatus::BUFFER_OVERFLOW.is_success()); + assert!(!NtStatus::BUFFER_OVERFLOW.is_error()); + assert_eq!(NtStatus::BUFFER_OVERFLOW.severity(), 2); + } + + #[test] + fn pending_is_pending() { + assert!(NtStatus::PENDING.is_pending()); + assert!(NtStatus::PENDING.is_success()); // severity 0b00 + assert!(!NtStatus::SUCCESS.is_pending()); + } + + #[test] + fn more_processing_required_is_error_severity() { + // 0xC0000016 has severity 0b11 (error), even though semantically + // it means "keep going" during authentication handshakes. + assert!(NtStatus::MORE_PROCESSING_REQUIRED.is_error()); + assert_eq!(NtStatus::MORE_PROCESSING_REQUIRED.severity(), 3); + } + + #[test] + fn display_known_code() { + assert_eq!(NtStatus::ACCESS_DENIED.to_string(), "STATUS_ACCESS_DENIED"); + } + + #[test] + fn display_unknown_code() { + let unknown = NtStatus(0xDEAD_BEEF); + assert_eq!(unknown.to_string(), "0xDEADBEEF"); + } + + #[test] + fn debug_known_code() { + let s = format!("{:?}", NtStatus::SUCCESS); + assert_eq!(s, "NtStatus(STATUS_SUCCESS)"); + } + + #[test] + fn debug_unknown_code() { + let s = format!("{:?}", NtStatus(0x1234_5678)); + assert_eq!(s, "NtStatus(0x12345678)"); + } + + #[test] + fn no_more_files_is_warning() { + assert!(NtStatus::NO_MORE_FILES.is_warning()); + } + + #[test] + fn default_is_success() { + assert_eq!(NtStatus::default(), NtStatus::SUCCESS); + } + + #[test] + fn is_success_or_partial() { + // SUCCESS is usable + assert!(NtStatus::SUCCESS.is_success_or_partial()); + // BUFFER_OVERFLOW is a warning with valid partial data + assert!(NtStatus::BUFFER_OVERFLOW.is_success_or_partial()); + // Errors are not usable + assert!(!NtStatus::ACCESS_DENIED.is_success_or_partial()); + // PENDING has success severity (0b00) so is_success() is true, + // but callers handle PENDING separately before reaching status checks. + assert!(NtStatus::PENDING.is_success_or_partial()); + // Other warnings (not BUFFER_OVERFLOW) are not usable + assert!(!NtStatus::NO_MORE_FILES.is_success_or_partial()); + } + + #[test] + fn all_error_codes_have_error_severity() { + let errors = [ + NtStatus::UNSUCCESSFUL, + NtStatus::NOT_IMPLEMENTED, + NtStatus::INVALID_PARAMETER, + NtStatus::NO_SUCH_DEVICE, + NtStatus::NO_SUCH_FILE, + NtStatus::END_OF_FILE, + NtStatus::ACCESS_DENIED, + NtStatus::BUFFER_TOO_SMALL, + NtStatus::OBJECT_NAME_NOT_FOUND, + NtStatus::OBJECT_NAME_COLLISION, + NtStatus::OBJECT_PATH_NOT_FOUND, + NtStatus::SHARING_VIOLATION, + NtStatus::FILE_LOCK_CONFLICT, + NtStatus::DELETE_PENDING, + NtStatus::LOGON_FAILURE, + NtStatus::ACCOUNT_DISABLED, + NtStatus::INSUFFICIENT_RESOURCES, + NtStatus::FILE_IS_A_DIRECTORY, + NtStatus::BAD_NETWORK_PATH, + NtStatus::NETWORK_NAME_DELETED, + NtStatus::BAD_NETWORK_NAME, + NtStatus::REQUEST_NOT_ACCEPTED, + NtStatus::NOT_A_DIRECTORY, + NtStatus::CANCELLED, + NtStatus::FILE_CLOSED, + NtStatus::USER_SESSION_DELETED, + NtStatus::INSUFF_SERVER_RESOURCES, + NtStatus::NOT_FOUND, + NtStatus::PATH_NOT_COVERED, + NtStatus::NETWORK_SESSION_EXPIRED, + NtStatus::MORE_PROCESSING_REQUIRED, + ]; + for status in &errors { + assert!( + status.is_error(), + "{status} should be error but severity is {}", + status.severity() + ); + } + } +}