Compound request integration tests: stitch_responses, capture_file_id, inherit_context, CREATE+CLOSE chain
Some checks are pending
Test / test (push) Waiting to run
Test / build (push) Blocked by required conditions

This commit is contained in:
Warren
2026-06-23 10:46:30 +08:00
parent 637227f4e4
commit 5300b672cb
2 changed files with 316 additions and 1 deletions

View File

@@ -708,6 +708,10 @@ async fn handle_smb1_multi_protocol(
#[cfg(test)]
mod tests {
use super::*;
use crate::conn::state::{Session, TreeConnect};
use crate::proto::messages::create::{CreateResponse, FileId};
use crate::proto::header::SMB2_MAGIC;
use crate::Share;
use uuid::Uuid;
fn test_conn() -> Arc<Connection> {
@@ -761,4 +765,315 @@ mod tests {
assert_eq!(got.snapshot(), expected);
assert!(!conn.session_preauth.read().await.contains_key(&7));
}
// ── Compound request response stitching ─────────────────────────────────
/// Build a minimal response frame of `body_len` bytes (64 header + body).
fn make_response(body_len: usize) -> Vec<u8> {
let mut buf = vec![0u8; SMB2_HEADER_LEN + body_len];
buf[..4].copy_from_slice(&SMB2_MAGIC);
buf
}
#[tokio::test]
async fn test_stitch_responses_single() {
let conn = test_conn();
let r1 = make_response(100);
let responses = vec![r1.clone()];
let stitched = stitch_responses(&conn, responses).await;
// Single response: no padding, NextCommand=0
assert_eq!(stitched.len(), 100 + SMB2_HEADER_LEN);
assert_eq!(&stitched[..4], &SMB2_MAGIC);
// NextCommand at offset 20 should be 0
let next = u32::from_le_bytes(stitched[20..24].try_into().unwrap());
assert_eq!(next, 0);
}
#[tokio::test]
async fn test_stitch_responses_aligned() {
let conn = test_conn();
// Two responses: 100 bytes + 80 bytes body
let r1 = make_response(100);
let r2 = make_response(80);
let responses = vec![r1, r2];
let stitched = stitch_responses(&conn, responses).await;
// First response: 64+100 = 164, aligned to 168 (next multiple of 8)
let total1 = SMB2_HEADER_LEN + 100;
let aligned1 = (total1 + 7) & !7;
assert_eq!(aligned1, 168);
// First response's NextCommand should point to second
let next1 = u32::from_le_bytes(stitched[20..24].try_into().unwrap());
assert_eq!(next1 as usize, aligned1);
// Second response body starts at offset `aligned1`
let next2 = u32::from_le_bytes(stitched[aligned1 + 20..aligned1 + 24].try_into().unwrap());
assert_eq!(next2, 0); // Last response, no NextCommand
// Total length = aligned1 + 64(body header) + 80(body)
assert_eq!(stitched.len(), aligned1 + SMB2_HEADER_LEN + 80);
}
#[tokio::test]
async fn test_stitch_responses_three_responses() {
let conn = test_conn();
let r1 = make_response(100);
let r2 = make_response(80);
let r3 = make_response(60);
let responses = vec![r1, r2, r3];
let stitched = stitch_responses(&conn, responses).await;
// First header at 0
let total1 = SMB2_HEADER_LEN + 100;
let aligned1 = (total1 + 7) & !7;
let next1 = u32::from_le_bytes(stitched[20..24].try_into().unwrap());
assert_eq!(next1 as usize, aligned1);
// Second header at aligned1
let total2 = SMB2_HEADER_LEN + 80;
let aligned2 = aligned1 + ((total2 + 7) & !7);
let next2 = u32::from_le_bytes(stitched[aligned1 + 20..aligned1 + 24].try_into().unwrap());
assert_eq!(next2 as usize, aligned2 - aligned1);
// Third header at aligned1 + aligned2_inner
let next3 = u32::from_le_bytes(stitched[aligned2 + 20..aligned2 + 24].try_into().unwrap());
assert_eq!(next3, 0); // Last
assert_eq!(stitched.len(), aligned2 + SMB2_HEADER_LEN + 60);
}
#[tokio::test]
async fn test_stitch_responses_empty_returns_empty() {
let conn = test_conn();
let stitched = stitch_responses(&conn, vec![]).await;
assert!(stitched.is_empty());
}
// ── FileId capture from CREATE response ─────────────────────────────────
/// Build CREATE response bytes with known FileId.
fn make_create_response(persistent: u64, volatile: u64) -> Vec<u8> {
let mut resp = Vec::new();
let cr = CreateResponse {
structure_size: 89,
oplock_level: 0,
flags: 0,
create_action: 1,
creation_time: 0,
last_access_time: 0,
last_write_time: 0,
change_time: 0,
allocation_size: 100,
end_of_file: 100,
file_attributes: 0,
reserved2: 0,
file_id: FileId::new(persistent, volatile),
create_contexts_offset: 0,
create_contexts_length: 0,
create_contexts: vec![],
};
cr.write_to(&mut resp).unwrap();
resp
}
#[test]
fn test_capture_create_file_id_found() {
// Build a full response: 64-byte header + CREATE body
let body = make_create_response(0xAAAABBBBCCCCDDDD, 0x1111222233334444);
let mut response = vec![0u8; SMB2_HEADER_LEN];
response[..4].copy_from_slice(&SMB2_MAGIC);
response[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes());
response.extend_from_slice(&body);
let file_id = capture_create_file_id(&response);
// FileId::new(0xAAAABBBBCCCCDDDD, 0x1111222233334444)
// LE: persistent=[0xDD,0xDD,0xCC,0xCC,0xBB,0xBB,0xAA,0xAA]
// volatile=[0x44,0x44,0x33,0x33,0x22,0x22,0x11,0x11]
assert_eq!(file_id, Some([0xDD, 0xDD, 0xCC, 0xCC, 0xBB, 0xBB, 0xAA, 0xAA,
0x44, 0x44, 0x33, 0x33, 0x22, 0x22, 0x11, 0x11]));
}
#[test]
fn test_capture_create_file_id_not_create_command() {
let response = vec![0u8; SMB2_HEADER_LEN];
// Response has no file_id because it's not a CREATE
assert!(capture_create_file_id(&response).is_none());
}
#[test]
fn test_capture_create_file_id_too_short() {
let mut response = vec![0u8; SMB2_HEADER_LEN + 10];
response[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes());
assert!(capture_create_file_id(&response).is_none());
}
// ── Related context inheritance ─────────────────────────────────────────
fn build_frame(command: Command, flags: u32, next: u32, session: u64, tree: u32, file_id_bytes: &[u8; 16]) -> Vec<u8> {
let mut f = vec![0u8; SMB2_HEADER_LEN];
f[..4].copy_from_slice(&SMB2_MAGIC);
f[4..6].copy_from_slice(&64u16.to_le_bytes()); // structure_size must be 64
f[12..14].copy_from_slice(&(command as u16).to_le_bytes());
f[16..20].copy_from_slice(&flags.to_le_bytes());
f[20..24].copy_from_slice(&next.to_le_bytes());
f[24..32].copy_from_slice(&0u64.to_le_bytes()); // message_id
// tail: Sync { reserved: 0, tree_id }
let tail_data = &tree.to_le_bytes();
f[36..40].copy_from_slice(tail_data);
f[40..48].copy_from_slice(&session.to_le_bytes());
// Append file_id at body_offset for CLOSE (offset 8)
let mut body = vec![0u8; 24]; // CLOSE body size
body[8..24].copy_from_slice(file_id_bytes);
f.extend_from_slice(&body);
f
}
#[tokio::test]
async fn test_inherit_related_context_close() {
// Use MAX values so inherit_related_context overrides them
let mut frame = build_frame(Command::Close, SMB2_FLAGS_RELATED_OPERATIONS, 0, u64::MAX, u32::MAX, &[0xFF; 16]);
let mut hdr = Smb2Header::parse(&frame).unwrap().0;
let prev_file_id = Some([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10]);
inherit_related_context(&mut frame, &mut hdr, 7, 3, prev_file_id);
// After inheritance: session_id should be 7
let session = u64::from_le_bytes(frame[40..48].try_into().unwrap());
assert_eq!(session, 7, "session_id should be inherited from 0xFFFF to 7");
// tree_id should be 3
let tree = u32::from_le_bytes(frame[36..40].try_into().unwrap());
assert_eq!(tree, 3, "tree_id should be inherited from 0xFFFFFFFF to 3");
// FileId should be inherited
let file_id_field = &frame[SMB2_HEADER_LEN + 8..SMB2_HEADER_LEN + 24];
assert_eq!(file_id_field, &prev_file_id.unwrap(), "FileId should be inherited from 0xFFFF to prev_file_id");
}
// ── Full compound chain dispatch test ───────────────────────────────────
#[tokio::test]
async fn test_compound_chain_create_then_close() {
use crate::server::SmbServer;
use crate::tests::memfs::MemFsBackend;
use crate::Access;
let server = SmbServer::builder()
.listen("127.0.0.1:0".parse().unwrap())
.user("alice", "password")
.share(
Share::new("home", MemFsBackend::new().with_file("test.txt", b"hello world"))
.user("alice", Access::ReadWrite),
)
.build()
.expect("build");
let state = server.state();
let conn = Arc::new(Connection::new(
state.config.server_guid,
state.config.max_read_size,
state.config.max_write_size,
));
state.active_connections.register(&conn).await;
let identity = Identity::User {
user: "alice".to_string(),
domain: String::new(),
};
let session = Session::new(1, identity, [0; 16], [0; 16], None, false, false, None);
let session = Arc::new(tokio::sync::RwLock::new(session));
let share = state.find_share("home").await.expect("share");
let tree = Arc::new(tokio::sync::RwLock::new(TreeConnect::new(
1,
share,
Access::ReadWrite,
)));
{
let sess = session.read().await;
sess.trees.write().await.insert(1, tree);
}
conn.sessions.write().await.insert(1, session);
// Build a compound CREATE+CLOSE frame
// Sub-frame 1: CREATE "test.txt"
let name_utf16: Vec<u8> = "test.txt".encode_utf16().flat_map(|c| c.to_le_bytes()).collect();
let create_body_len = 56 + name_utf16.len();
let create_frame_len = SMB2_HEADER_LEN + create_body_len;
let mut frame = vec![0u8; create_frame_len + SMB2_HEADER_LEN + 24]; // CREATE + CLOSE
// ── CREATE sub-frame ──
// Header
frame[..4].copy_from_slice(&SMB2_MAGIC);
frame[4..6].copy_from_slice(&64u16.to_le_bytes()); // structure_size
frame[12..14].copy_from_slice(&(Command::Create as u16).to_le_bytes());
frame[16..20].copy_from_slice(&SMB2_FLAGS_RELATED_OPERATIONS.to_le_bytes());
frame[20..24].copy_from_slice(&(create_frame_len as u32).to_le_bytes()); // NextCommand
frame[24..32].copy_from_slice(&0u64.to_le_bytes()); // message_id
frame[36..40].copy_from_slice(&1u32.to_le_bytes()); // tree_id
frame[40..48].copy_from_slice(&1u64.to_le_bytes()); // session_id
// CREATE body
let body_start = SMB2_HEADER_LEN;
frame[body_start..body_start+2].copy_from_slice(&57u16.to_le_bytes()); // structure_size
frame[body_start+3] = 0; // security_flags
frame[body_start+4..body_start+8].copy_from_slice(&2u32.to_le_bytes()); // impersonation_level
// desired_access = 0x00120089 (READ_CONTROL|SYNCHRONIZE|READ_ATTR|READ_DATA)
frame[body_start+24..body_start+28].copy_from_slice(&0x00120089u32.to_le_bytes());
frame[body_start+32..body_start+36].copy_from_slice(&7u32.to_le_bytes()); // share_access
frame[body_start+36..body_start+40].copy_from_slice(&1u32.to_le_bytes()); // create_disposition: OPEN
frame[body_start+40..body_start+44].copy_from_slice(&0x00000040u32.to_le_bytes()); // create_options: FILE_NON_DIRECTORY_FILE
// name_offset relative to header start: 64 + 56 = 120
frame[body_start+44..body_start+46].copy_from_slice(&120u16.to_le_bytes());
frame[body_start+46..body_start+48].copy_from_slice(&(name_utf16.len() as u16).to_le_bytes());
frame[body_start+56..body_start+56+name_utf16.len()].copy_from_slice(&name_utf16);
// ── CLOSE sub-frame ──
let close_start = create_frame_len;
frame[close_start..close_start+4].copy_from_slice(&SMB2_MAGIC);
frame[close_start+4..close_start+6].copy_from_slice(&64u16.to_le_bytes()); // structure_size
frame[close_start+12..close_start+14].copy_from_slice(&(Command::Close as u16).to_le_bytes());
frame[close_start+16..close_start+20].copy_from_slice(&(SMB2_FLAGS_RELATED_OPERATIONS).to_le_bytes());
frame[close_start+20..close_start+24].copy_from_slice(&0u32.to_le_bytes()); // NextCommand = last
frame[close_start+24..close_start+32].copy_from_slice(&0u64.to_le_bytes()); // message_id
frame[close_start+36..close_start+40].copy_from_slice(&0u32.to_le_bytes()); // tree_id = 0 (inherited)
frame[close_start+40..close_start+48].copy_from_slice(&0u64.to_le_bytes()); // session_id = 0 (inherited)
// CLOSE body: FileId = 0xFFFF...FFFF (auto-inherit from CREATE)
let close_body_start = close_start + SMB2_HEADER_LEN;
frame[close_body_start..close_body_start+2].copy_from_slice(&24u16.to_le_bytes()); // structure_size
fill_all_ones(&mut frame[close_body_start+8..close_body_start+24]); // FileId = FFFF...
let result = dispatch_frame(&state, &conn, &frame).await;
assert!(result.is_some(), "dispatch_frame returned None");
let response = result.unwrap();
// Response should contain two sub-responses
assert!(response.len() >= SMB2_HEADER_LEN * 2 + 64,
"response too short for compound chain: {} bytes", response.len());
// First response should be CREATE
let next1 = u32::from_le_bytes(response[20..24].try_into().unwrap());
assert!(next1 > 0, "first response should have NextCommand for compound chain");
// First response header should have SERVER_TO_REDIR
let flags1 = u32::from_le_bytes(response[16..20].try_into().unwrap());
assert_ne!(flags1 & SMB2_FLAGS_SERVER_TO_REDIR, 0);
// Second response should be CLOSE (at offset next1)
let hdr2_start = next1 as usize;
let flags2 = u32::from_le_bytes(response[hdr2_start+16..hdr2_start+20].try_into().unwrap());
assert_ne!(flags2 & SMB2_FLAGS_SERVER_TO_REDIR, 0);
// Second response should be last (NextCommand = 0)
let next2 = u32::from_le_bytes(response[hdr2_start+20..hdr2_start+24].try_into().unwrap());
assert_eq!(next2, 0, "second (last) response should have NextCommand = 0");
}
fn fill_all_ones(buf: &mut [u8]) {
for b in buf.iter_mut() {
*b = 0xFF;
}
}
}

View File

@@ -56,5 +56,5 @@ pub mod wire {
#[cfg(test)]
mod tests {
mod dynamic_config;
mod memfs;
pub(crate) mod memfs;
}