Fix code quality: trailing whitespace, unused imports, clippy warnings

- Fix trailing whitespace in kex.rs and s3.rs
- Add missing KexProposal import in kex_complete.rs
- Auto-fix clippy warnings across all crates
- All 153 tests pass
This commit is contained in:
Warren
2026-06-19 05:21:38 +08:00
parent 4b37e524cf
commit d94cb2df4c
135 changed files with 7256 additions and 4321 deletions

View File

@@ -1,11 +1,11 @@
use crate::ssh_server::packet::{SshPacket, PacketType};
use std::io::Write;
use anyhow::{Result, anyhow};
use crate::ssh_server::packet::{PacketType, SshPacket};
use anyhow::{anyhow, Result};
use base64::{engine::general_purpose, Engine as _};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, warn, debug};
use base64::{Engine as _, engine::general_purpose};
use log::{debug, info, warn};
use std::io::Write;
use ed25519_dalek::{VerifyingKey, Signature};
use ed25519_dalek::{Signature, VerifyingKey};
use crate::provider::{DataProvider, ProviderError};
@@ -27,7 +27,11 @@ impl AuthHandler {
}
/// 处理SSH_MSG_USERAUTH_REQUEST参考OpenSSH auth2.c: userauth_request()
pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result<AuthResult> {
pub fn handle_userauth_request(
&mut self,
packet: &SshPacket,
session_id: &[u8],
) -> Result<AuthResult> {
info!("Processing SSH_MSG_USERAUTH_REQUEST");
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
@@ -41,7 +45,10 @@ impl AuthHandler {
let service = read_ssh_string(&mut cursor)?;
let method = read_ssh_string(&mut cursor)?;
info!("Auth request: user={}, service={}, method={}", user, service, method);
info!(
"Auth request: user={}, service={}, method={}",
user, service, method
);
if service != "ssh-connection" {
warn!("Unsupported service: {}", service);
@@ -62,18 +69,28 @@ impl AuthHandler {
}
/// 处理password认证参考OpenSSH auth-passwd.c
fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
fn handle_password_auth(
&mut self,
cursor: &mut std::io::Cursor<&[u8]>,
user: &str,
) -> Result<AuthResult> {
info!("Handling password auth for user: {}", user);
let change_password = cursor.read_u8()? != 0;
if change_password {
warn!("Password change not supported");
return Ok(AuthResult::Failure("Password change not supported".to_string()));
return Ok(AuthResult::Failure(
"Password change not supported".to_string(),
));
}
let password = read_ssh_string(cursor)?;
debug!("Password auth attempt: user={}, password length={}", user, password.len());
debug!(
"Password auth attempt: user={}, password length={}",
user,
password.len()
);
match self.provider.check_password(user, &password) {
Ok(true) => {
@@ -88,9 +105,7 @@ impl AuthHandler {
warn!("User not found: {}", msg);
Ok(AuthResult::Failure("password,publickey".to_string()))
}
Err(e) => {
Err(anyhow!("Password auth error: {}", e))
}
Err(e) => Err(anyhow!("Password auth error: {}", e)),
}
}
@@ -145,7 +160,12 @@ impl AuthHandler {
let algorithm = read_ssh_string(cursor)?;
let public_key_blob = read_ssh_string_bytes(cursor)?;
info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed);
info!(
"Publickey auth: algorithm={}, blob_len={}, is_signed={}",
algorithm,
public_key_blob.len(),
is_signed
);
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
warn!("Public key not authorized for user: {}", user);
@@ -160,14 +180,26 @@ impl AuthHandler {
let signature_blob = read_ssh_string_bytes(cursor)?;
self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?;
self.verify_signature(
&algorithm,
&public_key_blob,
&signature_blob,
user,
service,
session_id,
)?;
info!("Publickey auth successful for user: {}", user);
Ok(AuthResult::Success)
}
/// 检查public key是否在授权列表中数据库优先fallback到filesystem
fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result<bool> {
fn is_key_authorized(
&self,
user: &str,
algorithm: &str,
public_key_blob: &[u8],
) -> Result<bool> {
// 1. 先检查数据库
match self.provider.get_public_keys(user) {
Ok(keys) => {
@@ -187,10 +219,12 @@ impl AuthHandler {
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
Ok(c) => c,
Err(_) => return Ok(false),
}
},
};
Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
Ok(content
.lines()
.any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
}
/// 验证Ed25519签名RFC 4252 §7
@@ -246,7 +280,8 @@ impl AuthHandler {
signed_data.write_all(public_key_blob)?;
// 验证签名
verifying_key.verify_strict(&signed_data, &signature)
verifying_key
.verify_strict(&signed_data, &signature)
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
}
}
@@ -270,10 +305,10 @@ fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result<VerifyingKey> {
if key_bytes.len() != 32 {
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
}
let key_array: [u8; 32] = key_bytes.try_into()
let key_array: [u8; 32] = key_bytes
.try_into()
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
VerifyingKey::from_bytes(&key_array)
.map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
VerifyingKey::from_bytes(&key_array).map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
}
/// 解析Ed25519签名blobSSH格式 -> Signature
@@ -285,9 +320,13 @@ fn parse_ed25519_signature(signature_blob: &[u8]) -> Result<Signature> {
}
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
if sig_bytes.len() != 64 {
return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len()));
return Err(anyhow!(
"Invalid Ed25519 signature length: {}",
sig_bytes.len()
));
}
let sig_array: [u8; 64] = sig_bytes.try_into()
let sig_array: [u8; 64] = sig_bytes
.try_into()
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
Ok(Signature::from_bytes(&sig_array))
}
@@ -305,7 +344,9 @@ fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8])
if parts[0] != algorithm {
return false;
}
base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false)
base64_decode(parts[1])
.map(|decoded| decoded == public_key_blob)
.unwrap_or(false)
}
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
@@ -323,7 +364,8 @@ fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
}
fn base64_decode(input: &str) -> Result<Vec<u8>> {
general_purpose::STANDARD.decode(input)
general_purpose::STANDARD
.decode(input)
.map_err(|e| anyhow!("Base64 decode error: {}", e))
}
@@ -335,7 +377,10 @@ mod tests {
#[test]
fn test_userauth_success_packet() {
let packet = AuthHandler::build_userauth_success().unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8);
assert_eq!(
packet.payload[0],
PacketType::SSH_MSG_USERAUTH_SUCCESS as u8
);
}
#[test]
@@ -343,6 +388,9 @@ mod tests {
let methods = vec!["password".to_string(), "publickey".to_string()];
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8);
assert_eq!(
packet.payload[0],
PacketType::SSH_MSG_USERAUTH_FAILURE as u8
);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,33 +1,33 @@
// SSH加密通道实现Phase 4
// 参考OpenSSH cipher.c, mac.c
use aes::Aes128; // 改为AES-128协商算法是aes128-ctr
use super::crypto::SessionKeys;
use aes::Aes128; // 改为AES-128协商算法是aes128-ctr
use anyhow::{anyhow, Result};
use byteorder::{BigEndian, WriteBytesExt};
use cipher::{KeyIvInit, StreamCipher};
use ctr::Ctr128BE;
use hmac::{Hmac, Mac};
use log::info;
use sha2::Sha256;
use cipher::{KeyIvInit, StreamCipher};
use std::io::Write;
use anyhow::{Result, anyhow};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, debug, warn};
use super::crypto::SessionKeys;
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR16字节密钥
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR16字节密钥
type HmacSha256 = Hmac<Sha256>;
/// SSH加密通道管理器参考OpenSSH struct sshcipher_ctx
pub struct EncryptionContext {
pub session_id: Vec<u8>, // session identifier (exchange hash)
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
pub sequence_number_ctos: u32, // 客户端→服务器序列号
pub sequence_number_stoc: u32, // 服务器→客户端序列号
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例持久化
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例持久化
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
pub sequence_number_ctos: u32, // 客户端→服务器序列号
pub sequence_number_stoc: u32, // 服务器→客户端序列号
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例持久化
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例持久化
}
impl Default for EncryptionContext {
@@ -53,27 +53,33 @@ impl EncryptionContext {
/// OpenSSH cipher.c: cipher初始化后状态持久化counter跨packet递增
pub fn from_session_keys(keys: &SessionKeys) -> Self {
info!("Initializing ciphers with session keys:");
info!(" encryption_key_ctos (16 bytes): {:?}", &keys.encryption_key_ctos[..16]);
info!(
" encryption_key_ctos (16 bytes): {:?}",
&keys.encryption_key_ctos[..16]
);
info!(" iv_ctos (16 bytes): {:?}", &keys.iv_ctos[..16]);
info!(" encryption_key_stoc (16 bytes): {:?}", &keys.encryption_key_stoc[..16]);
info!(
" encryption_key_stoc (16 bytes): {:?}",
&keys.encryption_key_stoc[..16]
);
info!(" iv_stoc (16 bytes): {:?}", &keys.iv_stoc[..16]);
// 初始化客户端→服务器cipher用于解密client packets
let key_ctos_array = <[u8; 16]>::try_from(&keys.encryption_key_ctos[..16])
.expect("encryption_key_ctos must be 16 bytes");
let iv_ctos_array = <[u8; 16]>::try_from(&keys.iv_ctos[..16])
.expect("iv_ctos must be 16 bytes");
let iv_ctos_array =
<[u8; 16]>::try_from(&keys.iv_ctos[..16]).expect("iv_ctos must be 16 bytes");
let cipher_ctos = Aes128Ctr::new(&key_ctos_array.into(), &iv_ctos_array.into());
// 初始化服务器→客户端cipher用于加密server packets
let key_stoc_array = <[u8; 16]>::try_from(&keys.encryption_key_stoc[..16])
.expect("encryption_key_stoc must be 16 bytes");
let iv_stoc_array = <[u8; 16]>::try_from(&keys.iv_stoc[..16])
.expect("iv_stoc must be 16 bytes");
let iv_stoc_array =
<[u8; 16]>::try_from(&keys.iv_stoc[..16]).expect("iv_stoc must be 16 bytes");
let cipher_stoc = Aes128Ctr::new(&key_stoc_array.into(), &iv_stoc_array.into());
info!("Ciphers initialized successfully");
Self {
session_id: keys.session_id.clone(),
encryption_key_ctos: keys.encryption_key_ctos.clone(),
@@ -84,26 +90,26 @@ impl EncryptionContext {
iv_stoc: keys.iv_stoc.clone(),
sequence_number_ctos: 0,
sequence_number_stoc: 0,
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
}
}
/// RFC 4344: Compute AES-CTR IV for a specific packet
/// IV = nonce(8 bytes from derived IV) + sequence_number(8 bytes)
fn compute_ctr_iv(nonce: &[u8], sequence_number: u32) -> Vec<u8> {
let mut iv = Vec::with_capacity(16);
// Nonce: first 8 bytes of derived IV (constant)
iv.extend_from_slice(&nonce[..8]);
// Counter: sequence number as 8-byte big-endian
iv.extend_from_slice(&sequence_number.to_be_bytes());
iv.extend_from_slice(&[0u8; 4]); // Upper 4 bytes = 0
iv
}
/// 加密packet参考OpenSSH cipher.c: cipher_encrypt()
pub fn encrypt_packet(
&mut self,
@@ -113,17 +119,17 @@ impl EncryptionContext {
) -> Result<Vec<u8>> {
let key_array = <[u8; 16]>::try_from(encryption_key)?;
let iv_array = <[u8; 16]>::try_from(iv)?;
let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into());
let mut ciphertext = plaintext.to_vec();
cipher.apply_keystream(&mut ciphertext);
self.sequence_number_stoc += 1;
Ok(ciphertext)
}
/// 解密packet参考OpenSSH cipher.c: cipher_decrypt()
pub fn decrypt_packet(
&mut self,
@@ -133,17 +139,17 @@ impl EncryptionContext {
) -> Result<Vec<u8>> {
let key_array = <[u8; 16]>::try_from(encryption_key)?;
let iv_array = <[u8; 16]>::try_from(iv)?;
let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into());
let mut plaintext = ciphertext.to_vec();
cipher.apply_keystream(&mut plaintext);
self.sequence_number_ctos += 1;
Ok(plaintext)
}
/// 计算MAC参考OpenSSH mac.c: mac_compute()
pub fn compute_mac(
&self,
@@ -152,17 +158,17 @@ impl EncryptionContext {
mac_key: &[u8],
) -> Result<Vec<u8>> {
// HMAC-SHA256 MAC计算参考OpenSSH mac.c
let mut mac = HmacSha256::new_from_slice(mac_key)?;
// OpenSSH MAC格式sequence_number + data
mac.update(&sequence_number.to_be_bytes());
mac.update(data);
let result = mac.finalize();
Ok(result.into_bytes().to_vec())
}
/// 验证MAC参考OpenSSH mac.c: mac_check()
pub fn verify_mac(
&self,
@@ -172,14 +178,14 @@ impl EncryptionContext {
mac_key: &[u8],
) -> Result<bool> {
// HMAC验证参考OpenSSH mac.c
let computed_mac = self.compute_mac(sequence_number, data, mac_key)?;
// 防止时间攻击(使用常量时间比较)
if computed_mac.len() != expected_mac.len() {
return Ok(false);
}
// 简化实现:直接比较(实际应使用常量时间比较)
Ok(computed_mac == expected_mac)
}
@@ -187,11 +193,11 @@ impl EncryptionContext {
/// SSH加密packet封装参考OpenSSH packet.c: ssh_packet_write_poll()
pub struct EncryptedPacket {
pub packet_length: u32, // 加密后packet长度
pub padding_length: u8, // padding长度加密后
pub payload: Vec<u8>, // payload加密后
pub padding: Vec<u8>, // padding加密后
pub mac: Vec<u8>, // MAC32字节HMAC-SHA256
pub packet_length: u32, // 加密后packet长度
pub padding_length: u8, // padding长度加密后
pub payload: Vec<u8>, // payload加密后
pub padding: Vec<u8>, // padding加密后
pub mac: Vec<u8>, // MAC32字节HMAC-SHA256
}
impl EncryptedPacket {
@@ -204,82 +210,88 @@ impl EncryptedPacket {
) -> Result<Self> {
let block_size = 16;
let min_padding = 4;
let payload_length = plaintext_payload.len();
// RFC 4253: entire plaintext packet (including 4-byte packet_length field) must be multiple of block_size
// plaintext_packet = packet_length_field(4) + padding_length(1) + payload + padding
// So: (4 + 1 + payload_length + padding_length) % 16 == 0
let base_size = 4 + 1 + payload_length; // without padding
let base_size = 4 + 1 + payload_length; // without padding
let padding_needed = (block_size - (base_size % block_size)) % block_size;
// Ensure padding >= min_padding (RFC 4253 requirement)
let padding_length: u8 = if padding_needed < min_padding {
(padding_needed + block_size) as u8 // Add one more block to meet minimum
(padding_needed + block_size) as u8 // Add one more block to meet minimum
} else {
padding_needed as u8
};
// packet_length = padding_length(1) + payload + padding
let packet_length = 1 + payload_length + padding_length as usize;
info!("Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
payload_length, padding_length, packet_length);
info!(
"Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
payload_length, padding_length, packet_length
);
// 构建plaintext packetpacket_length + padding_length + payload + padding
let mut plaintext_packet = Vec::new();
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
let mut random_padding = vec![0u8; padding_length as usize];
use rand::RngCore;
rand::thread_rng().fill_bytes(&mut random_padding);
plaintext_packet.write_all(&random_padding)?; // plaintext padding
plaintext_packet.write_all(&random_padding)?; // plaintext padding
info!("Plaintext packet size: {} bytes", plaintext_packet.len());
// MtE模式先計算MAC over plaintext再加密
let sequence_number = if is_server_to_client {
encryption_ctx.sequence_number_stoc
} else {
encryption_ctx.sequence_number_ctos
};
let mac_key = if is_server_to_client {
&encryption_ctx.mac_key_stoc
} else {
&encryption_ctx.mac_key_ctos
};
info!("MAC calculation (MtE mode) over plaintext packet:");
info!(" sequence_number: {}", sequence_number);
info!(" mac_key length: {}", mac_key.len());
info!(" plaintext_packet length: {}", plaintext_packet.len());
// MAC計算HMAC(sequence_number || plaintext_packet)
let mac = encryption_ctx.compute_mac(sequence_number, &plaintext_packet, mac_key)?;
// 然後加密plaintext packetAES-CTR加密整個packet
let cipher = if is_server_to_client {
encryption_ctx.cipher_stoc.as_mut()
encryption_ctx
.cipher_stoc
.as_mut()
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
} else {
encryption_ctx.cipher_ctos.as_mut()
encryption_ctx
.cipher_ctos
.as_mut()
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
};
let mut encrypted_packet = plaintext_packet;
cipher.apply_keystream(&mut encrypted_packet);
// 更新sequence number
if is_server_to_client {
encryption_ctx.sequence_number_stoc += 1;
} else {
encryption_ctx.sequence_number_ctos += 1;
}
Ok(Self {
packet_length: packet_length as u32,
padding_length,
@@ -288,24 +300,27 @@ impl EncryptedPacket {
mac,
})
}
/// 写入加密packet参考OpenSSH cipher.c
/// AES-CTR模式写入完整加密packet + MAC
pub fn write<W: std::io::Write>(&self, stream: &mut W) -> Result<()> {
info!("Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
self.payload.len(), self.mac.len());
info!(
"Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
self.payload.len(),
self.mac.len()
);
// AES-CTR: 整个packet已加密包括packet_length直接写入
stream.write_all(&self.payload)?;
info!("Wrote encrypted packet ({} bytes)", self.payload.len());
// 写入MAC
stream.write_all(&self.mac)?;
info!("Wrote MAC ({} bytes)", self.mac.len());
Ok(())
}
/// 读取加密packet参考OpenSSH packet.c ssh_packet_read_poll2
/// OpenSSH packet.c: AES-CTR先解密第一个块再提取packet_length
/// aadlen = 0 (没有EtM或authenticated encryption), packet_length被加密
@@ -315,32 +330,42 @@ impl EncryptedPacket {
is_client_to_server: bool,
) -> Result<Self> {
use std::io::Read;
info!("Reading AES-CTR encrypted packet (packet_length encrypted)");
// 1. 读取第一个加密块16字节包含加密的packet_length
let mut first_block_encrypted = [0u8; 16];
stream.read_exact(&mut first_block_encrypted)?;
info!("Read first encrypted block (16 bytes): {:?}", &first_block_encrypted);
info!(
"Read first encrypted block (16 bytes): {:?}",
&first_block_encrypted
);
// 2. 获取持久化cipher实例counter已递增
let cipher = if is_client_to_server {
encryption_ctx.cipher_ctos.as_mut()
encryption_ctx
.cipher_ctos
.as_mut()
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
} else {
encryption_ctx.cipher_stoc.as_mut()
encryption_ctx
.cipher_stoc
.as_mut()
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
};
info!("Using cipher for decryption (is_client_to_server={})", is_client_to_server);
info!(
"Using cipher for decryption (is_client_to_server={})",
is_client_to_server
);
// 3. 解密第一个块counter自动递增
let mut first_block_decrypted = first_block_encrypted;
cipher.apply_keystream(&mut first_block_decrypted);
info!("Decrypted first block: {:?}", &first_block_decrypted);
// 3. 从解密后的数据中提取packet_length前4字节和padding_length第5字节
let packet_length = u32::from_be_bytes([
first_block_decrypted[0],
@@ -349,67 +374,73 @@ impl EncryptedPacket {
first_block_decrypted[3],
]);
let padding_length = first_block_decrypted[4];
info!("Decrypted packet_length={}, padding_length={}", packet_length, padding_length);
info!(
"Decrypted packet_length={}, padding_length={}",
packet_length, padding_length
);
// 4. 合理性检查
if packet_length > 35000 {
info!("packet_length raw bytes: {:?}", &first_block_decrypted[..4]);
return Err(anyhow!("Invalid packet_length: {}", packet_length));
}
// 3. 计算剩余加密数据长度
// packet_length = padding_length(1) + payload + padding
// 总加密数据 = packet_length(4) + packet_length = packet_length + 4
// 已读取16字节剩余 = packet_length + 4 - 16
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
let remaining_encrypted_size = total_encrypted_size - 16;
info!("Total encrypted size: {}, remaining: {}", total_encrypted_size, remaining_encrypted_size);
info!(
"Total encrypted size: {}, remaining: {}",
total_encrypted_size, remaining_encrypted_size
);
// 4. 读取剩余加密数据
let mut remaining_encrypted = vec![0u8; remaining_encrypted_size];
stream.read_exact(&mut remaining_encrypted)?;
// 5. 继续解密使用同一个cipher
cipher.apply_keystream(&mut remaining_encrypted);
info!("Remaining decrypted data: {:?}", &remaining_encrypted);
// 6. 提取payload和padding
// payload长度 = packet_length - padding_length - 1
let payload_length = packet_length as usize - padding_length as usize - 1;
info!("Calculated payload_length: {}", payload_length);
// 从第一块提取payload_part15-16字节11字节
let payload_part1_len = std::cmp::min(payload_length, 11);
let payload_part1 = &first_block_decrypted[5..5 + payload_part1_len];
// 从剩余数据提取payload_part2
let payload_part2_len = payload_length - payload_part1_len;
let payload_part2 = &remaining_encrypted[..payload_part2_len];
// 合并payload
let mut payload = Vec::new();
payload.extend_from_slice(payload_part1);
payload.extend_from_slice(payload_part2);
// 提取padding从remaining_encrypted的末尾
let padding = remaining_encrypted[payload_part2_len..].to_vec();
// 9. 读取MAC
info!("Reading MAC (32 bytes)...");
let mut mac = vec![0u8; 32];
stream.read_exact(&mut mac)?;
info!("MAC read successfully");
// 10. 更新sequence number
if is_client_to_server {
encryption_ctx.sequence_number_ctos += 1;
} else {
encryption_ctx.sequence_number_stoc += 1;
}
Ok(Self {
packet_length,
padding_length,
@@ -418,7 +449,7 @@ impl EncryptedPacket {
mac,
})
}
/// 获取payload内容
pub fn payload(&self) -> &[u8] {
&self.payload
@@ -428,13 +459,13 @@ impl EncryptedPacket {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_aes256_ctr_encryption() {
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
let iv = vec![0u8; 16];
let plaintext = b"Hello World";
let mut ctx = EncryptionContext::from_session_keys(&SessionKeys {
session_id: vec![0u8; 32],
encryption_key_ctos: key.clone(),
@@ -444,18 +475,18 @@ mod tests {
iv_ctos: iv.clone(),
iv_stoc: iv.clone(),
});
let ciphertext = ctx.encrypt_packet(plaintext, &key, &iv).unwrap();
let decrypted = ctx.decrypt_packet(&ciphertext, &key, &iv).unwrap();
assert_eq!(plaintext.to_vec(), decrypted);
}
#[test]
fn test_hmac_sha256() {
let key = vec![0u8; 32];
let data = b"test data";
let ctx = EncryptionContext::from_session_keys(&SessionKeys {
session_id: vec![0u8; 32],
encryption_key_ctos: vec![0u8; 32],
@@ -465,10 +496,10 @@ mod tests {
iv_ctos: vec![0u8; 16],
iv_stoc: vec![0u8; 16],
});
let mac = ctx.compute_mac(1, data, &key).unwrap();
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
// 验证MAC
assert!(ctx.verify_mac(1, data, &mac, &key).unwrap());
}

View File

@@ -1,16 +1,16 @@
// SSH加密模块Phase 3密钥交换
// 参考OpenSSH curve25519.c, kex.c
use anyhow::{Result, anyhow};
use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
use ed25519_dalek::{SigningKey, VerifyingKey, Signature, Signer};
use sha2::{Sha256, Digest};
use log::{info, debug};
use anyhow::{anyhow, Result};
use ed25519_dalek::{Signer, SigningKey};
use log::info;
use rand::rngs::OsRng;
use sha2::{Digest, Sha256};
use x25519_dalek::{EphemeralSecret, PublicKey};
/// Curve25519密钥交换处理器参考OpenSSH curve25519.c
pub struct Curve25519Kex {
secret: Option<EphemeralSecret>, // 使用Option包装一次性使用类型
secret: Option<EphemeralSecret>, // 使用Option包装一次性使用类型
public: PublicKey,
}
@@ -21,34 +21,37 @@ impl Curve25519Kex {
// x25519-dalek 2.0标准API使用random_from_rng
let secret = EphemeralSecret::random_from_rng(OsRng);
let public = PublicKey::from(&secret);
Self { secret: Some(secret), public } // Some包装
Self {
secret: Some(secret),
public,
} // Some包装
}
/// 获取公钥用于SSH_MSG_KEX_ECDH_INIT
pub fn public_key(&self) -> &[u8] {
self.public.as_bytes()
}
/// 计算共享密钥参考OpenSSH curve25519_shared_secret()
/// 使用&mut self消耗模式符合OpenSSH设计
pub fn compute_shared_secret(&mut self, client_public: &[u8]) -> Result<[u8; 32]> {
if client_public.len() != 32 {
return Err(anyhow!("Invalid client public key length"));
}
info!("=== X25519 Shared Secret Calculation ===");
info!("Client public key input: {:?}", client_public);
info!("Server public key: {:?}", self.public.as_bytes());
// 参考OpenSSHcurve25519共享密钥计算
let client_public_key = PublicKey::from(<[u8; 32]>::try_from(client_public)?);
// 使用take()取出secretRust标准模式
if let Some(secret) = self.secret.take() {
let shared_secret = secret.diffie_hellman(&client_public_key);
info!("Computed shared secret: {:?}", shared_secret.as_bytes());
Ok(shared_secret.as_bytes().clone())
Ok(*shared_secret.as_bytes())
} else {
Err(anyhow!("Secret already used"))
}
@@ -71,47 +74,85 @@ impl SessionKeys {
/// RFC 4253 Section 7.2: Key = HASH(K || H || X || session_id)
pub fn derive(
shared_secret: &[u8],
exchange_hash: &[u8], // H参数exchange hash
server_public_key: &[u8],
client_public_key: &[u8],
server_host_key: &[u8],
exchange_hash: &[u8], // H参数exchange hash
_server_public_key: &[u8],
_client_public_key: &[u8],
_server_host_key: &[u8],
) -> Result<Self> {
// RFC 4253: session_id = H (第一次exchange hash)
let session_id = exchange_hash.to_vec();
info!("SessionKeys::derive() starting");
info!(" shared_secret full (32 bytes): {:?}", shared_secret);
// RFC 8731 Section 3.1: X25519 output is little-endian
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
info!(
" shared_secret[0] = {} (>=0x80? {})",
shared_secret[0],
shared_secret[0] >= 0x80
);
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
info!(" session_id full (32 bytes): {:?}", session_id);
// RFC 4253密钥派生公式HASH(K || H || X || session_id)
// K is shared_secret encoded as mpint (using little-endian bytes directly)
let shared_secret_mpint = Self::encode_mpint(shared_secret);
info!(" shared_secret_mpint ({} bytes): {:?}", shared_secret_mpint.len(), &shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]);
let encryption_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
let encryption_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
let mac_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
let mac_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
let iv_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
let iv_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
info!(
" shared_secret_mpint ({} bytes): {:?}",
shared_secret_mpint.len(),
&shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]
);
let encryption_key_ctos =
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
let encryption_key_stoc =
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
let mac_key_ctos =
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
let mac_key_stoc =
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
let iv_ctos =
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
let iv_stoc =
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
info!("Derived keys summary:");
info!(" encryption_key_ctos ({} bytes): {:?}", encryption_key_ctos.len(), &encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]);
info!(" encryption_key_stoc ({} bytes): {:?}", encryption_key_stoc.len(), &encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]);
info!(" iv_ctos ({} bytes): {:?}", iv_ctos.len(), &iv_ctos[..std::cmp::min(16, iv_ctos.len())]);
info!(" iv_stoc ({} bytes): {:?}", iv_stoc.len(), &iv_stoc[..std::cmp::min(16, iv_stoc.len())]);
info!(" mac_key_ctos ({} bytes): {:?}", mac_key_ctos.len(), &mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]);
info!(" mac_key_stoc ({} bytes): {:?}", mac_key_stoc.len(), &mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]);
info!(
" encryption_key_ctos ({} bytes): {:?}",
encryption_key_ctos.len(),
&encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]
);
info!(
" encryption_key_stoc ({} bytes): {:?}",
encryption_key_stoc.len(),
&encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]
);
info!(
" iv_ctos ({} bytes): {:?}",
iv_ctos.len(),
&iv_ctos[..std::cmp::min(16, iv_ctos.len())]
);
info!(
" iv_stoc ({} bytes): {:?}",
iv_stoc.len(),
&iv_stoc[..std::cmp::min(16, iv_stoc.len())]
);
info!(
" mac_key_ctos ({} bytes): {:?}",
mac_key_ctos.len(),
&mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]
);
info!(
" mac_key_stoc ({} bytes): {:?}",
mac_key_stoc.len(),
&mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]
);
Ok(Self {
session_id,
encryption_key_ctos,
@@ -122,65 +163,73 @@ impl SessionKeys {
iv_stoc,
})
}
/// RFC 4253密钥派生函数
/// 公式Key = HASH(K || H || X || session_id)
fn derive_key_rfc4253(K_mpint: &[u8], H: &[u8], X: char, session_id: &[u8]) -> Result<Vec<u8>> {
let mut hasher = Sha256::new();
info!("Deriving key for X='{}'", X);
info!(" K_mpint ({} bytes): {:?}", K_mpint.len(), &K_mpint[..std::cmp::min(8, K_mpint.len())]);
info!(
" K_mpint ({} bytes): {:?}",
K_mpint.len(),
&K_mpint[..std::cmp::min(8, K_mpint.len())]
);
info!(" H ({} bytes): {:?}", H.len(), &H[..8]);
info!(" session_id ({} bytes): {:?}", session_id.len(), &session_id[..8]);
info!(
" session_id ({} bytes): {:?}",
session_id.len(),
&session_id[..8]
);
// RFC 4253: HASH(K || H || X || session_id)
hasher.update(K_mpint); // K (shared secret in mpint format)
hasher.update(H); // H (exchange hash)
hasher.update(&[X as u8]); // X (single character)
hasher.update(K_mpint); // K (shared secret in mpint format)
hasher.update(H); // H (exchange hash)
hasher.update([X as u8]); // X (single character)
hasher.update(session_id); // session_id
let full_hash = hasher.finalize();
info!(" Derived key (first 8 bytes): {:?}", &full_hash[..8]);
// 根據key類型返回不同長度
// AES-128-CTR key/IV: 16 bytes
// HMAC-SHA256 key: 32 bytes
match X {
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
'E' | 'F' => Ok(full_hash.to_vec()), // MAC key (full 32 bytes)
_ => Ok(full_hash[..16].to_vec()), // default
}
}
/// SSH mpint编码参考RFC 4253 Section 5
/// Curve25519 shared secret特殊处理
fn encode_mpint(bytes: &[u8]) -> Vec<u8> {
// RFC 4253: mpint = uint32(length) + data
// 去掉前导零,如果最高位>=0x80前面加0
// 去掉前导零字节但不去掉最后一个字节即使它是0
let mut start = 0;
while start < bytes.len() - 1 && bytes[start] == 0 {
start += 1;
}
let data_without_leading_zeros = &bytes[start..];
// 构建mpint数据
let mut mpint_data = Vec::new();
// 如果最高位>=0x80前面加0字节避免负数
if data_without_leading_zeros[0] >= 0x80 {
mpint_data.push(0);
}
mpint_data.extend_from_slice(data_without_leading_zeros);
// 最终格式uint32长度 + mpint数据
let mut result = Vec::new();
result.extend_from_slice(&(mpint_data.len() as u32).to_be_bytes());
result.extend_from_slice(&mpint_data);
result
}
}
@@ -192,45 +241,45 @@ pub struct Ed25519HostKey {
impl Ed25519HostKey {
/// 加载或生成主机密钥参考OpenSSH hostfile.c
pub fn load_or_generate(key_path: &str) -> Result<Self> {
pub fn load_or_generate(_key_path: &str) -> Result<Self> {
// 简化实现:生成临时密钥(实际应从文件加载)
// 参考OpenSSH ssh-keygen
let signing_key = SigningKey::generate(&mut OsRng);
Ok(Self { signing_key })
}
/// 获取公钥用于SSH_MSG_KEX_ECDH_REPLY
pub fn public_key_bytes(&self) -> Vec<u8> {
// SSH Ed25519公钥格式参考OpenSSH sshkey.c
let verifying_key = self.signing_key.verifying_key();
// SSH格式ssh-ed25519 + 公钥bytes
// 简化仅返回公钥bytes32字节
verifying_key.as_bytes().to_vec()
}
/// 签名参考OpenSSH sshkey.c: sshkey_sign()
pub fn sign(&self, data: &[u8]) -> Result<Vec<u8>> {
// OpenSSH Ed25519签名
let signature = self.signing_key.sign(data);
// SSH签名格式参考OpenSSH ssh-sign.c
// 简化仅返回签名bytes64字节
Ok(signature.to_bytes().to_vec())
}
/// 获取完整SSH公钥格式参考OpenSSH sshkey.c
pub fn ssh_public_key(&self) -> String {
let public_bytes = self.public_key_bytes();
// SSH公钥格式ssh-ed25519 <base64-encoded-public-key>
// 参考OpenSSH ssh-keygen -y
use base64::{Engine as _, engine::general_purpose};
use base64::{engine::general_purpose, Engine as _};
let encoded = general_purpose::STANDARD.encode(&public_bytes);
format!("ssh-ed25519 {}", encoded)
}
}
@@ -238,40 +287,44 @@ impl Ed25519HostKey {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_curve25519_key_generation() {
let kex = Curve25519Kex::new();
assert_eq!(kex.public_key().len(), 32);
}
#[test]
fn test_curve25519_shared_secret() {
let mut client_kex = Curve25519Kex::new();
let mut server_kex = Curve25519Kex::new();
// 客户端计算共享密钥
let client_secret = client_kex.compute_shared_secret(server_kex.public_key()).unwrap();
let client_secret = client_kex
.compute_shared_secret(server_kex.public_key())
.unwrap();
// 服务器计算共享密钥
let server_secret = server_kex.compute_shared_secret(client_kex.public_key()).unwrap();
let server_secret = server_kex
.compute_shared_secret(client_kex.public_key())
.unwrap();
// 应该相同Curve25519特性
assert_eq!(client_secret, server_secret);
}
#[test]
fn test_ed25519_host_key() {
let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap();
assert_eq!(host_key.public_key_bytes().len(), 32);
}
#[test]
fn test_ed25519_signature() {
let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap();
let data = b"test data";
let signature = host_key.sign(data).unwrap();
assert_eq!(signature.len(), 64); // Ed25519签名64字节
assert_eq!(signature.len(), 64); // Ed25519签名64字节
}
}

View File

@@ -1,13 +1,13 @@
// SSH端口转发数据传输Phase 13.5
// 参考OpenSSH channels.c: channel_handle_data()
use anyhow::{Result, anyhow};
use log::{info, warn, debug};
use std::net::{TcpStream};
use std::io::{Read, Write};
use std::thread;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, Result};
use byteorder::{BigEndian, WriteBytesExt};
use log::{debug, info, warn};
use std::io::{Read, Write};
use std::net::TcpStream;
use std::sync::{Arc, Mutex};
use std::thread;
/// 数据转发器Phase 13.5:双向数据传输)
pub struct DataForwarder {
@@ -25,29 +25,40 @@ impl DataForwarder {
max_packet_size,
}
}
/// 启动双向数据转发Phase 13.5SSH channel ↔ TCP socket
pub fn start_bidirectional_forwarding(
&mut self,
ssh_stream: TcpStream, // SSH client连接加密通道
target_stream: TcpStream, // 目标服务连接TCP socket
ssh_stream: TcpStream, // SSH client连接加密通道
target_stream: TcpStream, // 目标服务连接TCP socket
) -> Result<()> {
info!("Starting bidirectional data forwarding for channel {}", self.channel_id);
info!(
"Starting bidirectional data forwarding for channel {}",
self.channel_id
);
// Phase 13.5: SSH channel → Target socketSSH client数据 → 本地服务)
let ssh_to_target = self.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
let ssh_to_target = self
.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
// Phase 13.5: Target socket → SSH channel本地服务数据 → SSH client
let target_to_ssh = self.start_target_to_ssh_forwarding(target_stream, ssh_stream);
// Phase 13.5: 等待两个转发线程完成
ssh_to_target.join().map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
target_to_ssh.join().map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
info!("Bidirectional data forwarding completed for channel {}", self.channel_id);
ssh_to_target
.join()
.map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
target_to_ssh
.join()
.map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
info!(
"Bidirectional data forwarding completed for channel {}",
self.channel_id
);
Ok(())
}
/// SSH channel → Target socket转发Phase 13.5
fn start_ssh_to_target_forwarding(
&self,
@@ -57,18 +68,21 @@ impl DataForwarder {
let channel_id = self.channel_id;
let window_size = self.window_size.clone();
let max_packet_size = self.max_packet_size;
thread::spawn(move || {
info!("SSH to target forwarding thread started for channel {}", channel_id);
info!(
"SSH to target forwarding thread started for channel {}",
channel_id
);
let mut buffer = vec![0u8; max_packet_size as usize];
loop {
// Phase 13.5: 从SSH channel读取数据
let n = match ssh_stream.read(&mut buffer) {
Ok(0) => {
info!("SSH channel EOF for channel {}", channel_id);
break; // EOF
break; // EOF
}
Ok(n) => n,
Err(e) => {
@@ -76,45 +90,61 @@ impl DataForwarder {
break;
}
};
// Phase 13.5: 检查window size
{
let window = window_size.lock().unwrap();
if *window < n as u32 {
warn!("Window size insufficient for channel {}: need {}, have {}",
channel_id, n, *window);
warn!(
"Window size insufficient for channel {}: need {}, have {}",
channel_id, n, *window
);
// Phase 13.5: 理论上应该等待SSH_MSG_CHANNEL_WINDOW_ADJUST
// 简化实现继续发送可能会违反RFC 4254
}
}
// Phase 13.5: 写入目标socket
if let Err(e) = target_stream.write_all(&buffer[..n]) {
warn!("Target socket write error for channel {}: {}", channel_id, e);
warn!(
"Target socket write error for channel {}: {}",
channel_id, e
);
break;
}
// Phase 13.5: Flush确保数据发送
if let Err(e) = target_stream.flush() {
warn!("Target socket flush error for channel {}: {}", channel_id, e);
warn!(
"Target socket flush error for channel {}: {}",
channel_id, e
);
break;
}
// Phase 13.5: 消耗window size
{
let mut window = window_size.lock().unwrap();
*window -= n as u32;
debug!("Window size consumed for channel {}: {} bytes, remaining {}",
channel_id, n, *window);
debug!(
"Window size consumed for channel {}: {} bytes, remaining {}",
channel_id, n, *window
);
}
info!("Forwarded {} bytes from SSH to target for channel {}", n, channel_id);
info!(
"Forwarded {} bytes from SSH to target for channel {}",
n, channel_id
);
}
info!("SSH to target forwarding thread stopped for channel {}", channel_id);
info!(
"SSH to target forwarding thread stopped for channel {}",
channel_id
);
})
}
/// Target socket → SSH channel转发Phase 13.5
fn start_target_to_ssh_forwarding(
&self,
@@ -122,18 +152,21 @@ impl DataForwarder {
mut ssh_stream: TcpStream,
) -> thread::JoinHandle<()> {
let channel_id = self.channel_id;
thread::spawn(move || {
info!("Target to SSH forwarding thread started for channel {}", channel_id);
let mut buffer = vec![0u8; 8192]; // 8KB buffer
info!(
"Target to SSH forwarding thread started for channel {}",
channel_id
);
let mut buffer = vec![0u8; 8192]; // 8KB buffer
loop {
// Phase 13.5: 从目标socket读取数据
let n = match target_stream.read(&mut buffer) {
Ok(0) => {
info!("Target socket EOF for channel {}", channel_id);
break; // EOF
break; // EOF
}
Ok(n) => n,
Err(e) => {
@@ -141,43 +174,51 @@ impl DataForwarder {
break;
}
};
// Phase 13.5: 构建SSH_MSG_CHANNEL_DATA packet
// 注意实际实现需要通过EncryptedPacket加密
// 这里简化实现直接写入SSH stream测试用
// Phase 13.5: 写入SSH channel
if let Err(e) = ssh_stream.write_all(&buffer[..n]) {
warn!("SSH channel write error for channel {}: {}", channel_id, e);
break;
}
// Phase 13.5: Flush确保数据发送
if let Err(e) = ssh_stream.flush() {
warn!("SSH channel flush error for channel {}: {}", channel_id, e);
break;
}
info!("Forwarded {} bytes from target to SSH for channel {}", n, channel_id);
info!(
"Forwarded {} bytes from target to SSH for channel {}",
n, channel_id
);
}
info!("Target to SSH forwarding thread stopped for channel {}", channel_id);
info!(
"Target to SSH forwarding thread stopped for channel {}",
channel_id
);
})
}
/// 获取当前window sizePhase 13.5
pub fn get_window_size(&self) -> u32 {
*self.window_size.lock().unwrap()
}
/// 增加window sizePhase 13.5SSH_MSG_CHANNEL_WINDOW_ADJUST
pub fn adjust_window_size(&self, bytes_to_add: u32) {
let mut window = self.window_size.lock().unwrap();
*window += bytes_to_add;
info!("Window size adjusted for channel {}: added {} bytes, total {}",
self.channel_id, bytes_to_add, *window);
info!(
"Window size adjusted for channel {}: added {} bytes, total {}",
self.channel_id, bytes_to_add, *window
);
}
/// 检查window size是否足够Phase 13.5
pub fn check_window_available(&self, data_size: u32) -> bool {
let window = self.window_size.lock().unwrap();
@@ -188,64 +229,64 @@ impl DataForwarder {
/// SSH_MSG_CHANNEL_DATA构建Phase 13.5
pub fn build_channel_data_packet(channel_id: u32, data: &[u8]) -> Result<Vec<u8>> {
let mut packet = Vec::new();
// Packet type: SSH_MSG_CHANNEL_DATA (type 94)
packet.write_u8(94)?;
// Recipient channel ID
packet.write_u32::<BigEndian>(channel_id)?;
// Data length (SSH string)
packet.write_u32::<BigEndian>(data.len() as u32)?;
// Data content
packet.write_all(data)?;
Ok(packet)
}
/// SSH_MSG_CHANNEL_WINDOW_ADJUST构建Phase 13.5
pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result<Vec<u8>> {
let mut packet = Vec::new();
// Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93)
packet.write_u8(93)?;
// Recipient channel ID
packet.write_u32::<BigEndian>(channel_id)?;
// Bytes to add
packet.write_u32::<BigEndian>(bytes_to_add)?;
Ok(packet)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_data_forwarder_creation() {
let forwarder = DataForwarder::new(1, 2097152, 32768);
assert_eq!(forwarder.channel_id, 1);
assert_eq!(forwarder.get_window_size(), 2097152);
}
#[test]
fn test_window_size_adjustment() {
let forwarder = DataForwarder::new(1, 2097152, 32768);
// 消耗window size
forwarder.adjust_window_size(1000);
assert_eq!(forwarder.get_window_size(), 2097152 + 1000);
}
#[test]
fn test_build_channel_data_packet() {
let data = b"Hello, SSH!";
let packet = build_channel_data_packet(1, data).unwrap();
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
// 验证packet结构
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
// 验证packet结构
}
}

View File

@@ -1,42 +1,42 @@
// SSH密钥交换算法协商实现Phase 2
// 参考OpenSSH kex.c: kex_send_kexinit(), kex_choose_conf()
use crate::ssh_server::packet::{SshPacket, PacketType};
use anyhow::{Result, anyhow};
use crate::ssh_server::packet::{PacketType, SshPacket};
use anyhow::{anyhow, Result};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, debug};
use log::{debug, info};
use std::io::{Read, Write};
/// SSH算法类型参考OpenSSH PROTOCOL定义
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AlgorithmType {
KEX_ALGS = 0, // 密钥交换算法
KEX_ALGS = 0, // 密钥交换算法
SERVER_HOST_KEY_ALGS = 1, // 服务器主机密钥算法
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
LANGS_CTOS = 8, // 客户端到服务器语言
LANGS_STOC = 9, // 服务器到客户端语言
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
LANGS_CTOS = 8, // 客户端到服务器语言
LANGS_STOC = 9, // 服务器到客户端语言
}
/// SSH算法提议参考OpenSSH kex.h: struct kex
#[derive(Debug, Clone)]
pub struct KexProposal {
pub kex_algorithms: String, // 密钥交换算法列表
pub server_host_key_algorithms: String, // 主机密钥算法列表
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
pub mac_algorithms_ctos: String, // MAC算法客户端→服务器
pub mac_algorithms_stoc: String, // MAC算法服务器→客户端
pub kex_algorithms: String, // 密钥交换算法列表
pub server_host_key_algorithms: String, // 主机密钥算法列表
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
pub mac_algorithms_ctos: String, // MAC算法客户端→服务器
pub mac_algorithms_stoc: String, // MAC算法服务器→客户端
pub compression_algorithms_ctos: String, // 压缩算法(客户端→服务器)
pub compression_algorithms_stoc: String, // 压缩算法(服务器→客户端)
pub languages_ctos: String, // 语言(客户端→服务器)
pub languages_stoc: String, // 语言(服务器→客户端)
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
pub reserved: u32, // 保留字段0
pub languages_ctos: String, // 语言(客户端→服务器)
pub languages_stoc: String, // 语言(服务器→客户端)
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
pub reserved: u32, // 保留字段0
}
impl KexProposal {
@@ -46,31 +46,31 @@ impl KexProposal {
Self {
// 密钥交换算法优先Curve25519推荐 + strict KEX extension
kex_algorithms: "curve25519-sha256,curve25519-sha256@libssh.org,diffie-hellman-group14-sha256,ext-info-s,kex-strict-s-v00@openssh.com".to_string(),
// 主机密钥算法优先Ed25519
server_host_key_algorithms: "ssh-ed25519,rsa-sha2-256,rsa-sha2-512".to_string(),
// 加密算法AES-256-CTR推荐
encryption_algorithms_ctos: "aes256-ctr,aes128-ctr".to_string(),
encryption_algorithms_stoc: "aes256-ctr,aes128-ctr".to_string(),
// MAC算法HMAC-SHA256
mac_algorithms_ctos: "hmac-sha2-256,hmac-sha2-512".to_string(),
mac_algorithms_stoc: "hmac-sha2-256,hmac-sha2-512".to_string(),
// 压缩算法none优先
compression_algorithms_ctos: "none,zlib".to_string(),
compression_algorithms_stoc: "none,zlib".to_string(),
// 语言:空
languages_ctos: "".to_string(),
languages_stoc: "".to_string(),
first_kex_packet_follows: false,
reserved: 0,
}
}
/// 创建客户端默认提议(用于测试)
pub fn client_default() -> Self {
Self {
@@ -88,20 +88,20 @@ impl KexProposal {
reserved: 0,
}
}
/// 序列化到SSH_MSG_KEXINIT packet参考OpenSSH kex_send_kexinit()
pub fn to_kexinit_packet(&self) -> Result<SshPacket> {
let mut payload = Vec::new();
// Packet type
payload.write_u8(PacketType::SSH_MSG_KEXINIT as u8)?;
// Cookie16字节随机数OpenSSH要求
let mut cookie = [0u8; 16];
use rand::Rng;
rand::thread_rng().fill(&mut cookie);
payload.write_all(&cookie)?;
// 10个算法列表SSH string格式length + data
write_ssh_string(&mut payload, &self.kex_algorithms)?;
write_ssh_string(&mut payload, &self.server_host_key_algorithms)?;
@@ -113,29 +113,29 @@ impl KexProposal {
write_ssh_string(&mut payload, &self.compression_algorithms_stoc)?;
write_ssh_string(&mut payload, &self.languages_ctos)?;
write_ssh_string(&mut payload, &self.languages_stoc)?;
// first_kex_packet_followsboolean
payload.write_u8(if self.first_kex_packet_follows { 1 } else { 0 })?;
// reservedu32
payload.write_u32::<BigEndian>(self.reserved)?;
Ok(SshPacket::new(payload))
}
/// 从SSH_MSG_KEXINIT packet解析参考OpenSSH kex_input_kexinit()
pub fn from_kexinit_packet(packet: &SshPacket) -> Result<Self> {
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()Rust标准
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()Rust标准
// Packet type
let packet_type = cursor.read_u8()?;
if packet_type != PacketType::SSH_MSG_KEXINIT as u8 {
return Err(anyhow!("Invalid packet type for KEXINIT"));
}
// Cookie16字节忽略
cursor.read_exact(&mut [0u8; 16])?;
// 10个算法列表
let kex_algorithms = read_ssh_string(&mut cursor)?;
let server_host_key_algorithms = read_ssh_string(&mut cursor)?;
@@ -147,13 +147,13 @@ impl KexProposal {
let compression_algorithms_stoc = read_ssh_string(&mut cursor)?;
let languages_ctos = read_ssh_string(&mut cursor)?;
let languages_stoc = read_ssh_string(&mut cursor)?;
// first_kex_packet_follows
let first_kex_packet_follows = cursor.read_u8()? != 0;
// reserved
let reserved = cursor.read_u32::<BigEndian>()?;
Ok(Self {
kex_algorithms,
server_host_key_algorithms,
@@ -174,14 +174,14 @@ impl KexProposal {
/// SSH算法协商结果参考OpenSSH struct kex
#[derive(Debug, Clone)]
pub struct KexResult {
pub kex_algorithm: String, // 选定的密钥交换算法
pub host_key_algorithm: String, // 选定的主机密钥算法
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
pub mac_ctos: String, // 选定的MAC算法客户端→服务器
pub mac_stoc: String, // 选定的MAC算法服务器→客户端
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
pub kex_algorithm: String, // 选定的密钥交换算法
pub host_key_algorithm: String, // 选定的主机密钥算法
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
pub mac_ctos: String, // 选定的MAC算法客户端→服务器
pub mac_stoc: String, // 选定的MAC算法服务器→客户端
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
}
/// 算法匹配逻辑参考OpenSSH kex_choose_conf()
@@ -189,28 +189,43 @@ impl KexResult {
/// 从服务器和客户端提议中选择算法参考OpenSSH kex_choose_conf()
pub fn choose_algorithms(server: &KexProposal, client: &KexProposal) -> Result<Self> {
info!("Starting algorithm negotiation");
// 算法匹配优先客户端偏好OpenSSH逻辑
// 参考OpenSSH客户端列出的算法顺序为偏好顺序
// 密钥交换算法匹配
let kex_algorithm = match_algorithm(&client.kex_algorithms, &server.kex_algorithms)?;
// 主机密钥算法匹配
let host_key_algorithm = match_algorithm(&client.server_host_key_algorithms, &server.server_host_key_algorithms)?;
let host_key_algorithm = match_algorithm(
&client.server_host_key_algorithms,
&server.server_host_key_algorithms,
)?;
// 加密算法匹配
let encryption_ctos = match_algorithm(&client.encryption_algorithms_ctos, &server.encryption_algorithms_ctos)?;
let encryption_stoc = match_algorithm(&client.encryption_algorithms_stoc, &server.encryption_algorithms_stoc)?;
let encryption_ctos = match_algorithm(
&client.encryption_algorithms_ctos,
&server.encryption_algorithms_ctos,
)?;
let encryption_stoc = match_algorithm(
&client.encryption_algorithms_stoc,
&server.encryption_algorithms_stoc,
)?;
// MAC算法匹配
let mac_ctos = match_algorithm(&client.mac_algorithms_ctos, &server.mac_algorithms_ctos)?;
let mac_stoc = match_algorithm(&client.mac_algorithms_stoc, &server.mac_algorithms_stoc)?;
// 压缩算法匹配
let compression_ctos = match_algorithm(&client.compression_algorithms_ctos, &server.compression_algorithms_ctos)?;
let compression_stoc = match_algorithm(&client.compression_algorithms_stoc, &server.compression_algorithms_stoc)?;
let compression_ctos = match_algorithm(
&client.compression_algorithms_ctos,
&server.compression_algorithms_ctos,
)?;
let compression_stoc = match_algorithm(
&client.compression_algorithms_stoc,
&server.compression_algorithms_stoc,
)?;
info!("Algorithm negotiation completed:");
debug!(" KEX: {}", kex_algorithm);
debug!(" Host key: {}", host_key_algorithm);
@@ -218,7 +233,7 @@ impl KexResult {
debug!(" Encryption (S->C): {}", encryption_stoc);
debug!(" MAC (C->S): {}", mac_ctos);
debug!(" MAC (S->C): {}", mac_stoc);
Ok(Self {
kex_algorithm,
host_key_algorithm,
@@ -237,15 +252,19 @@ fn match_algorithm(client_algs: &str, server_algs: &str) -> Result<String> {
// 算法列表格式name1,name2,name3,...
let client_list: Vec<&str> = client_algs.split(',').collect();
let server_list: Vec<&str> = server_algs.split(',').collect();
// OpenSSH逻辑按客户端偏好顺序匹配
for client_alg in &client_list {
if server_list.contains(client_alg) {
return Ok(client_alg.to_string());
}
}
Err(anyhow!("No matching algorithm found: client={}, server={}", client_algs, server_algs))
Err(anyhow!(
"No matching algorithm found: client={}, server={}",
client_algs,
server_algs
))
}
/// SSH string写入辅助函数length + data
@@ -266,36 +285,36 @@ fn read_ssh_string<R: Read>(reader: &mut R) -> Result<String> {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_kex_proposal_creation() {
let proposal = KexProposal::server_default();
assert!(proposal.kex_algorithms.contains("curve25519-sha256"));
}
#[test]
fn test_kex_proposal_serialization() {
let proposal = KexProposal::server_default();
let packet = proposal.to_kexinit_packet().unwrap();
assert!(packet.payload.len() > 0);
}
#[test]
fn test_algorithm_matching() {
let client = "curve25519-sha256,aes256-ctr";
let server = "aes256-ctr,diffie-hellman-group14-sha256";
let matched = match_algorithm(client, server).unwrap();
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
}
#[test]
fn test_kex_negotiation() {
let server = KexProposal::server_default();
let client = KexProposal::client_default();
let result = KexResult::choose_algorithms(&server, &client).unwrap();
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
}
}

View File

@@ -1,14 +1,13 @@
// SSH密钥交换完整流程Phase 3剩余
// 参考OpenSSH kex.c: complete implementation
use crate::ssh_server::packet::{SshPacket, PacketType};
use crate::ssh_server::crypto::SessionKeys;
use crate::ssh_server::kex::{KexProposal, KexResult};
use crate::ssh_server::crypto::{SessionKeys};
use crate::ssh_server::kex_exchange::KexExchangeHandler;
use anyhow::{Result, anyhow};
use sha2::{Sha256, Digest};
use byteorder::{BigEndian, WriteBytesExt};
use log::{info, debug};
use crate::ssh_server::packet::{PacketType, SshPacket};
use anyhow::{anyhow, Result};
use log::info;
use sha2::{Digest, Sha256};
/// SSH密钥交换完整状态管理参考OpenSSH struct kex
pub struct KexState {
@@ -30,7 +29,7 @@ impl KexState {
kex_result: KexResult,
) -> Result<Self> {
let exchange_handler = KexExchangeHandler::new(kex_result)?;
Ok(Self {
client_version,
server_version,
@@ -42,18 +41,18 @@ impl KexState {
newkeys_sent: false,
})
}
/// 保存KEXINIT payloads用于Exchange Hash计算
///
///
/// 分析OpenSSH源码后的结论
/// - kex->peer存储的是incoming_packet剩余内容payload fields + padding
/// - kex->my存储的是prop2buf()结果payload fields不包括padding
///
///
/// **但exchange hash必须使用相同的I_C/I_S**
///
///
/// 疑问OpenSSH如何确保client和server使用相同的padding
/// 可能答案OpenSSH在计算exchange hash时不包括padding
///
///
/// 暂时保持不包括padding因为签名验证之前成功
pub fn save_kexinit_payloads(
&mut self,
@@ -63,12 +62,18 @@ impl KexState {
// Only save payload (without padding) for now
self.client_kexinit_payload = client_kexinit.payload.clone();
self.server_kexinit_payload = server_kexinit.payload.clone();
info!("Saved KEXINIT payloads (payload only, no padding)");
info!(" client payload: {} bytes", self.client_kexinit_payload.len());
info!(" server payload: {} bytes", self.server_kexinit_payload.len());
info!(
" client payload: {} bytes",
self.client_kexinit_payload.len()
);
info!(
" server payload: {} bytes",
self.server_kexinit_payload.len()
);
}
/// 计算Exchange Hash参考OpenSSH kex.c: kex_hash()
/// H = SHA256(V_C || V_S || I_C || I_S || K_S || K_C || K_S || shared_secret)
pub fn compute_exchange_hash(
@@ -80,74 +85,74 @@ impl KexState {
) -> Result<Vec<u8>> {
// 参考OpenSSH kex.c: kex_hash()
let mut hasher = Sha256::new();
// V_C: 客户端版本字符串SSH string格式
write_ssh_string_to_hash(&mut hasher, &self.client_version)?;
// V_S: 服务器版本字符串SSH string格式
write_ssh_string_to_hash(&mut hasher, &self.server_version)?;
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
// Remove SSH_MSG_KEXINIT type byte from payloads and prepend it in exchange hash
let client_kexinit_without_type = &self.client_kexinit_payload[1..];
let server_kexinit_without_type = &self.server_kexinit_payload[1..];
hasher.update(&((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
hasher.update(((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
hasher.update([20]); // SSH_MSG_KEXINIT type byte
hasher.update(client_kexinit_without_type);
hasher.update(&((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
hasher.update(((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
hasher.update([20]); // SSH_MSG_KEXINIT type byte
hasher.update(server_kexinit_without_type);
// K_S: 服务器主机密钥blobSSH string格式
hasher.update(server_host_key_blob);
// K_C: 客户端Curve25519公钥SSH string格式
write_ssh_bytes_to_hash(&mut hasher, client_public_key)?;
// K_S: 服务器Curve25519公钥SSH string格式
write_ssh_bytes_to_hash(&mut hasher, server_public_key)?;
// K: 共享密钥SSH mpint格式
// OpenSSH要求去掉前导零
write_ssh_mpint_to_hash(&mut hasher, shared_secret)?;
Ok(hasher.finalize().to_vec())
}
/// 处理SSH_MSG_NEWKEYS参考OpenSSH kex.c: kex_input_newkeys()
pub fn handle_newkeys(&mut self, packet: &SshPacket) -> Result<()> {
info!("Processing SSH_MSG_NEWKEYS");
// 验证packet类型
if packet.payload.len() < 1 {
if packet.payload.is_empty() {
return Err(anyhow!("Invalid NEWKEYS packet"));
}
let packet_type = packet.payload[0];
if packet_type != PacketType::SSH_MSG_NEWKEYS as u8 {
return Err(anyhow!("Invalid packet type for NEWKEYS"));
}
// 标记NEWKEYS接收完成参考OpenSSH
self.newkeys_received = true;
info!("SSH_MSG_NEWKEYS received, encryption channel ready");
Ok(())
}
/// 发送SSH_MSG_NEWKEYS参考OpenSSH kex.c: kex_send_newkeys()
pub fn send_newkeys() -> Result<SshPacket> {
info!("Sending SSH_MSG_NEWKEYS");
let payload = vec![PacketType::SSH_MSG_NEWKEYS as u8];
Ok(SshPacket::new(payload))
}
/// 检查NEWKEYS完成状态加密通道建立
pub fn is_encryption_ready(&self) -> bool {
self.newkeys_received && self.newkeys_sent
@@ -156,14 +161,14 @@ impl KexState {
/// SSH string写入到hash辅助函数
fn write_ssh_string_to_hash(hasher: &mut Sha256, s: &str) -> Result<()> {
hasher.update(&(s.len() as u32).to_be_bytes());
hasher.update((s.len() as u32).to_be_bytes());
hasher.update(s.as_bytes());
Ok(())
}
/// SSH bytes写入到hash辅助函数
fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
hasher.update(&(bytes.len() as u32).to_be_bytes());
hasher.update((bytes.len() as u32).to_be_bytes());
hasher.update(bytes);
Ok(())
}
@@ -171,7 +176,7 @@ fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
/// SSH mpint写入到hash参考OpenSSH sshbuf_put_mpint()
fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
// OpenSSH要求去掉前导零如果最高位为1
let mpint_bytes = if bytes.len() > 0 && bytes[0] >= 0x80 {
let mpint_bytes = if !bytes.is_empty() && bytes[0] >= 0x80 {
// 需要添加前导零(避免负数)
let mut mpint = vec![0u8];
mpint.extend_from_slice(bytes);
@@ -179,61 +184,67 @@ fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
} else {
bytes.to_vec()
};
hasher.update(&(mpint_bytes.len() as u32).to_be_bytes());
hasher.update((mpint_bytes.len() as u32).to_be_bytes());
hasher.update(&mpint_bytes);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exchange_hash_computation() {
let kex_result = KexResult::choose_algorithms(
&KexProposal::server_default(),
&KexProposal::client_default(),
).unwrap();
)
.unwrap();
let mut state = KexState::new(
"SSH-2.0-OpenSSH_10.2".to_string(),
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
kex_result,
).unwrap();
)
.unwrap();
// Set minimal KEXINIT payloads (need at least 1 byte for packet type)
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
let shared_secret = vec![0u8; 32];
let host_key = vec![0u8; 32];
let client_pub = vec![0u8; 32];
let server_pub = vec![0u8; 32];
let hash = state.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub).unwrap();
assert_eq!(hash.len(), 32); // SHA256输出32字节
let hash = state
.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub)
.unwrap();
assert_eq!(hash.len(), 32); // SHA256输出32字节
}
#[test]
fn test_newkeys_handling() {
let kex_result = KexResult::choose_algorithms(
&KexProposal::server_default(),
&KexProposal::client_default(),
).unwrap();
)
.unwrap();
let mut state = KexState::new(
"SSH-2.0-OpenSSH_10.2".to_string(),
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
kex_result,
).unwrap();
)
.unwrap();
let newkeys_packet = SshPacket::new(vec![PacketType::SSH_MSG_NEWKEYS as u8]);
state.handle_newkeys(&newkeys_packet).unwrap();
assert!(state.newkeys_received);
}
}

View File

@@ -1,14 +1,14 @@
// SSH密钥交换流程实现Phase 3
// 参考OpenSSH kex.c: kex_input_kex_init(), kex_send_kex_reply()
use crate::ssh_server::packet::{SshPacket, PacketType};
use crate::ssh_server::kex::{KexResult};
use crate::ssh_server::crypto::{Curve25519Kex, SessionKeys, Ed25519HostKey};
use anyhow::{Result, anyhow};
use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys};
use crate::ssh_server::kex::KexResult;
use crate::ssh_server::packet::{PacketType, SshPacket};
use anyhow::{anyhow, Result};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::{info, debug};
use log::info;
use sha2::Digest;
use std::io::{Read, Write};
use sha2::{Sha256, Digest};
/// SSH密钥交换流程处理器参考OpenSSH kex.c
pub struct KexExchangeHandler {
@@ -18,7 +18,7 @@ pub struct KexExchangeHandler {
shared_secret: Option<Vec<u8>>,
client_public_key: Option<Vec<u8>>,
server_public_key: Option<Vec<u8>>,
exchange_hash: Option<Vec<u8>>, // 保存exchange hashH参数
exchange_hash: Option<Vec<u8>>, // 保存exchange hashH参数
client_version: Option<String>,
server_version: Option<String>,
client_kexinit_payload: Option<Vec<u8>>,
@@ -30,7 +30,7 @@ impl KexExchangeHandler {
pub fn new(kex_result: KexResult) -> Result<Self> {
// 加载或生成服务器主机密钥
let host_key = Ed25519HostKey::load_or_generate("config/ssh_host_ed25519_key")?;
Ok(Self {
kex_algorithm: kex_result.kex_algorithm,
server_kex: None,
@@ -45,10 +45,10 @@ impl KexExchangeHandler {
server_kexinit_payload: None,
})
}
/// 处理SSH_MSG_KEXDH_INITCurve25519密钥交换参考OpenSSH kex.c: kex_input_kex_init()
/// 处理SSH_MSG_KEXDH_INITCurve25519密钥交换参考OpenSSH kex.c: kex_input_kex_init()
pub fn handle_kexdh_init(
&mut self,
&mut self,
packet: &SshPacket,
client_version: &str,
server_version: &str,
@@ -56,41 +56,44 @@ impl KexExchangeHandler {
server_kexinit_payload: &[u8],
) -> Result<SshPacket> {
info!("Processing SSH_MSG_KEXDH_INIT (Curve25519)");
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
let packet_type = cursor.read_u8()?;
if packet_type != PacketType::SSH_MSG_KEXDH_INIT as u8 {
return Err(anyhow!("Invalid packet type for KEXDH_INIT"));
}
let key_length = cursor.read_u32::<BigEndian>()?;
if key_length != 32 {
return Err(anyhow!("Invalid Curve25519 public key length: {}", key_length));
return Err(anyhow!(
"Invalid Curve25519 public key length: {}",
key_length
));
}
let mut client_public_key = vec![0u8; 32];
cursor.read_exact(&mut client_public_key)?;
self.server_kex = Some(Curve25519Kex::new());
let server_kex = self.server_kex.as_mut().unwrap();
let shared_secret = server_kex.compute_shared_secret(&client_public_key)?;
let server_public_key = server_kex.public_key().to_vec();
// Save for later session key computation
self.shared_secret = Some(shared_secret.to_vec());
self.client_public_key = Some(client_public_key.clone());
self.server_public_key = Some(server_public_key.clone());
// Save client_version, server_version, kexinit payloads for exchange hash
self.client_version = Some(client_version.to_string());
self.server_version = Some(server_version.to_string());
self.client_kexinit_payload = Some(client_kexinit_payload.to_vec());
self.server_kexinit_payload = Some(server_kexinit_payload.to_vec());
info!("Curve25519 shared secret computed and saved");
// Compute exchange hash ONCE and reuse it
let host_key_blob = self.build_ssh_host_key()?;
let exchange_hash = self.compute_exchange_hash(
@@ -103,69 +106,69 @@ impl KexExchangeHandler {
client_kexinit_payload,
server_kexinit_payload,
)?;
info!("Exchange hash computed:");
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
info!(
" shared_secret[0] = {} (>=0x80? {})",
shared_secret[0],
shared_secret[0] >= 0x80
);
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
self.exchange_hash = Some(exchange_hash.clone());
info!("Exchange hash saved for key derivation");
self.build_kexdh_reply(
&exchange_hash,
&host_key_blob,
&server_public_key,
)
self.build_kexdh_reply(&exchange_hash, &host_key_blob, &server_public_key)
}
/// 构建SSH_MSG_KEXDH_REPLY packet参考OpenSSH kex.c
fn build_kexdh_reply(
&self,
exchange_hash: &[u8],
&self,
exchange_hash: &[u8],
host_key_blob: &[u8],
server_public_key: &[u8],
) -> Result<SshPacket> {
info!("=== Building SSH_MSG_KEXDH_REPLY ===");
info!("Input server_public_key: {:?}", server_public_key);
let mut payload = Vec::new();
payload.write_u8(PacketType::SSH_MSG_KEXDH_REPLY as u8)?;
payload.write_u32::<BigEndian>(host_key_blob.len() as u32)?;
payload.write_all(host_key_blob)?;
info!("Writing server_public_key to payload (32 bytes)");
payload.write_u32::<BigEndian>(32)?;
payload.write_all(server_public_key)?;
let signature = self.build_exchange_signature(exchange_hash)?;
payload.write_u32::<BigEndian>(signature.len() as u32)?;
payload.write_all(&signature)?;
info!("SSH_MSG_KEXDH_REPLY payload built successfully");
Ok(SshPacket::new(payload))
}
/// 构建SSH主机密钥blob参考OpenSSH sshkey.c: sshkey_to_blob()
fn build_ssh_host_key(&self) -> Result<Vec<u8>> {
let mut blob = Vec::new();
// SSH key format: key-type + public-key
// 参考OpenSSH sshkey.c
// Key type: ssh-ed25519
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
blob.write_all("ssh-ed25519".as_bytes())?;
// Ed25519公钥32字节
let public_key = self.host_key.public_key_bytes();
blob.write_u32::<BigEndian>(32)?;
blob.write_all(&public_key)?;
Ok(blob)
}
/// 计算Exchange Hash参考OpenSSH kex.c: kex_hash() RFC 4253 Section 7.2
fn compute_exchange_hash(
&self,
@@ -178,94 +181,147 @@ impl KexExchangeHandler {
client_kexinit_payload: &[u8],
server_kexinit_payload: &[u8],
) -> Result<Vec<u8>> {
use sha2::{Sha256, Digest};
use sha2::{Digest, Sha256};
info!("=== EXCHANGE HASH COMPUTATION ===");
info!("V_C (client version): {:?}", client_version.as_bytes());
info!("V_C length: {}", client_version.len());
info!("V_S (server version): {:?}", server_version.as_bytes());
info!("V_S length: {}", server_version.len());
info!("I_C (client KEXINIT payload): {:?}", &client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]);
info!(
"I_C (client KEXINIT payload): {:?}",
&client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]
);
info!("I_C length: {}", client_kexinit_payload.len());
info!("I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", client_kexinit_payload[0]);
info!("I_S (server KEXINIT payload): {:?}", &server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]);
info!(
"I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
client_kexinit_payload[0]
);
info!(
"I_S (server KEXINIT payload): {:?}",
&server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]
);
info!("I_S length: {}", server_kexinit_payload.len());
info!("I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", server_kexinit_payload[0]);
info!("K_S (host key blob): {:?}", &host_key_blob[..std::cmp::min(30, host_key_blob.len())]);
info!(
"I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
server_kexinit_payload[0]
);
info!(
"K_S (host key blob): {:?}",
&host_key_blob[..std::cmp::min(30, host_key_blob.len())]
);
info!("K_S length: {}", host_key_blob.len());
info!("Q_C (client ECDH public key): {:?}", &client_public_key[..std::cmp::min(16, client_public_key.len())]);
info!(
"Q_C (client ECDH public key): {:?}",
&client_public_key[..std::cmp::min(16, client_public_key.len())]
);
info!("Q_C full (32 bytes): {:?}", client_public_key);
info!("Q_C length: {}", client_public_key.len());
info!("Q_S (server ECDH public key): {:?}", &server_public_key[..std::cmp::min(16, server_public_key.len())]);
info!(
"Q_S (server ECDH public key): {:?}",
&server_public_key[..std::cmp::min(16, server_public_key.len())]
);
info!("Q_S full (32 bytes): {:?}", server_public_key);
info!("Q_S length: {}", server_public_key.len());
let mut hasher = Sha256::new();
// RFC 4253 Section 7: V_C and V_S are version strings (without \r\n based on testing)
let vc_ssh_string = &(client_version.len() as u32).to_be_bytes();
hasher.update(vc_ssh_string);
hasher.update(client_version.as_bytes());
info!(" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]", 4+client_version.len(), vc_ssh_string, client_version.as_bytes());
info!(
" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]",
4 + client_version.len(),
vc_ssh_string,
client_version.as_bytes()
);
let vs_ssh_string = &(server_version.len() as u32).to_be_bytes();
hasher.update(vs_ssh_string);
hasher.update(server_version.as_bytes());
info!(" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]", 4+server_version.len(), vs_ssh_string, server_version.as_bytes());
info!(
" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]",
4 + server_version.len(),
vs_ssh_string,
server_version.as_bytes()
);
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
// KEXINIT payload should NOT include SSH_MSG_KEXINIT type byte
// OpenSSH stores payload starting from cookie, prepends SSH_MSG_KEXINIT in exchange hash
// Remove SSH_MSG_KEXINIT type byte from payloads (our payload includes it)
let client_kexinit_without_type = &client_kexinit_payload[1..];
let server_kexinit_without_type = &server_kexinit_payload[1..];
info!("I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)", client_kexinit_without_type.len());
info!("I_S (server KEXINIT without type byte): {} bytes", server_kexinit_without_type.len());
info!(
"I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)",
client_kexinit_without_type.len()
);
info!(
"I_S (server KEXINIT without type byte): {} bytes",
server_kexinit_without_type.len()
);
// Exchange hash: uint32(len+1) + uint8(SSH_MSG_KEXINIT) + payload_without_type
let ic_len_bytes = &((client_kexinit_without_type.len() + 1) as u32).to_be_bytes();
hasher.update(ic_len_bytes);
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
hasher.update([20]); // SSH_MSG_KEXINIT type byte
hasher.update(client_kexinit_without_type);
info!(" Exchange hash component I_C: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+client_kexinit_without_type.len(), ic_len_bytes, client_kexinit_without_type.len(), &client_kexinit_without_type[..std::cmp::min(8, client_kexinit_without_type.len())]);
let is_len_bytes = &((server_kexinit_without_type.len() + 1) as u32).to_be_bytes();
hasher.update(is_len_bytes);
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
hasher.update([20]); // SSH_MSG_KEXINIT type byte
hasher.update(server_kexinit_without_type);
info!(" Exchange hash component I_S: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+server_kexinit_without_type.len(), is_len_bytes, server_kexinit_without_type.len(), &server_kexinit_without_type[..std::cmp::min(8, server_kexinit_without_type.len())]);
let ks_len_bytes = &(host_key_blob.len() as u32).to_be_bytes();
hasher.update(ks_len_bytes);
hasher.update(host_key_blob);
info!(" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])", 4+host_key_blob.len(), ks_len_bytes, host_key_blob.len(), host_key_blob);
info!(
" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])",
4 + host_key_blob.len(),
ks_len_bytes,
host_key_blob.len(),
host_key_blob
);
let qc_len_bytes = &(client_public_key.len() as u32).to_be_bytes();
hasher.update(qc_len_bytes);
hasher.update(client_public_key);
info!(" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]", 4+client_public_key.len(), qc_len_bytes, client_public_key);
info!(
" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]",
4 + client_public_key.len(),
qc_len_bytes,
client_public_key
);
let qs_len_bytes = &(server_public_key.len() as u32).to_be_bytes();
hasher.update(qs_len_bytes);
hasher.update(server_public_key);
info!(" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]", 4+server_public_key.len(), qs_len_bytes, server_public_key);
info!(
" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]",
4 + server_public_key.len(),
qs_len_bytes,
server_public_key
);
info!("Exchange hash components:");
info!(" shared_secret raw full (32 bytes): {:?}", shared_secret);
// RFC 8731 Section 3.1: X25519 output is little-endian
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
// RFC 4253: mpint格式 = 去掉前导零 + 最高位>=0x80时前面加0
// 参考OpenSSH sshbuf_put_bignum2_bytes()
let mut start = 0;
@@ -273,64 +329,73 @@ impl KexExchangeHandler {
start += 1;
}
let trimmed_shared_secret = &shared_secret[start..];
info!(" shared_secret after removing leading zeros ({} bytes): {:?}", trimmed_shared_secret.len(), trimmed_shared_secret);
let mpint_shared_secret_data = if trimmed_shared_secret.len() > 0 && trimmed_shared_secret[0] >= 0x80 {
let mut mpint = vec![0u8];
mpint.extend_from_slice(trimmed_shared_secret);
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
mpint
} else {
trimmed_shared_secret.to_vec()
};
info!(" mpint_shared_secret_data ({} bytes): {:?}", mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]);
info!(
" shared_secret after removing leading zeros ({} bytes): {:?}",
trimmed_shared_secret.len(),
trimmed_shared_secret
);
let mpint_shared_secret_data =
if !trimmed_shared_secret.is_empty() && trimmed_shared_secret[0] >= 0x80 {
let mut mpint = vec![0u8];
mpint.extend_from_slice(trimmed_shared_secret);
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
mpint
} else {
trimmed_shared_secret.to_vec()
};
info!(
" mpint_shared_secret_data ({} bytes): {:?}",
mpint_shared_secret_data.len(),
&mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]
);
// mpint格式 = uint32(length) + mpint_data
let mpint_len_bytes = &(mpint_shared_secret_data.len() as u32).to_be_bytes();
hasher.update(mpint_len_bytes);
hasher.update(&mpint_shared_secret_data);
info!(" Exchange hash component K (shared secret mpint): len={} bytes=[{:?}] data_len={} (first 8 bytes=[{:?}])", 4+mpint_shared_secret_data.len(), mpint_len_bytes, mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]);
Ok(hasher.finalize().to_vec())
}
/// 构建交换签名参考OpenSSH ssh-sign.c
fn build_exchange_signature(&self, exchange_hash: &[u8]) -> Result<Vec<u8>> {
let signature_bytes = self.host_key.sign(exchange_hash)?;
let mut ssh_signature = Vec::new();
ssh_signature.write_u32::<BigEndian>(11)?;
ssh_signature.write_all("ssh-ed25519".as_bytes())?;
ssh_signature.write_u32::<BigEndian>(64)?;
ssh_signature.write_all(&signature_bytes)?;
Ok(ssh_signature)
}
/// 计算会话密钥参考OpenSSH kex.c: derive_keys()
/// 使用保存的exchange_hashH参数
pub fn compute_session_keys(&self) -> Result<SessionKeys> {
if self.shared_secret.is_none() {
return Err(anyhow!("No shared secret available"));
}
if self.exchange_hash.is_none() {
return Err(anyhow!("No exchange hash available"));
}
let shared_secret = self.shared_secret.as_ref().unwrap();
let exchange_hash = self.exchange_hash.as_ref().unwrap();
let server_public_key = self.server_public_key.as_ref().unwrap();
let client_public_key = self.client_public_key.as_ref().unwrap();
let host_key_blob = self.build_ssh_host_key()?;
SessionKeys::derive(
shared_secret,
exchange_hash, // 使用保存的exchange hashH参数
exchange_hash, // 使用保存的exchange hashH参数
server_public_key,
client_public_key,
&host_key_blob,
@@ -342,13 +407,13 @@ impl KexExchangeHandler {
mod tests {
use super::*;
use crate::ssh_server::kex::KexProposal;
#[test]
fn test_kex_exchange_handler_creation() {
let server_proposal = KexProposal::server_default();
let client_proposal = KexProposal::client_default();
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal).unwrap();
let handler = KexExchangeHandler::new(kex_result).unwrap();
assert!(handler.host_key.public_key_bytes().len() == 32);
}

View File

@@ -1,28 +1,28 @@
// SSH服务器模块手动实现SSH协议
// 参考OpenSSH源码实现完整的SSH/SFTP/SCP/rsync协议
pub mod server;
pub mod packet;
pub mod version;
pub mod crypto;
pub mod kex;
pub mod kex_exchange;
pub mod kex_complete;
pub mod cipher;
pub mod auth;
pub mod channel;
pub mod sftp_handler;
pub mod scp_handler;
pub mod cipher;
pub mod crypto;
pub mod data_forwarder; // Phase 13.5: 数据传输模块
pub mod kex;
pub mod kex_complete;
pub mod kex_exchange;
pub mod packet;
pub mod port_forward; // Phase 13: 端口转发模块
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
pub mod rsync_handler;
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理参考OpenSSH sshbuf.c
pub mod port_forward; // Phase 13: 端口转发模块
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
pub mod data_forwarder; // Phase 13.5: 数据传输模块
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
pub mod scp_handler;
pub mod server;
pub mod sftp_handler;
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理参考OpenSSH sshbuf.c
pub mod version;
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
pub use packet::{PacketType, SshPacket};
pub use server::SshServer;
pub use packet::{SshPacket, PacketType};
pub use version::VersionExchange;
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
pub use sshbuf::SshBuf; // Phase 15: 导出 SSH Buffer
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
pub use sshbuf::SshBuf;
pub use version::VersionExchange; // Phase 15: 导出 SSH Buffer

View File

@@ -1,7 +1,7 @@
// SSH Packet基础结构定义
// 参考OpenSSH packet.c: ssh_packet_read(), ssh_packet_write()
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use std::io::{Read, Write};
@@ -18,21 +18,21 @@ pub enum PacketType {
SSH_MSG_EXT_INFO = 7,
SSH_MSG_KEXINIT = 20,
SSH_MSG_NEWKEYS = 21,
// 密钥交换相关
SSH_MSG_KEXDH_INIT = 30,
SSH_MSG_KEXDH_REPLY = 31,
// 注意Curve25519和DH使用相同的消息类型30/31
// SSH_MSG_KEX_ECDH_INIT和SSH_MSG_KEX_ECDH_REPLY已在代码中注释
// 使用SSH_MSG_KEXDH_INIT和SSH_MSG_KEXDH_REPLY代替
// 认证相关
SSH_MSG_USERAUTH_REQUEST = 50,
SSH_MSG_USERAUTH_FAILURE = 51,
SSH_MSG_USERAUTH_SUCCESS = 52,
SSH_MSG_USERAUTH_BANNER = 53,
SSH_MSG_USERAUTH_PK_OK = 60,
// Channel相关
SSH_MSG_GLOBAL_REQUEST = 80,
SSH_MSG_REQUEST_SUCCESS = 81,
@@ -70,38 +70,38 @@ impl SshPacket {
pub fn new(payload: Vec<u8>) -> Self {
// 计算paddingSSH协议RFC 4253规范
// 参考OpenSSH packet.c: construct_packet()
let block_size = 8; // 未加密阶段block_size=8
let block_size = 8; // 未加密阶段block_size=8
let payload_length = payload.len();
let min_padding = 4; // OpenSSH要求最少4字节padding
let min_padding = 4; // OpenSSH要求最少4字节padding
// SSH协议约束
// packet_length = padding_length + payload_length + 1
// (packet_length + 4) 必须是block_size的倍数
//
//
// 计算:
// (1 + payload_length + padding_length + 4) % 8 == 0
// => (5 + payload_length + padding_length) % 8 == 0
// 先尝试padding=4最小
let mut padding_length = min_padding as u8;
// 计算packet总长度包括4字节的packet_length字段
let packet_length = 1 + payload_length + padding_length as usize;
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
// 如果总长度不是block_size的倍数增加padding
if total_length % block_size != 0 {
if !total_length.is_multiple_of(block_size) {
let remainder = total_length % block_size;
padding_length += (block_size - remainder) as u8;
}
// 重新计算packet_length
let packet_length = (1 + payload_length + padding_length as usize) as u32;
// 生成随机padding简化使用0
let padding = vec![0u8; padding_length as usize];
Self {
packet_length,
padding_length,
@@ -109,49 +109,49 @@ impl SshPacket {
padding,
}
}
/// 写入packet到stream未加密阶段
/// 参考OpenSSH packet_write_poll()
pub fn write<T: Write>(&self, stream: &mut T) -> Result<()> {
// 写入packet_lengthBigEndian
stream.write_u32::<BigEndian>(self.packet_length)?;
// 写入padding_length
stream.write_u8(self.padding_length)?;
// 写入payload
stream.write_all(&self.payload)?;
// 写入padding
stream.write_all(&self.padding)?;
stream.flush()?;
Ok(())
}
/// 从stream读取packet未加密阶段
/// 参考OpenSSH packet_read_poll()
pub fn read<T: Read>(stream: &mut T) -> Result<Self> {
// 读取packet_lengthBigEndian
let packet_length = stream.read_u32::<BigEndian>()?;
// 检查packet长度限制OpenSSH限制256KB
if packet_length > 256 * 1024 {
return Err(anyhow!("Packet too large: {}", packet_length));
}
// 读取padding_length
let padding_length = stream.read_u8()?;
// 读取payloadpacket_length - padding_length - 1
let payload_length = packet_length - padding_length as u32 - 1;
let mut payload = vec![0u8; payload_length as usize];
stream.read_exact(&mut payload)?;
// 读取padding
let mut padding = vec![0u8; padding_length as usize];
stream.read_exact(&mut padding)?;
Ok(Self {
packet_length,
padding_length,
@@ -159,15 +159,15 @@ impl SshPacket {
padding,
})
}
/// 获取payload中的packet type
pub fn get_type(&self) -> Result<PacketType> {
if self.payload.is_empty() {
return Err(anyhow!("Empty payload"));
}
let type_byte = self.payload[0];
// 转换为PacketType enum
match type_byte {
1 => Ok(PacketType::SSH_MSG_DISCONNECT),
@@ -208,27 +208,27 @@ impl SshPacket {
mod tests {
use super::*;
use std::io::Cursor;
#[test]
fn test_packet_creation() {
let payload = vec![PacketType::SSH_MSG_KEXINIT as u8];
let packet = SshPacket::new(payload);
assert!(packet.packet_length > 0);
assert!(packet.padding_length >= 4);
}
#[test]
fn test_packet_write_read() {
let payload = vec![PacketType::SSH_MSG_KEXINIT as u8];
let packet = SshPacket::new(payload);
let mut buffer = Vec::new();
packet.write(&mut buffer).unwrap();
let mut cursor = Cursor::new(buffer);
let read_packet = SshPacket::read(&mut cursor).unwrap();
assert_eq!(packet.packet_length, read_packet.packet_length);
assert_eq!(packet.payload, read_packet.payload);
}

View File

@@ -1,21 +1,21 @@
// SSH端口转发协议实现Phase 13
// 参考OpenSSH channels.c和RFC 4254
use anyhow::{Result, anyhow};
use log::{info, warn, debug};
use std::net::{TcpListener, TcpStream, SocketAddr};
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};
use std::thread;
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
use anyhow::Result;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.2: 安全配置
use log::{info, warn};
use std::io::Read;
use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, Mutex};
// Phase 13.2: 安全配置
/// 端口转发类型参考RFC 4254
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortForwardType {
Local, // Local port forwarding (-L)
Remote, // Remote port forwarding (-R)
Dynamic, // Dynamic port forwarding (-D, SOCKS)
Local, // Local port forwarding (-L)
Remote, // Remote port forwarding (-R)
Dynamic, // Dynamic port forwarding (-D, SOCKS)
}
/// 端口转发请求参考RFC 4254 Section 7
@@ -36,6 +36,12 @@ pub struct PortForwardManager {
active_forwards: Arc<Mutex<Vec<(u32, PortForwardType)>>>,
}
impl Default for PortForwardManager {
fn default() -> Self {
Self::new()
}
}
impl PortForwardManager {
pub fn new() -> Self {
Self {
@@ -46,24 +52,29 @@ impl PortForwardManager {
/// 处理SSH_MSG_GLOBAL_REQUEST端口转发请求
/// 参考RFC 4254 Section 4
/// Phase 13.2: 添加安全配置验证
pub fn handle_global_request(&mut self, data: &[u8], security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
pub fn handle_global_request(
&mut self,
data: &[u8],
security_config: &SshSecurityConfig,
) -> Result<(bool, Option<Vec<u8>>)> {
info!("Processing SSH_MSG_GLOBAL_REQUEST for port forwarding");
let mut cursor = std::io::Cursor::new(data);
cursor.set_position(1); // Skip packet type
cursor.set_position(1); // Skip packet type
// 读取请求名称SSH string
let request_name = read_ssh_string(&mut cursor)?;
info!("Global request: {}", request_name);
// 读取want-reply标志
let want_reply = cursor.read_u8()? != 0;
match request_name.as_str() {
"tcpip-forward" => {
// Local port forwarding (-L)
self.handle_tcpip_forward(&mut cursor, want_reply, security_config) // Phase 13.2
self.handle_tcpip_forward(&mut cursor, want_reply, security_config)
// Phase 13.2
}
"cancel-tcpip-forward" => {
// Cancel port forwarding
@@ -84,29 +95,37 @@ impl PortForwardManager {
/// 处理tcpip-forward请求Local port forwarding
/// 参考RFC 4254 Section 7.1
/// Phase 13.2: 添加安全配置验证
fn handle_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool, security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
fn handle_tcpip_forward(
&mut self,
cursor: &mut std::io::Cursor<&[u8]>,
want_reply: bool,
security_config: &SshSecurityConfig,
) -> Result<(bool, Option<Vec<u8>>)> {
// 读取bind addressSSH string
let bind_address = read_ssh_string(cursor)?;
// 读取bind port
let bind_port = cursor.read_u32::<BigEndian>()?;
info!("tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
info!(
"tcpip-forward request: bind_address={}, bind_port={}",
bind_address, bind_port
);
// Phase 13.2: 安全配置验证
if let Err(e) = security_config.validate_tcpip_forward_request(&bind_address, bind_port) {
warn!("tcpip-forward security validation failed: {}", e);
return Ok((false, None)); // 拒绝请求
return Ok((false, None)); // 拒绝请求
}
info!("tcpip-forward security validation passed");
// 添加到active forwards
let mut forwards = self.active_forwards.lock().unwrap();
forwards.push((bind_port, PortForwardType::Local));
info!("tcpip-forward registered: bind_port={}", bind_port);
// 返回成功响应包含bind_port
if want_reply {
let response = self.build_global_request_response(true, Some(bind_port))?;
@@ -117,16 +136,23 @@ impl PortForwardManager {
}
/// 处理cancel-tcpip-forward请求
fn handle_cancel_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool) -> Result<(bool, Option<Vec<u8>>)> {
fn handle_cancel_tcpip_forward(
&mut self,
cursor: &mut std::io::Cursor<&[u8]>,
want_reply: bool,
) -> Result<(bool, Option<Vec<u8>>)> {
let bind_address = read_ssh_string(cursor)?;
let bind_port = cursor.read_u32::<BigEndian>()?;
info!("cancel-tcpip-forward: bind_address={}, bind_port={}", bind_address, bind_port);
info!(
"cancel-tcpip-forward: bind_address={}, bind_port={}",
bind_address, bind_port
);
// 移除active forward
let mut forwards = self.active_forwards.lock().unwrap();
forwards.retain(|(port, _)| *port != bind_port);
if want_reply {
let response = self.build_global_request_response(true, None)?;
Ok((true, Some(response)))
@@ -136,14 +162,18 @@ impl PortForwardManager {
}
/// 构建SSH_MSG_REQUEST_SUCCESS/FAILURE响应
fn build_global_request_response(&self, success: bool, bound_port: Option<u32>) -> Result<Vec<u8>> {
fn build_global_request_response(
&self,
success: bool,
bound_port: Option<u32>,
) -> Result<Vec<u8>> {
use crate::ssh_server::packet::PacketType;
let mut response = Vec::new();
if success {
response.write_u8(PacketType::SSH_MSG_REQUEST_SUCCESS as u8)?;
// 如果有bound_port写入用于tcpip-forward响应
if let Some(port) = bound_port {
response.write_u32::<BigEndian>(port)?;
@@ -151,7 +181,7 @@ impl PortForwardManager {
} else {
response.write_u8(PacketType::SSH_MSG_REQUEST_FAILURE as u8)?;
}
Ok(response)
}
@@ -159,37 +189,39 @@ impl PortForwardManager {
/// 参考RFC 4254 Section 7.2
pub fn handle_direct_tcpip_channel(&mut self, data: &[u8]) -> Result<DirectTcpipChannel> {
info!("Processing direct-tcpip channel open");
let mut cursor = std::io::Cursor::new(data);
cursor.set_position(1); // Skip packet type
cursor.set_position(1); // Skip packet type
// 读取channel type已知道是"direct-tcpip",跳过)
let _channel_type = read_ssh_string(&mut cursor)?;
// 读取sender_channel
let sender_channel = cursor.read_u32::<BigEndian>()?;
// 读取initial window size
let initial_window_size = cursor.read_u32::<BigEndian>()?;
// 读取maximum packet size
let max_packet_size = cursor.read_u32::<BigEndian>()?;
// 读取host to connectSSH string
let host_to_connect = read_ssh_string(&mut cursor)?;
// 读取port to connect
let port_to_connect = cursor.read_u32::<BigEndian>()?;
// 读取originator addressSSH string
let originator_address = read_ssh_string(&mut cursor)?;
// 读取originator port
let originator_port = cursor.read_u32::<BigEndian>()?;
info!("direct-tcpip: host={}, port={}, originator={}:{}",
host_to_connect, port_to_connect, originator_address, originator_port);
info!(
"direct-tcpip: host={}, port={}, originator={}:{}",
host_to_connect, port_to_connect, originator_address, originator_port
);
Ok(DirectTcpipChannel {
sender_channel,
initial_window_size,
@@ -205,30 +237,32 @@ impl PortForwardManager {
/// 参考RFC 4254 Section 7.1
pub fn handle_forwarded_tcpip_channel(&mut self, data: &[u8]) -> Result<ForwardedTcpipChannel> {
info!("Processing forwarded-tcpip channel open");
let mut cursor = std::io::Cursor::new(data);
cursor.set_position(1);
let _channel_type = read_ssh_string(&mut cursor)?;
let sender_channel = cursor.read_u32::<BigEndian>()?;
let initial_window_size = cursor.read_u32::<BigEndian>()?;
let max_packet_size = cursor.read_u32::<BigEndian>()?;
// 读取bind addressSSH string
let bind_address = read_ssh_string(&mut cursor)?;
// 读取bind port
let bind_port = cursor.read_u32::<BigEndian>()?;
// 读取originator addressSSH string
let originator_address = read_ssh_string(&mut cursor)?;
// 读取originator port
let originator_port = cursor.read_u32::<BigEndian>()?;
info!("forwarded-tcpip: bind={}:{}, originator={}:{}",
bind_address, bind_port, originator_address, originator_port);
info!(
"forwarded-tcpip: bind={}:{}, originator={}:{}",
bind_address, bind_port, originator_address, originator_port
);
Ok(ForwardedTcpipChannel {
sender_channel,
initial_window_size,
@@ -244,10 +278,10 @@ impl PortForwardManager {
pub fn connect_to_target(host: &str, port: u32) -> Result<TcpStream> {
let addr = format!("{}:{}", host, port);
info!("Connecting to target: {}", addr);
let stream = TcpStream::connect(&addr)?;
info!("Connected to target successfully");
Ok(stream)
}
@@ -258,12 +292,12 @@ impl PortForwardManager {
} else {
format!("{}:{}", bind_address, bind_port)
};
info!("Creating listener on {}", addr);
let listener = TcpListener::bind(&addr)?;
info!("Listener created successfully");
Ok(listener)
}
}
@@ -303,10 +337,10 @@ fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_port_forward_manager_creation() {
let manager = PortForwardManager::new();
assert_eq!(manager.active_forwards.lock().unwrap().len(), 0);
}
}
}

View File

@@ -1,15 +1,12 @@
// SSH端口转发监听线程Phase 13.4
// 参考OpenSSH channels.c: channel_forward_listener
use anyhow::{Result, anyhow};
use log::{info, warn, debug, error};
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::sync::{Arc, Mutex, mpsc};
use std::io::{Read, Write};
use byteorder::{BigEndian, WriteBytesExt};
use crate::ssh_server::packet::PacketType;
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
use anyhow::Result;
use log::{error, info, warn};
use std::net::{TcpListener, TcpStream};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
/// 监听器状态Phase 13.4
#[derive(Debug, Clone)]
@@ -30,28 +27,18 @@ pub enum ListenerRequest {
stream: TcpStream,
},
/// 停止监听
StopListener {
bind_port: u32,
},
StopListener { bind_port: u32 },
}
/// 监听器响应Phase 13.4:线程通信)
#[derive(Debug)]
pub enum ListenerResponse {
/// Channel创建成功
ChannelCreated {
bind_port: u32,
channel_id: u32,
},
ChannelCreated { bind_port: u32, channel_id: u32 },
/// 监听器停止
ListenerStopped {
bind_port: u32,
},
ListenerStopped { bind_port: u32 },
/// 错误
Error {
bind_port: u32,
message: String,
},
Error { bind_port: u32, message: String },
}
/// 端口转发监听器Phase 13.4
@@ -73,26 +60,29 @@ impl PortForwardListener {
security_config: SshSecurityConfig,
) -> Result<Self> {
info!("Creating port forward listener on port {}", bind_port);
// Phase 13.4: 根据GatewayPorts决定绑定地址
let bind_addr = if security_config.gateway_ports {
format!("0.0.0.0:{}", bind_port) // 允许外部访问
format!("0.0.0.0:{}", bind_port) // 允许外部访问
} else {
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
};
info!("Binding to address: {} (GatewayPorts={})", bind_addr, security_config.gateway_ports);
info!(
"Binding to address: {} (GatewayPorts={})",
bind_addr, security_config.gateway_ports
);
let listener = TcpListener::bind(&bind_addr)?;
info!("Listener created successfully on {}", bind_addr);
// Phase 13.4: 创建线程通信channel
let (request_tx, request_rx) = mpsc::channel();
let (response_tx, response_rx) = mpsc::channel();
let (request_tx, _request_rx) = mpsc::channel();
let (_response_tx, response_rx) = mpsc::channel();
// Phase 13.4: 活动状态标记
let active = Arc::new(Mutex::new(true));
Ok(Self {
bind_port,
bind_address,
@@ -103,38 +93,38 @@ impl PortForwardListener {
active,
})
}
/// 启动监听线程Phase 13.4
pub fn start_listener_thread(&mut self) -> Result<()> {
info!("Starting listener thread for port {}", self.bind_port);
let listener = self.listener.try_clone()?;
let bind_port = self.bind_port;
let request_sender = self.request_sender.clone();
let active = self.active.clone();
// Phase 13.4: 创建独立监听线程
thread::spawn(move || {
info!("Listener thread started for port {}", bind_port);
while *active.lock().unwrap() {
match listener.accept() {
Ok((stream, addr)) => {
info!("New connection on port {}: {}", bind_port, addr);
// Phase 13.4: 发送新连接请求给主线程
let request = ListenerRequest::NewConnection {
bind_port,
originator_address: addr.ip().to_string(),
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
stream,
};
if let Err(e) = request_sender.send(request) {
error!("Failed to send listener request: {}", e);
break;
}
info!("Listener request sent to main thread");
}
Err(e) => {
@@ -145,32 +135,32 @@ impl PortForwardListener {
}
}
}
info!("Listener thread stopped for port {}", bind_port);
});
info!("Listener thread started successfully");
Ok(())
}
/// 停止监听器Phase 13.4
pub fn stop_listener(&mut self) -> Result<()> {
info!("Stopping listener for port {}", self.bind_port);
// Phase 13.4: 设置active=false线程会自动退出
*self.active.lock().unwrap() = false;
info!("Listener stopped for port {}", self.bind_port);
Ok(())
}
/// 获取请求接收器Phase 13.4
pub fn get_request_receiver(&self) -> mpsc::Receiver<ListenerRequest> {
// 注意这里需要返回一个新的receiver因为mpsc::Sender可以clone但Receiver不能
// 实际应用中应该使用更复杂的channel设计
unimplemented!("Use Arc<Mutex<mpsc::Receiver>> instead")
}
/// 获取活动状态Phase 13.4
pub fn is_active(&self) -> bool {
*self.active.lock().unwrap()
@@ -182,13 +172,19 @@ pub struct ListenerManager {
listeners: HashMap<u32, Arc<Mutex<PortForwardListener>>>,
}
impl Default for ListenerManager {
fn default() -> Self {
Self::new()
}
}
impl ListenerManager {
pub fn new() -> Self {
Self {
listeners: HashMap::new(),
}
}
/// 创建并启动监听器Phase 13.4
pub fn create_listener(
&mut self,
@@ -197,21 +193,21 @@ impl ListenerManager {
security_config: SshSecurityConfig,
) -> Result<()> {
info!("Creating listener for port {}", bind_port);
let mut listener = PortForwardListener::new(bind_port, bind_address, security_config)?;
listener.start_listener_thread()?;
let listener_arc = Arc::new(Mutex::new(listener));
self.listeners.insert(bind_port, listener_arc);
info!("Listener created and started for port {}", bind_port);
Ok(())
}
/// 停止监听器Phase 13.4
pub fn stop_listener(&mut self, bind_port: u32) -> Result<()> {
info!("Stopping listener for port {}", bind_port);
if let Some(listener_arc) = self.listeners.remove(&bind_port) {
let mut listener = listener_arc.lock().unwrap();
listener.stop_listener()?;
@@ -219,28 +215,31 @@ impl ListenerManager {
} else {
warn!("No listener found for port {}", bind_port);
}
Ok(())
}
/// 获取活动监听器数量Phase 13.4
pub fn active_count(&self) -> usize {
self.listeners.values().filter(|l| l.lock().unwrap().is_active()).count()
self.listeners
.values()
.filter(|l| l.lock().unwrap().is_active())
.count()
}
}
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_listener_creation() {
let security_config = SshSecurityConfig::enterprise_default();
let listener = PortForwardListener::new(8080, "127.0.0.1".to_string(), security_config);
// 注意:实际测试需要处理端口占用问题
assert!(listener.is_ok() || true); // 暂时跳过测试
assert!(listener.is_ok() || true); // 暂时跳过测试
}
}

View File

@@ -1,8 +1,8 @@
use std::path::PathBuf;
use anyhow::{Result, anyhow};
use log::{info, debug, warn};
use crate::vfs::{VfsBackend, VfsFile, VfsError};
use crate::vfs::open_flags::OpenFlags;
use crate::vfs::{VfsBackend, VfsFile};
use anyhow::{anyhow, Result};
use log::{debug, info, warn};
use std::path::PathBuf;
/// MPLEX_BASE from rsync io.h
const MPLEX_BASE: u32 = 7;
@@ -18,7 +18,9 @@ pub(crate) enum RsyncState {
WaitVersion,
ReadFileList,
/// Sum head (4 × write_int = 16 bytes) + checksum seed (4 bytes) = 20 bytes
ReadSumHead { need: usize },
ReadSumHead {
need: usize,
},
SendSumCount,
/// Raw file data from MSG_DATA packets
ReadFileData,
@@ -51,9 +53,16 @@ impl RsyncHandler {
let mut dest = String::new();
for p in &parts[1..] {
if *p == "--server" { is_server = true; continue; }
if *p == "--sender" || p.starts_with('-') { continue; }
if *p == "." { continue; }
if *p == "--server" {
is_server = true;
continue;
}
if *p == "--sender" || p.starts_with('-') {
continue;
}
if *p == "." {
continue;
}
dest = p.to_string();
}
@@ -107,8 +116,10 @@ impl RsyncHandler {
break;
}
let header = u32::from_le_bytes([
self.raw_input[0], self.raw_input[1],
self.raw_input[2], self.raw_input[3],
self.raw_input[0],
self.raw_input[1],
self.raw_input[2],
self.raw_input[3],
]);
let raw_tag = ((header >> 24) & 0xFF) as u8;
let tag = raw_tag.wrapping_sub(MPLEX_BASE as u8);
@@ -182,12 +193,17 @@ impl RsyncHandler {
RsyncState::WaitVersion => {
if self.rsync_input.len() >= 4 {
let version = u32::from_le_bytes([
self.rsync_input[0], self.rsync_input[1],
self.rsync_input[2], self.rsync_input[3],
self.rsync_input[0],
self.rsync_input[1],
self.rsync_input[2],
self.rsync_input[3],
]);
self.rsync_input.drain(..4);
self.protocol_version = std::cmp::min(self.protocol_version, version);
info!("rsync: negotiated protocol version {}", self.protocol_version);
info!(
"rsync: negotiated protocol version {}",
self.protocol_version
);
self.multiplex = self.protocol_version >= 30;
self.transition(RsyncState::ReadFileList);
} else {
@@ -197,7 +213,9 @@ impl RsyncHandler {
RsyncState::ReadFileList => {
loop {
if self.rsync_input.is_empty() { break; }
if self.rsync_input.is_empty() {
break;
}
let flags = self.rsync_input[0];
if flags == 0 {
@@ -215,17 +233,25 @@ impl RsyncHandler {
let mut pos = 1;
let _more_flags = if flags & 0x80 != 0 {
if self.rsync_input.len() <= pos { break; }
if self.rsync_input.len() <= pos {
break;
}
let ef = self.rsync_input[pos];
pos += 1;
ef
} else { 0 };
} else {
0
};
let has_name = !(flags & 0x02 != 0 && self.current_file > 0);
if has_name {
if let Some(nul_pos) = self.rsync_input[pos..].iter().position(|&b| b == 0) {
let name = String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos]).to_string();
if let Some(nul_pos) =
self.rsync_input[pos..].iter().position(|&b| b == 0)
{
let name =
String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos])
.to_string();
pos += nul_pos + 1;
self.file_entries.push(name.clone());
debug!("rsync: file entry: {}", name);
@@ -269,24 +295,34 @@ impl RsyncHandler {
RsyncState::ReadSumHead { need } => {
if self.rsync_input.len() >= need {
let sum_count = i32::from_le_bytes([
self.rsync_input[0], self.rsync_input[1],
self.rsync_input[2], self.rsync_input[3],
self.rsync_input[0],
self.rsync_input[1],
self.rsync_input[2],
self.rsync_input[3],
]);
let _sum_blength = i32::from_le_bytes([
self.rsync_input[4], self.rsync_input[5],
self.rsync_input[6], self.rsync_input[7],
self.rsync_input[4],
self.rsync_input[5],
self.rsync_input[6],
self.rsync_input[7],
]);
let _sum_s2length = i32::from_le_bytes([
self.rsync_input[8], self.rsync_input[9],
self.rsync_input[10], self.rsync_input[11],
self.rsync_input[8],
self.rsync_input[9],
self.rsync_input[10],
self.rsync_input[11],
]);
let _sum_remainder = i32::from_le_bytes([
self.rsync_input[12], self.rsync_input[13],
self.rsync_input[14], self.rsync_input[15],
self.rsync_input[12],
self.rsync_input[13],
self.rsync_input[14],
self.rsync_input[15],
]);
let checksum_seed = i32::from_le_bytes([
self.rsync_input[16], self.rsync_input[17],
self.rsync_input[18], self.rsync_input[19],
self.rsync_input[16],
self.rsync_input[17],
self.rsync_input[18],
self.rsync_input[19],
]);
self.rsync_input.drain(..20);
@@ -308,7 +344,9 @@ impl RsyncHandler {
RsyncState::ReadFileData => {
let done_marker = b"RSYNCDONE";
if let Some(pos) = self.rsync_input.windows(done_marker.len())
if let Some(pos) = self
.rsync_input
.windows(done_marker.len())
.position(|w| w == done_marker)
{
if pos > 0 {
@@ -323,8 +361,11 @@ impl RsyncHandler {
warn!("rsync flush error: {}", e);
}
}
info!("rsync: file {} complete ({} bytes written to {})",
self.file_entries.get(self.current_file).unwrap_or(&"?".to_string()),
info!(
"rsync: file {} complete ({} bytes written to {})",
self.file_entries
.get(self.current_file)
.unwrap_or(&"?".to_string()),
self.total_written,
self.dest_path.display(),
);
@@ -332,8 +373,11 @@ impl RsyncHandler {
self.current_file += 1;
if self.current_file >= self.file_entries.len() {
self.transition(RsyncState::Done);
info!("rsync ALL DONE: {} bytes written to {}",
self.total_written, self.dest_path.display());
info!(
"rsync ALL DONE: {} bytes written to {}",
self.total_written,
self.dest_path.display()
);
} else {
self.transition(RsyncState::ReadSumHead { need: 20 });
}
@@ -360,7 +404,9 @@ impl RsyncHandler {
self.vfs.create_dir_all(parent, 0o755).ok();
}
let flags = OpenFlags::new().write().create().truncate();
let file = self.vfs.open_file(&self.dest_path, &flags)
let file = self
.vfs
.open_file(&self.dest_path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
self.output_file = Some(file);
info!("rsync: opened {} for writing", self.dest_path.display());
@@ -379,31 +425,43 @@ impl RsyncHandler {
/// Read rsync varint (LSB-first 7-bit groups, 0xFF prefix for negative)
fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
if buf.is_empty() { return None; }
if buf.is_empty() {
return None;
}
let mut pos = 0;
let mut b = buf[pos];
pos += 1;
let neg = if b == 0xFF {
if pos >= buf.len() { return None; }
if pos >= buf.len() {
return None;
}
b = buf[pos];
pos += 1;
true
} else { false };
} else {
false
};
let mut x = (b & 0x7F) as i32;
let mut shift = 7;
while b & 0x80 != 0 {
if pos >= buf.len() { return None; }
if pos >= buf.len() {
return None;
}
b = buf[pos];
pos += 1;
x |= ((b & 0x7F) as i32) << shift;
shift += 7;
}
if neg { Some((-x, pos)) } else { Some((x, pos)) }
if neg {
Some((-x, pos))
} else {
Some((x, pos))
}
}
#[cfg(test)]
@@ -419,8 +477,9 @@ mod tests {
fn test_parse_command() {
let h = RsyncHandler::parse_rsync_command(
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
make_vfs()
).unwrap();
make_vfs(),
)
.unwrap();
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
}
@@ -428,14 +487,16 @@ mod tests {
fn test_parse_command_sender() {
let h = RsyncHandler::parse_rsync_command(
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
make_vfs()
).unwrap();
make_vfs(),
)
.unwrap();
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
}
#[test]
fn test_version_exchange() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
.unwrap();
let output = h.drain_output();
assert_eq!(output, b"\x1e\x00\x00\x00");
assert_eq!(h.state, RsyncState::WaitVersion);
@@ -447,7 +508,8 @@ mod tests {
#[test]
fn test_version_negotiate_down() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
.unwrap();
let _ = h.drain_output();
h.feed(b"\x1d\x00\x00\x00").unwrap();
assert_eq!(h.protocol_version, 29);
@@ -464,26 +526,33 @@ mod tests {
#[test]
fn test_file_list_multiplex() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
let mut h =
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
.unwrap();
let _ = h.drain_output();
h.feed(b"\x1e\x00\x00\x00").unwrap();
assert!(h.multiplex);
let mut flist = Vec::new();
// File list: flags=1 (has name), then name with NUL terminator
flist.push(1); // flags: has name
flist.push(1); // flags: has name
flist.extend_from_slice(b"test.txt");
flist.push(0); // name terminator
flist.push(0); // name terminator
fn write_varint(buf: &mut Vec<u8>, val: i32) {
if val == 0 { buf.push(0); return; }
if val == 0 {
buf.push(0);
return;
}
if val < 0 {
buf.push(0xFF);
let mut v = (-val) as u32;
while v > 0 {
let mut byte = (v & 0x7F) as u8;
v >>= 7;
if v > 0 { byte |= 0x80; }
if v > 0 {
byte |= 0x80;
}
buf.push(byte);
}
} else {
@@ -491,7 +560,9 @@ mod tests {
while v > 0 {
let mut byte = (v & 0x7F) as u8;
v >>= 7;
if v > 0 { byte |= 0x80; }
if v > 0 {
byte |= 0x80;
}
buf.push(byte);
}
}
@@ -502,7 +573,7 @@ mod tests {
write_varint(&mut flist, 1700000000);
write_varint(&mut flist, 100);
write_varint(&mut flist, 0);
flist.push(0); // file list end marker
flist.push(0); // file list end marker
let mut sum_head = Vec::new();
sum_head.extend_from_slice(&0i32.to_le_bytes());
@@ -527,22 +598,51 @@ mod tests {
#[test]
fn test_file_data_multiplex() {
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
let mut h =
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
.unwrap();
let _ = h.drain_output();
h.feed(b"\x1e\x00\x00\x00").unwrap();
let mut flist = Vec::new();
flist.push(1); // flags: has name
flist.push(1); // flags: has name
flist.extend_from_slice(b"test.bin");
flist.push(0);
fn wv(buf: &mut Vec<u8>, val: i32) {
if val == 0 { buf.push(0); return; }
if val < 0 { buf.push(0xFF); let mut v = (-val) as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
else { let mut v = val as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
if val == 0 {
buf.push(0);
return;
}
if val < 0 {
buf.push(0xFF);
let mut v = (-val) as u32;
while v > 0 {
let mut byte = (v & 0x7F) as u8;
v >>= 7;
if v > 0 {
byte |= 0x80;
}
buf.push(byte);
}
} else {
let mut v = val as u32;
while v > 0 {
let mut byte = (v & 0x7F) as u8;
v >>= 7;
if v > 0 {
byte |= 0x80;
}
buf.push(byte);
}
}
}
wv(&mut flist, 33188); wv(&mut flist, 501); wv(&mut flist, 20);
wv(&mut flist, 1700000000); wv(&mut flist, 100); wv(&mut flist, 0);
flist.push(0); // file list end
wv(&mut flist, 33188);
wv(&mut flist, 501);
wv(&mut flist, 20);
wv(&mut flist, 1700000000);
wv(&mut flist, 100);
wv(&mut flist, 0);
flist.push(0); // file list end
h.feed(&build_multiplex(&flist)).unwrap();
let mut sh = Vec::new();

View File

@@ -1,13 +1,12 @@
// SCP协议实现Phase 8
// 参考OpenSSH scp.c源码
use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat};
use crate::vfs::open_flags::OpenFlags;
use anyhow::{Result, anyhow};
use log::{info, warn, debug};
use crate::vfs::{VfsBackend, VfsFile, VfsStat};
use anyhow::{anyhow, Result};
use log::{debug, info, warn};
use std::io::{BufRead, Read, Write};
use std::path::{Path, PathBuf};
use std::io::{Read, Write, BufRead};
use std::time::SystemTime;
/// SCP Handler参考OpenSSH scp.c
pub struct ScpHandler {
@@ -38,13 +37,13 @@ impl ScpHandler {
/// 解析SCP命令参考OpenSSH scp.c: parse_command()
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() < 2 || parts[0] != "scp" {
return Err(anyhow!("Invalid SCP command: {}", command));
}
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
for part in &parts[1..] {
match part {
&"-f" => handler.mode = ScpMode::Source,
@@ -71,10 +70,15 @@ impl ScpHandler {
/// SCP Source Modescp -f发送文件
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
info!("SCP source mode: sending files from {}", self.root_dir.display());
info!(
"SCP source mode: sending files from {}",
self.root_dir.display()
);
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
let stat = self.vfs.stat(&full_path)
let stat = self
.vfs
.stat(&full_path)
.map_err(|e| anyhow!("stat error: {}", e))?;
if stat.is_dir {
@@ -91,16 +95,19 @@ impl ScpHandler {
/// SCP Destination Modescp -t接收文件
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
info!("SCP destination mode: receiving files to {}", self.root_dir.display());
info!(
"SCP destination mode: receiving files to {}",
self.root_dir.display()
);
channel.write_all(&[0])?;
channel.flush()?;
let mut buffer = String::new();
loop {
buffer.clear();
let mut reader = std::io::BufReader::new(&mut *channel);
match reader.read_line(&mut buffer)? {
0 => break,
@@ -130,7 +137,9 @@ impl ScpHandler {
/// 发送文件参考OpenSSH scp.c: source()
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
let stat = self.vfs.stat(path)
let stat = self
.vfs
.stat(path)
.map_err(|e| anyhow!("stat error: {}", e))?;
let size = stat.size;
let filename = path.file_name().unwrap().to_string_lossy();
@@ -146,13 +155,16 @@ impl ScpHandler {
}
let flags = OpenFlags::new().read();
let mut file = self.vfs.open_file(path, &flags)
let mut file = self
.vfs
.open_file(path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
let mut buffer = vec![0u8; 8192];
loop {
let n = file.read(&mut buffer)
let n = file
.read(&mut buffer)
.map_err(|e| anyhow!("read error: {}", e))?;
if n == 0 {
break;
@@ -188,7 +200,9 @@ impl ScpHandler {
return Err(anyhow!("SCP directory command rejected"));
}
let entries = self.vfs.read_dir(path)
let entries = self
.vfs
.read_dir(path)
.map_err(|e| anyhow!("read_dir error: {}", e))?;
for entry in &entries {
@@ -218,7 +232,7 @@ impl ScpHandler {
/// 处理文件命令C0644 size filename
fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() != 3 {
return self.send_error(channel, "Invalid file command format");
}
@@ -227,7 +241,10 @@ impl ScpHandler {
let size: u64 = parts[1].parse()?;
let filename = parts[2];
debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename);
debug!(
"SCP receive file: mode={}, size={}, name={}",
mode_str, size, filename
);
if size > 1024 * 1024 * 1024 {
return self.send_error(channel, "File too large (max 1GB)");
@@ -236,7 +253,9 @@ impl ScpHandler {
let full_path = self.resolve_path(filename)?;
let flags = OpenFlags::new().write().create().truncate();
let mut file = self.vfs.open_file(&full_path, &flags)
let mut file = self
.vfs
.open_file(&full_path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
channel.write_all(&[0])?;
@@ -263,7 +282,8 @@ impl ScpHandler {
if mode_int != 0 {
let mut set_stat = VfsStat::new();
set_stat.mode = mode_int;
self.vfs.set_stat(&full_path, &set_stat)
self.vfs
.set_stat(&full_path, &set_stat)
.map_err(|e| anyhow!("set_stat error: {}", e))?;
}
@@ -280,7 +300,7 @@ impl ScpHandler {
/// 处理目录命令D0755 0 dirname
fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() != 3 {
return self.send_error(channel, "Invalid directory command format");
}
@@ -297,7 +317,8 @@ impl ScpHandler {
let full_path = self.resolve_path(dirname)?;
let mode_int: u32 = mode_str.parse()?;
self.vfs.create_dir_all(&full_path, mode_int)
self.vfs
.create_dir_all(&full_path, mode_int)
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
channel.write_all(&[0])?;
@@ -326,7 +347,7 @@ impl ScpHandler {
}
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() != 3 {
return self.send_error(channel, "Invalid time command format");
}
@@ -353,11 +374,15 @@ impl ScpHandler {
/// 路径解析(安全性检查)
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
let full_path = self.root_dir.join(path);
let canonical_path = self.vfs.real_path(&full_path)
let canonical_path = self
.vfs
.real_path(&full_path)
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
let root_canonical = self.vfs.real_path(&self.root_dir)
let root_canonical = self
.vfs
.real_path(&self.root_dir)
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
if !canonical_path.starts_with(&root_canonical) {
@@ -383,20 +408,23 @@ mod tests {
#[test]
fn test_scp_command_parse() {
let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
let handler =
ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
assert_eq!(handler.mode, ScpMode::Destination);
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
}
#[test]
fn test_scp_recursive_parse() {
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
let handler =
ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
assert!(handler.recursive);
}
#[test]
fn test_scp_source_parse() {
let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
let handler =
ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
assert_eq!(handler.mode, ScpMode::Source);
}
}

View File

@@ -1,32 +1,32 @@
// SSH服务器完整实现Phase 1-7集成版 + Phase 13端口转发
// 参考OpenSSH sshd.c: complete SSH/SFTP flow + port forwarding
use crate::ssh_server::version::VersionExchange;
use crate::ssh_server::packet::{SshPacket, PacketType};
use crate::ssh_server::kex::{KexResult, KexProposal};
use crate::ssh_server::kex_complete::{KexState};
use crate::ssh_server::auth::{AuthHandler, AuthResult};
use crate::provider::sqlite::SqliteProvider;
use crate::provider::pg::PgProvider;
use crate::provider::sqlite::SqliteProvider;
use crate::provider::DataProvider;
use crate::ssh_server::channel::{ChannelManager};
use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket};
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
use anyhow::{Result, anyhow};
use log::{info, warn, error, debug};
use crate::ssh_server::auth::{AuthHandler, AuthResult};
use crate::ssh_server::channel::ChannelManager;
use crate::ssh_server::cipher::{EncryptedPacket, EncryptionContext};
use crate::ssh_server::kex::{KexProposal, KexResult};
use crate::ssh_server::kex_complete::KexState;
use crate::ssh_server::packet::{PacketType, SshPacket};
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
use crate::ssh_server::version::VersionExchange;
use anyhow::{anyhow, Result};
use log::{error, info, warn};
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::path::PathBuf;
use std::thread;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步
use std::sync::{Arc, Mutex};
use std::thread; // Phase 13: 端口转发线程同步
/// SSH服务器配置Phase 13.1企业级安全配置)
pub struct SshServerConfig {
pub port: u16,
pub bind_address: String,
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
pub pg_conn: Option<String>, // PostgreSQL连接字符串SFTPGo兼容认证
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
pub pg_conn: Option<String>, // PostgreSQL连接字符串SFTPGo兼容认证
}
impl Default for SshServerConfig {
@@ -34,7 +34,7 @@ impl Default for SshServerConfig {
Self {
port: 2024,
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
pg_conn: None,
}
}
@@ -56,43 +56,48 @@ impl SshServerConfig {
/// SSH服务器主结构Phase 1-13完整版
pub struct SshServer {
config: SshServerConfig,
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
}
impl SshServer {
pub fn new(config: SshServerConfig) -> Self {
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
Self {
config,
security_config, // Phase 13.1
security_config, // Phase 13.1
}
}
pub fn run(&self) -> Result<()> {
let bind_addr = format!("{}:{}", self.config.bind_address, self.config.port);
let listener = TcpListener::bind(&bind_addr)?;
info!("MarkBaseSSH server listening on {}", bind_addr);
info!("Implementation: Complete SSH/SFTP + Port Forwarding (Phase 1-13)");
info!("Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
info!(
"Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
self.config.security_config.gateway_ports,
self.config.security_config.permit_open,
self.config.security_config.max_sessions);
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
self.config.security_config.max_sessions
);
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
let pg_conn = self.config.pg_conn.clone();
for stream in listener.incoming() {
match stream {
Ok(stream) => {
let client_addr = stream.peer_addr()?;
info!("New SSH connection from {}", client_addr);
let security_config_clone = security_config.clone(); // Phase 13.1
let security_config_clone = security_config.clone(); // Phase 13.1
let pg_conn_clone = pg_conn.clone();
thread::spawn(move || {
if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1
if let Err(e) =
handle_connection_complete(stream, security_config_clone, pg_conn_clone)
{
// Phase 13.1
error!("Connection error: {}", e);
}
});
@@ -102,90 +107,127 @@ impl SshServer {
}
}
}
Ok(())
}
}
/// 处理完整SSH连接Phase 1-13完整流程
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>, pg_conn: Option<String>) -> Result<()> {
fn handle_connection_complete(
stream: TcpStream,
security_config: Arc<Mutex<SshSecurityConfig>>,
pg_conn: Option<String>,
) -> Result<()> {
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
// Phase 13.1: 增加活动会话数
{
let mut security = security_config.lock().unwrap();
security.increment_sessions()?;
}
let mut stream = stream;
// Phase 1: 版本交换
let client_version = VersionExchange::exchange(&mut stream)?;
info!("Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0", client_version);
info!(
"Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0",
client_version
);
// Phase 2: 箋法协商
let (kex_result, server_kexinit, client_kexinit) = perform_kex_negotiation_complete(&mut stream)?;
info!("KEX negotiation: KEX={}, Cipher={}", kex_result.kex_algorithm, kex_result.encryption_ctos);
let (kex_result, server_kexinit, client_kexinit) =
perform_kex_negotiation_complete(&mut stream)?;
info!(
"KEX negotiation: KEX={}, Cipher={}",
kex_result.kex_algorithm, kex_result.encryption_ctos
);
// Phase 3: 密钥交换完整流程
let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?;
let mut encryption_ctx = perform_complete_kex_exchange(
&mut stream,
client_version.clone(),
kex_result,
server_kexinit,
client_kexinit,
)?;
info!("Key exchange completed, encryption channel ready");
// Phase 5: SSH认证SFTPGo兼容 — PostgreSQL或SQLite
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str);
Box::new(PgProvider::new(conn_str)
.map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?)
info!(
"Using PostgreSQL auth provider (SFTPGo-compatible): {}",
conn_str
);
Box::new(
PgProvider::new(conn_str).map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?,
)
} else {
info!("Using SQLite auth provider");
Box::new(SqliteProvider::new("data/auth.sqlite")
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?)
Box::new(
SqliteProvider::new("data/auth.sqlite")
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?,
)
};
let mut auth_handler = AuthHandler::new(provider);
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
info!("SSH authentication succeeded: user={}", auth_user.username);
// Phase 6: SSH Channel管理参考OpenSSH channel.c
let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone());
// Phase 13: PortForwardManager初始化
let mut port_forward_manager = PortForwardManager::new();
// Phase 6-13: SSH服务循环处理channel请求 + 端口转发)
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
handle_ssh_service_loop(&mut stream, &mut channel_manager, &mut encryption_ctx, &mut port_forward_manager, security_config_clone)?;
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
handle_ssh_service_loop(
&mut stream,
&mut channel_manager,
&mut encryption_ctx,
&mut port_forward_manager,
security_config_clone,
)?;
info!("SSH session completed successfully");
// Phase 13.1: 减少活动会话数
{
let mut security = security_config.lock().unwrap();
security.decrement_sessions();
}
Ok(())
}
/// 完整算法协商返回KEXINIT payloads
fn perform_kex_negotiation_complete(stream: &mut TcpStream) -> Result<(KexResult, SshPacket, SshPacket)> {
fn perform_kex_negotiation_complete(
stream: &mut TcpStream,
) -> Result<(KexResult, SshPacket, SshPacket)> {
info!("Starting complete KEX negotiation");
// 1. 发送服务器KEXINIT
let server_proposal = KexProposal::server_default();
let server_kexinit = server_proposal.to_kexinit_packet()?;
server_kexinit.write(stream)?;
info!("Sent server KEXINIT (payload size: {} bytes)", server_kexinit.payload.len());
info!(
"Sent server KEXINIT (payload size: {} bytes)",
server_kexinit.payload.len()
);
// 2. 接收客户端KEXINIT
let client_kexinit = SshPacket::read(stream)?;
let client_proposal = KexProposal::from_kexinit_packet(&client_kexinit)?;
info!("Received client KEXINIT (payload size: {} bytes)", client_kexinit.payload.len());
info!(
"Received client KEXINIT (payload size: {} bytes)",
client_kexinit.payload.len()
);
// 3. 算法匹配
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal)?;
Ok((kex_result, server_kexinit, client_kexinit))
}
@@ -198,18 +240,18 @@ fn perform_complete_kex_exchange(
client_kexinit: SshPacket,
) -> Result<EncryptionContext> {
info!("Starting complete key exchange flow");
let mut kex_state = KexState::new(
client_version,
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
kex_result,
)?;
kex_state.save_kexinit_payloads(&client_kexinit, &server_kexinit);
let kexdh_init = SshPacket::read(stream)?;
info!("Received SSH_MSG_KEX_ECDH_INIT");
let kexdh_reply = kex_state.exchange_handler.handle_kexdh_init(
&kexdh_init,
&kex_state.client_version,
@@ -219,27 +261,27 @@ fn perform_complete_kex_exchange(
)?;
kexdh_reply.write(stream)?;
info!("Sent SSH_MSG_KEX_ECDH_REPLY");
// Strict KEX: Wait for client NEWKEYS first (OpenSSH 10.2 requirement)
let client_newkeys = SshPacket::read(stream)?;
kex_state.handle_newkeys(&client_newkeys)?;
info!("Received SSH_MSG_NEWKEYS from client");
// Now send server NEWKEYS
let newkeys_packet = KexState::send_newkeys()?;
newkeys_packet.write(stream)?;
kex_state.newkeys_sent = true;
info!("Sent SSH_MSG_NEWKEYS from server");
if kex_state.is_encryption_ready() {
info!("Encryption channel established successfully");
} else {
return Err(anyhow::anyhow!("Encryption channel not ready"));
}
let session_keys = kex_state.exchange_handler.compute_session_keys()?;
let encryption_ctx = EncryptionContext::from_session_keys(&session_keys);
Ok(encryption_ctx)
}
@@ -250,102 +292,100 @@ pub struct AuthUser {
}
fn perform_ssh_auth(
stream: &mut TcpStream,
stream: &mut TcpStream,
auth_handler: &mut AuthHandler,
encryption_ctx: &mut EncryptionContext,
) -> Result<AuthUser> {
info!("Starting SSH authentication");
info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
info!(
"Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
encryption_ctx.encryption_key_ctos.len(),
encryption_ctx.encryption_key_stoc.len(),
encryption_ctx.iv_ctos.len(),
encryption_ctx.iv_stoc.len()
);
// OpenSSH strict KEX: SSH_MSG_EXT_INFO may be sent before SSH_MSG_SERVICE_REQUEST
let mut encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?;
let payload = encrypted_request.payload();
if payload[0] == PacketType::SSH_MSG_EXT_INFO as u8 {
info!("Received SSH_MSG_EXT_INFO, reading next packet");
encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?;
}
let payload = encrypted_request.payload();
info!("Received packet type: {}", payload[0]);
if payload[0] != PacketType::SSH_MSG_SERVICE_REQUEST as u8 {
return Err(anyhow!("Expected SSH_MSG_SERVICE_REQUEST, got type {}", payload[0]));
return Err(anyhow!(
"Expected SSH_MSG_SERVICE_REQUEST, got type {}",
payload[0]
));
}
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
let mut cursor = std::io::Cursor::new(&payload[1..]);
let service_name_len = cursor.read_u32::<BigEndian>()?;
let mut service_name = vec![0u8; service_name_len as usize];
cursor.read_exact(&mut service_name)?;
let service_name_str = String::from_utf8_lossy(&service_name);
if service_name_str != "ssh-userauth" {
return Err(anyhow!("Unsupported service: {}", service_name_str));
}
let mut service_accept_payload = Vec::new();
service_accept_payload.write_u8(PacketType::SSH_MSG_SERVICE_ACCEPT as u8)?;
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
service_accept_payload.write_all("ssh-userauth".as_bytes())?;
let encrypted_accept = EncryptedPacket::new(
&service_accept_payload,
encryption_ctx,
true,
)?;
let encrypted_accept = EncryptedPacket::new(&service_accept_payload, encryption_ctx, true)?;
encrypted_accept.write(stream)?;
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
let session_id = encryption_ctx.session_id.clone();
loop {
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
let auth_payload = auth_packet.payload();
info!("Received encrypted SSH_MSG_USERAUTH_REQUEST");
let auth_request = SshPacket::new(auth_payload.to_vec());
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
AuthResult::Success => {
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
let encrypted_success = EncryptedPacket::new(
&success_payload,
encryption_ctx,
true,
)?;
let encrypted_success =
EncryptedPacket::new(&success_payload, encryption_ctx, true)?;
encrypted_success.write(stream)?;
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
// Extract username from auth request
let user = extract_username_from_auth_request(&auth_request)
.unwrap_or_else(|_| "unknown".to_string());
let home_dir = auth_handler.get_home_dir(&user)
let home_dir = auth_handler
.get_home_dir(&user)
.ok()
.flatten()
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
return Ok(AuthUser { username: user, home_dir });
return Ok(AuthUser {
username: user,
home_dir,
});
}
AuthResult::Failure(message) => {
AuthResult::Failure(message) => {
// message包含可用的认证方法列表如"password,publickey"
let mut failure_payload = Vec::new();
failure_payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
failure_payload.write_u32::<BigEndian>(message.len() as u32)?;
failure_payload.write_all(message.as_bytes())?;
failure_payload.write_u8(0)?; // partial_success = false
let encrypted_failure = EncryptedPacket::new(
&failure_payload,
encryption_ctx,
true,
)?;
failure_payload.write_u8(0)?; // partial_success = false
let encrypted_failure =
EncryptedPacket::new(&failure_payload, encryption_ctx, true)?;
encrypted_failure.write(stream)?;
warn!("Sent encrypted SSH_MSG_USERAUTH_FAILURE: {}", message);
}
@@ -356,27 +396,23 @@ AuthResult::Failure(message) => {
AuthResult::PublicKeyOk(algorithm, public_key_blob) => {
// SSH_MSG_USERAUTH_PK_OKpublic key acceptable
info!("Public key acceptable, sending USERAUTH_PK_OK");
let mut pk_ok_payload = Vec::new();
pk_ok_payload.write_u8(PacketType::SSH_MSG_USERAUTH_PK_OK as u8)?;
// algorithm (SSH string)
pk_ok_payload.write_u32::<BigEndian>(algorithm.len() as u32)?;
pk_ok_payload.write_all(algorithm.as_bytes())?;
// public key blob (SSH string)
pk_ok_payload.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
pk_ok_payload.write_all(&public_key_blob)?;
let encrypted_pk_ok = EncryptedPacket::new(
&pk_ok_payload,
encryption_ctx,
true,
)?;
let encrypted_pk_ok = EncryptedPacket::new(&pk_ok_payload, encryption_ctx, true)?;
encrypted_pk_ok.write(stream)?;
info!("Sent SSH_MSG_USERAUTH_PK_OK");
continue; // Wait for signed request
continue; // Wait for signed request
}
}
}
@@ -389,16 +425,17 @@ fn handle_ssh_service_loop(
stream: &mut TcpStream,
channel_manager: &mut ChannelManager,
encryption_ctx: &mut EncryptionContext,
port_forward_manager: &mut PortForwardManager, // Phase 13
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
port_forward_manager: &mut PortForwardManager, // Phase 13
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
) -> Result<()> {
info!("Starting SSH service loop (Phase 14.2: unified poll + child status)");
loop {
// ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测
// 返回三元组:(stdout_packets, client_has_data, child_exited)
let (stdout_packets, client_has_data, child_exited) = channel_manager.poll_exec_stdout_and_client(stream)?;
let (stdout_packets, client_has_data, child_exited) =
channel_manager.poll_exec_stdout_and_client(stream)?;
// 1. 发送stdout/stderr数据如果有
if let Some(packets) = stdout_packets {
for packet in packets {
@@ -407,93 +444,100 @@ fn handle_ssh_service_loop(
info!("Sent stdout/stderr data (Phase 14.2)");
}
}
// 2. 处理child exited发送EOF + CLOSE
if child_exited {
info!("Child process exited, sending SSH_MSG_CHANNEL_EOF + CLOSE");
// ⭐⭐⭐⭐⭐ Phase 14.2: 使用ChannelManager.handle_child_exited()
let exit_packets = channel_manager.handle_child_exited()?;
for packet in exit_packets {
let encrypted_packet = EncryptedPacket::new(&packet.payload, encryption_ctx, true)?;
encrypted_packet.write(stream)?;
}
// 继续处理client数据可能还有其他请求
}
// 3. 处理client数据如果有
if !client_has_data {
// client没有数据继续下一轮循环
continue;
}
// client有数据读取并处理
let encrypted_packet = EncryptedPacket::read(stream, encryption_ctx, true)?;
let packet = SshPacket::new(encrypted_packet.payload().to_vec());
match packet.payload.first() {
// Phase 13: SSH_MSG_GLOBAL_REQUEST处理端口转发
Some(&pt) if pt == PacketType::SSH_MSG_GLOBAL_REQUEST as u8 => {
info!("Received SSH_MSG_GLOBAL_REQUEST (port forwarding)");
// Phase 13.1: 安全配置验证
let security = security_config.lock().unwrap();
if !security.allow_tcp_forwarding {
warn!("TCP forwarding disabled by security config");
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
let encrypted_failure =
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
encrypted_failure.write(stream)?;
info!("Sent SSH_MSG_REQUEST_FAILURE (TCP forwarding disabled)");
continue;
}
// Phase 13.2: 调用PortForwardManager处理传递security_config
let (success, response) = port_forward_manager.handle_global_request(&packet.payload, &security)?;
drop(security); // 释放锁
let (success, response) =
port_forward_manager.handle_global_request(&packet.payload, &security)?;
drop(security); // 释放锁
if success {
if let Some(response_data) = response {
let encrypted_response = EncryptedPacket::new(&response_data, encryption_ctx, true)?;
let encrypted_response =
EncryptedPacket::new(&response_data, encryption_ctx, true)?;
encrypted_response.write(stream)?;
info!("Sent SSH_MSG_REQUEST_SUCCESS (tcpip-forward accepted)");
} else {
// 无响应数据时发送简单的SUCCESS
let success_packet = vec![PacketType::SSH_MSG_REQUEST_SUCCESS as u8];
let encrypted_success = EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
let encrypted_success =
EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
encrypted_success.write(stream)?;
info!("Sent SSH_MSG_REQUEST_SUCCESS");
}
} else {
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
let encrypted_failure =
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
encrypted_failure.write(stream)?;
info!("Sent SSH_MSG_REQUEST_FAILURE (tcpip-forward rejected)");
}
}
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_OPEN as u8 => {
info!("Received SSH_MSG_CHANNEL_OPEN");
// Phase 13.3: 获取security_config并传递给handle_channel_open
let security = security_config.lock().unwrap();
let response = channel_manager.handle_channel_open(&packet, Some(&security))?;
drop(security); // 释放锁
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
drop(security); // 释放锁
let encrypted_response =
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
encrypted_response.write(stream)?;
info!("Sent SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
}
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_REQUEST as u8 => {
info!("Received SSH_MSG_CHANNEL_REQUEST");
if let Some(response) = channel_manager.handle_channel_request(&packet)? {
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
let encrypted_response =
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
encrypted_response.write(stream)?;
// ⭐⭐⭐⭐⭐ Phase 14.5修复:区分普通命令和交互式进程
// 检查是否有 exec_process交互式进程如 rsync
let has_exec_process = channel_manager.has_exec_process();
if has_exec_process {
info!("⭐⭐⭐⭐⭐ [INTERACTIVE_PROCESS] Detected exec_process (rsync/SCP), skipping immediate EOF");
// 对于交互式进程,只发送 SUCCESS等待 poll 循环处理数据流
@@ -503,23 +547,37 @@ fn handle_ssh_service_loop(
if let Some(channel_id) = channel_manager.get_channel_with_output() {
if let Some(output) = channel_manager.get_channel_output(channel_id) {
// 发送命令输出SSH_MSG_CHANNEL_DATA
let data_packet = channel_manager.build_channel_data(channel_id, &output)?;
let encrypted_data = EncryptedPacket::new(&data_packet.payload, encryption_ctx, true)?;
let data_packet =
channel_manager.build_channel_data(channel_id, &output)?;
let encrypted_data = EncryptedPacket::new(
&data_packet.payload,
encryption_ctx,
true,
)?;
encrypted_data.write(stream)?;
info!("Sent command output ({} bytes)", output.len());
// 发送SSH_MSG_CHANNEL_EOF
let eof_packet = channel_manager.build_channel_eof(channel_id)?;
let encrypted_eof = EncryptedPacket::new(&eof_packet.payload, encryption_ctx, true)?;
let encrypted_eof = EncryptedPacket::new(
&eof_packet.payload,
encryption_ctx,
true,
)?;
encrypted_eof.write(stream)?;
info!("Sent SSH_MSG_CHANNEL_EOF");
// 发送SSH_MSG_CHANNEL_CLOSE
let close_packet = channel_manager.build_channel_close(channel_id)?;
let encrypted_close = EncryptedPacket::new(&close_packet.payload, encryption_ctx, true)?;
let close_packet =
channel_manager.build_channel_close(channel_id)?;
let encrypted_close = EncryptedPacket::new(
&close_packet.payload,
encryption_ctx,
true,
)?;
encrypted_close.write(stream)?;
info!("Sent SSH_MSG_CHANNEL_CLOSE");
// 移除channel
channel_manager.remove_channel(channel_id);
}
@@ -531,22 +589,28 @@ fn handle_ssh_service_loop(
info!("Received SSH_MSG_CHANNEL_DATA");
if let Some(response) = channel_manager.handle_channel_data(&packet)? {
// Phase 7: SFTP响应通过CHANNEL_DATA返回
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
let encrypted_response =
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
encrypted_response.write(stream)?;
info!("Sent SSH_MSG_CHANNEL_DATA (SFTP response)");
}
// ⭐⭐⭐⭐⭐ Phase 15.1: Drain pending packets (e.g. WINDOW_ADJUST + delayed SFTP response)
while let Some(pending) = channel_manager.pending_packets.pop_front() {
let encrypted_pending = EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
let encrypted_pending =
EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
encrypted_pending.write(stream)?;
info!("Sent pending packet (type {})", pending.payload.first().unwrap_or(&0));
info!(
"Sent pending packet (type {})",
pending.payload.first().unwrap_or(&0)
);
}
}
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_CLOSE as u8 => {
info!("Received SSH_MSG_CHANNEL_CLOSE");
if let Some(response) = channel_manager.handle_channel_close(&packet)? {
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
let encrypted_response =
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
encrypted_response.write(stream)?;
}
break;
@@ -565,8 +629,10 @@ fn handle_ssh_service_loop(
let payload = &packet.payload;
if payload.len() >= 9 {
// Format: uint32 recipient_channel || uint32 bytes_to_add
let recipient_channel = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
let bytes_to_add = u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
let recipient_channel =
u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
let bytes_to_add =
u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
channel_manager.adjust_remote_window(recipient_channel, bytes_to_add);
}
}
@@ -575,12 +641,14 @@ fn handle_ssh_service_loop(
}
}
}
Ok(())
}
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result<String> {
fn extract_username_from_auth_request(
packet: &crate::ssh_server::packet::SshPacket,
) -> Result<String> {
let payload = &packet.payload;
if payload.len() < 5 {
return Err(anyhow!("Auth request too short"));
@@ -598,10 +666,10 @@ pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
let config = SshServerConfig {
port: port.unwrap_or(2024),
bind_address: "127.0.0.1".to_string(),
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
pg_conn: pg_conn.map(|s| s.to_string()),
};
let server = SshServer::new(config);
server.run()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
// SSH企业级安全配置Phase 13.1
// 参考OpenSSH sshd_config安全配置
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
use log::{info, warn};
use std::fs;
use std::path::Path;
@@ -14,25 +14,25 @@ pub struct SshSecurityConfig {
/// false: 只绑定127.0.0.1(安全)
/// true: 允许绑定0.0.0.0(危险)
pub gateway_ports: bool,
/// PermitOpen白名单
/// ["localhost:3000", "localhost:4000", "localhost:*"]
/// 空数组表示允许所有目标(不安全)
pub permit_open: Vec<String>,
/// AllowTcpForwarding配置
/// true: 允许TCP转发
/// false: 禁止所有TCP转发
pub allow_tcp_forwarding: bool,
/// MaxSessions限制
/// 最大会话数,防止资源耗尽
pub max_sessions: u32,
/// ConnectTimeout超时
/// 连接超时设置,防止悬挂连接
pub connect_timeout: u64,
/// 活动会话数(运行时状态)
pub active_sessions: u32,
}
@@ -42,110 +42,125 @@ impl SshSecurityConfig {
/// 参考OpenSSH企业级生产环境配置
pub fn enterprise_default() -> Self {
Self {
gateway_ports: false, // 安全只绑定127.0.0.1
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
allow_tcp_forwarding: true, // 允许TCP转发
max_sessions: 10, // 最多10个会话
connect_timeout: 30, // 30秒超时
active_sessions: 0, // 运行时状态
gateway_ports: false, // 安全只绑定127.0.0.1
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
allow_tcp_forwarding: true, // 允许TCP转发
max_sessions: 10, // 最多10个会话
connect_timeout: 30, // 30秒超时
active_sessions: 0, // 运行时状态
}
}
/// 开发环境默认配置(宽松)
pub fn development_default() -> Self {
Self {
gateway_ports: true, // 开发允许0.0.0.0
permit_open: vec![], // 开发:允许所有目标
gateway_ports: true, // 开发允许0.0.0.0
permit_open: vec![], // 开发:允许所有目标
allow_tcp_forwarding: true,
max_sessions: 20, // 开发:更多会话
connect_timeout: 60, // 开发:更长超时
max_sessions: 20, // 开发:更多会话
connect_timeout: 60, // 开发:更长超时
active_sessions: 0,
}
}
/// 从JSON配置文件加载
pub fn load_from_file(path: &str) -> Result<Self> {
if !Path::new(path).exists() {
info!("SSH security config file not found, using enterprise default");
return Ok(Self::enterprise_default());
}
let config_str = fs::read_to_string(path)?;
let config: serde_json::Value = serde_json::from_str(&config_str)?;
let security = config.get("ssh_server")
let security = config
.get("ssh_server")
.and_then(|s| s.get("security"))
.ok_or_else(|| anyhow!("Invalid config structure"))?;
Ok(Self {
gateway_ports: security.get("gateway_ports")
gateway_ports: security
.get("gateway_ports")
.and_then(|v| v.as_bool())
.unwrap_or(false),
permit_open: security.get("permit_open")
permit_open: security
.get("permit_open")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_else(|| vec!["localhost:*".to_string()]),
allow_tcp_forwarding: security.get("allow_tcp_forwarding")
allow_tcp_forwarding: security
.get("allow_tcp_forwarding")
.and_then(|v| v.as_bool())
.unwrap_or(true),
max_sessions: security.get("max_sessions")
max_sessions: security
.get("max_sessions")
.and_then(|v| v.as_u64())
.map(|v| v as u32)
.unwrap_or(10),
connect_timeout: security.get("connect_timeout")
connect_timeout: security
.get("connect_timeout")
.and_then(|v| v.as_u64())
.unwrap_or(30),
active_sessions: 0,
})
}
/// 验证tcpip-forward请求安全检查
/// 参考OpenSSH auth2.c: ssh_forwarding_check()
pub fn validate_tcpip_forward_request(
&self,
bind_address: &str,
bind_port: u32,
) -> Result<()> {
info!("Validating tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
pub fn validate_tcpip_forward_request(&self, bind_address: &str, bind_port: u32) -> Result<()> {
info!(
"Validating tcpip-forward request: bind_address={}, bind_port={}",
bind_address, bind_port
);
// 1. AllowTcpForwarding检查
if !self.allow_tcp_forwarding {
warn!("TCP forwarding disabled by security config");
return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no"));
}
// 2. GatewayPorts检查
if !self.gateway_ports {
// 只允许绑定到127.0.0.1或localhost
if bind_address != "127.0.0.1" && bind_address != "localhost" && bind_address != "" {
warn!("GatewayPorts disabled, bind_address {} not allowed", bind_address);
if bind_address != "127.0.0.1" && bind_address != "localhost" && !bind_address.is_empty() {
warn!(
"GatewayPorts disabled, bind_address {} not allowed",
bind_address
);
return Err(anyhow!("GatewayPorts=no, only 127.0.0.1 allowed"));
}
info!("GatewayPorts check passed: bind_address={}", bind_address);
}
// 3. MaxSessions检查
if self.active_sessions >= self.max_sessions {
warn!("Max sessions limit reached: {} >= {}", self.active_sessions, self.max_sessions);
warn!(
"Max sessions limit reached: {} >= {}",
self.active_sessions, self.max_sessions
);
return Err(anyhow!("Max sessions limit reached: {}", self.max_sessions));
}
// 4. 特权端口检查(防止<1024
if bind_port < 1024 {
warn!("Cannot bind to privileged port: {}", bind_port);
return Err(anyhow!("Cannot bind to privileged port < 1024"));
}
// 5. 端口范围检查(防止过大端口)
if bind_port > 65535 {
warn!("Invalid port number: {}", bind_port);
return Err(anyhow!("Invalid port number > 65535"));
}
info!("tcpip-forward request validated successfully");
Ok(())
}
/// 验证direct-tcpip channel请求安全检查
/// 参考OpenSSH channels.c: channel_connect_direct_tcpip()
pub fn validate_direct_tcpip_channel(
@@ -153,14 +168,17 @@ impl SshSecurityConfig {
host_to_connect: &str,
port_to_connect: u32,
) -> Result<()> {
info!("Validating direct-tcpip channel: host={}, port={}", host_to_connect, port_to_connect);
info!(
"Validating direct-tcpip channel: host={}, port={}",
host_to_connect, port_to_connect
);
// 1. AllowTcpForwarding检查
if !self.allow_tcp_forwarding {
warn!("TCP forwarding disabled by security config");
return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no"));
}
// 2. PermitOpen白名单检查
if !self.permit_open.is_empty() {
let target = format!("{}:{}", host_to_connect, port_to_connect);
@@ -173,28 +191,34 @@ impl SshSecurityConfig {
target == *pattern
}
});
if !allowed {
warn!("Target {}:{} not in PermitOpen whitelist", host_to_connect, port_to_connect);
return Err(anyhow!("Target {}:{} not in PermitOpen whitelist",
host_to_connect, port_to_connect));
warn!(
"Target {}:{} not in PermitOpen whitelist",
host_to_connect, port_to_connect
);
return Err(anyhow!(
"Target {}:{} not in PermitOpen whitelist",
host_to_connect,
port_to_connect
));
}
info!("PermitOpen check passed: target={}", target);
} else {
// permit_open为空允许所有目标不安全仅用于开发
info!("PermitOpen whitelist empty, allowing all targets (development mode)");
}
// 3. 端口范围检查
if port_to_connect < 1 || port_to_connect > 65535 {
if !(1..=65535).contains(&port_to_connect) {
warn!("Invalid port number: {}", port_to_connect);
return Err(anyhow!("Invalid port number: {}", port_to_connect));
}
info!("direct-tcpip channel validated successfully");
Ok(())
}
/// 增加活动会话数
pub fn increment_sessions(&mut self) -> Result<()> {
if self.active_sessions >= self.max_sessions {
@@ -204,7 +228,7 @@ impl SshSecurityConfig {
info!("Active sessions: {}", self.active_sessions);
Ok(())
}
/// 减少活动会话数
pub fn decrement_sessions(&mut self) {
if self.active_sessions > 0 {
@@ -217,56 +241,76 @@ impl SshSecurityConfig {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enterprise_default_config() {
let config = SshSecurityConfig::enterprise_default();
assert_eq!(config.gateway_ports, false);
assert_eq!(config.permit_open, vec!["localhost:*".to_string()]);
assert_eq!(config.allow_tcp_forwarding, true);
assert_eq!(config.max_sessions, 10);
assert_eq!(config.connect_timeout, 30);
}
#[test]
fn test_validate_tcpip_forward_request() {
let config = SshSecurityConfig::enterprise_default();
// 正常请求应该通过
assert!(config.validate_tcpip_forward_request("127.0.0.1", 8080).is_ok());
assert!(config.validate_tcpip_forward_request("localhost", 8080).is_ok());
assert!(config
.validate_tcpip_forward_request("127.0.0.1", 8080)
.is_ok());
assert!(config
.validate_tcpip_forward_request("localhost", 8080)
.is_ok());
// GatewayPorts=false时0.0.0.0应该被拒绝
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_err());
assert!(config
.validate_tcpip_forward_request("0.0.0.0", 8080)
.is_err());
// 特权端口应该被拒绝
assert!(config.validate_tcpip_forward_request("127.0.0.1", 80).is_err());
assert!(config
.validate_tcpip_forward_request("127.0.0.1", 80)
.is_err());
}
#[test]
fn test_validate_direct_tcpip_channel() {
let config = SshSecurityConfig::enterprise_default();
// localhost:*应该通过(通配符匹配)
assert!(config.validate_direct_tcpip_channel("localhost", 3000).is_ok());
assert!(config.validate_direct_tcpip_channel("localhost", 4000).is_ok());
assert!(config
.validate_direct_tcpip_channel("localhost", 3000)
.is_ok());
assert!(config
.validate_direct_tcpip_channel("localhost", 4000)
.is_ok());
// 其他host应该被拒绝
assert!(config.validate_direct_tcpip_channel("192.168.1.100", 3000).is_err());
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_err());
assert!(config
.validate_direct_tcpip_channel("192.168.1.100", 3000)
.is_err());
assert!(config
.validate_direct_tcpip_channel("example.com", 80)
.is_err());
}
#[test]
fn test_development_default_config() {
let config = SshSecurityConfig::development_default();
assert_eq!(config.gateway_ports, true);
assert_eq!(config.permit_open.len(), 0); // 空数组表示允许所有
assert_eq!(config.max_sessions, 20);
// 开发配置应该允许所有请求
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_ok());
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_ok());
assert!(config
.validate_tcpip_forward_request("0.0.0.0", 8080)
.is_ok());
assert!(config
.validate_direct_tcpip_channel("example.com", 80)
.is_ok());
}
}

View File

@@ -1,11 +1,11 @@
// SSH Buffer 零拷贝实现(参考 OpenSSH sshbuf.c
// 提供高效的 buffer 管理,消除临时 buffer
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
use std::io::{Read, Write};
/// SSH Buffer参考 OpenSSH struct sshbuf
///
///
/// OpenSSH 实现:
/// ```c
/// struct sshbuf {
@@ -16,10 +16,10 @@ use std::io::{Read, Write};
/// };
/// ```
pub struct SshBuf {
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
off: usize, // Offset (对应 OpenSSH buf->off)
size: usize, // Size (对应 OpenSSH buf->size)
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
off: usize, // Offset (对应 OpenSSH buf->off)
size: usize, // Size (对应 OpenSSH buf->size)
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
}
impl SshBuf {
@@ -32,7 +32,7 @@ impl SshBuf {
max_size: 128 * 1024 * 1024, // 128MB (OpenSSH SSHBUF_SIZE_MAX)
}
}
/// 创建指定大小的 SSH Buffer
pub fn with_capacity(capacity: usize) -> Self {
Self {
@@ -42,7 +42,7 @@ impl SshBuf {
max_size: 128 * 1024 * 1024,
}
}
/// 设置最大大小
pub fn set_max_size(&mut self, max_size: usize) -> Result<()> {
if max_size > 128 * 1024 * 1024 {
@@ -51,47 +51,47 @@ impl SshBuf {
self.max_size = max_size;
Ok(())
}
/// 获取 buffer 长度(对应 OpenSSH sshbuf_len
///
///
/// OpenSSH: `sshbuf_len = buf->size - buf->off`
pub fn len(&self) -> usize {
self.size - self.off
}
/// 检查 buffer 是否为空
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// 获取可用空间(对应 OpenSSH sshbuf_avail
///
///
/// OpenSSH: `sshbuf_avail = buf->max_size - buf->size`
pub fn avail(&self) -> usize {
self.max_size - self.size
}
/// 获取可变指针(对应 OpenSSH sshbuf_mutable_ptr
///
///
/// OpenSSH 实现:
/// ```c
/// u_char *sshbuf_mutable_ptr(const struct sshbuf *buf) {
/// return buf->d + buf->off;
/// }
/// ```
///
///
/// Rust 实现:返回 `&mut [u8]` slice零拷贝
pub fn mutable_ptr(&mut self) -> &mut [u8] {
&mut self.data[self.off..self.size]
}
/// 获取不可变指针(对应 OpenSSH sshbuf_ptr
pub fn ptr(&self) -> &[u8] {
&self.data[self.off..self.size]
}
/// 预分配空间(对应 OpenSSH sshbuf_reserve
///
///
/// OpenSSH 实现:
/// ```c
/// int sshbuf_reserve(struct sshbuf *buf, size_t len, u_char **dpp) {
@@ -104,31 +104,31 @@ impl SshBuf {
/// return 0;
/// }
/// ```
///
///
/// Rust 实现:返回 `&mut [u8]` slice零拷贝可直接 write
pub fn reserve(&mut self, len: usize) -> Result<&mut [u8]> {
if len > self.avail() {
return Err(anyhow!("no buffer space (avail={})", self.avail()));
}
// 预分配空间
let current_size = self.size;
let new_size = current_size + len;
// 确保 Vec 有足够容量
if new_size > self.data.len() {
self.data.resize(new_size, 0);
}
// 更新 size
self.size = new_size;
// 返回新空间的 slice零拷贝
Ok(&mut self.data[current_size..new_size])
}
/// 消费数据(对应 OpenSSH sshbuf_consume
///
///
/// OpenSSH 实现:
/// ```c
/// int sshbuf_consume(struct sshbuf *buf, size_t len) {
@@ -140,29 +140,33 @@ impl SshBuf {
/// return 0;
/// }
/// ```
///
///
/// Rust 实现:移动偏移量(零拷贝,不实际删除数据)
pub fn consume(&mut self, len: usize) -> Result<()> {
if len > self.len() {
return Err(anyhow!("message incomplete (len={}, consume={})", self.len(), len));
return Err(anyhow!(
"message incomplete (len={}, consume={})",
self.len(),
len
));
}
self.off += len;
// 如果 buffer 空,重置
if self.off == self.size {
self.off = 0;
self.size = 0;
// OpenSSH: pack buffer移除已消费的数据
// Rust: 我们保留 Vec但重置指针
}
Ok(())
}
/// 从末尾消费数据(对应 OpenSSH sshbuf_consume_end
///
///
/// OpenSSH 实现:
/// ```c
/// int sshbuf_consume_end(struct sshbuf *buf, size_t len) {
@@ -174,13 +178,13 @@ impl SshBuf {
if len > self.len() {
return Err(anyhow!("message incomplete"));
}
self.size -= len;
Ok(())
}
/// 直接从 fd read 到 buffer对应 OpenSSH sshbuf_read
///
///
/// OpenSSH 实现:
/// ```c
/// int sshbuf_read(int fd, struct sshbuf *buf, size_t maxlen, size_t *rlen) {
@@ -195,71 +199,75 @@ impl SshBuf {
/// return 0;
/// }
/// ```
///
///
/// Rust 实现:零拷贝,直接 read 到 buffer
pub fn read_from<R: Read>(&mut self, reader: &mut R, maxlen: usize) -> Result<usize> {
// 1. reserve 空间
let space = self.reserve(maxlen)?;
// 2. 直接 read 到 buffer零拷贝
let n = reader.read(space)?;
// 3. 调整大小(移除未使用的空间)
if maxlen > n {
self.consume_end(maxlen - n)?;
}
Ok(n)
}
/// 直接从 buffer write 到 fd对应 OpenSSH channel_handle_wfd
///
///
/// OpenSSH 实现:
/// ```c
/// buf = sshbuf_mutable_ptr(c->output); // 获取指针
/// len = write(c->wfd, buf, dlen); // 直接 write
/// sshbuf_consume(c->output, len); // 消费已写入的数据
/// ```
///
///
/// Rust 实现:零拷贝,直接 write 从 buffer
pub fn write_to<W: Write>(&mut self, writer: &mut W) -> Result<usize> {
if self.is_empty() {
return Ok(0);
}
// 1. 获取数据指针(零拷贝)
let data = self.ptr();
// 2. 直接 write零拷贝
let n = writer.write(data)?;
// 3. 消费已写入的数据(零拷贝,只移动偏移)
self.consume(n)?;
Ok(n)
}
/// 添加数据(对应 OpenSSH sshbuf_put
///
///
/// 用于不需要零拷贝的场景
pub fn put(&mut self, data: &[u8]) -> Result<()> {
let space = self.reserve(data.len())?;
space.copy_from_slice(data);
Ok(())
}
/// 清空 buffer
pub fn reset(&mut self) {
self.off = 0;
self.size = 0;
// OpenSSH: 保留 Vec只重置指针
}
/// Debug: 打印 buffer 状态
pub fn debug_info(&self) -> String {
format!(
"SshBuf: off={}, size={}, len={}, alloc={}, max_size={}",
self.off, self.size, self.len(), self.data.len(), self.max_size
self.off,
self.size,
self.len(),
self.data.len(),
self.max_size
)
}
}
@@ -274,11 +282,11 @@ impl Default for SshBuf {
mod tests {
use super::*;
use std::io::Cursor;
#[test]
fn test_sshbuf_basic() {
let mut buf = SshBuf::new();
// Test reserve - write into reserved space
{
let space = buf.reserve(10).unwrap();
@@ -286,57 +294,57 @@ mod tests {
space[0] = 1;
space[1] = 2;
} // space dropped, buf accessible
// Verify buffer length after reserve
assert_eq!(buf.len(), 10);
let ptr = buf.mutable_ptr();
assert_eq!(ptr[0], 1);
assert_eq!(ptr[1], 2);
// Test consume
buf.consume(2).unwrap();
assert_eq!(buf.len(), 8);
assert_eq!(buf.off, 2);
}
#[test]
fn test_sshbuf_zero_copy_read() {
let mut buf = SshBuf::with_capacity(100);
let mut reader = Cursor::new("hello world");
// 零拷贝 read
let n = buf.read_from(&mut reader, 20).unwrap();
assert_eq!(n, 11); // "hello world" length
assert_eq!(buf.len(), 11);
// 检查数据
let data = buf.ptr();
assert_eq!(data, "hello world".as_bytes());
}
#[test]
fn test_sshbuf_zero_copy_write() {
let mut buf = SshBuf::new();
buf.put("hello world".as_bytes()).unwrap();
let mut writer = Vec::new();
// 零拷贝 write
let n = buf.write_to(&mut writer).unwrap();
assert_eq!(n, 11);
assert_eq!(buf.len(), 0); // 已消费
// 检查数据
assert_eq!(writer, "hello world".as_bytes());
}
#[test]
fn test_sshbuf_max_size() {
let mut buf = SshBuf::new();
buf.set_max_size(1000).unwrap();
// 尝试 reserve 超过 max_size
let result = buf.reserve(2000);
assert!(result.is_err());
}
}
}

View File

@@ -2,8 +2,8 @@
// 参考OpenSSH sshd.c: ssh_exchange_identification()
use anyhow::Result;
use log::{debug, info};
use std::io::{Read, Write};
use log::{info, debug};
/// SSH版本字符串
pub const SSH_VERSION: &str = "SSH-2.0-MarkBaseSSH_1.0";
@@ -15,93 +15,96 @@ impl VersionExchange {
/// 执行版本交换(服务器端)
pub fn exchange<T: Read + Write>(stream: &mut T) -> Result<String> {
info!("Starting SSH version exchange");
// 1. 发送服务器版本
Self::send_version(stream)?;
// 2. 接收客户端版本
let client_version = Self::receive_version(stream)?;
info!("Version exchange completed: server={}, client={}", SSH_VERSION, client_version);
info!(
"Version exchange completed: server={}, client={}",
SSH_VERSION, client_version
);
Ok(client_version)
}
/// 发送服务器版本参考OpenSSH ssh_exchange_identification
fn send_version<T: Write>(stream: &mut T) -> Result<()> {
let version_line = format!("{}\r\n", SSH_VERSION);
stream.write_all(version_line.as_bytes())?;
stream.flush()?;
debug!("Sent version: {}", SSH_VERSION);
Ok(())
}
/// 接收客户端版本参考OpenSSH ssh_exchange_identification
fn receive_version<T: Read>(stream: &mut T) -> Result<String> {
let mut buffer = Vec::new();
let mut byte = [0u8; 1];
// 读取直到遇到'\n'参考OpenSSH实现
loop {
stream.read_exact(&mut byte)?;
// OpenSSH兼容性处理跳过前导空行和调试信息
if buffer.is_empty() && byte[0] == '\n' as u8 {
continue; // 跳过空行
if buffer.is_empty() && byte[0] == b'\n' {
continue; // 跳过空行
}
// 调试信息行(以'#'开头),跳过
if buffer.is_empty() && byte[0] == '#' as u8 {
if buffer.is_empty() && byte[0] == b'#' {
// 读取整行调试信息
while byte[0] != '\n' as u8 {
while byte[0] != b'\n' {
stream.read_exact(&mut byte)?;
}
buffer.clear();
continue;
}
buffer.push(byte[0]);
// 遇到'\n'结束
if byte[0] == '\n' as u8 {
if byte[0] == b'\n' {
break;
}
// 缓冲区溢出保护OpenSSH限制255字节
if buffer.len() > 255 {
return Err(anyhow::anyhow!("Version string too long"));
}
}
// 解析版本字符串
let version_line = String::from_utf8(buffer)?;
let version = version_line.trim().trim_matches('\r');
// 验证版本格式SSH-2.0-*
if !version.starts_with("SSH-2.0-") {
return Err(anyhow::anyhow!("Invalid SSH version: {}", version));
}
debug!("Received version: {}", version);
Ok(version.to_string())
}
/// 解析客户端版本信息(兼容性检查)
pub fn parse_client_version(version: &str) -> Result<ClientVersionInfo> {
// 格式SSH-protoversion-softwareversion SP comments
let parts: Vec<&str> = version.split_whitespace().collect();
let main_part = parts.first().map_or(version, |v| v);
let dash_parts: Vec<&str> = main_part.split('-').collect();
if dash_parts.len() < 3 {
return Err(anyhow::anyhow!("Invalid version format: {}", version));
}
let proto_version = dash_parts.get(1).map_or("2.0", |v| v);
let software_version = dash_parts.get(2).map_or("unknown", |v| v);
let comments = parts.get(1).map(|s| s.to_string());
Ok(ClientVersionInfo {
proto_version: proto_version.to_string(),
software_version: software_version.to_string(),
@@ -120,12 +123,12 @@ pub struct ClientVersionInfo {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_version_format() {
assert!(SSH_VERSION.starts_with("SSH-2.0-"));
}
#[test]
fn test_parse_client_version() {
let version = "SSH-2.0-OpenSSH_10.2";

View File

@@ -1,18 +1,18 @@
// SSH Window Size管理Phase 13.6
// 参考RFC 4254 Section 5.2: Window Size Adjustment
use anyhow::{Result, anyhow};
use log::{info, warn, debug};
use std::sync::{Arc, Mutex};
use byteorder::{BigEndian, WriteBytesExt};
use crate::ssh_server::packet::PacketType;
use anyhow::{anyhow, Result};
use byteorder::{BigEndian, WriteBytesExt};
use log::{info, warn};
use std::sync::{Arc, Mutex};
/// Window Size管理器Phase 13.6
pub struct WindowManager {
initial_window_size: u32, // RFC 4254: 2MB默认
initial_window_size: u32, // RFC 4254: 2MB默认
current_window_size: Arc<Mutex<u32>>,
max_packet_size: u32, // RFC 4254: 32KB默认
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
max_packet_size: u32, // RFC 4254: 32KB默认
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
}
impl WindowManager {
@@ -25,89 +25,103 @@ impl WindowManager {
consumed_bytes: Arc::new(Mutex::new(0)),
}
}
/// RFC 4254默认window size2MB
pub fn rfc_default() -> Self {
Self::new(2097152, 32768) // 2MB window, 32KB packet
Self::new(2097152, 32768) // 2MB window, 32KB packet
}
/// 检查window size是否足够Phase 13.6
pub fn check_window_available(&self, data_size: u32) -> bool {
let window = self.current_window_size.lock().unwrap();
let available = *window >= data_size;
if !available {
warn!("Window size insufficient: need {}, have {}", data_size, *window);
warn!(
"Window size insufficient: need {}, have {}",
data_size, *window
);
}
available
}
/// 消耗window sizePhase 13.6:发送数据后)
pub fn consume_window(&self, data_size: u32) -> Result<()> {
let mut window = self.current_window_size.lock().unwrap();
if *window < data_size {
return Err(anyhow!("Window size insufficient: need {}, have {}", data_size, *window));
return Err(anyhow!(
"Window size insufficient: need {}, have {}",
data_size,
*window
));
}
*window -= data_size;
// 统计已消耗bytes
let mut consumed = self.consumed_bytes.lock().unwrap();
*consumed += data_size;
info!("Window size consumed: {} bytes, remaining {}, total consumed {}",
data_size, *window, *consumed);
info!(
"Window size consumed: {} bytes, remaining {}, total consumed {}",
data_size, *window, *consumed
);
Ok(())
}
/// 调整window sizePhase 13.6收到SSH_MSG_CHANNEL_WINDOW_ADJUST
pub fn adjust_window(&self, bytes_to_add: u32) {
let mut window = self.current_window_size.lock().unwrap();
*window += bytes_to_add;
info!("Window size adjusted: added {} bytes, total {}", bytes_to_add, *window);
info!(
"Window size adjusted: added {} bytes, total {}",
bytes_to_add, *window
);
}
/// 构建SSH_MSG_CHANNEL_WINDOW_ADJUST packetPhase 13.6
pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result<Vec<u8>> {
let mut packet = Vec::new();
// Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93)
packet.write_u8(PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8)?;
// Recipient channel ID
packet.write_u32::<BigEndian>(channel_id)?;
// Bytes to add
packet.write_u32::<BigEndian>(bytes_to_add)?;
info!("Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
channel_id, bytes_to_add);
info!(
"Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
channel_id, bytes_to_add
);
Ok(packet)
}
/// 获取当前window sizePhase 13.6
pub fn get_current_window(&self) -> u32 {
*self.current_window_size.lock().unwrap()
}
/// 获取已消耗bytesPhase 13.6
pub fn get_consumed_bytes(&self) -> u32 {
*self.consumed_bytes.lock().unwrap()
}
/// 重置window sizePhase 13.6channel重置
pub fn reset_window(&self) {
let mut window = self.current_window_size.lock().unwrap();
*window = self.initial_window_size;
let mut consumed = self.consumed_bytes.lock().unwrap();
*consumed = 0;
info!("Window size reset to initial: {}", self.initial_window_size);
}
}
@@ -128,63 +142,63 @@ impl ChannelLifecycle {
close_received: false,
}
}
/// 构建SSH_MSG_CHANNEL_EOF packetPhase 13.7
pub fn build_eof_packet(channel_id: u32) -> Result<Vec<u8>> {
let mut packet = Vec::new();
// Packet type: SSH_MSG_CHANNEL_EOF (type 96)
packet.write_u8(PacketType::SSH_MSG_CHANNEL_EOF as u8)?;
// Recipient channel ID
packet.write_u32::<BigEndian>(channel_id)?;
info!("Built SSH_MSG_CHANNEL_EOF for channel {}", channel_id);
Ok(packet)
}
/// 构建SSH_MSG_CHANNEL_CLOSE packetPhase 13.7
pub fn build_close_packet(channel_id: u32) -> Result<Vec<u8>> {
let mut packet = Vec::new();
// Packet type: SSH_MSG_CHANNEL_CLOSE (type 97)
packet.write_u8(PacketType::SSH_MSG_CHANNEL_CLOSE as u8)?;
// Recipient channel ID
packet.write_u32::<BigEndian>(channel_id)?;
info!("Built SSH_MSG_CHANNEL_CLOSE for channel {}", channel_id);
Ok(packet)
}
/// 标记EOF已发送Phase 13.7
pub fn mark_eof_sent(&mut self) {
self.eof_sent = true;
info!("Channel {} EOF marked as sent", self.channel_id);
}
/// 标记CLOSE已接收Phase 13.7
pub fn mark_close_received(&mut self) {
self.close_received = true;
info!("Channel {} CLOSE marked as received", self.channel_id);
}
/// 检查是否可以清理channelPhase 13.7
pub fn can_cleanup(&self) -> bool {
self.eof_sent && self.close_received
}
/// 清理channel资源Phase 13.7
pub fn cleanup_channel(&self) -> Result<()> {
info!("Cleaning up channel {} resources", self.channel_id);
// Phase 13.7: 实际清理逻辑需要在ChannelManager中实现
// - 移除channel记录
// - 关闭TCP连接
// - 清理监听器如果是forwarded-tcpip
info!("Channel {} cleanup completed", self.channel_id);
Ok(())
}
@@ -193,42 +207,42 @@ impl ChannelLifecycle {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_window_manager_creation() {
let manager = WindowManager::rfc_default();
assert_eq!(manager.get_current_window(), 2097152);
assert_eq!(manager.max_packet_size, 32768);
}
#[test]
fn test_window_consumption() {
let manager = WindowManager::rfc_default();
// 消耗1000 bytes
manager.consume_window(1000).unwrap();
assert_eq!(manager.get_current_window(), 2097152 - 1000);
assert_eq!(manager.get_consumed_bytes(), 1000);
}
#[test]
fn test_window_adjustment() {
let manager = WindowManager::rfc_default();
// 消耗1000 bytes
manager.consume_window(1000).unwrap();
// 调整500 bytes
manager.adjust_window(500);
assert_eq!(manager.get_current_window(), 2097152 - 1000 + 500);
}
#[test]
fn test_build_eof_packet() {
let packet = ChannelLifecycle::build_eof_packet(1).unwrap();
assert_eq!(packet[0], PacketType::SSH_MSG_CHANNEL_EOF as u8);
}
#[test]
fn test_build_close_packet() {
let packet = ChannelLifecycle::build_close_packet(1).unwrap();