diff --git a/vendor/smb-server/src/dispatch.rs b/vendor/smb-server/src/dispatch.rs index 0df1352..8f2ed01 100644 --- a/vendor/smb-server/src/dispatch.rs +++ b/vendor/smb-server/src/dispatch.rs @@ -708,6 +708,10 @@ async fn handle_smb1_multi_protocol( #[cfg(test)] mod tests { use super::*; + use crate::conn::state::{Session, TreeConnect}; + use crate::proto::messages::create::{CreateResponse, FileId}; + use crate::proto::header::SMB2_MAGIC; + use crate::Share; use uuid::Uuid; fn test_conn() -> Arc { @@ -761,4 +765,315 @@ mod tests { assert_eq!(got.snapshot(), expected); assert!(!conn.session_preauth.read().await.contains_key(&7)); } + + // ── Compound request response stitching ───────────────────────────────── + + /// Build a minimal response frame of `body_len` bytes (64 header + body). + fn make_response(body_len: usize) -> Vec { + let mut buf = vec![0u8; SMB2_HEADER_LEN + body_len]; + buf[..4].copy_from_slice(&SMB2_MAGIC); + buf + } + + #[tokio::test] + async fn test_stitch_responses_single() { + let conn = test_conn(); + let r1 = make_response(100); + let responses = vec![r1.clone()]; + let stitched = stitch_responses(&conn, responses).await; + // Single response: no padding, NextCommand=0 + assert_eq!(stitched.len(), 100 + SMB2_HEADER_LEN); + assert_eq!(&stitched[..4], &SMB2_MAGIC); + // NextCommand at offset 20 should be 0 + let next = u32::from_le_bytes(stitched[20..24].try_into().unwrap()); + assert_eq!(next, 0); + } + + #[tokio::test] + async fn test_stitch_responses_aligned() { + let conn = test_conn(); + // Two responses: 100 bytes + 80 bytes body + let r1 = make_response(100); + let r2 = make_response(80); + let responses = vec![r1, r2]; + let stitched = stitch_responses(&conn, responses).await; + + // First response: 64+100 = 164, aligned to 168 (next multiple of 8) + let total1 = SMB2_HEADER_LEN + 100; + let aligned1 = (total1 + 7) & !7; + assert_eq!(aligned1, 168); + + // First response's NextCommand should point to second + let next1 = u32::from_le_bytes(stitched[20..24].try_into().unwrap()); + assert_eq!(next1 as usize, aligned1); + + // Second response body starts at offset `aligned1` + let next2 = u32::from_le_bytes(stitched[aligned1 + 20..aligned1 + 24].try_into().unwrap()); + assert_eq!(next2, 0); // Last response, no NextCommand + + // Total length = aligned1 + 64(body header) + 80(body) + assert_eq!(stitched.len(), aligned1 + SMB2_HEADER_LEN + 80); + } + + #[tokio::test] + async fn test_stitch_responses_three_responses() { + let conn = test_conn(); + let r1 = make_response(100); + let r2 = make_response(80); + let r3 = make_response(60); + let responses = vec![r1, r2, r3]; + let stitched = stitch_responses(&conn, responses).await; + + // First header at 0 + let total1 = SMB2_HEADER_LEN + 100; + let aligned1 = (total1 + 7) & !7; + let next1 = u32::from_le_bytes(stitched[20..24].try_into().unwrap()); + assert_eq!(next1 as usize, aligned1); + + // Second header at aligned1 + let total2 = SMB2_HEADER_LEN + 80; + let aligned2 = aligned1 + ((total2 + 7) & !7); + let next2 = u32::from_le_bytes(stitched[aligned1 + 20..aligned1 + 24].try_into().unwrap()); + assert_eq!(next2 as usize, aligned2 - aligned1); + + // Third header at aligned1 + aligned2_inner + let next3 = u32::from_le_bytes(stitched[aligned2 + 20..aligned2 + 24].try_into().unwrap()); + assert_eq!(next3, 0); // Last + + assert_eq!(stitched.len(), aligned2 + SMB2_HEADER_LEN + 60); + } + + #[tokio::test] + async fn test_stitch_responses_empty_returns_empty() { + let conn = test_conn(); + let stitched = stitch_responses(&conn, vec![]).await; + assert!(stitched.is_empty()); + } + + // ── FileId capture from CREATE response ───────────────────────────────── + + /// Build CREATE response bytes with known FileId. + fn make_create_response(persistent: u64, volatile: u64) -> Vec { + let mut resp = Vec::new(); + let cr = CreateResponse { + structure_size: 89, + oplock_level: 0, + flags: 0, + create_action: 1, + creation_time: 0, + last_access_time: 0, + last_write_time: 0, + change_time: 0, + allocation_size: 100, + end_of_file: 100, + file_attributes: 0, + reserved2: 0, + file_id: FileId::new(persistent, volatile), + create_contexts_offset: 0, + create_contexts_length: 0, + create_contexts: vec![], + }; + cr.write_to(&mut resp).unwrap(); + resp + } + + #[test] + fn test_capture_create_file_id_found() { + // Build a full response: 64-byte header + CREATE body + let body = make_create_response(0xAAAABBBBCCCCDDDD, 0x1111222233334444); + let mut response = vec![0u8; SMB2_HEADER_LEN]; + response[..4].copy_from_slice(&SMB2_MAGIC); + response[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes()); + response.extend_from_slice(&body); + + let file_id = capture_create_file_id(&response); + // FileId::new(0xAAAABBBBCCCCDDDD, 0x1111222233334444) + // LE: persistent=[0xDD,0xDD,0xCC,0xCC,0xBB,0xBB,0xAA,0xAA] + // volatile=[0x44,0x44,0x33,0x33,0x22,0x22,0x11,0x11] + assert_eq!(file_id, Some([0xDD, 0xDD, 0xCC, 0xCC, 0xBB, 0xBB, 0xAA, 0xAA, + 0x44, 0x44, 0x33, 0x33, 0x22, 0x22, 0x11, 0x11])); + } + + #[test] + fn test_capture_create_file_id_not_create_command() { + let response = vec![0u8; SMB2_HEADER_LEN]; + // Response has no file_id because it's not a CREATE + assert!(capture_create_file_id(&response).is_none()); + } + + #[test] + fn test_capture_create_file_id_too_short() { + let mut response = vec![0u8; SMB2_HEADER_LEN + 10]; + response[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes()); + assert!(capture_create_file_id(&response).is_none()); + } + + // ── Related context inheritance ───────────────────────────────────────── + + fn build_frame(command: Command, flags: u32, next: u32, session: u64, tree: u32, file_id_bytes: &[u8; 16]) -> Vec { + let mut f = vec![0u8; SMB2_HEADER_LEN]; + f[..4].copy_from_slice(&SMB2_MAGIC); + f[4..6].copy_from_slice(&64u16.to_le_bytes()); // structure_size must be 64 + f[12..14].copy_from_slice(&(command as u16).to_le_bytes()); + f[16..20].copy_from_slice(&flags.to_le_bytes()); + f[20..24].copy_from_slice(&next.to_le_bytes()); + f[24..32].copy_from_slice(&0u64.to_le_bytes()); // message_id + // tail: Sync { reserved: 0, tree_id } + let tail_data = &tree.to_le_bytes(); + f[36..40].copy_from_slice(tail_data); + f[40..48].copy_from_slice(&session.to_le_bytes()); + // Append file_id at body_offset for CLOSE (offset 8) + let mut body = vec![0u8; 24]; // CLOSE body size + body[8..24].copy_from_slice(file_id_bytes); + f.extend_from_slice(&body); + f + } + + #[tokio::test] + async fn test_inherit_related_context_close() { + // Use MAX values so inherit_related_context overrides them + let mut frame = build_frame(Command::Close, SMB2_FLAGS_RELATED_OPERATIONS, 0, u64::MAX, u32::MAX, &[0xFF; 16]); + let mut hdr = Smb2Header::parse(&frame).unwrap().0; + + let prev_file_id = Some([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10]); + + inherit_related_context(&mut frame, &mut hdr, 7, 3, prev_file_id); + + // After inheritance: session_id should be 7 + let session = u64::from_le_bytes(frame[40..48].try_into().unwrap()); + assert_eq!(session, 7, "session_id should be inherited from 0xFFFF to 7"); + + // tree_id should be 3 + let tree = u32::from_le_bytes(frame[36..40].try_into().unwrap()); + assert_eq!(tree, 3, "tree_id should be inherited from 0xFFFFFFFF to 3"); + + // FileId should be inherited + let file_id_field = &frame[SMB2_HEADER_LEN + 8..SMB2_HEADER_LEN + 24]; + assert_eq!(file_id_field, &prev_file_id.unwrap(), "FileId should be inherited from 0xFFFF to prev_file_id"); + } + + // ── Full compound chain dispatch test ─────────────────────────────────── + + #[tokio::test] + async fn test_compound_chain_create_then_close() { + use crate::server::SmbServer; + use crate::tests::memfs::MemFsBackend; + use crate::Access; + + let server = SmbServer::builder() + .listen("127.0.0.1:0".parse().unwrap()) + .user("alice", "password") + .share( + Share::new("home", MemFsBackend::new().with_file("test.txt", b"hello world")) + .user("alice", Access::ReadWrite), + ) + .build() + .expect("build"); + + let state = server.state(); + let conn = Arc::new(Connection::new( + state.config.server_guid, + state.config.max_read_size, + state.config.max_write_size, + )); + state.active_connections.register(&conn).await; + + let identity = Identity::User { + user: "alice".to_string(), + domain: String::new(), + }; + let session = Session::new(1, identity, [0; 16], [0; 16], None, false, false, None); + let session = Arc::new(tokio::sync::RwLock::new(session)); + let share = state.find_share("home").await.expect("share"); + let tree = Arc::new(tokio::sync::RwLock::new(TreeConnect::new( + 1, + share, + Access::ReadWrite, + ))); + { + let sess = session.read().await; + sess.trees.write().await.insert(1, tree); + } + conn.sessions.write().await.insert(1, session); + + // Build a compound CREATE+CLOSE frame + // Sub-frame 1: CREATE "test.txt" + let name_utf16: Vec = "test.txt".encode_utf16().flat_map(|c| c.to_le_bytes()).collect(); + let create_body_len = 56 + name_utf16.len(); + let create_frame_len = SMB2_HEADER_LEN + create_body_len; + + let mut frame = vec![0u8; create_frame_len + SMB2_HEADER_LEN + 24]; // CREATE + CLOSE + // ── CREATE sub-frame ── + // Header + frame[..4].copy_from_slice(&SMB2_MAGIC); + frame[4..6].copy_from_slice(&64u16.to_le_bytes()); // structure_size + frame[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes()); + frame[16..20].copy_from_slice(&SMB2_FLAGS_RELATED_OPERATIONS.to_le_bytes()); + frame[20..24].copy_from_slice(&(create_frame_len as u32).to_le_bytes()); // NextCommand + frame[24..32].copy_from_slice(&0u64.to_le_bytes()); // message_id + frame[36..40].copy_from_slice(&1u32.to_le_bytes()); // tree_id + frame[40..48].copy_from_slice(&1u64.to_le_bytes()); // session_id + + // CREATE body + let body_start = SMB2_HEADER_LEN; + frame[body_start..body_start+2].copy_from_slice(&57u16.to_le_bytes()); // structure_size + frame[body_start+3] = 0; // security_flags + frame[body_start+4..body_start+8].copy_from_slice(&2u32.to_le_bytes()); // impersonation_level + // desired_access = 0x00120089 (READ_CONTROL|SYNCHRONIZE|READ_ATTR|READ_DATA) + frame[body_start+24..body_start+28].copy_from_slice(&0x00120089u32.to_le_bytes()); + frame[body_start+32..body_start+36].copy_from_slice(&7u32.to_le_bytes()); // share_access + frame[body_start+36..body_start+40].copy_from_slice(&1u32.to_le_bytes()); // create_disposition: OPEN + frame[body_start+40..body_start+44].copy_from_slice(&0x00000040u32.to_le_bytes()); // create_options: FILE_NON_DIRECTORY_FILE + // name_offset relative to header start: 64 + 56 = 120 + frame[body_start+44..body_start+46].copy_from_slice(&120u16.to_le_bytes()); + frame[body_start+46..body_start+48].copy_from_slice(&(name_utf16.len() as u16).to_le_bytes()); + frame[body_start+56..body_start+56+name_utf16.len()].copy_from_slice(&name_utf16); + + // ── CLOSE sub-frame ── + let close_start = create_frame_len; + frame[close_start..close_start+4].copy_from_slice(&SMB2_MAGIC); + frame[close_start+4..close_start+6].copy_from_slice(&64u16.to_le_bytes()); // structure_size + frame[close_start+12..close_start+14].copy_from_slice(&(Command::Close as u16).to_le_bytes()); + frame[close_start+16..close_start+20].copy_from_slice(&(SMB2_FLAGS_RELATED_OPERATIONS).to_le_bytes()); + frame[close_start+20..close_start+24].copy_from_slice(&0u32.to_le_bytes()); // NextCommand = last + frame[close_start+24..close_start+32].copy_from_slice(&0u64.to_le_bytes()); // message_id + frame[close_start+36..close_start+40].copy_from_slice(&0u32.to_le_bytes()); // tree_id = 0 (inherited) + frame[close_start+40..close_start+48].copy_from_slice(&0u64.to_le_bytes()); // session_id = 0 (inherited) + // CLOSE body: FileId = 0xFFFF...FFFF (auto-inherit from CREATE) + let close_body_start = close_start + SMB2_HEADER_LEN; + frame[close_body_start..close_body_start+2].copy_from_slice(&24u16.to_le_bytes()); // structure_size + fill_all_ones(&mut frame[close_body_start+8..close_body_start+24]); // FileId = FFFF... + + let result = dispatch_frame(&state, &conn, &frame).await; + assert!(result.is_some(), "dispatch_frame returned None"); + let response = result.unwrap(); + + // Response should contain two sub-responses + assert!(response.len() >= SMB2_HEADER_LEN * 2 + 64, + "response too short for compound chain: {} bytes", response.len()); + + // First response should be CREATE + let next1 = u32::from_le_bytes(response[20..24].try_into().unwrap()); + assert!(next1 > 0, "first response should have NextCommand for compound chain"); + + // First response header should have SERVER_TO_REDIR + let flags1 = u32::from_le_bytes(response[16..20].try_into().unwrap()); + assert_ne!(flags1 & SMB2_FLAGS_SERVER_TO_REDIR, 0); + + // Second response should be CLOSE (at offset next1) + let hdr2_start = next1 as usize; + let flags2 = u32::from_le_bytes(response[hdr2_start+16..hdr2_start+20].try_into().unwrap()); + assert_ne!(flags2 & SMB2_FLAGS_SERVER_TO_REDIR, 0); + + // Second response should be last (NextCommand = 0) + let next2 = u32::from_le_bytes(response[hdr2_start+20..hdr2_start+24].try_into().unwrap()); + assert_eq!(next2, 0, "second (last) response should have NextCommand = 0"); + } + + fn fill_all_ones(buf: &mut [u8]) { + for b in buf.iter_mut() { + *b = 0xFF; + } + } } diff --git a/vendor/smb-server/src/lib.rs b/vendor/smb-server/src/lib.rs index fa16f4a..7ac46d8 100644 --- a/vendor/smb-server/src/lib.rs +++ b/vendor/smb-server/src/lib.rs @@ -56,5 +56,5 @@ pub mod wire { #[cfg(test)] mod tests { mod dynamic_config; - mod memfs; + pub(crate) mod memfs; }