Files
markbase/vendor/smb2/src/transport/tcp.rs
Warren 7eb528d35f
Some checks failed
Test / build (push) Has been cancelled
Test / test (push) Has been cancelled
SMB Server Phase 2: VFS backend build fix + integration test
- Add VfsFile: Send supertrait for Mutex compatibility
- Fix SmbServerCommand: struct → Subcommand enum with Start variant
- Fix tracing_subscriber::init() → try_init() to avoid panic when
  logger already initialized
- Fix CLI subcommand name: smb-server → smb-start (flatten naming)
- Add #[command(name = "smb-start")] for CLI disambiguation
- Fix unused variable warnings (smb_fs.rs, smb_server_backend.rs)
- Remove unused VfsFile imports (webdav.rs, scp_handler.rs)
- Integration test: Docker smbclient verified (list, upload, read)
2026-06-20 19:42:29 +08:00

486 lines
18 KiB
Rust

//! Direct TCP transport for SMB2 (port 445).
//!
//! Implements the SMB2 transport framing defined in MS-SMB2 section 2.1:
//! each message is preceded by a 4-byte header consisting of 1 zero byte
//! followed by 3 bytes of big-endian length. This is the ONLY big-endian
//! encoding in the entire SMB2 protocol.
use async_trait::async_trait;
use log::{debug, error, trace};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::Mutex;
use crate::error::{Error, Result};
use crate::transport::{TransportReceive, TransportSend};
/// Maximum frame size we accept (16 MiB = 16_777_216 bytes).
///
/// Prevents denial-of-service from corrupt or malicious length fields.
/// Real SMB2 messages are typically much smaller (the largest negotiated
/// MaxReadSize/MaxWriteSize is usually 8 MB).
///
/// Note: the transport header carries the length in only 3 bytes, so the
/// largest length that can actually appear on the wire is 0xFFFFFF
/// (16_777_215) — one less than this constant.
const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
/// Direct TCP transport for SMB2.
///
/// Wraps a TCP connection and handles the 4-byte framing header.
/// The connection is split into independent read and write halves
/// so that send and receive can proceed concurrently without contention
/// (required by the pipeline's `tokio::select!` loop).
///
/// Each half sits behind its own async mutex: `send` locks only the writer
/// and `receive` locks only the reader, so the two directions never wait on
/// each other, while concurrent calls in the same direction are serialized.
#[derive(Debug)]
pub struct TcpTransport {
/// The read half of the TCP connection, behind a mutex for `&self` access.
/// The lock is held for one whole `receive` call (header and body).
reader: Mutex<OwnedReadHalf>,
/// The write half of the TCP connection, behind a mutex for `&self` access.
/// The lock is held for one whole `send` call (header, payload and flush).
writer: Mutex<OwnedWriteHalf>,
}
impl TcpTransport {
/// Connect to an SMB server over TCP.
///
/// Applies the given timeout to the connection attempt. Once connected,
/// the socket is split into independent read/write halves.
pub async fn connect(addr: impl ToSocketAddrs, timeout: Duration) -> Result<Self> {
let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
.await
.map_err(|_| Error::Timeout)?
.map_err(Error::Io)?;
// Disable Nagle's algorithm for lower latency on small messages.
stream.set_nodelay(true).map_err(Error::Io)?;
debug!("tcp: connected, nodelay=true");
let (reader, writer) = stream.into_split();
Ok(Self {
reader: Mutex::new(reader),
writer: Mutex::new(writer),
})
}
}
#[async_trait]
impl TransportSend for TcpTransport {
    /// Send one SMB2 message, prefixed with the 4-byte transport header.
    ///
    /// The header is a zero byte followed by the payload length as a 3-byte
    /// big-endian integer. Header, payload and flush all happen under the
    /// writer lock, so frames from concurrent callers are not interleaved.
    ///
    /// # Errors
    ///
    /// * `Error::invalid_data(..)` if `data` is longer than the effective
    ///   frame limit: the smaller of `MAX_FRAME_SIZE` and 0xFFFFFF, the
    ///   largest value the 3-byte length field can carry.
    /// * `Error::Io` if writing to or flushing the socket fails.
    async fn send(&self, data: &[u8]) -> Result<()> {
        // Largest payload length representable in the 24-bit length field.
        const MAX_ENCODABLE_LEN: usize = 0x00FF_FFFF;

        // MAX_FRAME_SIZE alone is not a sufficient bound: a payload of exactly
        // 16 MiB (0x1000000) passes `len > MAX_FRAME_SIZE` but does not fit in
        // 24 bits, so the truncating casts below would emit a header claiming
        // length 0 and desynchronize the stream.
        let limit = MAX_FRAME_SIZE.min(MAX_ENCODABLE_LEN);

        let len = data.len();
        if len > limit {
            return Err(Error::invalid_data(format!(
                "message size {} exceeds maximum frame size {}",
                len, limit
            )));
        }

        // 0x00 marker + 3-byte big-endian length. `len <= 0xFFFFFF` is
        // guaranteed above, so each cast keeps exactly the intended byte.
        let frame_header = [0x00, (len >> 16) as u8, (len >> 8) as u8, len as u8];

        let mut writer = self.writer.lock().await;
        writer.write_all(&frame_header).await.map_err(Error::Io)?;
        writer.write_all(data).await.map_err(Error::Io)?;
        writer.flush().await.map_err(Error::Io)?;
        trace!("tcp: sent frame, len={}", len);
        Ok(())
    }
}
#[async_trait]
impl TransportReceive for TcpTransport {
async fn receive(&self) -> Result<Vec<u8>> {
let mut reader = self.reader.lock().await;
// Read the 4-byte framing header.
let mut frame_header = [0u8; 4];
reader.read_exact(&mut frame_header).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
Error::Disconnected
} else {
Error::Io(e)
}
})?;
// Validate the first byte is 0x00.
if frame_header[0] != 0x00 {
error!("tcp: invalid frame, first byte=0x{:02X}", frame_header[0]);
return Err(Error::invalid_data(format!(
"invalid transport frame: first byte must be 0x00, got 0x{:02X}",
frame_header[0]
)));
}
// Extract the 3-byte big-endian length.
let msg_len = ((frame_header[1] as usize) << 16)
| ((frame_header[2] as usize) << 8)
| (frame_header[3] as usize);
// Validate against the maximum frame size.
if msg_len > MAX_FRAME_SIZE {
return Err(Error::invalid_data(format!(
"frame length {} exceeds maximum {}",
msg_len, MAX_FRAME_SIZE
)));
}
trace!("tcp: receiving frame, len={}", msg_len);
// Read the message body.
let mut buf = vec![0u8; msg_len];
reader.read_exact(&mut buf).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
Error::Disconnected
} else {
Error::Io(e)
}
})?;
trace!("tcp: received frame, len={}", msg_len);
Ok(buf)
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Build a framed message: the 4-byte transport header (0x00 followed by the
/// payload length as 3 big-endian bytes) and then the payload itself.
fn frame_message(payload: &[u8]) -> Vec<u8> {
    let len = payload.len();
    let header = [0x00, (len >> 16) as u8, (len >> 8) as u8, len as u8];
    [&header[..], payload].concat()
}
// ── Send framing tests ──────────────────────────────────────────
#[test]
fn frame_header_format_small_message() {
    // Four payload bytes ("SMB2 magic") => header [0x00, 0x00, 0x00, 0x04].
    let payload: [u8; 4] = [0xFE, 0x53, 0x4D, 0x42];
    let framed = frame_message(&payload);
    let (header, body) = framed.split_at(4);
    assert_eq!(header[0], 0x00, "first byte must be 0x00");
    assert_eq!(header[1], 0x00, "length high byte");
    assert_eq!(header[2], 0x00, "length mid byte");
    assert_eq!(header[3], 0x04, "length low byte = 4");
    assert_eq!(body, &payload[..]);
}
#[test]
fn frame_header_format_medium_message() {
    // 300 = 0x00012C, so the header must be [0x00, 0x00, 0x01, 0x2C].
    let framed = frame_message(&[0xAA; 300]);
    assert_eq!(framed[..4], [0x00u8, 0x00, 0x01, 0x2C]);
    // 4 header bytes + 300 payload bytes.
    assert_eq!(framed.len(), 304);
}
#[test]
fn frame_header_format_large_message() {
    // 66051 = 0x010203, so every length byte of the header is distinct.
    let body_len = 0x01_02_03;
    let framed = frame_message(&vec![0xBB; body_len]);
    assert_eq!(framed[..4], [0x00u8, 0x01, 0x02, 0x03]);
}
#[test]
fn frame_header_empty_payload() {
    // An empty payload yields a bare header with a zero length.
    let header_only = frame_message(&[]);
    assert_eq!(header_only, [0x00u8; 4]);
}
// ── Receive framing tests (using tokio_test-style mock streams) ──
/// Feed raw bytes to a `TcpTransport` over a loopback TCP connection and
/// return whatever a single `receive()` call produces.
///
/// A spawned client task connects, writes `data`, and shuts its side of the
/// connection down, so the transport sees EOF right after the bytes.
async fn receive_from_bytes(data: &[u8]) -> Result<Vec<u8>> {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    let bytes = data.to_vec();
    let client_task = tokio::spawn(async move {
        let mut client = TcpStream::connect(addr).await.unwrap();
        client.write_all(&bytes).await.unwrap();
        client.shutdown().await.unwrap();
    });

    let (accepted, _peer_addr) = listener.accept().await.unwrap();
    let (read_half, write_half) = accepted.into_split();
    let transport = TcpTransport {
        reader: Mutex::new(read_half),
        writer: Mutex::new(write_half),
    };

    let outcome = transport.receive().await;
    client_task.await.unwrap();
    outcome
}
#[tokio::test]
async fn receive_valid_frame() {
    // A well-formed frame comes back as exactly its payload.
    let payload: [u8; 6] = [0xFE, 0x53, 0x4D, 0x42, 0x01, 0x02];
    let wire = frame_message(&payload);
    let received = receive_from_bytes(&wire).await.unwrap();
    assert_eq!(received, payload);
}
#[tokio::test]
async fn receive_empty_payload() {
    // A zero-length frame is valid and yields an empty payload.
    let wire = frame_message(&[]);
    let received = receive_from_bytes(&wire).await.unwrap();
    assert!(received.is_empty());
}
#[tokio::test]
async fn receive_first_byte_not_zero_returns_error() {
    // Marker byte is 0x01 where 0x00 is required; the rest is a valid frame.
    let wire = [0x01, 0x00, 0x00, 0x04, 0xAA, 0xBB, 0xCC, 0xDD];
    let outcome = receive_from_bytes(&wire).await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("first byte must be 0x00"),
        "unexpected error: {err}"
    );
}
#[tokio::test]
async fn receive_length_exceeds_max_returns_error() {
    // The length field is only 3 bytes wide, so the largest value a peer can
    // announce is 0xFFFFFF = 16_777_215, which is still below MAX_FRAME_SIZE
    // (16 * 1024 * 1024 = 16_777_216). The "length exceeds maximum" branch
    // therefore cannot be reached from the wire; it stays as defense-in-depth.
    //
    // What this test does exercise is the boundary: announce the largest
    // possible length but send no payload, and expect Disconnected when the
    // body read hits EOF.
    let header_only = [0x00, 0xFF, 0xFF, 0xFF];
    let outcome = receive_from_bytes(&header_only).await;
    assert!(outcome.is_err());
    // The header is accepted; the failure comes from the payload read.
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, Error::Disconnected),
        "expected Disconnected for truncated large frame, got: {err}"
    );
}
#[tokio::test]
async fn receive_disconnected_on_eof() {
    // The peer closes without sending a single byte.
    let outcome = receive_from_bytes(&[]).await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, Error::Disconnected),
        "expected Disconnected, got: {err}"
    );
}
#[tokio::test]
async fn receive_partial_header_returns_disconnected() {
    // The stream ends after half of the 4-byte header.
    let half_header = [0x00, 0x00];
    let outcome = receive_from_bytes(&half_header).await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, Error::Disconnected),
        "expected Disconnected for partial header, got: {err}"
    );
}
#[tokio::test]
async fn receive_partial_payload_returns_disconnected() {
    // The header announces 10 payload bytes but only 3 arrive before EOF.
    let truncated = [0x00, 0x00, 0x00, 0x0A, 0x01, 0x02, 0x03];
    let outcome = receive_from_bytes(&truncated).await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, Error::Disconnected),
        "expected Disconnected for truncated payload, got: {err}"
    );
}
#[tokio::test]
async fn send_and_receive_roundtrip() {
    // One message sent through `send` must come back unchanged from
    // `receive` on the other end of a loopback connection.
    let payload: [u8; 6] = [0xFE, 0x53, 0x4D, 0x42, 0xDE, 0xAD];

    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    let sender_task = tokio::spawn(async move {
        let (read_half, write_half) = TcpStream::connect(addr).await.unwrap().into_split();
        let sender = TcpTransport {
            reader: Mutex::new(read_half),
            writer: Mutex::new(write_half),
        };
        sender.send(&payload).await.unwrap();
    });

    let (accepted, _peer_addr) = listener.accept().await.unwrap();
    let (read_half, write_half) = accepted.into_split();
    let receiver = TcpTransport {
        reader: Mutex::new(read_half),
        writer: Mutex::new(write_half),
    };

    let received = receiver.receive().await.unwrap();
    assert_eq!(received, payload);
    sender_task.await.unwrap();
}
#[tokio::test]
async fn send_and_receive_multiple_messages() {
    // Several back-to-back frames must be split apart again in send order.
    const MESSAGES: [&[u8]; 3] = [&[0x01, 0x02], &[0x03, 0x04, 0x05], &[0x06]];

    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    let sender_task = tokio::spawn(async move {
        let (read_half, write_half) = TcpStream::connect(addr).await.unwrap().into_split();
        let sender = TcpTransport {
            reader: Mutex::new(read_half),
            writer: Mutex::new(write_half),
        };
        for message in MESSAGES {
            sender.send(message).await.unwrap();
        }
    });

    let (accepted, _peer_addr) = listener.accept().await.unwrap();
    let (read_half, write_half) = accepted.into_split();
    let receiver = TcpTransport {
        reader: Mutex::new(read_half),
        writer: Mutex::new(write_half),
    };

    for expected in MESSAGES {
        assert_eq!(receiver.receive().await.unwrap(), expected);
    }
    sender_task.await.unwrap();
}
#[tokio::test]
async fn partial_reads_are_handled_by_read_exact() {
    // TCP may hand the payload over in arbitrary chunks. Byte-at-a-time
    // delivery cannot be forced over a real socket, so an 8 KiB payload is
    // used instead: it is more likely to arrive across several reads, and
    // `read_exact` has to reassemble it.
    let payload: Vec<u8> = (0..8192usize).map(|i| (i % 256) as u8).collect();
    let outgoing = payload.clone();

    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    let sender_task = tokio::spawn(async move {
        let (read_half, write_half) = TcpStream::connect(addr).await.unwrap().into_split();
        let sender = TcpTransport {
            reader: Mutex::new(read_half),
            writer: Mutex::new(write_half),
        };
        sender.send(&outgoing).await.unwrap();
    });

    let (accepted, _peer_addr) = listener.accept().await.unwrap();
    let (read_half, write_half) = accepted.into_split();
    let receiver = TcpTransport {
        reader: Mutex::new(read_half),
        writer: Mutex::new(write_half),
    };

    let received = receiver.receive().await.unwrap();
    assert_eq!(received.len(), payload.len());
    assert_eq!(received, payload);
    sender_task.await.unwrap();
}
#[tokio::test]
async fn connect_with_timeout() {
    // A local listener accepts immediately, so a 5 s budget is generous.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let connecting = TcpTransport::connect(addr, Duration::from_secs(5));
    let transport = connecting.await.unwrap();

    // Server side: accept, and keep only the write half.
    let (accepted, _peer_addr) = listener.accept().await.unwrap();
    let (server_reader, mut server_writer) = accepted.into_split();
    drop(server_reader);

    // Hand-build one frame (header announcing 4 bytes, then the payload)
    // and push it from the "server" side.
    let payload = vec![0xDE, 0xAD, 0xBE, 0xEF];
    let frame: Vec<u8> = [0x00, 0x00, 0x00, 0x04]
        .iter()
        .chain(payload.iter())
        .copied()
        .collect();
    server_writer.write_all(&frame).await.unwrap();
    server_writer.flush().await.unwrap();

    // The connected transport must deliver exactly the payload.
    let received = transport.receive().await.unwrap();
    assert_eq!(received, payload);
}
#[tokio::test]
async fn connect_timeout_fires() {
    // 192.0.2.1 belongs to TEST-NET (RFC 5737) and should be unreachable,
    // so a connection attempt with a 100 ms budget is expected to fail.
    let budget = Duration::from_millis(100);
    let outcome = TcpTransport::connect("192.0.2.1:445", budget).await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    // Depending on the OS the failure is either the timeout itself or an
    // immediate I/O error; both are acceptable.
    assert!(
        matches!(err, Error::Timeout | Error::Io(_)),
        "expected Timeout or Io error, got: {err}"
    );
}
}