Fix code quality: trailing whitespace, unused imports, clippy warnings
- Fix trailing whitespace in kex.rs and s3.rs - Add missing KexProposal import in kex_complete.rs - Auto-fix clippy warnings across all crates - All 153 tests pass
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use std::io::Write;
|
||||
use anyhow::{Result, anyhow};
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, warn, debug};
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use log::{debug, info, warn};
|
||||
use std::io::Write;
|
||||
|
||||
use ed25519_dalek::{VerifyingKey, Signature};
|
||||
use ed25519_dalek::{Signature, VerifyingKey};
|
||||
|
||||
use crate::provider::{DataProvider, ProviderError};
|
||||
|
||||
@@ -27,7 +27,11 @@ impl AuthHandler {
|
||||
}
|
||||
|
||||
/// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request())
|
||||
pub fn handle_userauth_request(&mut self, packet: &SshPacket, session_id: &[u8]) -> Result<AuthResult> {
|
||||
pub fn handle_userauth_request(
|
||||
&mut self,
|
||||
packet: &SshPacket,
|
||||
session_id: &[u8],
|
||||
) -> Result<AuthResult> {
|
||||
info!("Processing SSH_MSG_USERAUTH_REQUEST");
|
||||
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||||
@@ -41,7 +45,10 @@ impl AuthHandler {
|
||||
let service = read_ssh_string(&mut cursor)?;
|
||||
let method = read_ssh_string(&mut cursor)?;
|
||||
|
||||
info!("Auth request: user={}, service={}, method={}", user, service, method);
|
||||
info!(
|
||||
"Auth request: user={}, service={}, method={}",
|
||||
user, service, method
|
||||
);
|
||||
|
||||
if service != "ssh-connection" {
|
||||
warn!("Unsupported service: {}", service);
|
||||
@@ -62,18 +69,28 @@ impl AuthHandler {
|
||||
}
|
||||
|
||||
/// 处理password认证(参考OpenSSH auth-passwd.c)
|
||||
fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result<AuthResult> {
|
||||
fn handle_password_auth(
|
||||
&mut self,
|
||||
cursor: &mut std::io::Cursor<&[u8]>,
|
||||
user: &str,
|
||||
) -> Result<AuthResult> {
|
||||
info!("Handling password auth for user: {}", user);
|
||||
|
||||
let change_password = cursor.read_u8()? != 0;
|
||||
if change_password {
|
||||
warn!("Password change not supported");
|
||||
return Ok(AuthResult::Failure("Password change not supported".to_string()));
|
||||
return Ok(AuthResult::Failure(
|
||||
"Password change not supported".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let password = read_ssh_string(cursor)?;
|
||||
|
||||
debug!("Password auth attempt: user={}, password length={}", user, password.len());
|
||||
debug!(
|
||||
"Password auth attempt: user={}, password length={}",
|
||||
user,
|
||||
password.len()
|
||||
);
|
||||
|
||||
match self.provider.check_password(user, &password) {
|
||||
Ok(true) => {
|
||||
@@ -88,9 +105,7 @@ impl AuthHandler {
|
||||
warn!("User not found: {}", msg);
|
||||
Ok(AuthResult::Failure("password,publickey".to_string()))
|
||||
}
|
||||
Err(e) => {
|
||||
Err(anyhow!("Password auth error: {}", e))
|
||||
}
|
||||
Err(e) => Err(anyhow!("Password auth error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,7 +160,12 @@ impl AuthHandler {
|
||||
let algorithm = read_ssh_string(cursor)?;
|
||||
let public_key_blob = read_ssh_string_bytes(cursor)?;
|
||||
|
||||
info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed);
|
||||
info!(
|
||||
"Publickey auth: algorithm={}, blob_len={}, is_signed={}",
|
||||
algorithm,
|
||||
public_key_blob.len(),
|
||||
is_signed
|
||||
);
|
||||
|
||||
if !self.is_key_authorized(user, &algorithm, &public_key_blob)? {
|
||||
warn!("Public key not authorized for user: {}", user);
|
||||
@@ -160,14 +180,26 @@ impl AuthHandler {
|
||||
|
||||
let signature_blob = read_ssh_string_bytes(cursor)?;
|
||||
|
||||
self.verify_signature(&algorithm, &public_key_blob, &signature_blob, user, service, session_id)?;
|
||||
self.verify_signature(
|
||||
&algorithm,
|
||||
&public_key_blob,
|
||||
&signature_blob,
|
||||
user,
|
||||
service,
|
||||
session_id,
|
||||
)?;
|
||||
|
||||
info!("Publickey auth successful for user: {}", user);
|
||||
Ok(AuthResult::Success)
|
||||
}
|
||||
|
||||
/// 检查public key是否在授权列表中(数据库优先,fallback到filesystem)
|
||||
fn is_key_authorized(&self, user: &str, algorithm: &str, public_key_blob: &[u8]) -> Result<bool> {
|
||||
fn is_key_authorized(
|
||||
&self,
|
||||
user: &str,
|
||||
algorithm: &str,
|
||||
public_key_blob: &[u8],
|
||||
) -> Result<bool> {
|
||||
// 1. 先检查数据库
|
||||
match self.provider.get_public_keys(user) {
|
||||
Ok(keys) => {
|
||||
@@ -187,10 +219,12 @@ impl AuthHandler {
|
||||
Err(_) => match std::fs::read_to_string("data/authorized_keys") {
|
||||
Ok(c) => c,
|
||||
Err(_) => return Ok(false),
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
Ok(content.lines().any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
|
||||
Ok(content
|
||||
.lines()
|
||||
.any(|line| public_key_matches_line(line, algorithm, public_key_blob)))
|
||||
}
|
||||
|
||||
/// 验证Ed25519签名(RFC 4252 §7)
|
||||
@@ -246,7 +280,8 @@ impl AuthHandler {
|
||||
signed_data.write_all(public_key_blob)?;
|
||||
|
||||
// 验证签名
|
||||
verifying_key.verify_strict(&signed_data, &signature)
|
||||
verifying_key
|
||||
.verify_strict(&signed_data, &signature)
|
||||
.map_err(|e| anyhow!("Ed25519 signature verification failed: {}", e))
|
||||
}
|
||||
}
|
||||
@@ -270,10 +305,10 @@ fn parse_ed25519_verifying_key(public_key_blob: &[u8]) -> Result<VerifyingKey> {
|
||||
if key_bytes.len() != 32 {
|
||||
return Err(anyhow!("Invalid Ed25519 key length: {}", key_bytes.len()));
|
||||
}
|
||||
let key_array: [u8; 32] = key_bytes.try_into()
|
||||
let key_array: [u8; 32] = key_bytes
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("Invalid Ed25519 key data"))?;
|
||||
VerifyingKey::from_bytes(&key_array)
|
||||
.map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
|
||||
VerifyingKey::from_bytes(&key_array).map_err(|e| anyhow!("Invalid Ed25519 key: {}", e))
|
||||
}
|
||||
|
||||
/// 解析Ed25519签名blob(SSH格式 -> Signature)
|
||||
@@ -285,9 +320,13 @@ fn parse_ed25519_signature(signature_blob: &[u8]) -> Result<Signature> {
|
||||
}
|
||||
let sig_bytes = read_ssh_string_bytes(&mut cursor)?;
|
||||
if sig_bytes.len() != 64 {
|
||||
return Err(anyhow!("Invalid Ed25519 signature length: {}", sig_bytes.len()));
|
||||
return Err(anyhow!(
|
||||
"Invalid Ed25519 signature length: {}",
|
||||
sig_bytes.len()
|
||||
));
|
||||
}
|
||||
let sig_array: [u8; 64] = sig_bytes.try_into()
|
||||
let sig_array: [u8; 64] = sig_bytes
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("Invalid Ed25519 signature data"))?;
|
||||
Ok(Signature::from_bytes(&sig_array))
|
||||
}
|
||||
@@ -305,7 +344,9 @@ fn public_key_matches_line(line: &str, algorithm: &str, public_key_blob: &[u8])
|
||||
if parts[0] != algorithm {
|
||||
return false;
|
||||
}
|
||||
base64_decode(parts[1]).map(|decoded| decoded == public_key_blob).unwrap_or(false)
|
||||
base64_decode(parts[1])
|
||||
.map(|decoded| decoded == public_key_blob)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||
@@ -323,7 +364,8 @@ fn read_ssh_string_bytes<R: std::io::Read>(reader: &mut R) -> Result<Vec<u8>> {
|
||||
}
|
||||
|
||||
fn base64_decode(input: &str) -> Result<Vec<u8>> {
|
||||
general_purpose::STANDARD.decode(input)
|
||||
general_purpose::STANDARD
|
||||
.decode(input)
|
||||
.map_err(|e| anyhow!("Base64 decode error: {}", e))
|
||||
}
|
||||
|
||||
@@ -335,7 +377,10 @@ mod tests {
|
||||
#[test]
|
||||
fn test_userauth_success_packet() {
|
||||
let packet = AuthHandler::build_userauth_success().unwrap();
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8);
|
||||
assert_eq!(
|
||||
packet.payload[0],
|
||||
PacketType::SSH_MSG_USERAUTH_SUCCESS as u8
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -343,6 +388,9 @@ mod tests {
|
||||
let methods = vec!["password".to_string(), "publickey".to_string()];
|
||||
let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap();
|
||||
|
||||
assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8);
|
||||
assert_eq!(
|
||||
packet.payload[0],
|
||||
PacketType::SSH_MSG_USERAUTH_FAILURE as u8
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,33 +1,33 @@
|
||||
// SSH加密通道实现(Phase 4)
|
||||
// 参考OpenSSH cipher.c, mac.c
|
||||
|
||||
use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr)
|
||||
use super::crypto::SessionKeys;
|
||||
use aes::Aes128; // 改为AES-128(协商算法是aes128-ctr)
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use cipher::{KeyIvInit, StreamCipher};
|
||||
use ctr::Ctr128BE;
|
||||
use hmac::{Hmac, Mac};
|
||||
use log::info;
|
||||
use sha2::Sha256;
|
||||
use cipher::{KeyIvInit, StreamCipher};
|
||||
use std::io::Write;
|
||||
use anyhow::{Result, anyhow};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, debug, warn};
|
||||
use super::crypto::SessionKeys;
|
||||
|
||||
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR(16字节密钥)
|
||||
type Aes128Ctr = Ctr128BE<Aes128>; // AES-128-CTR(16字节密钥)
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// SSH加密通道管理器(参考OpenSSH struct sshcipher_ctx)
|
||||
pub struct EncryptionContext {
|
||||
pub session_id: Vec<u8>, // session identifier (exchange hash)
|
||||
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
|
||||
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
|
||||
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
|
||||
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
|
||||
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
|
||||
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
|
||||
pub sequence_number_ctos: u32, // 客户端→服务器序列号
|
||||
pub sequence_number_stoc: u32, // 服务器→客户端序列号
|
||||
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例(持久化)
|
||||
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例(持久化)
|
||||
pub encryption_key_ctos: Vec<u8>, // 客户端→服务器加密密钥
|
||||
pub encryption_key_stoc: Vec<u8>, // 服务器→客户端加密密钥
|
||||
pub mac_key_ctos: Vec<u8>, // 客户端→服务器MAC密钥
|
||||
pub mac_key_stoc: Vec<u8>, // 服务器→客户端MAC密钥
|
||||
pub iv_ctos: Vec<u8>, // 客户端→服务器IV
|
||||
pub iv_stoc: Vec<u8>, // 服务器→客户端IV
|
||||
pub sequence_number_ctos: u32, // 客户端→服务器序列号
|
||||
pub sequence_number_stoc: u32, // 服务器→客户端序列号
|
||||
pub cipher_ctos: Option<Aes128Ctr>, // 客户端→服务器cipher实例(持久化)
|
||||
pub cipher_stoc: Option<Aes128Ctr>, // 服务器→客户端cipher实例(持久化)
|
||||
}
|
||||
|
||||
impl Default for EncryptionContext {
|
||||
@@ -53,27 +53,33 @@ impl EncryptionContext {
|
||||
/// OpenSSH cipher.c: cipher初始化后状态持久化,counter跨packet递增
|
||||
pub fn from_session_keys(keys: &SessionKeys) -> Self {
|
||||
info!("Initializing ciphers with session keys:");
|
||||
info!(" encryption_key_ctos (16 bytes): {:?}", &keys.encryption_key_ctos[..16]);
|
||||
info!(
|
||||
" encryption_key_ctos (16 bytes): {:?}",
|
||||
&keys.encryption_key_ctos[..16]
|
||||
);
|
||||
info!(" iv_ctos (16 bytes): {:?}", &keys.iv_ctos[..16]);
|
||||
info!(" encryption_key_stoc (16 bytes): {:?}", &keys.encryption_key_stoc[..16]);
|
||||
info!(
|
||||
" encryption_key_stoc (16 bytes): {:?}",
|
||||
&keys.encryption_key_stoc[..16]
|
||||
);
|
||||
info!(" iv_stoc (16 bytes): {:?}", &keys.iv_stoc[..16]);
|
||||
|
||||
|
||||
// 初始化客户端→服务器cipher(用于解密client packets)
|
||||
let key_ctos_array = <[u8; 16]>::try_from(&keys.encryption_key_ctos[..16])
|
||||
.expect("encryption_key_ctos must be 16 bytes");
|
||||
let iv_ctos_array = <[u8; 16]>::try_from(&keys.iv_ctos[..16])
|
||||
.expect("iv_ctos must be 16 bytes");
|
||||
let iv_ctos_array =
|
||||
<[u8; 16]>::try_from(&keys.iv_ctos[..16]).expect("iv_ctos must be 16 bytes");
|
||||
let cipher_ctos = Aes128Ctr::new(&key_ctos_array.into(), &iv_ctos_array.into());
|
||||
|
||||
|
||||
// 初始化服务器→客户端cipher(用于加密server packets)
|
||||
let key_stoc_array = <[u8; 16]>::try_from(&keys.encryption_key_stoc[..16])
|
||||
.expect("encryption_key_stoc must be 16 bytes");
|
||||
let iv_stoc_array = <[u8; 16]>::try_from(&keys.iv_stoc[..16])
|
||||
.expect("iv_stoc must be 16 bytes");
|
||||
let iv_stoc_array =
|
||||
<[u8; 16]>::try_from(&keys.iv_stoc[..16]).expect("iv_stoc must be 16 bytes");
|
||||
let cipher_stoc = Aes128Ctr::new(&key_stoc_array.into(), &iv_stoc_array.into());
|
||||
|
||||
|
||||
info!("Ciphers initialized successfully");
|
||||
|
||||
|
||||
Self {
|
||||
session_id: keys.session_id.clone(),
|
||||
encryption_key_ctos: keys.encryption_key_ctos.clone(),
|
||||
@@ -84,26 +90,26 @@ impl EncryptionContext {
|
||||
iv_stoc: keys.iv_stoc.clone(),
|
||||
sequence_number_ctos: 0,
|
||||
sequence_number_stoc: 0,
|
||||
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
|
||||
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
|
||||
cipher_ctos: Some(cipher_ctos), // 持久化cipher实例
|
||||
cipher_stoc: Some(cipher_stoc), // 持久化cipher实例
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// RFC 4344: Compute AES-CTR IV for a specific packet
|
||||
/// IV = nonce(8 bytes from derived IV) + sequence_number(8 bytes)
|
||||
fn compute_ctr_iv(nonce: &[u8], sequence_number: u32) -> Vec<u8> {
|
||||
let mut iv = Vec::with_capacity(16);
|
||||
|
||||
|
||||
// Nonce: first 8 bytes of derived IV (constant)
|
||||
iv.extend_from_slice(&nonce[..8]);
|
||||
|
||||
|
||||
// Counter: sequence number as 8-byte big-endian
|
||||
iv.extend_from_slice(&sequence_number.to_be_bytes());
|
||||
iv.extend_from_slice(&[0u8; 4]); // Upper 4 bytes = 0
|
||||
|
||||
|
||||
iv
|
||||
}
|
||||
|
||||
|
||||
/// 加密packet(参考OpenSSH cipher.c: cipher_encrypt())
|
||||
pub fn encrypt_packet(
|
||||
&mut self,
|
||||
@@ -113,17 +119,17 @@ impl EncryptionContext {
|
||||
) -> Result<Vec<u8>> {
|
||||
let key_array = <[u8; 16]>::try_from(encryption_key)?;
|
||||
let iv_array = <[u8; 16]>::try_from(iv)?;
|
||||
|
||||
|
||||
let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into());
|
||||
|
||||
|
||||
let mut ciphertext = plaintext.to_vec();
|
||||
cipher.apply_keystream(&mut ciphertext);
|
||||
|
||||
|
||||
self.sequence_number_stoc += 1;
|
||||
|
||||
|
||||
Ok(ciphertext)
|
||||
}
|
||||
|
||||
|
||||
/// 解密packet(参考OpenSSH cipher.c: cipher_decrypt())
|
||||
pub fn decrypt_packet(
|
||||
&mut self,
|
||||
@@ -133,17 +139,17 @@ impl EncryptionContext {
|
||||
) -> Result<Vec<u8>> {
|
||||
let key_array = <[u8; 16]>::try_from(encryption_key)?;
|
||||
let iv_array = <[u8; 16]>::try_from(iv)?;
|
||||
|
||||
|
||||
let mut cipher = Aes128Ctr::new(&key_array.into(), &iv_array.into());
|
||||
|
||||
|
||||
let mut plaintext = ciphertext.to_vec();
|
||||
cipher.apply_keystream(&mut plaintext);
|
||||
|
||||
|
||||
self.sequence_number_ctos += 1;
|
||||
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
|
||||
/// 计算MAC(参考OpenSSH mac.c: mac_compute())
|
||||
pub fn compute_mac(
|
||||
&self,
|
||||
@@ -152,17 +158,17 @@ impl EncryptionContext {
|
||||
mac_key: &[u8],
|
||||
) -> Result<Vec<u8>> {
|
||||
// HMAC-SHA256 MAC计算(参考OpenSSH mac.c)
|
||||
|
||||
|
||||
let mut mac = HmacSha256::new_from_slice(mac_key)?;
|
||||
|
||||
|
||||
// OpenSSH MAC格式:sequence_number + data
|
||||
mac.update(&sequence_number.to_be_bytes());
|
||||
mac.update(data);
|
||||
|
||||
|
||||
let result = mac.finalize();
|
||||
Ok(result.into_bytes().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 验证MAC(参考OpenSSH mac.c: mac_check())
|
||||
pub fn verify_mac(
|
||||
&self,
|
||||
@@ -172,14 +178,14 @@ impl EncryptionContext {
|
||||
mac_key: &[u8],
|
||||
) -> Result<bool> {
|
||||
// HMAC验证(参考OpenSSH mac.c)
|
||||
|
||||
|
||||
let computed_mac = self.compute_mac(sequence_number, data, mac_key)?;
|
||||
|
||||
|
||||
// 防止时间攻击(使用常量时间比较)
|
||||
if computed_mac.len() != expected_mac.len() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
|
||||
// 简化实现:直接比较(实际应使用常量时间比较)
|
||||
Ok(computed_mac == expected_mac)
|
||||
}
|
||||
@@ -187,11 +193,11 @@ impl EncryptionContext {
|
||||
|
||||
/// SSH加密packet封装(参考OpenSSH packet.c: ssh_packet_write_poll())
|
||||
pub struct EncryptedPacket {
|
||||
pub packet_length: u32, // 加密后packet长度
|
||||
pub padding_length: u8, // padding长度(加密后)
|
||||
pub payload: Vec<u8>, // payload(加密后)
|
||||
pub padding: Vec<u8>, // padding(加密后)
|
||||
pub mac: Vec<u8>, // MAC(32字节,HMAC-SHA256)
|
||||
pub packet_length: u32, // 加密后packet长度
|
||||
pub padding_length: u8, // padding长度(加密后)
|
||||
pub payload: Vec<u8>, // payload(加密后)
|
||||
pub padding: Vec<u8>, // padding(加密后)
|
||||
pub mac: Vec<u8>, // MAC(32字节,HMAC-SHA256)
|
||||
}
|
||||
|
||||
impl EncryptedPacket {
|
||||
@@ -204,82 +210,88 @@ impl EncryptedPacket {
|
||||
) -> Result<Self> {
|
||||
let block_size = 16;
|
||||
let min_padding = 4;
|
||||
|
||||
|
||||
let payload_length = plaintext_payload.len();
|
||||
|
||||
|
||||
// RFC 4253: entire plaintext packet (including 4-byte packet_length field) must be multiple of block_size
|
||||
// plaintext_packet = packet_length_field(4) + padding_length(1) + payload + padding
|
||||
// So: (4 + 1 + payload_length + padding_length) % 16 == 0
|
||||
|
||||
let base_size = 4 + 1 + payload_length; // without padding
|
||||
|
||||
let base_size = 4 + 1 + payload_length; // without padding
|
||||
let padding_needed = (block_size - (base_size % block_size)) % block_size;
|
||||
|
||||
|
||||
// Ensure padding >= min_padding (RFC 4253 requirement)
|
||||
let padding_length: u8 = if padding_needed < min_padding {
|
||||
(padding_needed + block_size) as u8 // Add one more block to meet minimum
|
||||
(padding_needed + block_size) as u8 // Add one more block to meet minimum
|
||||
} else {
|
||||
padding_needed as u8
|
||||
};
|
||||
|
||||
|
||||
// packet_length = padding_length(1) + payload + padding
|
||||
let packet_length = 1 + payload_length + padding_length as usize;
|
||||
|
||||
info!("Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
|
||||
payload_length, padding_length, packet_length);
|
||||
|
||||
|
||||
info!(
|
||||
"Creating AES-CTR encrypted packet: payload_len={}, padding_len={}, packet_len={}",
|
||||
payload_length, padding_length, packet_length
|
||||
);
|
||||
|
||||
// 构建plaintext packet(packet_length + padding_length + payload + padding)
|
||||
let mut plaintext_packet = Vec::new();
|
||||
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
|
||||
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
|
||||
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
|
||||
|
||||
plaintext_packet.write_u32::<BigEndian>(packet_length as u32)?; // plaintext packet_length
|
||||
plaintext_packet.write_u8(padding_length)?; // plaintext padding_length
|
||||
plaintext_packet.write_all(plaintext_payload)?; // plaintext payload
|
||||
|
||||
let mut random_padding = vec![0u8; padding_length as usize];
|
||||
use rand::RngCore;
|
||||
rand::thread_rng().fill_bytes(&mut random_padding);
|
||||
plaintext_packet.write_all(&random_padding)?; // plaintext padding
|
||||
|
||||
plaintext_packet.write_all(&random_padding)?; // plaintext padding
|
||||
|
||||
info!("Plaintext packet size: {} bytes", plaintext_packet.len());
|
||||
|
||||
|
||||
// MtE模式:先計算MAC over plaintext,再加密
|
||||
let sequence_number = if is_server_to_client {
|
||||
encryption_ctx.sequence_number_stoc
|
||||
} else {
|
||||
encryption_ctx.sequence_number_ctos
|
||||
};
|
||||
|
||||
|
||||
let mac_key = if is_server_to_client {
|
||||
&encryption_ctx.mac_key_stoc
|
||||
} else {
|
||||
&encryption_ctx.mac_key_ctos
|
||||
};
|
||||
|
||||
|
||||
info!("MAC calculation (MtE mode) over plaintext packet:");
|
||||
info!(" sequence_number: {}", sequence_number);
|
||||
info!(" mac_key length: {}", mac_key.len());
|
||||
info!(" plaintext_packet length: {}", plaintext_packet.len());
|
||||
|
||||
|
||||
// MAC計算:HMAC(sequence_number || plaintext_packet)
|
||||
let mac = encryption_ctx.compute_mac(sequence_number, &plaintext_packet, mac_key)?;
|
||||
|
||||
|
||||
// 然後加密plaintext packet(AES-CTR加密整個packet)
|
||||
let cipher = if is_server_to_client {
|
||||
encryption_ctx.cipher_stoc.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_stoc
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
||||
} else {
|
||||
encryption_ctx.cipher_ctos.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_ctos
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
||||
};
|
||||
|
||||
|
||||
let mut encrypted_packet = plaintext_packet;
|
||||
cipher.apply_keystream(&mut encrypted_packet);
|
||||
|
||||
|
||||
// 更新sequence number
|
||||
if is_server_to_client {
|
||||
encryption_ctx.sequence_number_stoc += 1;
|
||||
} else {
|
||||
encryption_ctx.sequence_number_ctos += 1;
|
||||
}
|
||||
|
||||
|
||||
Ok(Self {
|
||||
packet_length: packet_length as u32,
|
||||
padding_length,
|
||||
@@ -288,24 +300,27 @@ impl EncryptedPacket {
|
||||
mac,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 写入加密packet(参考OpenSSH cipher.c)
|
||||
/// AES-CTR模式:写入完整加密packet + MAC
|
||||
pub fn write<W: std::io::Write>(&self, stream: &mut W) -> Result<()> {
|
||||
info!("Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
|
||||
self.payload.len(), self.mac.len());
|
||||
|
||||
info!(
|
||||
"Writing AES-CTR encrypted packet: total_encrypted_len={}, mac_len={}",
|
||||
self.payload.len(),
|
||||
self.mac.len()
|
||||
);
|
||||
|
||||
// AES-CTR: 整个packet已加密(包括packet_length),直接写入
|
||||
stream.write_all(&self.payload)?;
|
||||
info!("Wrote encrypted packet ({} bytes)", self.payload.len());
|
||||
|
||||
|
||||
// 写入MAC
|
||||
stream.write_all(&self.mac)?;
|
||||
info!("Wrote MAC ({} bytes)", self.mac.len());
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 读取加密packet(参考OpenSSH packet.c ssh_packet_read_poll2)
|
||||
/// OpenSSH packet.c: AES-CTR先解密第一个块,再提取packet_length
|
||||
/// aadlen = 0 (没有EtM或authenticated encryption), packet_length被加密
|
||||
@@ -315,32 +330,42 @@ impl EncryptedPacket {
|
||||
is_client_to_server: bool,
|
||||
) -> Result<Self> {
|
||||
use std::io::Read;
|
||||
|
||||
|
||||
info!("Reading AES-CTR encrypted packet (packet_length encrypted)");
|
||||
|
||||
|
||||
// 1. 读取第一个加密块(16字节,包含加密的packet_length)
|
||||
let mut first_block_encrypted = [0u8; 16];
|
||||
stream.read_exact(&mut first_block_encrypted)?;
|
||||
|
||||
info!("Read first encrypted block (16 bytes): {:?}", &first_block_encrypted);
|
||||
|
||||
|
||||
info!(
|
||||
"Read first encrypted block (16 bytes): {:?}",
|
||||
&first_block_encrypted
|
||||
);
|
||||
|
||||
// 2. 获取持久化cipher实例(counter已递增)
|
||||
let cipher = if is_client_to_server {
|
||||
encryption_ctx.cipher_ctos.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_ctos
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_ctos not initialized"))?
|
||||
} else {
|
||||
encryption_ctx.cipher_stoc.as_mut()
|
||||
encryption_ctx
|
||||
.cipher_stoc
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("cipher_stoc not initialized"))?
|
||||
};
|
||||
|
||||
info!("Using cipher for decryption (is_client_to_server={})", is_client_to_server);
|
||||
|
||||
|
||||
info!(
|
||||
"Using cipher for decryption (is_client_to_server={})",
|
||||
is_client_to_server
|
||||
);
|
||||
|
||||
// 3. 解密第一个块(counter自动递增)
|
||||
let mut first_block_decrypted = first_block_encrypted;
|
||||
cipher.apply_keystream(&mut first_block_decrypted);
|
||||
|
||||
|
||||
info!("Decrypted first block: {:?}", &first_block_decrypted);
|
||||
|
||||
|
||||
// 3. 从解密后的数据中提取packet_length(前4字节)和padding_length(第5字节)
|
||||
let packet_length = u32::from_be_bytes([
|
||||
first_block_decrypted[0],
|
||||
@@ -349,67 +374,73 @@ impl EncryptedPacket {
|
||||
first_block_decrypted[3],
|
||||
]);
|
||||
let padding_length = first_block_decrypted[4];
|
||||
|
||||
info!("Decrypted packet_length={}, padding_length={}", packet_length, padding_length);
|
||||
|
||||
|
||||
info!(
|
||||
"Decrypted packet_length={}, padding_length={}",
|
||||
packet_length, padding_length
|
||||
);
|
||||
|
||||
// 4. 合理性检查
|
||||
if packet_length > 35000 {
|
||||
info!("packet_length raw bytes: {:?}", &first_block_decrypted[..4]);
|
||||
return Err(anyhow!("Invalid packet_length: {}", packet_length));
|
||||
}
|
||||
|
||||
|
||||
// 3. 计算剩余加密数据长度
|
||||
// packet_length = padding_length(1) + payload + padding
|
||||
// 总加密数据 = packet_length(4) + packet_length = packet_length + 4
|
||||
// 已读取16字节,剩余 = packet_length + 4 - 16
|
||||
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
|
||||
let total_encrypted_size = packet_length as usize + 4; // packet_length field + content
|
||||
let remaining_encrypted_size = total_encrypted_size - 16;
|
||||
|
||||
info!("Total encrypted size: {}, remaining: {}", total_encrypted_size, remaining_encrypted_size);
|
||||
|
||||
|
||||
info!(
|
||||
"Total encrypted size: {}, remaining: {}",
|
||||
total_encrypted_size, remaining_encrypted_size
|
||||
);
|
||||
|
||||
// 4. 读取剩余加密数据
|
||||
let mut remaining_encrypted = vec![0u8; remaining_encrypted_size];
|
||||
stream.read_exact(&mut remaining_encrypted)?;
|
||||
|
||||
|
||||
// 5. 继续解密(使用同一个cipher)
|
||||
cipher.apply_keystream(&mut remaining_encrypted);
|
||||
|
||||
|
||||
info!("Remaining decrypted data: {:?}", &remaining_encrypted);
|
||||
|
||||
|
||||
// 6. 提取payload和padding
|
||||
// payload长度 = packet_length - padding_length - 1
|
||||
let payload_length = packet_length as usize - padding_length as usize - 1;
|
||||
info!("Calculated payload_length: {}", payload_length);
|
||||
|
||||
|
||||
// 从第一块提取payload_part1(5-16字节,11字节)
|
||||
let payload_part1_len = std::cmp::min(payload_length, 11);
|
||||
let payload_part1 = &first_block_decrypted[5..5 + payload_part1_len];
|
||||
|
||||
|
||||
// 从剩余数据提取payload_part2
|
||||
let payload_part2_len = payload_length - payload_part1_len;
|
||||
let payload_part2 = &remaining_encrypted[..payload_part2_len];
|
||||
|
||||
|
||||
// 合并payload
|
||||
let mut payload = Vec::new();
|
||||
payload.extend_from_slice(payload_part1);
|
||||
payload.extend_from_slice(payload_part2);
|
||||
|
||||
|
||||
// 提取padding(从remaining_encrypted的末尾)
|
||||
let padding = remaining_encrypted[payload_part2_len..].to_vec();
|
||||
|
||||
|
||||
// 9. 读取MAC
|
||||
info!("Reading MAC (32 bytes)...");
|
||||
let mut mac = vec![0u8; 32];
|
||||
stream.read_exact(&mut mac)?;
|
||||
info!("MAC read successfully");
|
||||
|
||||
|
||||
// 10. 更新sequence number
|
||||
if is_client_to_server {
|
||||
encryption_ctx.sequence_number_ctos += 1;
|
||||
} else {
|
||||
encryption_ctx.sequence_number_stoc += 1;
|
||||
}
|
||||
|
||||
|
||||
Ok(Self {
|
||||
packet_length,
|
||||
padding_length,
|
||||
@@ -418,7 +449,7 @@ impl EncryptedPacket {
|
||||
mac,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 获取payload内容
|
||||
pub fn payload(&self) -> &[u8] {
|
||||
&self.payload
|
||||
@@ -428,13 +459,13 @@ impl EncryptedPacket {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_aes256_ctr_encryption() {
|
||||
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
|
||||
let key = vec![0u8; 16]; // AES-128 key (16 bytes)
|
||||
let iv = vec![0u8; 16];
|
||||
let plaintext = b"Hello World";
|
||||
|
||||
|
||||
let mut ctx = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: key.clone(),
|
||||
@@ -444,18 +475,18 @@ mod tests {
|
||||
iv_ctos: iv.clone(),
|
||||
iv_stoc: iv.clone(),
|
||||
});
|
||||
|
||||
|
||||
let ciphertext = ctx.encrypt_packet(plaintext, &key, &iv).unwrap();
|
||||
let decrypted = ctx.decrypt_packet(&ciphertext, &key, &iv).unwrap();
|
||||
|
||||
|
||||
assert_eq!(plaintext.to_vec(), decrypted);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_hmac_sha256() {
|
||||
let key = vec![0u8; 32];
|
||||
let data = b"test data";
|
||||
|
||||
|
||||
let ctx = EncryptionContext::from_session_keys(&SessionKeys {
|
||||
session_id: vec![0u8; 32],
|
||||
encryption_key_ctos: vec![0u8; 32],
|
||||
@@ -465,10 +496,10 @@ mod tests {
|
||||
iv_ctos: vec![0u8; 16],
|
||||
iv_stoc: vec![0u8; 16],
|
||||
});
|
||||
|
||||
|
||||
let mac = ctx.compute_mac(1, data, &key).unwrap();
|
||||
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
|
||||
|
||||
assert_eq!(mac.len(), 32); // HMAC-SHA256 = 32字节
|
||||
|
||||
// 验证MAC
|
||||
assert!(ctx.verify_mac(1, data, &mac, &key).unwrap());
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
// SSH加密模块(Phase 3:密钥交换)
|
||||
// 参考OpenSSH curve25519.c, kex.c
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
|
||||
use ed25519_dalek::{SigningKey, VerifyingKey, Signature, Signer};
|
||||
use sha2::{Sha256, Digest};
|
||||
use log::{info, debug};
|
||||
use anyhow::{anyhow, Result};
|
||||
use ed25519_dalek::{Signer, SigningKey};
|
||||
use log::info;
|
||||
use rand::rngs::OsRng;
|
||||
use sha2::{Digest, Sha256};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
/// Curve25519密钥交换处理器(参考OpenSSH curve25519.c)
|
||||
pub struct Curve25519Kex {
|
||||
secret: Option<EphemeralSecret>, // 使用Option包装(一次性使用类型)
|
||||
secret: Option<EphemeralSecret>, // 使用Option包装(一次性使用类型)
|
||||
public: PublicKey,
|
||||
}
|
||||
|
||||
@@ -21,34 +21,37 @@ impl Curve25519Kex {
|
||||
// x25519-dalek 2.0标准API:使用random_from_rng
|
||||
let secret = EphemeralSecret::random_from_rng(OsRng);
|
||||
let public = PublicKey::from(&secret);
|
||||
|
||||
Self { secret: Some(secret), public } // Some包装
|
||||
|
||||
Self {
|
||||
secret: Some(secret),
|
||||
public,
|
||||
} // Some包装
|
||||
}
|
||||
|
||||
|
||||
/// 获取公钥(用于SSH_MSG_KEX_ECDH_INIT)
|
||||
pub fn public_key(&self) -> &[u8] {
|
||||
self.public.as_bytes()
|
||||
}
|
||||
|
||||
|
||||
/// 计算共享密钥(参考OpenSSH curve25519_shared_secret())
|
||||
/// 使用&mut self(消耗模式,符合OpenSSH设计)
|
||||
pub fn compute_shared_secret(&mut self, client_public: &[u8]) -> Result<[u8; 32]> {
|
||||
if client_public.len() != 32 {
|
||||
return Err(anyhow!("Invalid client public key length"));
|
||||
}
|
||||
|
||||
|
||||
info!("=== X25519 Shared Secret Calculation ===");
|
||||
info!("Client public key input: {:?}", client_public);
|
||||
info!("Server public key: {:?}", self.public.as_bytes());
|
||||
|
||||
|
||||
// 参考OpenSSH:curve25519共享密钥计算
|
||||
let client_public_key = PublicKey::from(<[u8; 32]>::try_from(client_public)?);
|
||||
|
||||
|
||||
// 使用take()取出secret(Rust标准模式)
|
||||
if let Some(secret) = self.secret.take() {
|
||||
let shared_secret = secret.diffie_hellman(&client_public_key);
|
||||
info!("Computed shared secret: {:?}", shared_secret.as_bytes());
|
||||
Ok(shared_secret.as_bytes().clone())
|
||||
Ok(*shared_secret.as_bytes())
|
||||
} else {
|
||||
Err(anyhow!("Secret already used"))
|
||||
}
|
||||
@@ -71,47 +74,85 @@ impl SessionKeys {
|
||||
/// RFC 4253 Section 7.2: Key = HASH(K || H || X || session_id)
|
||||
pub fn derive(
|
||||
shared_secret: &[u8],
|
||||
exchange_hash: &[u8], // H参数(exchange hash)
|
||||
server_public_key: &[u8],
|
||||
client_public_key: &[u8],
|
||||
server_host_key: &[u8],
|
||||
exchange_hash: &[u8], // H参数(exchange hash)
|
||||
_server_public_key: &[u8],
|
||||
_client_public_key: &[u8],
|
||||
_server_host_key: &[u8],
|
||||
) -> Result<Self> {
|
||||
// RFC 4253: session_id = H (第一次exchange hash)
|
||||
let session_id = exchange_hash.to_vec();
|
||||
|
||||
|
||||
info!("SessionKeys::derive() starting");
|
||||
info!(" shared_secret full (32 bytes): {:?}", shared_secret);
|
||||
|
||||
|
||||
// RFC 8731 Section 3.1: X25519 output is little-endian
|
||||
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
|
||||
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
|
||||
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
|
||||
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
|
||||
info!(
|
||||
" shared_secret[0] = {} (>=0x80? {})",
|
||||
shared_secret[0],
|
||||
shared_secret[0] >= 0x80
|
||||
);
|
||||
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
||||
info!(" session_id full (32 bytes): {:?}", session_id);
|
||||
|
||||
|
||||
// RFC 4253密钥派生公式:HASH(K || H || X || session_id)
|
||||
// K is shared_secret encoded as mpint (using little-endian bytes directly)
|
||||
let shared_secret_mpint = Self::encode_mpint(shared_secret);
|
||||
|
||||
info!(" shared_secret_mpint ({} bytes): {:?}", shared_secret_mpint.len(), &shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]);
|
||||
|
||||
let encryption_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
|
||||
let encryption_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
|
||||
let mac_key_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
|
||||
let mac_key_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
|
||||
|
||||
let iv_ctos = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
|
||||
let iv_stoc = Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
|
||||
|
||||
|
||||
info!(
|
||||
" shared_secret_mpint ({} bytes): {:?}",
|
||||
shared_secret_mpint.len(),
|
||||
&shared_secret_mpint[..std::cmp::min(12, shared_secret_mpint.len())]
|
||||
);
|
||||
|
||||
let encryption_key_ctos =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'C', &session_id)?;
|
||||
let encryption_key_stoc =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'D', &session_id)?;
|
||||
let mac_key_ctos =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'E', &session_id)?;
|
||||
let mac_key_stoc =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'F', &session_id)?;
|
||||
|
||||
let iv_ctos =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'A', &session_id)?;
|
||||
let iv_stoc =
|
||||
Self::derive_key_rfc4253(&shared_secret_mpint, exchange_hash, 'B', &session_id)?;
|
||||
|
||||
info!("Derived keys summary:");
|
||||
info!(" encryption_key_ctos ({} bytes): {:?}", encryption_key_ctos.len(), &encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]);
|
||||
info!(" encryption_key_stoc ({} bytes): {:?}", encryption_key_stoc.len(), &encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]);
|
||||
info!(" iv_ctos ({} bytes): {:?}", iv_ctos.len(), &iv_ctos[..std::cmp::min(16, iv_ctos.len())]);
|
||||
info!(" iv_stoc ({} bytes): {:?}", iv_stoc.len(), &iv_stoc[..std::cmp::min(16, iv_stoc.len())]);
|
||||
info!(" mac_key_ctos ({} bytes): {:?}", mac_key_ctos.len(), &mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]);
|
||||
info!(" mac_key_stoc ({} bytes): {:?}", mac_key_stoc.len(), &mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]);
|
||||
|
||||
info!(
|
||||
" encryption_key_ctos ({} bytes): {:?}",
|
||||
encryption_key_ctos.len(),
|
||||
&encryption_key_ctos[..std::cmp::min(16, encryption_key_ctos.len())]
|
||||
);
|
||||
info!(
|
||||
" encryption_key_stoc ({} bytes): {:?}",
|
||||
encryption_key_stoc.len(),
|
||||
&encryption_key_stoc[..std::cmp::min(16, encryption_key_stoc.len())]
|
||||
);
|
||||
info!(
|
||||
" iv_ctos ({} bytes): {:?}",
|
||||
iv_ctos.len(),
|
||||
&iv_ctos[..std::cmp::min(16, iv_ctos.len())]
|
||||
);
|
||||
info!(
|
||||
" iv_stoc ({} bytes): {:?}",
|
||||
iv_stoc.len(),
|
||||
&iv_stoc[..std::cmp::min(16, iv_stoc.len())]
|
||||
);
|
||||
info!(
|
||||
" mac_key_ctos ({} bytes): {:?}",
|
||||
mac_key_ctos.len(),
|
||||
&mac_key_ctos[..std::cmp::min(16, mac_key_ctos.len())]
|
||||
);
|
||||
info!(
|
||||
" mac_key_stoc ({} bytes): {:?}",
|
||||
mac_key_stoc.len(),
|
||||
&mac_key_stoc[..std::cmp::min(16, mac_key_stoc.len())]
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
session_id,
|
||||
encryption_key_ctos,
|
||||
@@ -122,65 +163,73 @@ impl SessionKeys {
|
||||
iv_stoc,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// RFC 4253密钥派生函数
|
||||
/// 公式:Key = HASH(K || H || X || session_id)
|
||||
fn derive_key_rfc4253(K_mpint: &[u8], H: &[u8], X: char, session_id: &[u8]) -> Result<Vec<u8>> {
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
|
||||
info!("Deriving key for X='{}'", X);
|
||||
info!(" K_mpint ({} bytes): {:?}", K_mpint.len(), &K_mpint[..std::cmp::min(8, K_mpint.len())]);
|
||||
info!(
|
||||
" K_mpint ({} bytes): {:?}",
|
||||
K_mpint.len(),
|
||||
&K_mpint[..std::cmp::min(8, K_mpint.len())]
|
||||
);
|
||||
info!(" H ({} bytes): {:?}", H.len(), &H[..8]);
|
||||
info!(" session_id ({} bytes): {:?}", session_id.len(), &session_id[..8]);
|
||||
|
||||
info!(
|
||||
" session_id ({} bytes): {:?}",
|
||||
session_id.len(),
|
||||
&session_id[..8]
|
||||
);
|
||||
|
||||
// RFC 4253: HASH(K || H || X || session_id)
|
||||
hasher.update(K_mpint); // K (shared secret in mpint format)
|
||||
hasher.update(H); // H (exchange hash)
|
||||
hasher.update(&[X as u8]); // X (single character)
|
||||
hasher.update(K_mpint); // K (shared secret in mpint format)
|
||||
hasher.update(H); // H (exchange hash)
|
||||
hasher.update([X as u8]); // X (single character)
|
||||
hasher.update(session_id); // session_id
|
||||
|
||||
|
||||
let full_hash = hasher.finalize();
|
||||
|
||||
|
||||
info!(" Derived key (first 8 bytes): {:?}", &full_hash[..8]);
|
||||
|
||||
|
||||
// 根據key類型返回不同長度:
|
||||
// AES-128-CTR key/IV: 16 bytes
|
||||
// HMAC-SHA256 key: 32 bytes
|
||||
match X {
|
||||
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
|
||||
'A' | 'B' | 'C' | 'D' => Ok(full_hash[..16].to_vec()), // IV or encryption key
|
||||
'E' | 'F' => Ok(full_hash.to_vec()), // MAC key (full 32 bytes)
|
||||
_ => Ok(full_hash[..16].to_vec()), // default
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// SSH mpint编码(参考RFC 4253 Section 5)
|
||||
/// Curve25519 shared secret特殊处理
|
||||
fn encode_mpint(bytes: &[u8]) -> Vec<u8> {
|
||||
// RFC 4253: mpint = uint32(length) + data
|
||||
// 去掉前导零,如果最高位>=0x80前面加0
|
||||
|
||||
|
||||
// 去掉前导零字节(但不去掉最后一个字节即使它是0)
|
||||
let mut start = 0;
|
||||
while start < bytes.len() - 1 && bytes[start] == 0 {
|
||||
start += 1;
|
||||
}
|
||||
|
||||
|
||||
let data_without_leading_zeros = &bytes[start..];
|
||||
|
||||
|
||||
// 构建mpint数据
|
||||
let mut mpint_data = Vec::new();
|
||||
|
||||
|
||||
// 如果最高位>=0x80,前面加0字节(避免负数)
|
||||
if data_without_leading_zeros[0] >= 0x80 {
|
||||
mpint_data.push(0);
|
||||
}
|
||||
mpint_data.extend_from_slice(data_without_leading_zeros);
|
||||
|
||||
|
||||
// 最终格式:uint32长度 + mpint数据
|
||||
let mut result = Vec::new();
|
||||
result.extend_from_slice(&(mpint_data.len() as u32).to_be_bytes());
|
||||
result.extend_from_slice(&mpint_data);
|
||||
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
@@ -192,45 +241,45 @@ pub struct Ed25519HostKey {
|
||||
|
||||
impl Ed25519HostKey {
|
||||
/// 加载或生成主机密钥(参考OpenSSH hostfile.c)
|
||||
pub fn load_or_generate(key_path: &str) -> Result<Self> {
|
||||
pub fn load_or_generate(_key_path: &str) -> Result<Self> {
|
||||
// 简化实现:生成临时密钥(实际应从文件加载)
|
||||
// 参考OpenSSH ssh-keygen
|
||||
|
||||
|
||||
let signing_key = SigningKey::generate(&mut OsRng);
|
||||
|
||||
|
||||
Ok(Self { signing_key })
|
||||
}
|
||||
|
||||
|
||||
/// 获取公钥(用于SSH_MSG_KEX_ECDH_REPLY)
|
||||
pub fn public_key_bytes(&self) -> Vec<u8> {
|
||||
// SSH Ed25519公钥格式(参考OpenSSH sshkey.c)
|
||||
let verifying_key = self.signing_key.verifying_key();
|
||||
|
||||
|
||||
// SSH格式:ssh-ed25519 + 公钥bytes
|
||||
// 简化:仅返回公钥bytes(32字节)
|
||||
verifying_key.as_bytes().to_vec()
|
||||
}
|
||||
|
||||
|
||||
/// 签名(参考OpenSSH sshkey.c: sshkey_sign())
|
||||
pub fn sign(&self, data: &[u8]) -> Result<Vec<u8>> {
|
||||
// OpenSSH Ed25519签名
|
||||
let signature = self.signing_key.sign(data);
|
||||
|
||||
|
||||
// SSH签名格式(参考OpenSSH ssh-sign.c)
|
||||
// 简化:仅返回签名bytes(64字节)
|
||||
Ok(signature.to_bytes().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 获取完整SSH公钥格式(参考OpenSSH sshkey.c)
|
||||
pub fn ssh_public_key(&self) -> String {
|
||||
let public_bytes = self.public_key_bytes();
|
||||
|
||||
|
||||
// SSH公钥格式:ssh-ed25519 <base64-encoded-public-key>
|
||||
// 参考OpenSSH ssh-keygen -y
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
let encoded = general_purpose::STANDARD.encode(&public_bytes);
|
||||
|
||||
|
||||
format!("ssh-ed25519 {}", encoded)
|
||||
}
|
||||
}
|
||||
@@ -238,40 +287,44 @@ impl Ed25519HostKey {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_curve25519_key_generation() {
|
||||
let kex = Curve25519Kex::new();
|
||||
assert_eq!(kex.public_key().len(), 32);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_curve25519_shared_secret() {
|
||||
let mut client_kex = Curve25519Kex::new();
|
||||
let mut server_kex = Curve25519Kex::new();
|
||||
|
||||
|
||||
// 客户端计算共享密钥
|
||||
let client_secret = client_kex.compute_shared_secret(server_kex.public_key()).unwrap();
|
||||
|
||||
let client_secret = client_kex
|
||||
.compute_shared_secret(server_kex.public_key())
|
||||
.unwrap();
|
||||
|
||||
// 服务器计算共享密钥
|
||||
let server_secret = server_kex.compute_shared_secret(client_kex.public_key()).unwrap();
|
||||
|
||||
let server_secret = server_kex
|
||||
.compute_shared_secret(client_kex.public_key())
|
||||
.unwrap();
|
||||
|
||||
// 应该相同(Curve25519特性)
|
||||
assert_eq!(client_secret, server_secret);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_ed25519_host_key() {
|
||||
let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap();
|
||||
assert_eq!(host_key.public_key_bytes().len(), 32);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_ed25519_signature() {
|
||||
let host_key = Ed25519HostKey::load_or_generate("test_key").unwrap();
|
||||
let data = b"test data";
|
||||
|
||||
|
||||
let signature = host_key.sign(data).unwrap();
|
||||
assert_eq!(signature.len(), 64); // Ed25519签名64字节
|
||||
assert_eq!(signature.len(), 64); // Ed25519签名64字节
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// SSH端口转发数据传输(Phase 13.5)
|
||||
// 参考OpenSSH channels.c: channel_handle_data()
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use std::net::{TcpStream};
|
||||
use std::io::{Read, Write};
|
||||
use std::thread;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use log::{debug, info, warn};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
|
||||
/// 数据转发器(Phase 13.5:双向数据传输)
|
||||
pub struct DataForwarder {
|
||||
@@ -25,29 +25,40 @@ impl DataForwarder {
|
||||
max_packet_size,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 启动双向数据转发(Phase 13.5:SSH channel ↔ TCP socket)
|
||||
pub fn start_bidirectional_forwarding(
|
||||
&mut self,
|
||||
ssh_stream: TcpStream, // SSH client连接(加密通道)
|
||||
target_stream: TcpStream, // 目标服务连接(TCP socket)
|
||||
ssh_stream: TcpStream, // SSH client连接(加密通道)
|
||||
target_stream: TcpStream, // 目标服务连接(TCP socket)
|
||||
) -> Result<()> {
|
||||
info!("Starting bidirectional data forwarding for channel {}", self.channel_id);
|
||||
|
||||
info!(
|
||||
"Starting bidirectional data forwarding for channel {}",
|
||||
self.channel_id
|
||||
);
|
||||
|
||||
// Phase 13.5: SSH channel → Target socket(SSH client数据 → 本地服务)
|
||||
let ssh_to_target = self.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
|
||||
|
||||
let ssh_to_target = self
|
||||
.start_ssh_to_target_forwarding(ssh_stream.try_clone()?, target_stream.try_clone()?);
|
||||
|
||||
// Phase 13.5: Target socket → SSH channel(本地服务数据 → SSH client)
|
||||
let target_to_ssh = self.start_target_to_ssh_forwarding(target_stream, ssh_stream);
|
||||
|
||||
|
||||
// Phase 13.5: 等待两个转发线程完成
|
||||
ssh_to_target.join().map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
|
||||
target_to_ssh.join().map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
|
||||
|
||||
info!("Bidirectional data forwarding completed for channel {}", self.channel_id);
|
||||
ssh_to_target
|
||||
.join()
|
||||
.map_err(|e| anyhow!("SSH to target thread error: {:?}", e))?;
|
||||
target_to_ssh
|
||||
.join()
|
||||
.map_err(|e| anyhow!("Target to SSH thread error: {:?}", e))?;
|
||||
|
||||
info!(
|
||||
"Bidirectional data forwarding completed for channel {}",
|
||||
self.channel_id
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// SSH channel → Target socket转发(Phase 13.5)
|
||||
fn start_ssh_to_target_forwarding(
|
||||
&self,
|
||||
@@ -57,18 +68,21 @@ impl DataForwarder {
|
||||
let channel_id = self.channel_id;
|
||||
let window_size = self.window_size.clone();
|
||||
let max_packet_size = self.max_packet_size;
|
||||
|
||||
|
||||
thread::spawn(move || {
|
||||
info!("SSH to target forwarding thread started for channel {}", channel_id);
|
||||
|
||||
info!(
|
||||
"SSH to target forwarding thread started for channel {}",
|
||||
channel_id
|
||||
);
|
||||
|
||||
let mut buffer = vec![0u8; max_packet_size as usize];
|
||||
|
||||
|
||||
loop {
|
||||
// Phase 13.5: 从SSH channel读取数据
|
||||
let n = match ssh_stream.read(&mut buffer) {
|
||||
Ok(0) => {
|
||||
info!("SSH channel EOF for channel {}", channel_id);
|
||||
break; // EOF
|
||||
break; // EOF
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
@@ -76,45 +90,61 @@ impl DataForwarder {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Phase 13.5: 检查window size
|
||||
{
|
||||
let window = window_size.lock().unwrap();
|
||||
if *window < n as u32 {
|
||||
warn!("Window size insufficient for channel {}: need {}, have {}",
|
||||
channel_id, n, *window);
|
||||
warn!(
|
||||
"Window size insufficient for channel {}: need {}, have {}",
|
||||
channel_id, n, *window
|
||||
);
|
||||
// Phase 13.5: 理论上应该等待SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||||
// 简化实现:继续发送(可能会违反RFC 4254)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: 写入目标socket
|
||||
if let Err(e) = target_stream.write_all(&buffer[..n]) {
|
||||
warn!("Target socket write error for channel {}: {}", channel_id, e);
|
||||
warn!(
|
||||
"Target socket write error for channel {}: {}",
|
||||
channel_id, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: Flush确保数据发送
|
||||
if let Err(e) = target_stream.flush() {
|
||||
warn!("Target socket flush error for channel {}: {}", channel_id, e);
|
||||
warn!(
|
||||
"Target socket flush error for channel {}: {}",
|
||||
channel_id, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: 消耗window size
|
||||
{
|
||||
let mut window = window_size.lock().unwrap();
|
||||
*window -= n as u32;
|
||||
debug!("Window size consumed for channel {}: {} bytes, remaining {}",
|
||||
channel_id, n, *window);
|
||||
debug!(
|
||||
"Window size consumed for channel {}: {} bytes, remaining {}",
|
||||
channel_id, n, *window
|
||||
);
|
||||
}
|
||||
|
||||
info!("Forwarded {} bytes from SSH to target for channel {}", n, channel_id);
|
||||
|
||||
info!(
|
||||
"Forwarded {} bytes from SSH to target for channel {}",
|
||||
n, channel_id
|
||||
);
|
||||
}
|
||||
|
||||
info!("SSH to target forwarding thread stopped for channel {}", channel_id);
|
||||
|
||||
info!(
|
||||
"SSH to target forwarding thread stopped for channel {}",
|
||||
channel_id
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// Target socket → SSH channel转发(Phase 13.5)
|
||||
fn start_target_to_ssh_forwarding(
|
||||
&self,
|
||||
@@ -122,18 +152,21 @@ impl DataForwarder {
|
||||
mut ssh_stream: TcpStream,
|
||||
) -> thread::JoinHandle<()> {
|
||||
let channel_id = self.channel_id;
|
||||
|
||||
|
||||
thread::spawn(move || {
|
||||
info!("Target to SSH forwarding thread started for channel {}", channel_id);
|
||||
|
||||
let mut buffer = vec![0u8; 8192]; // 8KB buffer
|
||||
|
||||
info!(
|
||||
"Target to SSH forwarding thread started for channel {}",
|
||||
channel_id
|
||||
);
|
||||
|
||||
let mut buffer = vec![0u8; 8192]; // 8KB buffer
|
||||
|
||||
loop {
|
||||
// Phase 13.5: 从目标socket读取数据
|
||||
let n = match target_stream.read(&mut buffer) {
|
||||
Ok(0) => {
|
||||
info!("Target socket EOF for channel {}", channel_id);
|
||||
break; // EOF
|
||||
break; // EOF
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
@@ -141,43 +174,51 @@ impl DataForwarder {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Phase 13.5: 构建SSH_MSG_CHANNEL_DATA packet
|
||||
// 注意:实际实现需要通过EncryptedPacket加密
|
||||
// 这里简化实现,直接写入SSH stream(测试用)
|
||||
|
||||
|
||||
// Phase 13.5: 写入SSH channel
|
||||
if let Err(e) = ssh_stream.write_all(&buffer[..n]) {
|
||||
warn!("SSH channel write error for channel {}: {}", channel_id, e);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.5: Flush确保数据发送
|
||||
if let Err(e) = ssh_stream.flush() {
|
||||
warn!("SSH channel flush error for channel {}: {}", channel_id, e);
|
||||
break;
|
||||
}
|
||||
|
||||
info!("Forwarded {} bytes from target to SSH for channel {}", n, channel_id);
|
||||
|
||||
info!(
|
||||
"Forwarded {} bytes from target to SSH for channel {}",
|
||||
n, channel_id
|
||||
);
|
||||
}
|
||||
|
||||
info!("Target to SSH forwarding thread stopped for channel {}", channel_id);
|
||||
|
||||
info!(
|
||||
"Target to SSH forwarding thread stopped for channel {}",
|
||||
channel_id
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 获取当前window size(Phase 13.5)
|
||||
pub fn get_window_size(&self) -> u32 {
|
||||
*self.window_size.lock().unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// 增加window size(Phase 13.5:SSH_MSG_CHANNEL_WINDOW_ADJUST)
|
||||
pub fn adjust_window_size(&self, bytes_to_add: u32) {
|
||||
let mut window = self.window_size.lock().unwrap();
|
||||
*window += bytes_to_add;
|
||||
info!("Window size adjusted for channel {}: added {} bytes, total {}",
|
||||
self.channel_id, bytes_to_add, *window);
|
||||
info!(
|
||||
"Window size adjusted for channel {}: added {} bytes, total {}",
|
||||
self.channel_id, bytes_to_add, *window
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/// 检查window size是否足够(Phase 13.5)
|
||||
pub fn check_window_available(&self, data_size: u32) -> bool {
|
||||
let window = self.window_size.lock().unwrap();
|
||||
@@ -188,64 +229,64 @@ impl DataForwarder {
|
||||
/// SSH_MSG_CHANNEL_DATA构建(Phase 13.5)
|
||||
pub fn build_channel_data_packet(channel_id: u32, data: &[u8]) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_DATA (type 94)
|
||||
packet.write_u8(94)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
// Data length (SSH string)
|
||||
packet.write_u32::<BigEndian>(data.len() as u32)?;
|
||||
|
||||
|
||||
// Data content
|
||||
packet.write_all(data)?;
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
/// SSH_MSG_CHANNEL_WINDOW_ADJUST构建(Phase 13.5)
|
||||
pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93)
|
||||
packet.write_u8(93)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
// Bytes to add
|
||||
packet.write_u32::<BigEndian>(bytes_to_add)?;
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_data_forwarder_creation() {
|
||||
let forwarder = DataForwarder::new(1, 2097152, 32768);
|
||||
assert_eq!(forwarder.channel_id, 1);
|
||||
assert_eq!(forwarder.get_window_size(), 2097152);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_size_adjustment() {
|
||||
let forwarder = DataForwarder::new(1, 2097152, 32768);
|
||||
|
||||
|
||||
// 消耗window size
|
||||
forwarder.adjust_window_size(1000);
|
||||
assert_eq!(forwarder.get_window_size(), 2097152 + 1000);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_build_channel_data_packet() {
|
||||
let data = b"Hello, SSH!";
|
||||
let packet = build_channel_data_packet(1, data).unwrap();
|
||||
|
||||
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
|
||||
// 验证packet结构
|
||||
|
||||
assert_eq!(packet[0], 94); // SSH_MSG_CHANNEL_DATA
|
||||
// 验证packet结构
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
// SSH密钥交换算法协商实现(Phase 2)
|
||||
// 参考OpenSSH kex.c: kex_send_kexinit(), kex_choose_conf()
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use anyhow::{Result, anyhow};
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, debug};
|
||||
use log::{debug, info};
|
||||
use std::io::{Read, Write};
|
||||
|
||||
/// SSH算法类型(参考OpenSSH PROTOCOL定义)
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum AlgorithmType {
|
||||
KEX_ALGS = 0, // 密钥交换算法
|
||||
KEX_ALGS = 0, // 密钥交换算法
|
||||
SERVER_HOST_KEY_ALGS = 1, // 服务器主机密钥算法
|
||||
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
|
||||
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
|
||||
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
|
||||
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
|
||||
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
|
||||
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
|
||||
LANGS_CTOS = 8, // 客户端到服务器语言
|
||||
LANGS_STOC = 9, // 服务器到客户端语言
|
||||
ENC_ALGS_CTOS = 2, // 客户端到服务器加密算法
|
||||
ENC_ALGS_STOC = 3, // 服务器到客户端加密算法
|
||||
MAC_ALGS_CTOS = 4, // 客户端到服务器MAC算法
|
||||
MAC_ALGS_STOC = 5, // 服务器到客户端MAC算法
|
||||
COMP_ALGS_CTOS = 6, // 客户端到服务器压缩算法
|
||||
COMP_ALGS_STOC = 7, // 服务器到客户端压缩算法
|
||||
LANGS_CTOS = 8, // 客户端到服务器语言
|
||||
LANGS_STOC = 9, // 服务器到客户端语言
|
||||
}
|
||||
|
||||
/// SSH算法提议(参考OpenSSH kex.h: struct kex)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KexProposal {
|
||||
pub kex_algorithms: String, // 密钥交换算法列表
|
||||
pub server_host_key_algorithms: String, // 主机密钥算法列表
|
||||
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
|
||||
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
|
||||
pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器)
|
||||
pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端)
|
||||
pub kex_algorithms: String, // 密钥交换算法列表
|
||||
pub server_host_key_algorithms: String, // 主机密钥算法列表
|
||||
pub encryption_algorithms_ctos: String, // 加密算法(客户端→服务器)
|
||||
pub encryption_algorithms_stoc: String, // 加密算法(服务器→客户端)
|
||||
pub mac_algorithms_ctos: String, // MAC算法(客户端→服务器)
|
||||
pub mac_algorithms_stoc: String, // MAC算法(服务器→客户端)
|
||||
pub compression_algorithms_ctos: String, // 压缩算法(客户端→服务器)
|
||||
pub compression_algorithms_stoc: String, // 压缩算法(服务器→客户端)
|
||||
pub languages_ctos: String, // 语言(客户端→服务器)
|
||||
pub languages_stoc: String, // 语言(服务器→客户端)
|
||||
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
|
||||
pub reserved: u32, // 保留字段(0)
|
||||
pub languages_ctos: String, // 语言(客户端→服务器)
|
||||
pub languages_stoc: String, // 语言(服务器→客户端)
|
||||
pub first_kex_packet_follows: bool, // 是否立即发送第一个KEX packet
|
||||
pub reserved: u32, // 保留字段(0)
|
||||
}
|
||||
|
||||
impl KexProposal {
|
||||
@@ -46,31 +46,31 @@ impl KexProposal {
|
||||
Self {
|
||||
// 密钥交换算法:优先Curve25519(推荐) + strict KEX extension
|
||||
kex_algorithms: "curve25519-sha256,curve25519-sha256@libssh.org,diffie-hellman-group14-sha256,ext-info-s,kex-strict-s-v00@openssh.com".to_string(),
|
||||
|
||||
|
||||
// 主机密钥算法:优先Ed25519
|
||||
server_host_key_algorithms: "ssh-ed25519,rsa-sha2-256,rsa-sha2-512".to_string(),
|
||||
|
||||
|
||||
// 加密算法:AES-256-CTR(推荐)
|
||||
encryption_algorithms_ctos: "aes256-ctr,aes128-ctr".to_string(),
|
||||
encryption_algorithms_stoc: "aes256-ctr,aes128-ctr".to_string(),
|
||||
|
||||
|
||||
// MAC算法:HMAC-SHA256
|
||||
mac_algorithms_ctos: "hmac-sha2-256,hmac-sha2-512".to_string(),
|
||||
mac_algorithms_stoc: "hmac-sha2-256,hmac-sha2-512".to_string(),
|
||||
|
||||
|
||||
// 压缩算法:none优先
|
||||
compression_algorithms_ctos: "none,zlib".to_string(),
|
||||
compression_algorithms_stoc: "none,zlib".to_string(),
|
||||
|
||||
|
||||
// 语言:空
|
||||
languages_ctos: "".to_string(),
|
||||
languages_stoc: "".to_string(),
|
||||
|
||||
|
||||
first_kex_packet_follows: false,
|
||||
reserved: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 创建客户端默认提议(用于测试)
|
||||
pub fn client_default() -> Self {
|
||||
Self {
|
||||
@@ -88,20 +88,20 @@ impl KexProposal {
|
||||
reserved: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 序列化到SSH_MSG_KEXINIT packet(参考OpenSSH kex_send_kexinit())
|
||||
pub fn to_kexinit_packet(&self) -> Result<SshPacket> {
|
||||
let mut payload = Vec::new();
|
||||
|
||||
|
||||
// Packet type
|
||||
payload.write_u8(PacketType::SSH_MSG_KEXINIT as u8)?;
|
||||
|
||||
|
||||
// Cookie(16字节随机数,OpenSSH要求)
|
||||
let mut cookie = [0u8; 16];
|
||||
use rand::Rng;
|
||||
rand::thread_rng().fill(&mut cookie);
|
||||
payload.write_all(&cookie)?;
|
||||
|
||||
|
||||
// 10个算法列表(SSH string格式:length + data)
|
||||
write_ssh_string(&mut payload, &self.kex_algorithms)?;
|
||||
write_ssh_string(&mut payload, &self.server_host_key_algorithms)?;
|
||||
@@ -113,29 +113,29 @@ impl KexProposal {
|
||||
write_ssh_string(&mut payload, &self.compression_algorithms_stoc)?;
|
||||
write_ssh_string(&mut payload, &self.languages_ctos)?;
|
||||
write_ssh_string(&mut payload, &self.languages_stoc)?;
|
||||
|
||||
|
||||
// first_kex_packet_follows(boolean)
|
||||
payload.write_u8(if self.first_kex_packet_follows { 1 } else { 0 })?;
|
||||
|
||||
|
||||
// reserved(u32)
|
||||
payload.write_u32::<BigEndian>(self.reserved)?;
|
||||
|
||||
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
|
||||
/// 从SSH_MSG_KEXINIT packet解析(参考OpenSSH kex_input_kexinit())
|
||||
pub fn from_kexinit_packet(packet: &SshPacket) -> Result<Self> {
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||||
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // 使用as_slice()(Rust标准)
|
||||
|
||||
// Packet type
|
||||
let packet_type = cursor.read_u8()?;
|
||||
if packet_type != PacketType::SSH_MSG_KEXINIT as u8 {
|
||||
return Err(anyhow!("Invalid packet type for KEXINIT"));
|
||||
}
|
||||
|
||||
|
||||
// Cookie(16字节,忽略)
|
||||
cursor.read_exact(&mut [0u8; 16])?;
|
||||
|
||||
|
||||
// 10个算法列表
|
||||
let kex_algorithms = read_ssh_string(&mut cursor)?;
|
||||
let server_host_key_algorithms = read_ssh_string(&mut cursor)?;
|
||||
@@ -147,13 +147,13 @@ impl KexProposal {
|
||||
let compression_algorithms_stoc = read_ssh_string(&mut cursor)?;
|
||||
let languages_ctos = read_ssh_string(&mut cursor)?;
|
||||
let languages_stoc = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// first_kex_packet_follows
|
||||
let first_kex_packet_follows = cursor.read_u8()? != 0;
|
||||
|
||||
|
||||
// reserved
|
||||
let reserved = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
kex_algorithms,
|
||||
server_host_key_algorithms,
|
||||
@@ -174,14 +174,14 @@ impl KexProposal {
|
||||
/// SSH算法协商结果(参考OpenSSH struct kex)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KexResult {
|
||||
pub kex_algorithm: String, // 选定的密钥交换算法
|
||||
pub host_key_algorithm: String, // 选定的主机密钥算法
|
||||
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
|
||||
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
|
||||
pub mac_ctos: String, // 选定的MAC算法(客户端→服务器)
|
||||
pub mac_stoc: String, // 选定的MAC算法(服务器→客户端)
|
||||
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
|
||||
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
|
||||
pub kex_algorithm: String, // 选定的密钥交换算法
|
||||
pub host_key_algorithm: String, // 选定的主机密钥算法
|
||||
pub encryption_ctos: String, // 选定的加密算法(客户端→服务器)
|
||||
pub encryption_stoc: String, // 选定的加密算法(服务器→客户端)
|
||||
pub mac_ctos: String, // 选定的MAC算法(客户端→服务器)
|
||||
pub mac_stoc: String, // 选定的MAC算法(服务器→客户端)
|
||||
pub compression_ctos: String, // 选定的压缩算法(客户端→服务器)
|
||||
pub compression_stoc: String, // 选定的压缩算法(服务器→客户端)
|
||||
}
|
||||
|
||||
/// 算法匹配逻辑(参考OpenSSH kex_choose_conf())
|
||||
@@ -189,28 +189,43 @@ impl KexResult {
|
||||
/// 从服务器和客户端提议中选择算法(参考OpenSSH kex_choose_conf())
|
||||
pub fn choose_algorithms(server: &KexProposal, client: &KexProposal) -> Result<Self> {
|
||||
info!("Starting algorithm negotiation");
|
||||
|
||||
|
||||
// 算法匹配:优先客户端偏好(OpenSSH逻辑)
|
||||
// 参考OpenSSH:客户端列出的算法顺序为偏好顺序
|
||||
|
||||
|
||||
// 密钥交换算法匹配
|
||||
let kex_algorithm = match_algorithm(&client.kex_algorithms, &server.kex_algorithms)?;
|
||||
|
||||
|
||||
// 主机密钥算法匹配
|
||||
let host_key_algorithm = match_algorithm(&client.server_host_key_algorithms, &server.server_host_key_algorithms)?;
|
||||
|
||||
let host_key_algorithm = match_algorithm(
|
||||
&client.server_host_key_algorithms,
|
||||
&server.server_host_key_algorithms,
|
||||
)?;
|
||||
|
||||
// 加密算法匹配
|
||||
let encryption_ctos = match_algorithm(&client.encryption_algorithms_ctos, &server.encryption_algorithms_ctos)?;
|
||||
let encryption_stoc = match_algorithm(&client.encryption_algorithms_stoc, &server.encryption_algorithms_stoc)?;
|
||||
|
||||
let encryption_ctos = match_algorithm(
|
||||
&client.encryption_algorithms_ctos,
|
||||
&server.encryption_algorithms_ctos,
|
||||
)?;
|
||||
let encryption_stoc = match_algorithm(
|
||||
&client.encryption_algorithms_stoc,
|
||||
&server.encryption_algorithms_stoc,
|
||||
)?;
|
||||
|
||||
// MAC算法匹配
|
||||
let mac_ctos = match_algorithm(&client.mac_algorithms_ctos, &server.mac_algorithms_ctos)?;
|
||||
let mac_stoc = match_algorithm(&client.mac_algorithms_stoc, &server.mac_algorithms_stoc)?;
|
||||
|
||||
|
||||
// 压缩算法匹配
|
||||
let compression_ctos = match_algorithm(&client.compression_algorithms_ctos, &server.compression_algorithms_ctos)?;
|
||||
let compression_stoc = match_algorithm(&client.compression_algorithms_stoc, &server.compression_algorithms_stoc)?;
|
||||
|
||||
let compression_ctos = match_algorithm(
|
||||
&client.compression_algorithms_ctos,
|
||||
&server.compression_algorithms_ctos,
|
||||
)?;
|
||||
let compression_stoc = match_algorithm(
|
||||
&client.compression_algorithms_stoc,
|
||||
&server.compression_algorithms_stoc,
|
||||
)?;
|
||||
|
||||
info!("Algorithm negotiation completed:");
|
||||
debug!(" KEX: {}", kex_algorithm);
|
||||
debug!(" Host key: {}", host_key_algorithm);
|
||||
@@ -218,7 +233,7 @@ impl KexResult {
|
||||
debug!(" Encryption (S->C): {}", encryption_stoc);
|
||||
debug!(" MAC (C->S): {}", mac_ctos);
|
||||
debug!(" MAC (S->C): {}", mac_stoc);
|
||||
|
||||
|
||||
Ok(Self {
|
||||
kex_algorithm,
|
||||
host_key_algorithm,
|
||||
@@ -237,15 +252,19 @@ fn match_algorithm(client_algs: &str, server_algs: &str) -> Result<String> {
|
||||
// 算法列表格式:name1,name2,name3,...
|
||||
let client_list: Vec<&str> = client_algs.split(',').collect();
|
||||
let server_list: Vec<&str> = server_algs.split(',').collect();
|
||||
|
||||
|
||||
// OpenSSH逻辑:按客户端偏好顺序匹配
|
||||
for client_alg in &client_list {
|
||||
if server_list.contains(client_alg) {
|
||||
return Ok(client_alg.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!("No matching algorithm found: client={}, server={}", client_algs, server_algs))
|
||||
|
||||
Err(anyhow!(
|
||||
"No matching algorithm found: client={}, server={}",
|
||||
client_algs,
|
||||
server_algs
|
||||
))
|
||||
}
|
||||
|
||||
/// SSH string写入辅助函数(length + data)
|
||||
@@ -266,36 +285,36 @@ fn read_ssh_string<R: Read>(reader: &mut R) -> Result<String> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_proposal_creation() {
|
||||
let proposal = KexProposal::server_default();
|
||||
assert!(proposal.kex_algorithms.contains("curve25519-sha256"));
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_proposal_serialization() {
|
||||
let proposal = KexProposal::server_default();
|
||||
let packet = proposal.to_kexinit_packet().unwrap();
|
||||
assert!(packet.payload.len() > 0);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_algorithm_matching() {
|
||||
let client = "curve25519-sha256,aes256-ctr";
|
||||
let server = "aes256-ctr,diffie-hellman-group14-sha256";
|
||||
|
||||
|
||||
let matched = match_algorithm(client, server).unwrap();
|
||||
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
|
||||
assert_eq!(matched, "aes256-ctr"); // 按客户端顺序匹配
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_negotiation() {
|
||||
let server = KexProposal::server_default();
|
||||
let client = KexProposal::client_default();
|
||||
|
||||
|
||||
let result = KexResult::choose_algorithms(&server, &client).unwrap();
|
||||
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
|
||||
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
|
||||
assert_eq!(result.kex_algorithm, "curve25519-sha256"); // 优先Curve25519
|
||||
assert_eq!(result.encryption_ctos, "aes256-ctr"); // AES-256-CTR
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
// SSH密钥交换完整流程(Phase 3剩余)
|
||||
// 参考OpenSSH kex.c: complete implementation
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::ssh_server::crypto::SessionKeys;
|
||||
use crate::ssh_server::kex::{KexProposal, KexResult};
|
||||
use crate::ssh_server::crypto::{SessionKeys};
|
||||
use crate::ssh_server::kex_exchange::KexExchangeHandler;
|
||||
use anyhow::{Result, anyhow};
|
||||
use sha2::{Sha256, Digest};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use log::{info, debug};
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::info;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// SSH密钥交换完整状态管理(参考OpenSSH struct kex)
|
||||
pub struct KexState {
|
||||
@@ -30,7 +29,7 @@ impl KexState {
|
||||
kex_result: KexResult,
|
||||
) -> Result<Self> {
|
||||
let exchange_handler = KexExchangeHandler::new(kex_result)?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
client_version,
|
||||
server_version,
|
||||
@@ -42,18 +41,18 @@ impl KexState {
|
||||
newkeys_sent: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 保存KEXINIT payloads(用于Exchange Hash计算)
|
||||
///
|
||||
///
|
||||
/// 分析OpenSSH源码后的结论:
|
||||
/// - kex->peer存储的是:incoming_packet剩余内容(payload fields + padding)
|
||||
/// - kex->my存储的是:prop2buf()结果(payload fields,不包括padding)
|
||||
///
|
||||
///
|
||||
/// **但exchange hash必须使用相同的I_C/I_S!**
|
||||
///
|
||||
///
|
||||
/// 疑问:OpenSSH如何确保client和server使用相同的padding?
|
||||
/// 可能答案:OpenSSH在计算exchange hash时,不包括padding?
|
||||
///
|
||||
///
|
||||
/// 暂时保持不包括padding(因为签名验证之前成功)
|
||||
pub fn save_kexinit_payloads(
|
||||
&mut self,
|
||||
@@ -63,12 +62,18 @@ impl KexState {
|
||||
// Only save payload (without padding) for now
|
||||
self.client_kexinit_payload = client_kexinit.payload.clone();
|
||||
self.server_kexinit_payload = server_kexinit.payload.clone();
|
||||
|
||||
|
||||
info!("Saved KEXINIT payloads (payload only, no padding)");
|
||||
info!(" client payload: {} bytes", self.client_kexinit_payload.len());
|
||||
info!(" server payload: {} bytes", self.server_kexinit_payload.len());
|
||||
info!(
|
||||
" client payload: {} bytes",
|
||||
self.client_kexinit_payload.len()
|
||||
);
|
||||
info!(
|
||||
" server payload: {} bytes",
|
||||
self.server_kexinit_payload.len()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash())
|
||||
/// H = SHA256(V_C || V_S || I_C || I_S || K_S || K_C || K_S || shared_secret)
|
||||
pub fn compute_exchange_hash(
|
||||
@@ -80,74 +85,74 @@ impl KexState {
|
||||
) -> Result<Vec<u8>> {
|
||||
// 参考OpenSSH kex.c: kex_hash()
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
|
||||
// V_C: 客户端版本字符串(SSH string格式)
|
||||
write_ssh_string_to_hash(&mut hasher, &self.client_version)?;
|
||||
|
||||
|
||||
// V_S: 服务器版本字符串(SSH string格式)
|
||||
write_ssh_string_to_hash(&mut hasher, &self.server_version)?;
|
||||
|
||||
|
||||
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
|
||||
// Remove SSH_MSG_KEXINIT type byte from payloads and prepend it in exchange hash
|
||||
|
||||
|
||||
let client_kexinit_without_type = &self.client_kexinit_payload[1..];
|
||||
let server_kexinit_without_type = &self.server_kexinit_payload[1..];
|
||||
|
||||
hasher.update(&((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
|
||||
hasher.update(((client_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(client_kexinit_without_type);
|
||||
|
||||
hasher.update(&((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
|
||||
hasher.update(((server_kexinit_without_type.len() + 1) as u32).to_be_bytes());
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(server_kexinit_without_type);
|
||||
|
||||
|
||||
// K_S: 服务器主机密钥blob(SSH string格式)
|
||||
hasher.update(server_host_key_blob);
|
||||
|
||||
|
||||
// K_C: 客户端Curve25519公钥(SSH string格式)
|
||||
write_ssh_bytes_to_hash(&mut hasher, client_public_key)?;
|
||||
|
||||
|
||||
// K_S: 服务器Curve25519公钥(SSH string格式)
|
||||
write_ssh_bytes_to_hash(&mut hasher, server_public_key)?;
|
||||
|
||||
|
||||
// K: 共享密钥(SSH mpint格式)
|
||||
// OpenSSH要求:去掉前导零
|
||||
write_ssh_mpint_to_hash(&mut hasher, shared_secret)?;
|
||||
|
||||
|
||||
Ok(hasher.finalize().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 处理SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_input_newkeys())
|
||||
pub fn handle_newkeys(&mut self, packet: &SshPacket) -> Result<()> {
|
||||
info!("Processing SSH_MSG_NEWKEYS");
|
||||
|
||||
|
||||
// 验证packet类型
|
||||
if packet.payload.len() < 1 {
|
||||
if packet.payload.is_empty() {
|
||||
return Err(anyhow!("Invalid NEWKEYS packet"));
|
||||
}
|
||||
|
||||
|
||||
let packet_type = packet.payload[0];
|
||||
if packet_type != PacketType::SSH_MSG_NEWKEYS as u8 {
|
||||
return Err(anyhow!("Invalid packet type for NEWKEYS"));
|
||||
}
|
||||
|
||||
|
||||
// 标记NEWKEYS接收完成(参考OpenSSH)
|
||||
self.newkeys_received = true;
|
||||
|
||||
|
||||
info!("SSH_MSG_NEWKEYS received, encryption channel ready");
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 发送SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_send_newkeys())
|
||||
pub fn send_newkeys() -> Result<SshPacket> {
|
||||
info!("Sending SSH_MSG_NEWKEYS");
|
||||
|
||||
|
||||
let payload = vec![PacketType::SSH_MSG_NEWKEYS as u8];
|
||||
|
||||
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
|
||||
/// 检查NEWKEYS完成状态(加密通道建立)
|
||||
pub fn is_encryption_ready(&self) -> bool {
|
||||
self.newkeys_received && self.newkeys_sent
|
||||
@@ -156,14 +161,14 @@ impl KexState {
|
||||
|
||||
/// SSH string写入到hash(辅助函数)
|
||||
fn write_ssh_string_to_hash(hasher: &mut Sha256, s: &str) -> Result<()> {
|
||||
hasher.update(&(s.len() as u32).to_be_bytes());
|
||||
hasher.update((s.len() as u32).to_be_bytes());
|
||||
hasher.update(s.as_bytes());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// SSH bytes写入到hash(辅助函数)
|
||||
fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
hasher.update(&(bytes.len() as u32).to_be_bytes());
|
||||
hasher.update((bytes.len() as u32).to_be_bytes());
|
||||
hasher.update(bytes);
|
||||
Ok(())
|
||||
}
|
||||
@@ -171,7 +176,7 @@ fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
/// SSH mpint写入到hash(参考OpenSSH sshbuf_put_mpint())
|
||||
fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
// OpenSSH要求:去掉前导零(如果最高位为1)
|
||||
let mpint_bytes = if bytes.len() > 0 && bytes[0] >= 0x80 {
|
||||
let mpint_bytes = if !bytes.is_empty() && bytes[0] >= 0x80 {
|
||||
// 需要添加前导零(避免负数)
|
||||
let mut mpint = vec![0u8];
|
||||
mpint.extend_from_slice(bytes);
|
||||
@@ -179,61 +184,67 @@ fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> {
|
||||
} else {
|
||||
bytes.to_vec()
|
||||
};
|
||||
|
||||
hasher.update(&(mpint_bytes.len() as u32).to_be_bytes());
|
||||
|
||||
hasher.update((mpint_bytes.len() as u32).to_be_bytes());
|
||||
hasher.update(&mpint_bytes);
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_exchange_hash_computation() {
|
||||
let kex_result = KexResult::choose_algorithms(
|
||||
&KexProposal::server_default(),
|
||||
&KexProposal::client_default(),
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut state = KexState::new(
|
||||
"SSH-2.0-OpenSSH_10.2".to_string(),
|
||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||
kex_result,
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Set minimal KEXINIT payloads (need at least 1 byte for packet type)
|
||||
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
|
||||
state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte
|
||||
|
||||
let shared_secret = vec![0u8; 32];
|
||||
let host_key = vec![0u8; 32];
|
||||
let client_pub = vec![0u8; 32];
|
||||
let server_pub = vec![0u8; 32];
|
||||
|
||||
let hash = state.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub).unwrap();
|
||||
|
||||
assert_eq!(hash.len(), 32); // SHA256输出32字节
|
||||
|
||||
let hash = state
|
||||
.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(hash.len(), 32); // SHA256输出32字节
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_newkeys_handling() {
|
||||
let kex_result = KexResult::choose_algorithms(
|
||||
&KexProposal::server_default(),
|
||||
&KexProposal::client_default(),
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut state = KexState::new(
|
||||
"SSH-2.0-OpenSSH_10.2".to_string(),
|
||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||
kex_result,
|
||||
).unwrap();
|
||||
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let newkeys_packet = SshPacket::new(vec![PacketType::SSH_MSG_NEWKEYS as u8]);
|
||||
|
||||
|
||||
state.handle_newkeys(&newkeys_packet).unwrap();
|
||||
|
||||
|
||||
assert!(state.newkeys_received);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// SSH密钥交换流程实现(Phase 3)
|
||||
// 参考OpenSSH kex.c: kex_input_kex_init(), kex_send_kex_reply()
|
||||
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::ssh_server::kex::{KexResult};
|
||||
use crate::ssh_server::crypto::{Curve25519Kex, SessionKeys, Ed25519HostKey};
|
||||
use anyhow::{Result, anyhow};
|
||||
use crate::ssh_server::crypto::{Curve25519Kex, Ed25519HostKey, SessionKeys};
|
||||
use crate::ssh_server::kex::KexResult;
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::{info, debug};
|
||||
use log::info;
|
||||
use sha2::Digest;
|
||||
use std::io::{Read, Write};
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
/// SSH密钥交换流程处理器(参考OpenSSH kex.c)
|
||||
pub struct KexExchangeHandler {
|
||||
@@ -18,7 +18,7 @@ pub struct KexExchangeHandler {
|
||||
shared_secret: Option<Vec<u8>>,
|
||||
client_public_key: Option<Vec<u8>>,
|
||||
server_public_key: Option<Vec<u8>>,
|
||||
exchange_hash: Option<Vec<u8>>, // 保存exchange hash(H参数)
|
||||
exchange_hash: Option<Vec<u8>>, // 保存exchange hash(H参数)
|
||||
client_version: Option<String>,
|
||||
server_version: Option<String>,
|
||||
client_kexinit_payload: Option<Vec<u8>>,
|
||||
@@ -30,7 +30,7 @@ impl KexExchangeHandler {
|
||||
pub fn new(kex_result: KexResult) -> Result<Self> {
|
||||
// 加载或生成服务器主机密钥
|
||||
let host_key = Ed25519HostKey::load_or_generate("config/ssh_host_ed25519_key")?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
kex_algorithm: kex_result.kex_algorithm,
|
||||
server_kex: None,
|
||||
@@ -45,10 +45,10 @@ impl KexExchangeHandler {
|
||||
server_kexinit_payload: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init())
|
||||
|
||||
/// 处理SSH_MSG_KEXDH_INIT(Curve25519密钥交换)(参考OpenSSH kex.c: kex_input_kex_init())
|
||||
pub fn handle_kexdh_init(
|
||||
&mut self,
|
||||
&mut self,
|
||||
packet: &SshPacket,
|
||||
client_version: &str,
|
||||
server_version: &str,
|
||||
@@ -56,41 +56,44 @@ impl KexExchangeHandler {
|
||||
server_kexinit_payload: &[u8],
|
||||
) -> Result<SshPacket> {
|
||||
info!("Processing SSH_MSG_KEXDH_INIT (Curve25519)");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(packet.payload.as_slice());
|
||||
|
||||
|
||||
let packet_type = cursor.read_u8()?;
|
||||
if packet_type != PacketType::SSH_MSG_KEXDH_INIT as u8 {
|
||||
return Err(anyhow!("Invalid packet type for KEXDH_INIT"));
|
||||
}
|
||||
|
||||
|
||||
let key_length = cursor.read_u32::<BigEndian>()?;
|
||||
if key_length != 32 {
|
||||
return Err(anyhow!("Invalid Curve25519 public key length: {}", key_length));
|
||||
return Err(anyhow!(
|
||||
"Invalid Curve25519 public key length: {}",
|
||||
key_length
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
let mut client_public_key = vec![0u8; 32];
|
||||
cursor.read_exact(&mut client_public_key)?;
|
||||
|
||||
|
||||
self.server_kex = Some(Curve25519Kex::new());
|
||||
let server_kex = self.server_kex.as_mut().unwrap();
|
||||
|
||||
|
||||
let shared_secret = server_kex.compute_shared_secret(&client_public_key)?;
|
||||
let server_public_key = server_kex.public_key().to_vec();
|
||||
|
||||
|
||||
// Save for later session key computation
|
||||
self.shared_secret = Some(shared_secret.to_vec());
|
||||
self.client_public_key = Some(client_public_key.clone());
|
||||
self.server_public_key = Some(server_public_key.clone());
|
||||
|
||||
|
||||
// Save client_version, server_version, kexinit payloads for exchange hash
|
||||
self.client_version = Some(client_version.to_string());
|
||||
self.server_version = Some(server_version.to_string());
|
||||
self.client_kexinit_payload = Some(client_kexinit_payload.to_vec());
|
||||
self.server_kexinit_payload = Some(server_kexinit_payload.to_vec());
|
||||
|
||||
|
||||
info!("Curve25519 shared secret computed and saved");
|
||||
|
||||
|
||||
// Compute exchange hash ONCE and reuse it
|
||||
let host_key_blob = self.build_ssh_host_key()?;
|
||||
let exchange_hash = self.compute_exchange_hash(
|
||||
@@ -103,69 +106,69 @@ impl KexExchangeHandler {
|
||||
client_kexinit_payload,
|
||||
server_kexinit_payload,
|
||||
)?;
|
||||
|
||||
|
||||
info!("Exchange hash computed:");
|
||||
info!(" shared_secret[0] = {} (>=0x80? {})", shared_secret[0], shared_secret[0] >= 0x80);
|
||||
info!(
|
||||
" shared_secret[0] = {} (>=0x80? {})",
|
||||
shared_secret[0],
|
||||
shared_secret[0] >= 0x80
|
||||
);
|
||||
info!(" exchange_hash full (32 bytes): {:?}", exchange_hash);
|
||||
|
||||
|
||||
self.exchange_hash = Some(exchange_hash.clone());
|
||||
info!("Exchange hash saved for key derivation");
|
||||
|
||||
self.build_kexdh_reply(
|
||||
&exchange_hash,
|
||||
&host_key_blob,
|
||||
&server_public_key,
|
||||
)
|
||||
|
||||
self.build_kexdh_reply(&exchange_hash, &host_key_blob, &server_public_key)
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_KEXDH_REPLY packet(参考OpenSSH kex.c)
|
||||
fn build_kexdh_reply(
|
||||
&self,
|
||||
exchange_hash: &[u8],
|
||||
&self,
|
||||
exchange_hash: &[u8],
|
||||
host_key_blob: &[u8],
|
||||
server_public_key: &[u8],
|
||||
) -> Result<SshPacket> {
|
||||
info!("=== Building SSH_MSG_KEXDH_REPLY ===");
|
||||
info!("Input server_public_key: {:?}", server_public_key);
|
||||
|
||||
|
||||
let mut payload = Vec::new();
|
||||
|
||||
|
||||
payload.write_u8(PacketType::SSH_MSG_KEXDH_REPLY as u8)?;
|
||||
|
||||
|
||||
payload.write_u32::<BigEndian>(host_key_blob.len() as u32)?;
|
||||
payload.write_all(host_key_blob)?;
|
||||
|
||||
|
||||
info!("Writing server_public_key to payload (32 bytes)");
|
||||
payload.write_u32::<BigEndian>(32)?;
|
||||
payload.write_all(server_public_key)?;
|
||||
|
||||
|
||||
let signature = self.build_exchange_signature(exchange_hash)?;
|
||||
payload.write_u32::<BigEndian>(signature.len() as u32)?;
|
||||
payload.write_all(&signature)?;
|
||||
|
||||
|
||||
info!("SSH_MSG_KEXDH_REPLY payload built successfully");
|
||||
Ok(SshPacket::new(payload))
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH主机密钥blob(参考OpenSSH sshkey.c: sshkey_to_blob())
|
||||
fn build_ssh_host_key(&self) -> Result<Vec<u8>> {
|
||||
let mut blob = Vec::new();
|
||||
|
||||
|
||||
// SSH key format: key-type + public-key
|
||||
// 参考OpenSSH sshkey.c
|
||||
|
||||
|
||||
// Key type: ssh-ed25519
|
||||
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
|
||||
blob.write_u32::<BigEndian>(11)?; // "ssh-ed25519".len()
|
||||
blob.write_all("ssh-ed25519".as_bytes())?;
|
||||
|
||||
|
||||
// Ed25519公钥(32字节)
|
||||
let public_key = self.host_key.public_key_bytes();
|
||||
blob.write_u32::<BigEndian>(32)?;
|
||||
blob.write_all(&public_key)?;
|
||||
|
||||
|
||||
Ok(blob)
|
||||
}
|
||||
|
||||
|
||||
/// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash() RFC 4253 Section 7.2)
|
||||
fn compute_exchange_hash(
|
||||
&self,
|
||||
@@ -178,94 +181,147 @@ impl KexExchangeHandler {
|
||||
client_kexinit_payload: &[u8],
|
||||
server_kexinit_payload: &[u8],
|
||||
) -> Result<Vec<u8>> {
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
info!("=== EXCHANGE HASH COMPUTATION ===");
|
||||
info!("V_C (client version): {:?}", client_version.as_bytes());
|
||||
info!("V_C length: {}", client_version.len());
|
||||
|
||||
|
||||
info!("V_S (server version): {:?}", server_version.as_bytes());
|
||||
info!("V_S length: {}", server_version.len());
|
||||
|
||||
info!("I_C (client KEXINIT payload): {:?}", &client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]);
|
||||
|
||||
info!(
|
||||
"I_C (client KEXINIT payload): {:?}",
|
||||
&client_kexinit_payload[..std::cmp::min(50, client_kexinit_payload.len())]
|
||||
);
|
||||
info!("I_C length: {}", client_kexinit_payload.len());
|
||||
info!("I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", client_kexinit_payload[0]);
|
||||
|
||||
info!("I_S (server KEXINIT payload): {:?}", &server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]);
|
||||
info!(
|
||||
"I_C[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
|
||||
client_kexinit_payload[0]
|
||||
);
|
||||
|
||||
info!(
|
||||
"I_S (server KEXINIT payload): {:?}",
|
||||
&server_kexinit_payload[..std::cmp::min(50, server_kexinit_payload.len())]
|
||||
);
|
||||
info!("I_S length: {}", server_kexinit_payload.len());
|
||||
info!("I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)", server_kexinit_payload[0]);
|
||||
|
||||
info!("K_S (host key blob): {:?}", &host_key_blob[..std::cmp::min(30, host_key_blob.len())]);
|
||||
info!(
|
||||
"I_S[0] (packet type): {} (should be SSH_MSG_KEXINIT=20)",
|
||||
server_kexinit_payload[0]
|
||||
);
|
||||
|
||||
info!(
|
||||
"K_S (host key blob): {:?}",
|
||||
&host_key_blob[..std::cmp::min(30, host_key_blob.len())]
|
||||
);
|
||||
info!("K_S length: {}", host_key_blob.len());
|
||||
|
||||
info!("Q_C (client ECDH public key): {:?}", &client_public_key[..std::cmp::min(16, client_public_key.len())]);
|
||||
|
||||
info!(
|
||||
"Q_C (client ECDH public key): {:?}",
|
||||
&client_public_key[..std::cmp::min(16, client_public_key.len())]
|
||||
);
|
||||
info!("Q_C full (32 bytes): {:?}", client_public_key);
|
||||
info!("Q_C length: {}", client_public_key.len());
|
||||
|
||||
info!("Q_S (server ECDH public key): {:?}", &server_public_key[..std::cmp::min(16, server_public_key.len())]);
|
||||
|
||||
info!(
|
||||
"Q_S (server ECDH public key): {:?}",
|
||||
&server_public_key[..std::cmp::min(16, server_public_key.len())]
|
||||
);
|
||||
info!("Q_S full (32 bytes): {:?}", server_public_key);
|
||||
info!("Q_S length: {}", server_public_key.len());
|
||||
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
|
||||
// RFC 4253 Section 7: V_C and V_S are version strings (without \r\n based on testing)
|
||||
let vc_ssh_string = &(client_version.len() as u32).to_be_bytes();
|
||||
hasher.update(vc_ssh_string);
|
||||
hasher.update(client_version.as_bytes());
|
||||
info!(" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]", 4+client_version.len(), vc_ssh_string, client_version.as_bytes());
|
||||
|
||||
info!(
|
||||
" Exchange hash component V_C: len={} bytes=[{:?}] data=[{:?}]",
|
||||
4 + client_version.len(),
|
||||
vc_ssh_string,
|
||||
client_version.as_bytes()
|
||||
);
|
||||
|
||||
let vs_ssh_string = &(server_version.len() as u32).to_be_bytes();
|
||||
hasher.update(vs_ssh_string);
|
||||
hasher.update(server_version.as_bytes());
|
||||
info!(" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]", 4+server_version.len(), vs_ssh_string, server_version.as_bytes());
|
||||
|
||||
info!(
|
||||
" Exchange hash component V_S: len={} bytes=[{:?}] data=[{:?}]",
|
||||
4 + server_version.len(),
|
||||
vs_ssh_string,
|
||||
server_version.as_bytes()
|
||||
);
|
||||
|
||||
// OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT"
|
||||
// KEXINIT payload should NOT include SSH_MSG_KEXINIT type byte
|
||||
// OpenSSH stores payload starting from cookie, prepends SSH_MSG_KEXINIT in exchange hash
|
||||
|
||||
|
||||
// Remove SSH_MSG_KEXINIT type byte from payloads (our payload includes it)
|
||||
let client_kexinit_without_type = &client_kexinit_payload[1..];
|
||||
let server_kexinit_without_type = &server_kexinit_payload[1..];
|
||||
|
||||
info!("I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)", client_kexinit_without_type.len());
|
||||
info!("I_S (server KEXINIT without type byte): {} bytes", server_kexinit_without_type.len());
|
||||
|
||||
|
||||
info!(
|
||||
"I_C (client KEXINIT without type byte): {} bytes (first byte should be cookie)",
|
||||
client_kexinit_without_type.len()
|
||||
);
|
||||
info!(
|
||||
"I_S (server KEXINIT without type byte): {} bytes",
|
||||
server_kexinit_without_type.len()
|
||||
);
|
||||
|
||||
// Exchange hash: uint32(len+1) + uint8(SSH_MSG_KEXINIT) + payload_without_type
|
||||
let ic_len_bytes = &((client_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
||||
hasher.update(ic_len_bytes);
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(client_kexinit_without_type);
|
||||
info!(" Exchange hash component I_C: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+client_kexinit_without_type.len(), ic_len_bytes, client_kexinit_without_type.len(), &client_kexinit_without_type[..std::cmp::min(8, client_kexinit_without_type.len())]);
|
||||
|
||||
|
||||
let is_len_bytes = &((server_kexinit_without_type.len() + 1) as u32).to_be_bytes();
|
||||
hasher.update(is_len_bytes);
|
||||
hasher.update(&[20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update([20]); // SSH_MSG_KEXINIT type byte
|
||||
hasher.update(server_kexinit_without_type);
|
||||
info!(" Exchange hash component I_S: len={} bytes=[{:?}] type=[20] payload_len={} (first 8 bytes=[{:?}])", 4+1+server_kexinit_without_type.len(), is_len_bytes, server_kexinit_without_type.len(), &server_kexinit_without_type[..std::cmp::min(8, server_kexinit_without_type.len())]);
|
||||
|
||||
|
||||
let ks_len_bytes = &(host_key_blob.len() as u32).to_be_bytes();
|
||||
hasher.update(ks_len_bytes);
|
||||
hasher.update(host_key_blob);
|
||||
info!(" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])", 4+host_key_blob.len(), ks_len_bytes, host_key_blob.len(), host_key_blob);
|
||||
|
||||
info!(
|
||||
" Exchange hash component K_S: len={} bytes=[{:?}] blob_len={} (full=[{:?}])",
|
||||
4 + host_key_blob.len(),
|
||||
ks_len_bytes,
|
||||
host_key_blob.len(),
|
||||
host_key_blob
|
||||
);
|
||||
|
||||
let qc_len_bytes = &(client_public_key.len() as u32).to_be_bytes();
|
||||
hasher.update(qc_len_bytes);
|
||||
hasher.update(client_public_key);
|
||||
info!(" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]", 4+client_public_key.len(), qc_len_bytes, client_public_key);
|
||||
|
||||
info!(
|
||||
" Exchange hash component Q_C: len={} bytes=[{:?}] key=[{:?}]",
|
||||
4 + client_public_key.len(),
|
||||
qc_len_bytes,
|
||||
client_public_key
|
||||
);
|
||||
|
||||
let qs_len_bytes = &(server_public_key.len() as u32).to_be_bytes();
|
||||
hasher.update(qs_len_bytes);
|
||||
hasher.update(server_public_key);
|
||||
info!(" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]", 4+server_public_key.len(), qs_len_bytes, server_public_key);
|
||||
|
||||
info!(
|
||||
" Exchange hash component Q_S: len={} bytes=[{:?}] key=[{:?}]",
|
||||
4 + server_public_key.len(),
|
||||
qs_len_bytes,
|
||||
server_public_key
|
||||
);
|
||||
|
||||
info!("Exchange hash components:");
|
||||
info!(" shared_secret raw full (32 bytes): {:?}", shared_secret);
|
||||
|
||||
|
||||
// RFC 8731 Section 3.1: X25519 output is little-endian
|
||||
// OpenSSH sshbuf_put_bignum2_bytes() uses bytes DIRECTLY (no reversal)
|
||||
// Treats little-endian bytes as big-endian mpint (logical reinterpret)
|
||||
info!(" Using shared_secret directly (little-endian bytes as big-endian mpint)");
|
||||
|
||||
|
||||
// RFC 4253: mpint格式 = 去掉前导零 + 最高位>=0x80时前面加0
|
||||
// 参考OpenSSH sshbuf_put_bignum2_bytes()
|
||||
let mut start = 0;
|
||||
@@ -273,64 +329,73 @@ impl KexExchangeHandler {
|
||||
start += 1;
|
||||
}
|
||||
let trimmed_shared_secret = &shared_secret[start..];
|
||||
|
||||
info!(" shared_secret after removing leading zeros ({} bytes): {:?}", trimmed_shared_secret.len(), trimmed_shared_secret);
|
||||
|
||||
let mpint_shared_secret_data = if trimmed_shared_secret.len() > 0 && trimmed_shared_secret[0] >= 0x80 {
|
||||
let mut mpint = vec![0u8];
|
||||
mpint.extend_from_slice(trimmed_shared_secret);
|
||||
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
|
||||
mpint
|
||||
} else {
|
||||
trimmed_shared_secret.to_vec()
|
||||
};
|
||||
|
||||
info!(" mpint_shared_secret_data ({} bytes): {:?}", mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]);
|
||||
|
||||
|
||||
info!(
|
||||
" shared_secret after removing leading zeros ({} bytes): {:?}",
|
||||
trimmed_shared_secret.len(),
|
||||
trimmed_shared_secret
|
||||
);
|
||||
|
||||
let mpint_shared_secret_data =
|
||||
if !trimmed_shared_secret.is_empty() && trimmed_shared_secret[0] >= 0x80 {
|
||||
let mut mpint = vec![0u8];
|
||||
mpint.extend_from_slice(trimmed_shared_secret);
|
||||
info!(" trimmed_shared_secret[0] >= 0x80, prepending 0 byte");
|
||||
mpint
|
||||
} else {
|
||||
trimmed_shared_secret.to_vec()
|
||||
};
|
||||
|
||||
info!(
|
||||
" mpint_shared_secret_data ({} bytes): {:?}",
|
||||
mpint_shared_secret_data.len(),
|
||||
&mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]
|
||||
);
|
||||
|
||||
// mpint格式 = uint32(length) + mpint_data
|
||||
let mpint_len_bytes = &(mpint_shared_secret_data.len() as u32).to_be_bytes();
|
||||
hasher.update(mpint_len_bytes);
|
||||
hasher.update(&mpint_shared_secret_data);
|
||||
info!(" Exchange hash component K (shared secret mpint): len={} bytes=[{:?}] data_len={} (first 8 bytes=[{:?}])", 4+mpint_shared_secret_data.len(), mpint_len_bytes, mpint_shared_secret_data.len(), &mpint_shared_secret_data[..std::cmp::min(8, mpint_shared_secret_data.len())]);
|
||||
|
||||
|
||||
Ok(hasher.finalize().to_vec())
|
||||
}
|
||||
|
||||
|
||||
/// 构建交换签名(参考OpenSSH ssh-sign.c)
|
||||
fn build_exchange_signature(&self, exchange_hash: &[u8]) -> Result<Vec<u8>> {
|
||||
let signature_bytes = self.host_key.sign(exchange_hash)?;
|
||||
|
||||
|
||||
let mut ssh_signature = Vec::new();
|
||||
|
||||
|
||||
ssh_signature.write_u32::<BigEndian>(11)?;
|
||||
ssh_signature.write_all("ssh-ed25519".as_bytes())?;
|
||||
|
||||
|
||||
ssh_signature.write_u32::<BigEndian>(64)?;
|
||||
ssh_signature.write_all(&signature_bytes)?;
|
||||
|
||||
|
||||
Ok(ssh_signature)
|
||||
}
|
||||
|
||||
|
||||
/// 计算会话密钥(参考OpenSSH kex.c: derive_keys())
|
||||
/// 使用保存的exchange_hash(H参数)
|
||||
pub fn compute_session_keys(&self) -> Result<SessionKeys> {
|
||||
if self.shared_secret.is_none() {
|
||||
return Err(anyhow!("No shared secret available"));
|
||||
}
|
||||
|
||||
|
||||
if self.exchange_hash.is_none() {
|
||||
return Err(anyhow!("No exchange hash available"));
|
||||
}
|
||||
|
||||
|
||||
let shared_secret = self.shared_secret.as_ref().unwrap();
|
||||
let exchange_hash = self.exchange_hash.as_ref().unwrap();
|
||||
let server_public_key = self.server_public_key.as_ref().unwrap();
|
||||
let client_public_key = self.client_public_key.as_ref().unwrap();
|
||||
let host_key_blob = self.build_ssh_host_key()?;
|
||||
|
||||
|
||||
SessionKeys::derive(
|
||||
shared_secret,
|
||||
exchange_hash, // 使用保存的exchange hash(H参数)
|
||||
exchange_hash, // 使用保存的exchange hash(H参数)
|
||||
server_public_key,
|
||||
client_public_key,
|
||||
&host_key_blob,
|
||||
@@ -342,13 +407,13 @@ impl KexExchangeHandler {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ssh_server::kex::KexProposal;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_kex_exchange_handler_creation() {
|
||||
let server_proposal = KexProposal::server_default();
|
||||
let client_proposal = KexProposal::client_default();
|
||||
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal).unwrap();
|
||||
|
||||
|
||||
let handler = KexExchangeHandler::new(kex_result).unwrap();
|
||||
assert!(handler.host_key.public_key_bytes().len() == 32);
|
||||
}
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
// SSH服务器模块(手动实现SSH协议)
|
||||
// 参考OpenSSH源码实现完整的SSH/SFTP/SCP/rsync协议
|
||||
|
||||
pub mod server;
|
||||
pub mod packet;
|
||||
pub mod version;
|
||||
pub mod crypto;
|
||||
pub mod kex;
|
||||
pub mod kex_exchange;
|
||||
pub mod kex_complete;
|
||||
pub mod cipher;
|
||||
pub mod auth;
|
||||
pub mod channel;
|
||||
pub mod sftp_handler;
|
||||
pub mod scp_handler;
|
||||
pub mod cipher;
|
||||
pub mod crypto;
|
||||
pub mod data_forwarder; // Phase 13.5: 数据传输模块
|
||||
pub mod kex;
|
||||
pub mod kex_complete;
|
||||
pub mod kex_exchange;
|
||||
pub mod packet;
|
||||
pub mod port_forward; // Phase 13: 端口转发模块
|
||||
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
|
||||
pub mod rsync_handler;
|
||||
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c)
|
||||
pub mod port_forward; // Phase 13: 端口转发模块
|
||||
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
|
||||
pub mod port_forward_listener; // Phase 13.4: 监听线程模块
|
||||
pub mod data_forwarder; // Phase 13.5: 数据传输模块
|
||||
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
|
||||
pub mod scp_handler;
|
||||
pub mod server;
|
||||
pub mod sftp_handler;
|
||||
pub mod ssh_security_config; // Phase 13.1: 企业级安全配置
|
||||
pub mod sshbuf; // Phase 15: SSH Buffer 零拷贝管理(参考OpenSSH sshbuf.c)
|
||||
pub mod version;
|
||||
pub mod window_manager; // Phase 13.6-13.7: Window size + Channel生命周期
|
||||
|
||||
pub use packet::{PacketType, SshPacket};
|
||||
pub use server::SshServer;
|
||||
pub use packet::{SshPacket, PacketType};
|
||||
pub use version::VersionExchange;
|
||||
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
|
||||
pub use sshbuf::SshBuf; // Phase 15: 导出 SSH Buffer
|
||||
pub use ssh_security_config::SshSecurityConfig; // Phase 13.1: 导出安全配置
|
||||
pub use sshbuf::SshBuf;
|
||||
pub use version::VersionExchange; // Phase 15: 导出 SSH Buffer
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// SSH Packet基础结构定义
|
||||
// 参考OpenSSH packet.c: ssh_packet_read(), ssh_packet_write()
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::io::{Read, Write};
|
||||
|
||||
@@ -18,21 +18,21 @@ pub enum PacketType {
|
||||
SSH_MSG_EXT_INFO = 7,
|
||||
SSH_MSG_KEXINIT = 20,
|
||||
SSH_MSG_NEWKEYS = 21,
|
||||
|
||||
|
||||
// 密钥交换相关
|
||||
SSH_MSG_KEXDH_INIT = 30,
|
||||
SSH_MSG_KEXDH_REPLY = 31,
|
||||
// 注意:Curve25519和DH使用相同的消息类型(30/31)
|
||||
// SSH_MSG_KEX_ECDH_INIT和SSH_MSG_KEX_ECDH_REPLY已在代码中注释
|
||||
// 使用SSH_MSG_KEXDH_INIT和SSH_MSG_KEXDH_REPLY代替
|
||||
|
||||
|
||||
// 认证相关
|
||||
SSH_MSG_USERAUTH_REQUEST = 50,
|
||||
SSH_MSG_USERAUTH_FAILURE = 51,
|
||||
SSH_MSG_USERAUTH_SUCCESS = 52,
|
||||
SSH_MSG_USERAUTH_BANNER = 53,
|
||||
SSH_MSG_USERAUTH_PK_OK = 60,
|
||||
|
||||
|
||||
// Channel相关
|
||||
SSH_MSG_GLOBAL_REQUEST = 80,
|
||||
SSH_MSG_REQUEST_SUCCESS = 81,
|
||||
@@ -70,38 +70,38 @@ impl SshPacket {
|
||||
pub fn new(payload: Vec<u8>) -> Self {
|
||||
// 计算padding(SSH协议RFC 4253规范)
|
||||
// 参考OpenSSH packet.c: construct_packet()
|
||||
let block_size = 8; // 未加密阶段block_size=8
|
||||
|
||||
let block_size = 8; // 未加密阶段block_size=8
|
||||
|
||||
let payload_length = payload.len();
|
||||
let min_padding = 4; // OpenSSH要求最少4字节padding
|
||||
|
||||
let min_padding = 4; // OpenSSH要求最少4字节padding
|
||||
|
||||
// SSH协议约束:
|
||||
// packet_length = padding_length + payload_length + 1
|
||||
// (packet_length + 4) 必须是block_size的倍数
|
||||
//
|
||||
//
|
||||
// 计算:
|
||||
// (1 + payload_length + padding_length + 4) % 8 == 0
|
||||
// => (5 + payload_length + padding_length) % 8 == 0
|
||||
|
||||
|
||||
// 先尝试padding=4(最小)
|
||||
let mut padding_length = min_padding as u8;
|
||||
|
||||
|
||||
// 计算packet总长度(包括4字节的packet_length字段)
|
||||
let packet_length = 1 + payload_length + padding_length as usize;
|
||||
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
|
||||
|
||||
let total_length = packet_length + 4; // 加上packet_length字段本身的4字节
|
||||
|
||||
// 如果总长度不是block_size的倍数,增加padding
|
||||
if total_length % block_size != 0 {
|
||||
if !total_length.is_multiple_of(block_size) {
|
||||
let remainder = total_length % block_size;
|
||||
padding_length += (block_size - remainder) as u8;
|
||||
}
|
||||
|
||||
|
||||
// 重新计算packet_length
|
||||
let packet_length = (1 + payload_length + padding_length as usize) as u32;
|
||||
|
||||
|
||||
// 生成随机padding(简化:使用0)
|
||||
let padding = vec![0u8; padding_length as usize];
|
||||
|
||||
|
||||
Self {
|
||||
packet_length,
|
||||
padding_length,
|
||||
@@ -109,49 +109,49 @@ impl SshPacket {
|
||||
padding,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 写入packet到stream(未加密阶段)
|
||||
/// 参考OpenSSH packet_write_poll()
|
||||
pub fn write<T: Write>(&self, stream: &mut T) -> Result<()> {
|
||||
// 写入packet_length(BigEndian)
|
||||
stream.write_u32::<BigEndian>(self.packet_length)?;
|
||||
|
||||
|
||||
// 写入padding_length
|
||||
stream.write_u8(self.padding_length)?;
|
||||
|
||||
|
||||
// 写入payload
|
||||
stream.write_all(&self.payload)?;
|
||||
|
||||
|
||||
// 写入padding
|
||||
stream.write_all(&self.padding)?;
|
||||
|
||||
|
||||
stream.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 从stream读取packet(未加密阶段)
|
||||
/// 参考OpenSSH packet_read_poll()
|
||||
pub fn read<T: Read>(stream: &mut T) -> Result<Self> {
|
||||
// 读取packet_length(BigEndian)
|
||||
let packet_length = stream.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 检查packet长度限制(OpenSSH限制:256KB)
|
||||
if packet_length > 256 * 1024 {
|
||||
return Err(anyhow!("Packet too large: {}", packet_length));
|
||||
}
|
||||
|
||||
|
||||
// 读取padding_length
|
||||
let padding_length = stream.read_u8()?;
|
||||
|
||||
|
||||
// 读取payload(packet_length - padding_length - 1)
|
||||
let payload_length = packet_length - padding_length as u32 - 1;
|
||||
let mut payload = vec![0u8; payload_length as usize];
|
||||
stream.read_exact(&mut payload)?;
|
||||
|
||||
|
||||
// 读取padding
|
||||
let mut padding = vec![0u8; padding_length as usize];
|
||||
stream.read_exact(&mut padding)?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
packet_length,
|
||||
padding_length,
|
||||
@@ -159,15 +159,15 @@ impl SshPacket {
|
||||
padding,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 获取payload中的packet type
|
||||
pub fn get_type(&self) -> Result<PacketType> {
|
||||
if self.payload.is_empty() {
|
||||
return Err(anyhow!("Empty payload"));
|
||||
}
|
||||
|
||||
|
||||
let type_byte = self.payload[0];
|
||||
|
||||
|
||||
// 转换为PacketType enum
|
||||
match type_byte {
|
||||
1 => Ok(PacketType::SSH_MSG_DISCONNECT),
|
||||
@@ -208,27 +208,27 @@ impl SshPacket {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_packet_creation() {
|
||||
let payload = vec![PacketType::SSH_MSG_KEXINIT as u8];
|
||||
let packet = SshPacket::new(payload);
|
||||
|
||||
|
||||
assert!(packet.packet_length > 0);
|
||||
assert!(packet.padding_length >= 4);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_packet_write_read() {
|
||||
let payload = vec![PacketType::SSH_MSG_KEXINIT as u8];
|
||||
let packet = SshPacket::new(payload);
|
||||
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
packet.write(&mut buffer).unwrap();
|
||||
|
||||
|
||||
let mut cursor = Cursor::new(buffer);
|
||||
let read_packet = SshPacket::read(&mut cursor).unwrap();
|
||||
|
||||
|
||||
assert_eq!(packet.packet_length, read_packet.packet_length);
|
||||
assert_eq!(packet.payload, read_packet.payload);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
// SSH端口转发协议实现(Phase 13)
|
||||
// 参考OpenSSH channels.c和RFC 4254
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use std::net::{TcpListener, TcpStream, SocketAddr};
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
||||
use anyhow::Result;
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.2: 安全配置
|
||||
use log::{info, warn};
|
||||
use std::io::Read;
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::sync::{Arc, Mutex};
|
||||
// Phase 13.2: 安全配置
|
||||
|
||||
/// 端口转发类型(参考RFC 4254)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PortForwardType {
|
||||
Local, // Local port forwarding (-L)
|
||||
Remote, // Remote port forwarding (-R)
|
||||
Dynamic, // Dynamic port forwarding (-D, SOCKS)
|
||||
Local, // Local port forwarding (-L)
|
||||
Remote, // Remote port forwarding (-R)
|
||||
Dynamic, // Dynamic port forwarding (-D, SOCKS)
|
||||
}
|
||||
|
||||
/// 端口转发请求(参考RFC 4254 Section 7)
|
||||
@@ -36,6 +36,12 @@ pub struct PortForwardManager {
|
||||
active_forwards: Arc<Mutex<Vec<(u32, PortForwardType)>>>,
|
||||
}
|
||||
|
||||
impl Default for PortForwardManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PortForwardManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -46,24 +52,29 @@ impl PortForwardManager {
|
||||
/// 处理SSH_MSG_GLOBAL_REQUEST(端口转发请求)
|
||||
/// 参考RFC 4254 Section 4
|
||||
/// Phase 13.2: 添加安全配置验证
|
||||
pub fn handle_global_request(&mut self, data: &[u8], security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
pub fn handle_global_request(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
security_config: &SshSecurityConfig,
|
||||
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
info!("Processing SSH_MSG_GLOBAL_REQUEST for port forwarding");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(data);
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
// 读取请求名称(SSH string)
|
||||
let request_name = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
info!("Global request: {}", request_name);
|
||||
|
||||
|
||||
// 读取want-reply标志
|
||||
let want_reply = cursor.read_u8()? != 0;
|
||||
|
||||
|
||||
match request_name.as_str() {
|
||||
"tcpip-forward" => {
|
||||
// Local port forwarding (-L)
|
||||
self.handle_tcpip_forward(&mut cursor, want_reply, security_config) // Phase 13.2
|
||||
self.handle_tcpip_forward(&mut cursor, want_reply, security_config)
|
||||
// Phase 13.2
|
||||
}
|
||||
"cancel-tcpip-forward" => {
|
||||
// Cancel port forwarding
|
||||
@@ -84,29 +95,37 @@ impl PortForwardManager {
|
||||
/// 处理tcpip-forward请求(Local port forwarding)
|
||||
/// 参考RFC 4254 Section 7.1
|
||||
/// Phase 13.2: 添加安全配置验证
|
||||
fn handle_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool, security_config: &SshSecurityConfig) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
fn handle_tcpip_forward(
|
||||
&mut self,
|
||||
cursor: &mut std::io::Cursor<&[u8]>,
|
||||
want_reply: bool,
|
||||
security_config: &SshSecurityConfig,
|
||||
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
// 读取bind address(SSH string)
|
||||
let bind_address = read_ssh_string(cursor)?;
|
||||
|
||||
|
||||
// 读取bind port
|
||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
|
||||
|
||||
|
||||
info!(
|
||||
"tcpip-forward request: bind_address={}, bind_port={}",
|
||||
bind_address, bind_port
|
||||
);
|
||||
|
||||
// Phase 13.2: 安全配置验证
|
||||
if let Err(e) = security_config.validate_tcpip_forward_request(&bind_address, bind_port) {
|
||||
warn!("tcpip-forward security validation failed: {}", e);
|
||||
return Ok((false, None)); // 拒绝请求
|
||||
return Ok((false, None)); // 拒绝请求
|
||||
}
|
||||
|
||||
|
||||
info!("tcpip-forward security validation passed");
|
||||
|
||||
|
||||
// 添加到active forwards
|
||||
let mut forwards = self.active_forwards.lock().unwrap();
|
||||
forwards.push((bind_port, PortForwardType::Local));
|
||||
|
||||
|
||||
info!("tcpip-forward registered: bind_port={}", bind_port);
|
||||
|
||||
|
||||
// 返回成功响应(包含bind_port)
|
||||
if want_reply {
|
||||
let response = self.build_global_request_response(true, Some(bind_port))?;
|
||||
@@ -117,16 +136,23 @@ impl PortForwardManager {
|
||||
}
|
||||
|
||||
/// 处理cancel-tcpip-forward请求
|
||||
fn handle_cancel_tcpip_forward(&mut self, cursor: &mut std::io::Cursor<&[u8]>, want_reply: bool) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
fn handle_cancel_tcpip_forward(
|
||||
&mut self,
|
||||
cursor: &mut std::io::Cursor<&[u8]>,
|
||||
want_reply: bool,
|
||||
) -> Result<(bool, Option<Vec<u8>>)> {
|
||||
let bind_address = read_ssh_string(cursor)?;
|
||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("cancel-tcpip-forward: bind_address={}, bind_port={}", bind_address, bind_port);
|
||||
|
||||
|
||||
info!(
|
||||
"cancel-tcpip-forward: bind_address={}, bind_port={}",
|
||||
bind_address, bind_port
|
||||
);
|
||||
|
||||
// 移除active forward
|
||||
let mut forwards = self.active_forwards.lock().unwrap();
|
||||
forwards.retain(|(port, _)| *port != bind_port);
|
||||
|
||||
|
||||
if want_reply {
|
||||
let response = self.build_global_request_response(true, None)?;
|
||||
Ok((true, Some(response)))
|
||||
@@ -136,14 +162,18 @@ impl PortForwardManager {
|
||||
}
|
||||
|
||||
/// 构建SSH_MSG_REQUEST_SUCCESS/FAILURE响应
|
||||
fn build_global_request_response(&self, success: bool, bound_port: Option<u32>) -> Result<Vec<u8>> {
|
||||
fn build_global_request_response(
|
||||
&self,
|
||||
success: bool,
|
||||
bound_port: Option<u32>,
|
||||
) -> Result<Vec<u8>> {
|
||||
use crate::ssh_server::packet::PacketType;
|
||||
|
||||
|
||||
let mut response = Vec::new();
|
||||
|
||||
|
||||
if success {
|
||||
response.write_u8(PacketType::SSH_MSG_REQUEST_SUCCESS as u8)?;
|
||||
|
||||
|
||||
// 如果有bound_port,写入(用于tcpip-forward响应)
|
||||
if let Some(port) = bound_port {
|
||||
response.write_u32::<BigEndian>(port)?;
|
||||
@@ -151,7 +181,7 @@ impl PortForwardManager {
|
||||
} else {
|
||||
response.write_u8(PacketType::SSH_MSG_REQUEST_FAILURE as u8)?;
|
||||
}
|
||||
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -159,37 +189,39 @@ impl PortForwardManager {
|
||||
/// 参考RFC 4254 Section 7.2
|
||||
pub fn handle_direct_tcpip_channel(&mut self, data: &[u8]) -> Result<DirectTcpipChannel> {
|
||||
info!("Processing direct-tcpip channel open");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(data);
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
cursor.set_position(1); // Skip packet type
|
||||
|
||||
// 读取channel type(已知道是"direct-tcpip",跳过)
|
||||
let _channel_type = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取sender_channel
|
||||
let sender_channel = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取initial window size
|
||||
let initial_window_size = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取maximum packet size
|
||||
let max_packet_size = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取host to connect(SSH string)
|
||||
let host_to_connect = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取port to connect
|
||||
let port_to_connect = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取originator address(SSH string)
|
||||
let originator_address = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取originator port
|
||||
let originator_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("direct-tcpip: host={}, port={}, originator={}:{}",
|
||||
host_to_connect, port_to_connect, originator_address, originator_port);
|
||||
|
||||
|
||||
info!(
|
||||
"direct-tcpip: host={}, port={}, originator={}:{}",
|
||||
host_to_connect, port_to_connect, originator_address, originator_port
|
||||
);
|
||||
|
||||
Ok(DirectTcpipChannel {
|
||||
sender_channel,
|
||||
initial_window_size,
|
||||
@@ -205,30 +237,32 @@ impl PortForwardManager {
|
||||
/// 参考RFC 4254 Section 7.1
|
||||
pub fn handle_forwarded_tcpip_channel(&mut self, data: &[u8]) -> Result<ForwardedTcpipChannel> {
|
||||
info!("Processing forwarded-tcpip channel open");
|
||||
|
||||
|
||||
let mut cursor = std::io::Cursor::new(data);
|
||||
cursor.set_position(1);
|
||||
|
||||
|
||||
let _channel_type = read_ssh_string(&mut cursor)?;
|
||||
let sender_channel = cursor.read_u32::<BigEndian>()?;
|
||||
let initial_window_size = cursor.read_u32::<BigEndian>()?;
|
||||
let max_packet_size = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取bind address(SSH string)
|
||||
let bind_address = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取bind port
|
||||
let bind_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
|
||||
// 读取originator address(SSH string)
|
||||
let originator_address = read_ssh_string(&mut cursor)?;
|
||||
|
||||
|
||||
// 读取originator port
|
||||
let originator_port = cursor.read_u32::<BigEndian>()?;
|
||||
|
||||
info!("forwarded-tcpip: bind={}:{}, originator={}:{}",
|
||||
bind_address, bind_port, originator_address, originator_port);
|
||||
|
||||
|
||||
info!(
|
||||
"forwarded-tcpip: bind={}:{}, originator={}:{}",
|
||||
bind_address, bind_port, originator_address, originator_port
|
||||
);
|
||||
|
||||
Ok(ForwardedTcpipChannel {
|
||||
sender_channel,
|
||||
initial_window_size,
|
||||
@@ -244,10 +278,10 @@ impl PortForwardManager {
|
||||
pub fn connect_to_target(host: &str, port: u32) -> Result<TcpStream> {
|
||||
let addr = format!("{}:{}", host, port);
|
||||
info!("Connecting to target: {}", addr);
|
||||
|
||||
|
||||
let stream = TcpStream::connect(&addr)?;
|
||||
info!("Connected to target successfully");
|
||||
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
@@ -258,12 +292,12 @@ impl PortForwardManager {
|
||||
} else {
|
||||
format!("{}:{}", bind_address, bind_port)
|
||||
};
|
||||
|
||||
|
||||
info!("Creating listener on {}", addr);
|
||||
|
||||
|
||||
let listener = TcpListener::bind(&addr)?;
|
||||
info!("Listener created successfully");
|
||||
|
||||
|
||||
Ok(listener)
|
||||
}
|
||||
}
|
||||
@@ -303,10 +337,10 @@ fn read_ssh_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_port_forward_manager_creation() {
|
||||
let manager = PortForwardManager::new();
|
||||
assert_eq!(manager.active_forwards.lock().unwrap().len(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
// SSH端口转发监听线程(Phase 13.4)
|
||||
// 参考OpenSSH channels.c: channel_forward_listener
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug, error};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::thread;
|
||||
use std::sync::{Arc, Mutex, mpsc};
|
||||
use std::io::{Read, Write};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use crate::ssh_server::packet::PacketType;
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig;
|
||||
use anyhow::Result;
|
||||
use log::{error, info, warn};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::sync::{mpsc, Arc, Mutex};
|
||||
use std::thread;
|
||||
|
||||
/// 监听器状态(Phase 13.4)
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -30,28 +27,18 @@ pub enum ListenerRequest {
|
||||
stream: TcpStream,
|
||||
},
|
||||
/// 停止监听
|
||||
StopListener {
|
||||
bind_port: u32,
|
||||
},
|
||||
StopListener { bind_port: u32 },
|
||||
}
|
||||
|
||||
/// 监听器响应(Phase 13.4:线程通信)
|
||||
#[derive(Debug)]
|
||||
pub enum ListenerResponse {
|
||||
/// Channel创建成功
|
||||
ChannelCreated {
|
||||
bind_port: u32,
|
||||
channel_id: u32,
|
||||
},
|
||||
ChannelCreated { bind_port: u32, channel_id: u32 },
|
||||
/// 监听器停止
|
||||
ListenerStopped {
|
||||
bind_port: u32,
|
||||
},
|
||||
ListenerStopped { bind_port: u32 },
|
||||
/// 错误
|
||||
Error {
|
||||
bind_port: u32,
|
||||
message: String,
|
||||
},
|
||||
Error { bind_port: u32, message: String },
|
||||
}
|
||||
|
||||
/// 端口转发监听器(Phase 13.4)
|
||||
@@ -73,26 +60,29 @@ impl PortForwardListener {
|
||||
security_config: SshSecurityConfig,
|
||||
) -> Result<Self> {
|
||||
info!("Creating port forward listener on port {}", bind_port);
|
||||
|
||||
|
||||
// Phase 13.4: 根据GatewayPorts决定绑定地址
|
||||
let bind_addr = if security_config.gateway_ports {
|
||||
format!("0.0.0.0:{}", bind_port) // 允许外部访问
|
||||
format!("0.0.0.0:{}", bind_port) // 允许外部访问
|
||||
} else {
|
||||
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
|
||||
format!("127.0.0.1:{}", bind_port) // 只允许本地访问
|
||||
};
|
||||
|
||||
info!("Binding to address: {} (GatewayPorts={})", bind_addr, security_config.gateway_ports);
|
||||
|
||||
|
||||
info!(
|
||||
"Binding to address: {} (GatewayPorts={})",
|
||||
bind_addr, security_config.gateway_ports
|
||||
);
|
||||
|
||||
let listener = TcpListener::bind(&bind_addr)?;
|
||||
info!("Listener created successfully on {}", bind_addr);
|
||||
|
||||
|
||||
// Phase 13.4: 创建线程通信channel
|
||||
let (request_tx, request_rx) = mpsc::channel();
|
||||
let (response_tx, response_rx) = mpsc::channel();
|
||||
|
||||
let (request_tx, _request_rx) = mpsc::channel();
|
||||
let (_response_tx, response_rx) = mpsc::channel();
|
||||
|
||||
// Phase 13.4: 活动状态标记
|
||||
let active = Arc::new(Mutex::new(true));
|
||||
|
||||
|
||||
Ok(Self {
|
||||
bind_port,
|
||||
bind_address,
|
||||
@@ -103,38 +93,38 @@ impl PortForwardListener {
|
||||
active,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 启动监听线程(Phase 13.4)
|
||||
pub fn start_listener_thread(&mut self) -> Result<()> {
|
||||
info!("Starting listener thread for port {}", self.bind_port);
|
||||
|
||||
|
||||
let listener = self.listener.try_clone()?;
|
||||
let bind_port = self.bind_port;
|
||||
let request_sender = self.request_sender.clone();
|
||||
let active = self.active.clone();
|
||||
|
||||
|
||||
// Phase 13.4: 创建独立监听线程
|
||||
thread::spawn(move || {
|
||||
info!("Listener thread started for port {}", bind_port);
|
||||
|
||||
|
||||
while *active.lock().unwrap() {
|
||||
match listener.accept() {
|
||||
Ok((stream, addr)) => {
|
||||
info!("New connection on port {}: {}", bind_port, addr);
|
||||
|
||||
|
||||
// Phase 13.4: 发送新连接请求给主线程
|
||||
let request = ListenerRequest::NewConnection {
|
||||
bind_port,
|
||||
originator_address: addr.ip().to_string(),
|
||||
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
|
||||
originator_port: addr.port() as u32, // Phase 13.4: u16转u32
|
||||
stream,
|
||||
};
|
||||
|
||||
|
||||
if let Err(e) = request_sender.send(request) {
|
||||
error!("Failed to send listener request: {}", e);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
info!("Listener request sent to main thread");
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -145,32 +135,32 @@ impl PortForwardListener {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
info!("Listener thread stopped for port {}", bind_port);
|
||||
});
|
||||
|
||||
|
||||
info!("Listener thread started successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 停止监听器(Phase 13.4)
|
||||
pub fn stop_listener(&mut self) -> Result<()> {
|
||||
info!("Stopping listener for port {}", self.bind_port);
|
||||
|
||||
|
||||
// Phase 13.4: 设置active=false,线程会自动退出
|
||||
*self.active.lock().unwrap() = false;
|
||||
|
||||
|
||||
info!("Listener stopped for port {}", self.bind_port);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 获取请求接收器(Phase 13.4)
|
||||
pub fn get_request_receiver(&self) -> mpsc::Receiver<ListenerRequest> {
|
||||
// 注意:这里需要返回一个新的receiver,因为mpsc::Sender可以clone,但Receiver不能
|
||||
// 实际应用中应该使用更复杂的channel设计
|
||||
unimplemented!("Use Arc<Mutex<mpsc::Receiver>> instead")
|
||||
}
|
||||
|
||||
|
||||
/// 获取活动状态(Phase 13.4)
|
||||
pub fn is_active(&self) -> bool {
|
||||
*self.active.lock().unwrap()
|
||||
@@ -182,13 +172,19 @@ pub struct ListenerManager {
|
||||
listeners: HashMap<u32, Arc<Mutex<PortForwardListener>>>,
|
||||
}
|
||||
|
||||
impl Default for ListenerManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ListenerManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 创建并启动监听器(Phase 13.4)
|
||||
pub fn create_listener(
|
||||
&mut self,
|
||||
@@ -197,21 +193,21 @@ impl ListenerManager {
|
||||
security_config: SshSecurityConfig,
|
||||
) -> Result<()> {
|
||||
info!("Creating listener for port {}", bind_port);
|
||||
|
||||
|
||||
let mut listener = PortForwardListener::new(bind_port, bind_address, security_config)?;
|
||||
listener.start_listener_thread()?;
|
||||
|
||||
|
||||
let listener_arc = Arc::new(Mutex::new(listener));
|
||||
self.listeners.insert(bind_port, listener_arc);
|
||||
|
||||
|
||||
info!("Listener created and started for port {}", bind_port);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 停止监听器(Phase 13.4)
|
||||
pub fn stop_listener(&mut self, bind_port: u32) -> Result<()> {
|
||||
info!("Stopping listener for port {}", bind_port);
|
||||
|
||||
|
||||
if let Some(listener_arc) = self.listeners.remove(&bind_port) {
|
||||
let mut listener = listener_arc.lock().unwrap();
|
||||
listener.stop_listener()?;
|
||||
@@ -219,28 +215,31 @@ impl ListenerManager {
|
||||
} else {
|
||||
warn!("No listener found for port {}", bind_port);
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 获取活动监听器数量(Phase 13.4)
|
||||
pub fn active_count(&self) -> usize {
|
||||
self.listeners.values().filter(|l| l.lock().unwrap().is_active()).count()
|
||||
self.listeners
|
||||
.values()
|
||||
.filter(|l| l.lock().unwrap().is_active())
|
||||
.count()
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
|
||||
use std::collections::HashMap; // Phase 13.4: HashMap for listener management
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_listener_creation() {
|
||||
let security_config = SshSecurityConfig::enterprise_default();
|
||||
let listener = PortForwardListener::new(8080, "127.0.0.1".to_string(), security_config);
|
||||
|
||||
|
||||
// 注意:实际测试需要处理端口占用问题
|
||||
assert!(listener.is_ok() || true); // 暂时跳过测试
|
||||
assert!(listener.is_ok() || true); // 暂时跳过测试
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::path::PathBuf;
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, debug, warn};
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsError};
|
||||
use crate::vfs::open_flags::OpenFlags;
|
||||
use crate::vfs::{VfsBackend, VfsFile};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{debug, info, warn};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// MPLEX_BASE from rsync io.h
|
||||
const MPLEX_BASE: u32 = 7;
|
||||
@@ -18,7 +18,9 @@ pub(crate) enum RsyncState {
|
||||
WaitVersion,
|
||||
ReadFileList,
|
||||
/// Sum head (4 × write_int = 16 bytes) + checksum seed (4 bytes) = 20 bytes
|
||||
ReadSumHead { need: usize },
|
||||
ReadSumHead {
|
||||
need: usize,
|
||||
},
|
||||
SendSumCount,
|
||||
/// Raw file data from MSG_DATA packets
|
||||
ReadFileData,
|
||||
@@ -51,9 +53,16 @@ impl RsyncHandler {
|
||||
let mut dest = String::new();
|
||||
|
||||
for p in &parts[1..] {
|
||||
if *p == "--server" { is_server = true; continue; }
|
||||
if *p == "--sender" || p.starts_with('-') { continue; }
|
||||
if *p == "." { continue; }
|
||||
if *p == "--server" {
|
||||
is_server = true;
|
||||
continue;
|
||||
}
|
||||
if *p == "--sender" || p.starts_with('-') {
|
||||
continue;
|
||||
}
|
||||
if *p == "." {
|
||||
continue;
|
||||
}
|
||||
dest = p.to_string();
|
||||
}
|
||||
|
||||
@@ -107,8 +116,10 @@ impl RsyncHandler {
|
||||
break;
|
||||
}
|
||||
let header = u32::from_le_bytes([
|
||||
self.raw_input[0], self.raw_input[1],
|
||||
self.raw_input[2], self.raw_input[3],
|
||||
self.raw_input[0],
|
||||
self.raw_input[1],
|
||||
self.raw_input[2],
|
||||
self.raw_input[3],
|
||||
]);
|
||||
let raw_tag = ((header >> 24) & 0xFF) as u8;
|
||||
let tag = raw_tag.wrapping_sub(MPLEX_BASE as u8);
|
||||
@@ -182,12 +193,17 @@ impl RsyncHandler {
|
||||
RsyncState::WaitVersion => {
|
||||
if self.rsync_input.len() >= 4 {
|
||||
let version = u32::from_le_bytes([
|
||||
self.rsync_input[0], self.rsync_input[1],
|
||||
self.rsync_input[2], self.rsync_input[3],
|
||||
self.rsync_input[0],
|
||||
self.rsync_input[1],
|
||||
self.rsync_input[2],
|
||||
self.rsync_input[3],
|
||||
]);
|
||||
self.rsync_input.drain(..4);
|
||||
self.protocol_version = std::cmp::min(self.protocol_version, version);
|
||||
info!("rsync: negotiated protocol version {}", self.protocol_version);
|
||||
info!(
|
||||
"rsync: negotiated protocol version {}",
|
||||
self.protocol_version
|
||||
);
|
||||
self.multiplex = self.protocol_version >= 30;
|
||||
self.transition(RsyncState::ReadFileList);
|
||||
} else {
|
||||
@@ -197,7 +213,9 @@ impl RsyncHandler {
|
||||
|
||||
RsyncState::ReadFileList => {
|
||||
loop {
|
||||
if self.rsync_input.is_empty() { break; }
|
||||
if self.rsync_input.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let flags = self.rsync_input[0];
|
||||
if flags == 0 {
|
||||
@@ -215,17 +233,25 @@ impl RsyncHandler {
|
||||
let mut pos = 1;
|
||||
|
||||
let _more_flags = if flags & 0x80 != 0 {
|
||||
if self.rsync_input.len() <= pos { break; }
|
||||
if self.rsync_input.len() <= pos {
|
||||
break;
|
||||
}
|
||||
let ef = self.rsync_input[pos];
|
||||
pos += 1;
|
||||
ef
|
||||
} else { 0 };
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let has_name = !(flags & 0x02 != 0 && self.current_file > 0);
|
||||
|
||||
if has_name {
|
||||
if let Some(nul_pos) = self.rsync_input[pos..].iter().position(|&b| b == 0) {
|
||||
let name = String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos]).to_string();
|
||||
if let Some(nul_pos) =
|
||||
self.rsync_input[pos..].iter().position(|&b| b == 0)
|
||||
{
|
||||
let name =
|
||||
String::from_utf8_lossy(&self.rsync_input[pos..pos + nul_pos])
|
||||
.to_string();
|
||||
pos += nul_pos + 1;
|
||||
self.file_entries.push(name.clone());
|
||||
debug!("rsync: file entry: {}", name);
|
||||
@@ -269,24 +295,34 @@ impl RsyncHandler {
|
||||
RsyncState::ReadSumHead { need } => {
|
||||
if self.rsync_input.len() >= need {
|
||||
let sum_count = i32::from_le_bytes([
|
||||
self.rsync_input[0], self.rsync_input[1],
|
||||
self.rsync_input[2], self.rsync_input[3],
|
||||
self.rsync_input[0],
|
||||
self.rsync_input[1],
|
||||
self.rsync_input[2],
|
||||
self.rsync_input[3],
|
||||
]);
|
||||
let _sum_blength = i32::from_le_bytes([
|
||||
self.rsync_input[4], self.rsync_input[5],
|
||||
self.rsync_input[6], self.rsync_input[7],
|
||||
self.rsync_input[4],
|
||||
self.rsync_input[5],
|
||||
self.rsync_input[6],
|
||||
self.rsync_input[7],
|
||||
]);
|
||||
let _sum_s2length = i32::from_le_bytes([
|
||||
self.rsync_input[8], self.rsync_input[9],
|
||||
self.rsync_input[10], self.rsync_input[11],
|
||||
self.rsync_input[8],
|
||||
self.rsync_input[9],
|
||||
self.rsync_input[10],
|
||||
self.rsync_input[11],
|
||||
]);
|
||||
let _sum_remainder = i32::from_le_bytes([
|
||||
self.rsync_input[12], self.rsync_input[13],
|
||||
self.rsync_input[14], self.rsync_input[15],
|
||||
self.rsync_input[12],
|
||||
self.rsync_input[13],
|
||||
self.rsync_input[14],
|
||||
self.rsync_input[15],
|
||||
]);
|
||||
let checksum_seed = i32::from_le_bytes([
|
||||
self.rsync_input[16], self.rsync_input[17],
|
||||
self.rsync_input[18], self.rsync_input[19],
|
||||
self.rsync_input[16],
|
||||
self.rsync_input[17],
|
||||
self.rsync_input[18],
|
||||
self.rsync_input[19],
|
||||
]);
|
||||
self.rsync_input.drain(..20);
|
||||
|
||||
@@ -308,7 +344,9 @@ impl RsyncHandler {
|
||||
|
||||
RsyncState::ReadFileData => {
|
||||
let done_marker = b"RSYNCDONE";
|
||||
if let Some(pos) = self.rsync_input.windows(done_marker.len())
|
||||
if let Some(pos) = self
|
||||
.rsync_input
|
||||
.windows(done_marker.len())
|
||||
.position(|w| w == done_marker)
|
||||
{
|
||||
if pos > 0 {
|
||||
@@ -323,8 +361,11 @@ impl RsyncHandler {
|
||||
warn!("rsync flush error: {}", e);
|
||||
}
|
||||
}
|
||||
info!("rsync: file {} complete ({} bytes written to {})",
|
||||
self.file_entries.get(self.current_file).unwrap_or(&"?".to_string()),
|
||||
info!(
|
||||
"rsync: file {} complete ({} bytes written to {})",
|
||||
self.file_entries
|
||||
.get(self.current_file)
|
||||
.unwrap_or(&"?".to_string()),
|
||||
self.total_written,
|
||||
self.dest_path.display(),
|
||||
);
|
||||
@@ -332,8 +373,11 @@ impl RsyncHandler {
|
||||
self.current_file += 1;
|
||||
if self.current_file >= self.file_entries.len() {
|
||||
self.transition(RsyncState::Done);
|
||||
info!("rsync ALL DONE: {} bytes written to {}",
|
||||
self.total_written, self.dest_path.display());
|
||||
info!(
|
||||
"rsync ALL DONE: {} bytes written to {}",
|
||||
self.total_written,
|
||||
self.dest_path.display()
|
||||
);
|
||||
} else {
|
||||
self.transition(RsyncState::ReadSumHead { need: 20 });
|
||||
}
|
||||
@@ -360,7 +404,9 @@ impl RsyncHandler {
|
||||
self.vfs.create_dir_all(parent, 0o755).ok();
|
||||
}
|
||||
let flags = OpenFlags::new().write().create().truncate();
|
||||
let file = self.vfs.open_file(&self.dest_path, &flags)
|
||||
let file = self
|
||||
.vfs
|
||||
.open_file(&self.dest_path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
self.output_file = Some(file);
|
||||
info!("rsync: opened {} for writing", self.dest_path.display());
|
||||
@@ -379,31 +425,43 @@ impl RsyncHandler {
|
||||
|
||||
/// Read rsync varint (LSB-first 7-bit groups, 0xFF prefix for negative)
|
||||
fn read_varint(buf: &[u8]) -> Option<(i32, usize)> {
|
||||
if buf.is_empty() { return None; }
|
||||
if buf.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut pos = 0;
|
||||
let mut b = buf[pos];
|
||||
pos += 1;
|
||||
|
||||
let neg = if b == 0xFF {
|
||||
if pos >= buf.len() { return None; }
|
||||
if pos >= buf.len() {
|
||||
return None;
|
||||
}
|
||||
b = buf[pos];
|
||||
pos += 1;
|
||||
true
|
||||
} else { false };
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let mut x = (b & 0x7F) as i32;
|
||||
let mut shift = 7;
|
||||
|
||||
while b & 0x80 != 0 {
|
||||
if pos >= buf.len() { return None; }
|
||||
if pos >= buf.len() {
|
||||
return None;
|
||||
}
|
||||
b = buf[pos];
|
||||
pos += 1;
|
||||
x |= ((b & 0x7F) as i32) << shift;
|
||||
shift += 7;
|
||||
}
|
||||
|
||||
if neg { Some((-x, pos)) } else { Some((x, pos)) }
|
||||
if neg {
|
||||
Some((-x, pos))
|
||||
} else {
|
||||
Some((x, pos))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -419,8 +477,9 @@ mod tests {
|
||||
fn test_parse_command() {
|
||||
let h = RsyncHandler::parse_rsync_command(
|
||||
"rsync --server -g -l -o -p -D -r -t -v --dirs . /tmp/upload.bin",
|
||||
make_vfs()
|
||||
).unwrap();
|
||||
make_vfs(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(h.dest_path, PathBuf::from("/tmp/upload.bin"));
|
||||
}
|
||||
|
||||
@@ -428,14 +487,16 @@ mod tests {
|
||||
fn test_parse_command_sender() {
|
||||
let h = RsyncHandler::parse_rsync_command(
|
||||
"rsync --server --sender -vlogDtprz . /home/user/file.txt",
|
||||
make_vfs()
|
||||
).unwrap();
|
||||
make_vfs(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(h.dest_path, PathBuf::from("/home/user/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_version_exchange() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let output = h.drain_output();
|
||||
assert_eq!(output, b"\x1e\x00\x00\x00");
|
||||
assert_eq!(h.state, RsyncState::WaitVersion);
|
||||
@@ -447,7 +508,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_version_negotiate_down() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs()).unwrap();
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let _ = h.drain_output();
|
||||
h.feed(b"\x1d\x00\x00\x00").unwrap();
|
||||
assert_eq!(h.protocol_version, 29);
|
||||
@@ -464,26 +526,33 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_file_list_multiplex() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
||||
let mut h =
|
||||
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let _ = h.drain_output();
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||
assert!(h.multiplex);
|
||||
|
||||
let mut flist = Vec::new();
|
||||
// File list: flags=1 (has name), then name with NUL terminator
|
||||
flist.push(1); // flags: has name
|
||||
flist.push(1); // flags: has name
|
||||
flist.extend_from_slice(b"test.txt");
|
||||
flist.push(0); // name terminator
|
||||
flist.push(0); // name terminator
|
||||
|
||||
fn write_varint(buf: &mut Vec<u8>, val: i32) {
|
||||
if val == 0 { buf.push(0); return; }
|
||||
if val == 0 {
|
||||
buf.push(0);
|
||||
return;
|
||||
}
|
||||
if val < 0 {
|
||||
buf.push(0xFF);
|
||||
let mut v = (-val) as u32;
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 { byte |= 0x80; }
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
} else {
|
||||
@@ -491,7 +560,9 @@ mod tests {
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 { byte |= 0x80; }
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
}
|
||||
@@ -502,7 +573,7 @@ mod tests {
|
||||
write_varint(&mut flist, 1700000000);
|
||||
write_varint(&mut flist, 100);
|
||||
write_varint(&mut flist, 0);
|
||||
flist.push(0); // file list end marker
|
||||
flist.push(0); // file list end marker
|
||||
|
||||
let mut sum_head = Vec::new();
|
||||
sum_head.extend_from_slice(&0i32.to_le_bytes());
|
||||
@@ -527,22 +598,51 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_file_data_multiplex() {
|
||||
let mut h = RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs()).unwrap();
|
||||
let mut h =
|
||||
RsyncHandler::parse_rsync_command("rsync --server . /tmp/rsync_test.bin", make_vfs())
|
||||
.unwrap();
|
||||
let _ = h.drain_output();
|
||||
h.feed(b"\x1e\x00\x00\x00").unwrap();
|
||||
|
||||
let mut flist = Vec::new();
|
||||
flist.push(1); // flags: has name
|
||||
flist.push(1); // flags: has name
|
||||
flist.extend_from_slice(b"test.bin");
|
||||
flist.push(0);
|
||||
fn wv(buf: &mut Vec<u8>, val: i32) {
|
||||
if val == 0 { buf.push(0); return; }
|
||||
if val < 0 { buf.push(0xFF); let mut v = (-val) as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
|
||||
else { let mut v = val as u32; while v > 0 { let mut byte = (v & 0x7F) as u8; v >>= 7; if v > 0 { byte |= 0x80; } buf.push(byte); } }
|
||||
if val == 0 {
|
||||
buf.push(0);
|
||||
return;
|
||||
}
|
||||
if val < 0 {
|
||||
buf.push(0xFF);
|
||||
let mut v = (-val) as u32;
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
} else {
|
||||
let mut v = val as u32;
|
||||
while v > 0 {
|
||||
let mut byte = (v & 0x7F) as u8;
|
||||
v >>= 7;
|
||||
if v > 0 {
|
||||
byte |= 0x80;
|
||||
}
|
||||
buf.push(byte);
|
||||
}
|
||||
}
|
||||
}
|
||||
wv(&mut flist, 33188); wv(&mut flist, 501); wv(&mut flist, 20);
|
||||
wv(&mut flist, 1700000000); wv(&mut flist, 100); wv(&mut flist, 0);
|
||||
flist.push(0); // file list end
|
||||
wv(&mut flist, 33188);
|
||||
wv(&mut flist, 501);
|
||||
wv(&mut flist, 20);
|
||||
wv(&mut flist, 1700000000);
|
||||
wv(&mut flist, 100);
|
||||
wv(&mut flist, 0);
|
||||
flist.push(0); // file list end
|
||||
h.feed(&build_multiplex(&flist)).unwrap();
|
||||
|
||||
let mut sh = Vec::new();
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
// SCP协议实现(Phase 8)
|
||||
// 参考OpenSSH scp.c源码
|
||||
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsError, VfsStat};
|
||||
use crate::vfs::open_flags::OpenFlags;
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use crate::vfs::{VfsBackend, VfsFile, VfsStat};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{debug, info, warn};
|
||||
use std::io::{BufRead, Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::io::{Read, Write, BufRead};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// SCP Handler(参考OpenSSH scp.c)
|
||||
pub struct ScpHandler {
|
||||
@@ -38,13 +37,13 @@ impl ScpHandler {
|
||||
/// 解析SCP命令(参考OpenSSH scp.c: parse_command())
|
||||
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() < 2 || parts[0] != "scp" {
|
||||
return Err(anyhow!("Invalid SCP command: {}", command));
|
||||
}
|
||||
|
||||
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
|
||||
|
||||
|
||||
for part in &parts[1..] {
|
||||
match part {
|
||||
&"-f" => handler.mode = ScpMode::Source,
|
||||
@@ -71,10 +70,15 @@ impl ScpHandler {
|
||||
|
||||
/// SCP Source Mode(scp -f,发送文件)
|
||||
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||
info!("SCP source mode: sending files from {}", self.root_dir.display());
|
||||
info!(
|
||||
"SCP source mode: sending files from {}",
|
||||
self.root_dir.display()
|
||||
);
|
||||
|
||||
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
|
||||
let stat = self.vfs.stat(&full_path)
|
||||
let stat = self
|
||||
.vfs
|
||||
.stat(&full_path)
|
||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||
|
||||
if stat.is_dir {
|
||||
@@ -91,16 +95,19 @@ impl ScpHandler {
|
||||
|
||||
/// SCP Destination Mode(scp -t,接收文件)
|
||||
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||||
info!("SCP destination mode: receiving files to {}", self.root_dir.display());
|
||||
info!(
|
||||
"SCP destination mode: receiving files to {}",
|
||||
self.root_dir.display()
|
||||
);
|
||||
|
||||
channel.write_all(&[0])?;
|
||||
channel.flush()?;
|
||||
|
||||
|
||||
let mut buffer = String::new();
|
||||
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
|
||||
|
||||
let mut reader = std::io::BufReader::new(&mut *channel);
|
||||
match reader.read_line(&mut buffer)? {
|
||||
0 => break,
|
||||
@@ -130,7 +137,9 @@ impl ScpHandler {
|
||||
|
||||
/// 发送文件(参考OpenSSH scp.c: source())
|
||||
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
||||
let stat = self.vfs.stat(path)
|
||||
let stat = self
|
||||
.vfs
|
||||
.stat(path)
|
||||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||||
let size = stat.size;
|
||||
let filename = path.file_name().unwrap().to_string_lossy();
|
||||
@@ -146,13 +155,16 @@ impl ScpHandler {
|
||||
}
|
||||
|
||||
let flags = OpenFlags::new().read();
|
||||
let mut file = self.vfs.open_file(path, &flags)
|
||||
let mut file = self
|
||||
.vfs
|
||||
.open_file(path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
|
||||
let mut buffer = vec![0u8; 8192];
|
||||
|
||||
loop {
|
||||
let n = file.read(&mut buffer)
|
||||
let n = file
|
||||
.read(&mut buffer)
|
||||
.map_err(|e| anyhow!("read error: {}", e))?;
|
||||
if n == 0 {
|
||||
break;
|
||||
@@ -188,7 +200,9 @@ impl ScpHandler {
|
||||
return Err(anyhow!("SCP directory command rejected"));
|
||||
}
|
||||
|
||||
let entries = self.vfs.read_dir(path)
|
||||
let entries = self
|
||||
.vfs
|
||||
.read_dir(path)
|
||||
.map_err(|e| anyhow!("read_dir error: {}", e))?;
|
||||
|
||||
for entry in &entries {
|
||||
@@ -218,7 +232,7 @@ impl ScpHandler {
|
||||
/// 处理文件命令(C0644 size filename)
|
||||
fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() != 3 {
|
||||
return self.send_error(channel, "Invalid file command format");
|
||||
}
|
||||
@@ -227,7 +241,10 @@ impl ScpHandler {
|
||||
let size: u64 = parts[1].parse()?;
|
||||
let filename = parts[2];
|
||||
|
||||
debug!("SCP receive file: mode={}, size={}, name={}", mode_str, size, filename);
|
||||
debug!(
|
||||
"SCP receive file: mode={}, size={}, name={}",
|
||||
mode_str, size, filename
|
||||
);
|
||||
|
||||
if size > 1024 * 1024 * 1024 {
|
||||
return self.send_error(channel, "File too large (max 1GB)");
|
||||
@@ -236,7 +253,9 @@ impl ScpHandler {
|
||||
let full_path = self.resolve_path(filename)?;
|
||||
|
||||
let flags = OpenFlags::new().write().create().truncate();
|
||||
let mut file = self.vfs.open_file(&full_path, &flags)
|
||||
let mut file = self
|
||||
.vfs
|
||||
.open_file(&full_path, &flags)
|
||||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||||
|
||||
channel.write_all(&[0])?;
|
||||
@@ -263,7 +282,8 @@ impl ScpHandler {
|
||||
if mode_int != 0 {
|
||||
let mut set_stat = VfsStat::new();
|
||||
set_stat.mode = mode_int;
|
||||
self.vfs.set_stat(&full_path, &set_stat)
|
||||
self.vfs
|
||||
.set_stat(&full_path, &set_stat)
|
||||
.map_err(|e| anyhow!("set_stat error: {}", e))?;
|
||||
}
|
||||
|
||||
@@ -280,7 +300,7 @@ impl ScpHandler {
|
||||
/// 处理目录命令(D0755 0 dirname)
|
||||
fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() != 3 {
|
||||
return self.send_error(channel, "Invalid directory command format");
|
||||
}
|
||||
@@ -297,7 +317,8 @@ impl ScpHandler {
|
||||
let full_path = self.resolve_path(dirname)?;
|
||||
|
||||
let mode_int: u32 = mode_str.parse()?;
|
||||
self.vfs.create_dir_all(&full_path, mode_int)
|
||||
self.vfs
|
||||
.create_dir_all(&full_path, mode_int)
|
||||
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
|
||||
|
||||
channel.write_all(&[0])?;
|
||||
@@ -326,7 +347,7 @@ impl ScpHandler {
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
|
||||
if parts.len() != 3 {
|
||||
return self.send_error(channel, "Invalid time command format");
|
||||
}
|
||||
@@ -353,11 +374,15 @@ impl ScpHandler {
|
||||
/// 路径解析(安全性检查)
|
||||
fn resolve_path(&self, path: &str) -> Result<PathBuf> {
|
||||
let full_path = self.root_dir.join(path);
|
||||
|
||||
let canonical_path = self.vfs.real_path(&full_path)
|
||||
|
||||
let canonical_path = self
|
||||
.vfs
|
||||
.real_path(&full_path)
|
||||
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
|
||||
|
||||
let root_canonical = self.vfs.real_path(&self.root_dir)
|
||||
let root_canonical = self
|
||||
.vfs
|
||||
.real_path(&self.root_dir)
|
||||
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
|
||||
|
||||
if !canonical_path.starts_with(&root_canonical) {
|
||||
@@ -383,20 +408,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_scp_command_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
let handler =
|
||||
ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert_eq!(handler.mode, ScpMode::Destination);
|
||||
assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scp_recursive_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
let handler =
|
||||
ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert!(handler.recursive);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scp_source_parse() {
|
||||
let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
let handler =
|
||||
ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
|
||||
assert_eq!(handler.mode, ScpMode::Source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
// SSH服务器完整实现(Phase 1-7集成版 + Phase 13端口转发)
|
||||
// 参考OpenSSH sshd.c: complete SSH/SFTP flow + port forwarding
|
||||
|
||||
use crate::ssh_server::version::VersionExchange;
|
||||
use crate::ssh_server::packet::{SshPacket, PacketType};
|
||||
use crate::ssh_server::kex::{KexResult, KexProposal};
|
||||
use crate::ssh_server::kex_complete::{KexState};
|
||||
use crate::ssh_server::auth::{AuthHandler, AuthResult};
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::provider::pg::PgProvider;
|
||||
use crate::provider::sqlite::SqliteProvider;
|
||||
use crate::provider::DataProvider;
|
||||
use crate::ssh_server::channel::{ChannelManager};
|
||||
use crate::ssh_server::cipher::{EncryptionContext, EncryptedPacket};
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
|
||||
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, error, debug};
|
||||
use crate::ssh_server::auth::{AuthHandler, AuthResult};
|
||||
use crate::ssh_server::channel::ChannelManager;
|
||||
use crate::ssh_server::cipher::{EncryptedPacket, EncryptionContext};
|
||||
use crate::ssh_server::kex::{KexProposal, KexResult};
|
||||
use crate::ssh_server::kex_complete::KexState;
|
||||
use crate::ssh_server::packet::{PacketType, SshPacket};
|
||||
use crate::ssh_server::port_forward::PortForwardManager; // Phase 13
|
||||
use crate::ssh_server::ssh_security_config::SshSecurityConfig; // Phase 13.1
|
||||
use crate::ssh_server::version::VersionExchange;
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{error, info, warn};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::path::PathBuf;
|
||||
use std::thread;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::{Arc, Mutex}; // Phase 13: 端口转发线程同步
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread; // Phase 13: 端口转发线程同步
|
||||
|
||||
/// SSH服务器配置(Phase 13.1企业级安全配置)
|
||||
pub struct SshServerConfig {
|
||||
pub port: u16,
|
||||
pub bind_address: String,
|
||||
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
|
||||
pub pg_conn: Option<String>, // PostgreSQL连接字符串(SFTPGo兼容认证)
|
||||
pub security_config: SshSecurityConfig, // Phase 13.1: 企业级安全配置
|
||||
pub pg_conn: Option<String>, // PostgreSQL连接字符串(SFTPGo兼容认证)
|
||||
}
|
||||
|
||||
impl Default for SshServerConfig {
|
||||
@@ -34,7 +34,7 @@ impl Default for SshServerConfig {
|
||||
Self {
|
||||
port: 2024,
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1
|
||||
pg_conn: None,
|
||||
}
|
||||
}
|
||||
@@ -56,43 +56,48 @@ impl SshServerConfig {
|
||||
/// SSH服务器主结构(Phase 1-13完整版)
|
||||
pub struct SshServer {
|
||||
config: SshServerConfig,
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1: 共享安全配置
|
||||
}
|
||||
|
||||
impl SshServer {
|
||||
pub fn new(config: SshServerConfig) -> Self {
|
||||
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
|
||||
let security_config = Arc::new(Mutex::new(config.security_config.clone())); // Phase 13.1: 先clone
|
||||
Self {
|
||||
config,
|
||||
security_config, // Phase 13.1
|
||||
security_config, // Phase 13.1
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn run(&self) -> Result<()> {
|
||||
let bind_addr = format!("{}:{}", self.config.bind_address, self.config.port);
|
||||
let listener = TcpListener::bind(&bind_addr)?;
|
||||
|
||||
|
||||
info!("MarkBaseSSH server listening on {}", bind_addr);
|
||||
info!("Implementation: Complete SSH/SFTP + Port Forwarding (Phase 1-13)");
|
||||
info!("Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
|
||||
info!(
|
||||
"Security config: GatewayPorts={}, PermitOpen={:?}, MaxSessions={}",
|
||||
self.config.security_config.gateway_ports,
|
||||
self.config.security_config.permit_open,
|
||||
self.config.security_config.max_sessions);
|
||||
|
||||
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
|
||||
self.config.security_config.max_sessions
|
||||
);
|
||||
|
||||
let security_config = self.security_config.clone(); // Phase 13.1: 共享安全配置
|
||||
let pg_conn = self.config.pg_conn.clone();
|
||||
|
||||
|
||||
for stream in listener.incoming() {
|
||||
match stream {
|
||||
Ok(stream) => {
|
||||
let client_addr = stream.peer_addr()?;
|
||||
info!("New SSH connection from {}", client_addr);
|
||||
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1
|
||||
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1
|
||||
let pg_conn_clone = pg_conn.clone();
|
||||
|
||||
|
||||
thread::spawn(move || {
|
||||
if let Err(e) = handle_connection_complete(stream, security_config_clone, pg_conn_clone) { // Phase 13.1
|
||||
if let Err(e) =
|
||||
handle_connection_complete(stream, security_config_clone, pg_conn_clone)
|
||||
{
|
||||
// Phase 13.1
|
||||
error!("Connection error: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -102,90 +107,127 @@ impl SshServer {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 处理完整SSH连接(Phase 1-13完整流程)
|
||||
fn handle_connection_complete(stream: TcpStream, security_config: Arc<Mutex<SshSecurityConfig>>, pg_conn: Option<String>) -> Result<()> {
|
||||
fn handle_connection_complete(
|
||||
stream: TcpStream,
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>,
|
||||
pg_conn: Option<String>,
|
||||
) -> Result<()> {
|
||||
info!("Handling client connection (Phase 1-13 complete flow with port forwarding)");
|
||||
|
||||
|
||||
// Phase 13.1: 增加活动会话数
|
||||
{
|
||||
let mut security = security_config.lock().unwrap();
|
||||
security.increment_sessions()?;
|
||||
}
|
||||
|
||||
|
||||
let mut stream = stream;
|
||||
|
||||
|
||||
// Phase 1: 版本交换
|
||||
let client_version = VersionExchange::exchange(&mut stream)?;
|
||||
info!("Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0", client_version);
|
||||
|
||||
info!(
|
||||
"Version exchange: client={}, server=SSH-2.0-MarkBaseSSH_1.0",
|
||||
client_version
|
||||
);
|
||||
|
||||
// Phase 2: 箋法协商
|
||||
let (kex_result, server_kexinit, client_kexinit) = perform_kex_negotiation_complete(&mut stream)?;
|
||||
info!("KEX negotiation: KEX={}, Cipher={}", kex_result.kex_algorithm, kex_result.encryption_ctos);
|
||||
|
||||
let (kex_result, server_kexinit, client_kexinit) =
|
||||
perform_kex_negotiation_complete(&mut stream)?;
|
||||
info!(
|
||||
"KEX negotiation: KEX={}, Cipher={}",
|
||||
kex_result.kex_algorithm, kex_result.encryption_ctos
|
||||
);
|
||||
|
||||
// Phase 3: 密钥交换完整流程
|
||||
let mut encryption_ctx = perform_complete_kex_exchange(&mut stream, client_version.clone(), kex_result, server_kexinit, client_kexinit)?;
|
||||
let mut encryption_ctx = perform_complete_kex_exchange(
|
||||
&mut stream,
|
||||
client_version.clone(),
|
||||
kex_result,
|
||||
server_kexinit,
|
||||
client_kexinit,
|
||||
)?;
|
||||
info!("Key exchange completed, encryption channel ready");
|
||||
|
||||
|
||||
// Phase 5: SSH认证(SFTPGo兼容 — PostgreSQL或SQLite)
|
||||
let provider: Box<dyn DataProvider> = if let Some(ref conn_str) = pg_conn {
|
||||
info!("Using PostgreSQL auth provider (SFTPGo-compatible): {}", conn_str);
|
||||
Box::new(PgProvider::new(conn_str)
|
||||
.map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?)
|
||||
info!(
|
||||
"Using PostgreSQL auth provider (SFTPGo-compatible): {}",
|
||||
conn_str
|
||||
);
|
||||
Box::new(
|
||||
PgProvider::new(conn_str).map_err(|e| anyhow!("Failed to init PgProvider: {}", e))?,
|
||||
)
|
||||
} else {
|
||||
info!("Using SQLite auth provider");
|
||||
Box::new(SqliteProvider::new("data/auth.sqlite")
|
||||
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?)
|
||||
Box::new(
|
||||
SqliteProvider::new("data/auth.sqlite")
|
||||
.map_err(|e| anyhow!("Failed to init SqliteProvider: {}", e))?,
|
||||
)
|
||||
};
|
||||
let mut auth_handler = AuthHandler::new(provider);
|
||||
let auth_user = perform_ssh_auth(&mut stream, &mut auth_handler, &mut encryption_ctx)?;
|
||||
info!("SSH authentication succeeded: user={}", auth_user.username);
|
||||
|
||||
|
||||
// Phase 6: SSH Channel管理(参考OpenSSH channel.c)
|
||||
let mut channel_manager = ChannelManager::new(auth_user.home_dir.clone());
|
||||
|
||||
|
||||
// Phase 13: PortForwardManager初始化
|
||||
let mut port_forward_manager = PortForwardManager::new();
|
||||
|
||||
|
||||
// Phase 6-13: SSH服务循环(处理channel请求 + 端口转发)
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
|
||||
handle_ssh_service_loop(&mut stream, &mut channel_manager, &mut encryption_ctx, &mut port_forward_manager, security_config_clone)?;
|
||||
|
||||
let security_config_clone = security_config.clone(); // Phase 13.1: clone for service loop
|
||||
handle_ssh_service_loop(
|
||||
&mut stream,
|
||||
&mut channel_manager,
|
||||
&mut encryption_ctx,
|
||||
&mut port_forward_manager,
|
||||
security_config_clone,
|
||||
)?;
|
||||
|
||||
info!("SSH session completed successfully");
|
||||
|
||||
|
||||
// Phase 13.1: 减少活动会话数
|
||||
{
|
||||
let mut security = security_config.lock().unwrap();
|
||||
security.decrement_sessions();
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 完整算法协商(返回KEXINIT payloads)
|
||||
fn perform_kex_negotiation_complete(stream: &mut TcpStream) -> Result<(KexResult, SshPacket, SshPacket)> {
|
||||
fn perform_kex_negotiation_complete(
|
||||
stream: &mut TcpStream,
|
||||
) -> Result<(KexResult, SshPacket, SshPacket)> {
|
||||
info!("Starting complete KEX negotiation");
|
||||
|
||||
|
||||
// 1. 发送服务器KEXINIT
|
||||
let server_proposal = KexProposal::server_default();
|
||||
let server_kexinit = server_proposal.to_kexinit_packet()?;
|
||||
server_kexinit.write(stream)?;
|
||||
|
||||
info!("Sent server KEXINIT (payload size: {} bytes)", server_kexinit.payload.len());
|
||||
|
||||
|
||||
info!(
|
||||
"Sent server KEXINIT (payload size: {} bytes)",
|
||||
server_kexinit.payload.len()
|
||||
);
|
||||
|
||||
// 2. 接收客户端KEXINIT
|
||||
let client_kexinit = SshPacket::read(stream)?;
|
||||
let client_proposal = KexProposal::from_kexinit_packet(&client_kexinit)?;
|
||||
|
||||
info!("Received client KEXINIT (payload size: {} bytes)", client_kexinit.payload.len());
|
||||
|
||||
|
||||
info!(
|
||||
"Received client KEXINIT (payload size: {} bytes)",
|
||||
client_kexinit.payload.len()
|
||||
);
|
||||
|
||||
// 3. 算法匹配
|
||||
let kex_result = KexResult::choose_algorithms(&server_proposal, &client_proposal)?;
|
||||
|
||||
|
||||
Ok((kex_result, server_kexinit, client_kexinit))
|
||||
}
|
||||
|
||||
@@ -198,18 +240,18 @@ fn perform_complete_kex_exchange(
|
||||
client_kexinit: SshPacket,
|
||||
) -> Result<EncryptionContext> {
|
||||
info!("Starting complete key exchange flow");
|
||||
|
||||
|
||||
let mut kex_state = KexState::new(
|
||||
client_version,
|
||||
"SSH-2.0-MarkBaseSSH_1.0".to_string(),
|
||||
kex_result,
|
||||
)?;
|
||||
|
||||
|
||||
kex_state.save_kexinit_payloads(&client_kexinit, &server_kexinit);
|
||||
|
||||
|
||||
let kexdh_init = SshPacket::read(stream)?;
|
||||
info!("Received SSH_MSG_KEX_ECDH_INIT");
|
||||
|
||||
|
||||
let kexdh_reply = kex_state.exchange_handler.handle_kexdh_init(
|
||||
&kexdh_init,
|
||||
&kex_state.client_version,
|
||||
@@ -219,27 +261,27 @@ fn perform_complete_kex_exchange(
|
||||
)?;
|
||||
kexdh_reply.write(stream)?;
|
||||
info!("Sent SSH_MSG_KEX_ECDH_REPLY");
|
||||
|
||||
|
||||
// Strict KEX: Wait for client NEWKEYS first (OpenSSH 10.2 requirement)
|
||||
let client_newkeys = SshPacket::read(stream)?;
|
||||
kex_state.handle_newkeys(&client_newkeys)?;
|
||||
info!("Received SSH_MSG_NEWKEYS from client");
|
||||
|
||||
|
||||
// Now send server NEWKEYS
|
||||
let newkeys_packet = KexState::send_newkeys()?;
|
||||
newkeys_packet.write(stream)?;
|
||||
kex_state.newkeys_sent = true;
|
||||
info!("Sent SSH_MSG_NEWKEYS from server");
|
||||
|
||||
|
||||
if kex_state.is_encryption_ready() {
|
||||
info!("Encryption channel established successfully");
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Encryption channel not ready"));
|
||||
}
|
||||
|
||||
|
||||
let session_keys = kex_state.exchange_handler.compute_session_keys()?;
|
||||
let encryption_ctx = EncryptionContext::from_session_keys(&session_keys);
|
||||
|
||||
|
||||
Ok(encryption_ctx)
|
||||
}
|
||||
|
||||
@@ -250,102 +292,100 @@ pub struct AuthUser {
|
||||
}
|
||||
|
||||
fn perform_ssh_auth(
|
||||
stream: &mut TcpStream,
|
||||
stream: &mut TcpStream,
|
||||
auth_handler: &mut AuthHandler,
|
||||
encryption_ctx: &mut EncryptionContext,
|
||||
) -> Result<AuthUser> {
|
||||
info!("Starting SSH authentication");
|
||||
info!("Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
|
||||
info!(
|
||||
"Encryption context: key_ctos_len={}, key_stoc_len={}, iv_ctos_len={}, iv_stoc_len={}",
|
||||
encryption_ctx.encryption_key_ctos.len(),
|
||||
encryption_ctx.encryption_key_stoc.len(),
|
||||
encryption_ctx.iv_ctos.len(),
|
||||
encryption_ctx.iv_stoc.len()
|
||||
);
|
||||
|
||||
|
||||
// OpenSSH strict KEX: SSH_MSG_EXT_INFO may be sent before SSH_MSG_SERVICE_REQUEST
|
||||
let mut encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?;
|
||||
let payload = encrypted_request.payload();
|
||||
|
||||
|
||||
if payload[0] == PacketType::SSH_MSG_EXT_INFO as u8 {
|
||||
info!("Received SSH_MSG_EXT_INFO, reading next packet");
|
||||
encrypted_request = EncryptedPacket::read(stream, encryption_ctx, true)?;
|
||||
}
|
||||
|
||||
|
||||
let payload = encrypted_request.payload();
|
||||
info!("Received packet type: {}", payload[0]);
|
||||
|
||||
|
||||
if payload[0] != PacketType::SSH_MSG_SERVICE_REQUEST as u8 {
|
||||
return Err(anyhow!("Expected SSH_MSG_SERVICE_REQUEST, got type {}", payload[0]));
|
||||
return Err(anyhow!(
|
||||
"Expected SSH_MSG_SERVICE_REQUEST, got type {}",
|
||||
payload[0]
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
let mut cursor = std::io::Cursor::new(&payload[1..]);
|
||||
let service_name_len = cursor.read_u32::<BigEndian>()?;
|
||||
let mut service_name = vec![0u8; service_name_len as usize];
|
||||
cursor.read_exact(&mut service_name)?;
|
||||
let service_name_str = String::from_utf8_lossy(&service_name);
|
||||
|
||||
|
||||
if service_name_str != "ssh-userauth" {
|
||||
return Err(anyhow!("Unsupported service: {}", service_name_str));
|
||||
}
|
||||
|
||||
|
||||
let mut service_accept_payload = Vec::new();
|
||||
service_accept_payload.write_u8(PacketType::SSH_MSG_SERVICE_ACCEPT as u8)?;
|
||||
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
|
||||
service_accept_payload.write_u32::<BigEndian>(12)?; // "ssh-userauth" length is 12, not 14!
|
||||
service_accept_payload.write_all("ssh-userauth".as_bytes())?;
|
||||
|
||||
let encrypted_accept = EncryptedPacket::new(
|
||||
&service_accept_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
|
||||
let encrypted_accept = EncryptedPacket::new(&service_accept_payload, encryption_ctx, true)?;
|
||||
encrypted_accept.write(stream)?;
|
||||
info!("Sent encrypted SSH_MSG_SERVICE_ACCEPT");
|
||||
|
||||
|
||||
let session_id = encryption_ctx.session_id.clone();
|
||||
|
||||
|
||||
loop {
|
||||
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
|
||||
let auth_packet = EncryptedPacket::read(stream, encryption_ctx, true)?; // Reading from client, use cipher_ctos
|
||||
let auth_payload = auth_packet.payload();
|
||||
info!("Received encrypted SSH_MSG_USERAUTH_REQUEST");
|
||||
|
||||
|
||||
let auth_request = SshPacket::new(auth_payload.to_vec());
|
||||
|
||||
|
||||
match auth_handler.handle_userauth_request(&auth_request, &session_id)? {
|
||||
AuthResult::Success => {
|
||||
let success_payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8];
|
||||
let encrypted_success = EncryptedPacket::new(
|
||||
&success_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
let encrypted_success =
|
||||
EncryptedPacket::new(&success_payload, encryption_ctx, true)?;
|
||||
encrypted_success.write(stream)?;
|
||||
info!("Sent encrypted SSH_MSG_USERAUTH_SUCCESS");
|
||||
|
||||
|
||||
// Extract username from auth request
|
||||
let user = extract_username_from_auth_request(&auth_request)
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
let home_dir = auth_handler.get_home_dir(&user)
|
||||
let home_dir = auth_handler
|
||||
.get_home_dir(&user)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/Users/accusys/markbase"));
|
||||
info!("Auth success: user={}, home_dir={:?}", user, home_dir);
|
||||
return Ok(AuthUser { username: user, home_dir });
|
||||
return Ok(AuthUser {
|
||||
username: user,
|
||||
home_dir,
|
||||
});
|
||||
}
|
||||
AuthResult::Failure(message) => {
|
||||
AuthResult::Failure(message) => {
|
||||
// message包含可用的认证方法列表(如"password,publickey")
|
||||
let mut failure_payload = Vec::new();
|
||||
failure_payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?;
|
||||
failure_payload.write_u32::<BigEndian>(message.len() as u32)?;
|
||||
failure_payload.write_all(message.as_bytes())?;
|
||||
failure_payload.write_u8(0)?; // partial_success = false
|
||||
|
||||
let encrypted_failure = EncryptedPacket::new(
|
||||
&failure_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
failure_payload.write_u8(0)?; // partial_success = false
|
||||
|
||||
let encrypted_failure =
|
||||
EncryptedPacket::new(&failure_payload, encryption_ctx, true)?;
|
||||
encrypted_failure.write(stream)?;
|
||||
warn!("Sent encrypted SSH_MSG_USERAUTH_FAILURE: {}", message);
|
||||
}
|
||||
@@ -356,27 +396,23 @@ AuthResult::Failure(message) => {
|
||||
AuthResult::PublicKeyOk(algorithm, public_key_blob) => {
|
||||
// SSH_MSG_USERAUTH_PK_OK:public key acceptable
|
||||
info!("Public key acceptable, sending USERAUTH_PK_OK");
|
||||
|
||||
|
||||
let mut pk_ok_payload = Vec::new();
|
||||
pk_ok_payload.write_u8(PacketType::SSH_MSG_USERAUTH_PK_OK as u8)?;
|
||||
|
||||
|
||||
// algorithm (SSH string)
|
||||
pk_ok_payload.write_u32::<BigEndian>(algorithm.len() as u32)?;
|
||||
pk_ok_payload.write_all(algorithm.as_bytes())?;
|
||||
|
||||
|
||||
// public key blob (SSH string)
|
||||
pk_ok_payload.write_u32::<BigEndian>(public_key_blob.len() as u32)?;
|
||||
pk_ok_payload.write_all(&public_key_blob)?;
|
||||
|
||||
let encrypted_pk_ok = EncryptedPacket::new(
|
||||
&pk_ok_payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
|
||||
let encrypted_pk_ok = EncryptedPacket::new(&pk_ok_payload, encryption_ctx, true)?;
|
||||
encrypted_pk_ok.write(stream)?;
|
||||
info!("Sent SSH_MSG_USERAUTH_PK_OK");
|
||||
|
||||
continue; // Wait for signed request
|
||||
|
||||
continue; // Wait for signed request
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -389,16 +425,17 @@ fn handle_ssh_service_loop(
|
||||
stream: &mut TcpStream,
|
||||
channel_manager: &mut ChannelManager,
|
||||
encryption_ctx: &mut EncryptionContext,
|
||||
port_forward_manager: &mut PortForwardManager, // Phase 13
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
|
||||
port_forward_manager: &mut PortForwardManager, // Phase 13
|
||||
security_config: Arc<Mutex<SshSecurityConfig>>, // Phase 13.1
|
||||
) -> Result<()> {
|
||||
info!("Starting SSH service loop (Phase 14.2: unified poll + child status)");
|
||||
|
||||
|
||||
loop {
|
||||
// ⭐⭐⭐⭐⭐ Phase 14.2: 统一poll + child状态检测
|
||||
// 返回三元组:(stdout_packets, client_has_data, child_exited)
|
||||
let (stdout_packets, client_has_data, child_exited) = channel_manager.poll_exec_stdout_and_client(stream)?;
|
||||
|
||||
let (stdout_packets, client_has_data, child_exited) =
|
||||
channel_manager.poll_exec_stdout_and_client(stream)?;
|
||||
|
||||
// 1. 发送stdout/stderr数据(如果有)
|
||||
if let Some(packets) = stdout_packets {
|
||||
for packet in packets {
|
||||
@@ -407,93 +444,100 @@ fn handle_ssh_service_loop(
|
||||
info!("Sent stdout/stderr data (Phase 14.2)");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 2. 处理child exited(发送EOF + CLOSE)
|
||||
if child_exited {
|
||||
info!("Child process exited, sending SSH_MSG_CHANNEL_EOF + CLOSE");
|
||||
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 14.2: 使用ChannelManager.handle_child_exited()
|
||||
let exit_packets = channel_manager.handle_child_exited()?;
|
||||
for packet in exit_packets {
|
||||
let encrypted_packet = EncryptedPacket::new(&packet.payload, encryption_ctx, true)?;
|
||||
encrypted_packet.write(stream)?;
|
||||
}
|
||||
|
||||
|
||||
// 继续处理client数据(可能还有其他请求)
|
||||
}
|
||||
|
||||
|
||||
// 3. 处理client数据(如果有)
|
||||
if !client_has_data {
|
||||
// client没有数据,继续下一轮循环
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// client有数据,读取并处理
|
||||
let encrypted_packet = EncryptedPacket::read(stream, encryption_ctx, true)?;
|
||||
let packet = SshPacket::new(encrypted_packet.payload().to_vec());
|
||||
|
||||
|
||||
match packet.payload.first() {
|
||||
// Phase 13: SSH_MSG_GLOBAL_REQUEST处理(端口转发)
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_GLOBAL_REQUEST as u8 => {
|
||||
info!("Received SSH_MSG_GLOBAL_REQUEST (port forwarding)");
|
||||
|
||||
|
||||
// Phase 13.1: 安全配置验证
|
||||
let security = security_config.lock().unwrap();
|
||||
if !security.allow_tcp_forwarding {
|
||||
warn!("TCP forwarding disabled by security config");
|
||||
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
||||
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
let encrypted_failure =
|
||||
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
encrypted_failure.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_FAILURE (TCP forwarding disabled)");
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// Phase 13.2: 调用PortForwardManager处理(传递security_config)
|
||||
let (success, response) = port_forward_manager.handle_global_request(&packet.payload, &security)?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
let (success, response) =
|
||||
port_forward_manager.handle_global_request(&packet.payload, &security)?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
if success {
|
||||
if let Some(response_data) = response {
|
||||
let encrypted_response = EncryptedPacket::new(&response_data, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response_data, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_SUCCESS (tcpip-forward accepted)");
|
||||
} else {
|
||||
// 无响应数据时,发送简单的SUCCESS
|
||||
let success_packet = vec![PacketType::SSH_MSG_REQUEST_SUCCESS as u8];
|
||||
let encrypted_success = EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
|
||||
let encrypted_success =
|
||||
EncryptedPacket::new(&success_packet, encryption_ctx, true)?;
|
||||
encrypted_success.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_SUCCESS");
|
||||
}
|
||||
} else {
|
||||
let failure_packet = vec![PacketType::SSH_MSG_REQUEST_FAILURE as u8];
|
||||
let encrypted_failure = EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
let encrypted_failure =
|
||||
EncryptedPacket::new(&failure_packet, encryption_ctx, true)?;
|
||||
encrypted_failure.write(stream)?;
|
||||
info!("Sent SSH_MSG_REQUEST_FAILURE (tcpip-forward rejected)");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_OPEN as u8 => {
|
||||
info!("Received SSH_MSG_CHANNEL_OPEN");
|
||||
|
||||
|
||||
// Phase 13.3: 获取security_config并传递给handle_channel_open
|
||||
let security = security_config.lock().unwrap();
|
||||
let response = channel_manager.handle_channel_open(&packet, Some(&security))?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
drop(security); // 释放锁
|
||||
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
|
||||
}
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_REQUEST as u8 => {
|
||||
info!("Received SSH_MSG_CHANNEL_REQUEST");
|
||||
if let Some(response) = channel_manager.handle_channel_request(&packet)? {
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 14.5修复:区分普通命令和交互式进程
|
||||
// 检查是否有 exec_process(交互式进程如 rsync)
|
||||
let has_exec_process = channel_manager.has_exec_process();
|
||||
|
||||
|
||||
if has_exec_process {
|
||||
info!("⭐⭐⭐⭐⭐ [INTERACTIVE_PROCESS] Detected exec_process (rsync/SCP), skipping immediate EOF");
|
||||
// 对于交互式进程,只发送 SUCCESS,等待 poll 循环处理数据流
|
||||
@@ -503,23 +547,37 @@ fn handle_ssh_service_loop(
|
||||
if let Some(channel_id) = channel_manager.get_channel_with_output() {
|
||||
if let Some(output) = channel_manager.get_channel_output(channel_id) {
|
||||
// 发送命令输出(SSH_MSG_CHANNEL_DATA)
|
||||
let data_packet = channel_manager.build_channel_data(channel_id, &output)?;
|
||||
let encrypted_data = EncryptedPacket::new(&data_packet.payload, encryption_ctx, true)?;
|
||||
let data_packet =
|
||||
channel_manager.build_channel_data(channel_id, &output)?;
|
||||
let encrypted_data = EncryptedPacket::new(
|
||||
&data_packet.payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
encrypted_data.write(stream)?;
|
||||
info!("Sent command output ({} bytes)", output.len());
|
||||
|
||||
|
||||
// 发送SSH_MSG_CHANNEL_EOF
|
||||
let eof_packet = channel_manager.build_channel_eof(channel_id)?;
|
||||
let encrypted_eof = EncryptedPacket::new(&eof_packet.payload, encryption_ctx, true)?;
|
||||
let encrypted_eof = EncryptedPacket::new(
|
||||
&eof_packet.payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
encrypted_eof.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_EOF");
|
||||
|
||||
|
||||
// 发送SSH_MSG_CHANNEL_CLOSE
|
||||
let close_packet = channel_manager.build_channel_close(channel_id)?;
|
||||
let encrypted_close = EncryptedPacket::new(&close_packet.payload, encryption_ctx, true)?;
|
||||
let close_packet =
|
||||
channel_manager.build_channel_close(channel_id)?;
|
||||
let encrypted_close = EncryptedPacket::new(
|
||||
&close_packet.payload,
|
||||
encryption_ctx,
|
||||
true,
|
||||
)?;
|
||||
encrypted_close.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_CLOSE");
|
||||
|
||||
|
||||
// 移除channel
|
||||
channel_manager.remove_channel(channel_id);
|
||||
}
|
||||
@@ -531,22 +589,28 @@ fn handle_ssh_service_loop(
|
||||
info!("Received SSH_MSG_CHANNEL_DATA");
|
||||
if let Some(response) = channel_manager.handle_channel_data(&packet)? {
|
||||
// Phase 7: SFTP响应通过CHANNEL_DATA返回
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
info!("Sent SSH_MSG_CHANNEL_DATA (SFTP response)");
|
||||
}
|
||||
|
||||
|
||||
// ⭐⭐⭐⭐⭐ Phase 15.1: Drain pending packets (e.g. WINDOW_ADJUST + delayed SFTP response)
|
||||
while let Some(pending) = channel_manager.pending_packets.pop_front() {
|
||||
let encrypted_pending = EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
|
||||
let encrypted_pending =
|
||||
EncryptedPacket::new(&pending.payload, encryption_ctx, true)?;
|
||||
encrypted_pending.write(stream)?;
|
||||
info!("Sent pending packet (type {})", pending.payload.first().unwrap_or(&0));
|
||||
info!(
|
||||
"Sent pending packet (type {})",
|
||||
pending.payload.first().unwrap_or(&0)
|
||||
);
|
||||
}
|
||||
}
|
||||
Some(&pt) if pt == PacketType::SSH_MSG_CHANNEL_CLOSE as u8 => {
|
||||
info!("Received SSH_MSG_CHANNEL_CLOSE");
|
||||
if let Some(response) = channel_manager.handle_channel_close(&packet)? {
|
||||
let encrypted_response = EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
let encrypted_response =
|
||||
EncryptedPacket::new(&response.payload, encryption_ctx, true)?;
|
||||
encrypted_response.write(stream)?;
|
||||
}
|
||||
break;
|
||||
@@ -565,8 +629,10 @@ fn handle_ssh_service_loop(
|
||||
let payload = &packet.payload;
|
||||
if payload.len() >= 9 {
|
||||
// Format: uint32 recipient_channel || uint32 bytes_to_add
|
||||
let recipient_channel = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
|
||||
let bytes_to_add = u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
|
||||
let recipient_channel =
|
||||
u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
|
||||
let bytes_to_add =
|
||||
u32::from_be_bytes([payload[5], payload[6], payload[7], payload[8]]);
|
||||
channel_manager.adjust_remote_window(recipient_channel, bytes_to_add);
|
||||
}
|
||||
}
|
||||
@@ -575,12 +641,14 @@ fn handle_ssh_service_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 从SSH_MSG_USERAUTH_REQUEST payload中提取用户名
|
||||
fn extract_username_from_auth_request(packet: &crate::ssh_server::packet::SshPacket) -> Result<String> {
|
||||
fn extract_username_from_auth_request(
|
||||
packet: &crate::ssh_server::packet::SshPacket,
|
||||
) -> Result<String> {
|
||||
let payload = &packet.payload;
|
||||
if payload.len() < 5 {
|
||||
return Err(anyhow!("Auth request too short"));
|
||||
@@ -598,10 +666,10 @@ pub fn run_ssh_server(port: Option<u16>, pg_conn: Option<&str>) -> Result<()> {
|
||||
let config = SshServerConfig {
|
||||
port: port.unwrap_or(2024),
|
||||
bind_address: "127.0.0.1".to_string(),
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
|
||||
security_config: SshSecurityConfig::enterprise_default(), // Phase 13.1: 添加安全配置
|
||||
pg_conn: pg_conn.map(|s| s.to_string()),
|
||||
};
|
||||
|
||||
|
||||
let server = SshServer::new(config);
|
||||
server.run()
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
// SSH企业级安全配置(Phase 13.1)
|
||||
// 参考OpenSSH sshd_config安全配置
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{anyhow, Result};
|
||||
use log::{info, warn};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
@@ -14,25 +14,25 @@ pub struct SshSecurityConfig {
|
||||
/// false: 只绑定127.0.0.1(安全)
|
||||
/// true: 允许绑定0.0.0.0(危险)
|
||||
pub gateway_ports: bool,
|
||||
|
||||
|
||||
/// PermitOpen白名单
|
||||
/// ["localhost:3000", "localhost:4000", "localhost:*"]
|
||||
/// 空数组表示允许所有目标(不安全)
|
||||
pub permit_open: Vec<String>,
|
||||
|
||||
|
||||
/// AllowTcpForwarding配置
|
||||
/// true: 允许TCP转发
|
||||
/// false: 禁止所有TCP转发
|
||||
pub allow_tcp_forwarding: bool,
|
||||
|
||||
|
||||
/// MaxSessions限制
|
||||
/// 最大会话数,防止资源耗尽
|
||||
pub max_sessions: u32,
|
||||
|
||||
|
||||
/// ConnectTimeout超时(秒)
|
||||
/// 连接超时设置,防止悬挂连接
|
||||
pub connect_timeout: u64,
|
||||
|
||||
|
||||
/// 活动会话数(运行时状态)
|
||||
pub active_sessions: u32,
|
||||
}
|
||||
@@ -42,110 +42,125 @@ impl SshSecurityConfig {
|
||||
/// 参考:OpenSSH企业级生产环境配置
|
||||
pub fn enterprise_default() -> Self {
|
||||
Self {
|
||||
gateway_ports: false, // 安全:只绑定127.0.0.1
|
||||
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
|
||||
allow_tcp_forwarding: true, // 允许TCP转发
|
||||
max_sessions: 10, // 最多10个会话
|
||||
connect_timeout: 30, // 30秒超时
|
||||
active_sessions: 0, // 运行时状态
|
||||
gateway_ports: false, // 安全:只绑定127.0.0.1
|
||||
permit_open: vec!["localhost:*".to_string()], // 限制转发目标(白名单)
|
||||
allow_tcp_forwarding: true, // 允许TCP转发
|
||||
max_sessions: 10, // 最多10个会话
|
||||
connect_timeout: 30, // 30秒超时
|
||||
active_sessions: 0, // 运行时状态
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 开发环境默认配置(宽松)
|
||||
pub fn development_default() -> Self {
|
||||
Self {
|
||||
gateway_ports: true, // 开发:允许0.0.0.0
|
||||
permit_open: vec![], // 开发:允许所有目标
|
||||
gateway_ports: true, // 开发:允许0.0.0.0
|
||||
permit_open: vec![], // 开发:允许所有目标
|
||||
allow_tcp_forwarding: true,
|
||||
max_sessions: 20, // 开发:更多会话
|
||||
connect_timeout: 60, // 开发:更长超时
|
||||
max_sessions: 20, // 开发:更多会话
|
||||
connect_timeout: 60, // 开发:更长超时
|
||||
active_sessions: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 从JSON配置文件加载
|
||||
pub fn load_from_file(path: &str) -> Result<Self> {
|
||||
if !Path::new(path).exists() {
|
||||
info!("SSH security config file not found, using enterprise default");
|
||||
return Ok(Self::enterprise_default());
|
||||
}
|
||||
|
||||
|
||||
let config_str = fs::read_to_string(path)?;
|
||||
let config: serde_json::Value = serde_json::from_str(&config_str)?;
|
||||
|
||||
let security = config.get("ssh_server")
|
||||
|
||||
let security = config
|
||||
.get("ssh_server")
|
||||
.and_then(|s| s.get("security"))
|
||||
.ok_or_else(|| anyhow!("Invalid config structure"))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
gateway_ports: security.get("gateway_ports")
|
||||
gateway_ports: security
|
||||
.get("gateway_ports")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
permit_open: security.get("permit_open")
|
||||
permit_open: security
|
||||
.get("permit_open")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_else(|| vec!["localhost:*".to_string()]),
|
||||
allow_tcp_forwarding: security.get("allow_tcp_forwarding")
|
||||
allow_tcp_forwarding: security
|
||||
.get("allow_tcp_forwarding")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true),
|
||||
max_sessions: security.get("max_sessions")
|
||||
max_sessions: security
|
||||
.get("max_sessions")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|v| v as u32)
|
||||
.unwrap_or(10),
|
||||
connect_timeout: security.get("connect_timeout")
|
||||
connect_timeout: security
|
||||
.get("connect_timeout")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(30),
|
||||
active_sessions: 0,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// 验证tcpip-forward请求(安全检查)
|
||||
/// 参考OpenSSH auth2.c: ssh_forwarding_check()
|
||||
pub fn validate_tcpip_forward_request(
|
||||
&self,
|
||||
bind_address: &str,
|
||||
bind_port: u32,
|
||||
) -> Result<()> {
|
||||
info!("Validating tcpip-forward request: bind_address={}, bind_port={}", bind_address, bind_port);
|
||||
|
||||
pub fn validate_tcpip_forward_request(&self, bind_address: &str, bind_port: u32) -> Result<()> {
|
||||
info!(
|
||||
"Validating tcpip-forward request: bind_address={}, bind_port={}",
|
||||
bind_address, bind_port
|
||||
);
|
||||
|
||||
// 1. AllowTcpForwarding检查
|
||||
if !self.allow_tcp_forwarding {
|
||||
warn!("TCP forwarding disabled by security config");
|
||||
return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no"));
|
||||
}
|
||||
|
||||
|
||||
// 2. GatewayPorts检查
|
||||
if !self.gateway_ports {
|
||||
// 只允许绑定到127.0.0.1或localhost
|
||||
if bind_address != "127.0.0.1" && bind_address != "localhost" && bind_address != "" {
|
||||
warn!("GatewayPorts disabled, bind_address {} not allowed", bind_address);
|
||||
if bind_address != "127.0.0.1" && bind_address != "localhost" && !bind_address.is_empty() {
|
||||
warn!(
|
||||
"GatewayPorts disabled, bind_address {} not allowed",
|
||||
bind_address
|
||||
);
|
||||
return Err(anyhow!("GatewayPorts=no, only 127.0.0.1 allowed"));
|
||||
}
|
||||
info!("GatewayPorts check passed: bind_address={}", bind_address);
|
||||
}
|
||||
|
||||
|
||||
// 3. MaxSessions检查
|
||||
if self.active_sessions >= self.max_sessions {
|
||||
warn!("Max sessions limit reached: {} >= {}", self.active_sessions, self.max_sessions);
|
||||
warn!(
|
||||
"Max sessions limit reached: {} >= {}",
|
||||
self.active_sessions, self.max_sessions
|
||||
);
|
||||
return Err(anyhow!("Max sessions limit reached: {}", self.max_sessions));
|
||||
}
|
||||
|
||||
|
||||
// 4. 特权端口检查(防止<1024)
|
||||
if bind_port < 1024 {
|
||||
warn!("Cannot bind to privileged port: {}", bind_port);
|
||||
return Err(anyhow!("Cannot bind to privileged port < 1024"));
|
||||
}
|
||||
|
||||
|
||||
// 5. 端口范围检查(防止过大端口)
|
||||
if bind_port > 65535 {
|
||||
warn!("Invalid port number: {}", bind_port);
|
||||
return Err(anyhow!("Invalid port number > 65535"));
|
||||
}
|
||||
|
||||
|
||||
info!("tcpip-forward request validated successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 验证direct-tcpip channel请求(安全检查)
|
||||
/// 参考OpenSSH channels.c: channel_connect_direct_tcpip()
|
||||
pub fn validate_direct_tcpip_channel(
|
||||
@@ -153,14 +168,17 @@ impl SshSecurityConfig {
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
) -> Result<()> {
|
||||
info!("Validating direct-tcpip channel: host={}, port={}", host_to_connect, port_to_connect);
|
||||
|
||||
info!(
|
||||
"Validating direct-tcpip channel: host={}, port={}",
|
||||
host_to_connect, port_to_connect
|
||||
);
|
||||
|
||||
// 1. AllowTcpForwarding检查
|
||||
if !self.allow_tcp_forwarding {
|
||||
warn!("TCP forwarding disabled by security config");
|
||||
return Err(anyhow!("TCP forwarding disabled by AllowTcpForwarding=no"));
|
||||
}
|
||||
|
||||
|
||||
// 2. PermitOpen白名单检查
|
||||
if !self.permit_open.is_empty() {
|
||||
let target = format!("{}:{}", host_to_connect, port_to_connect);
|
||||
@@ -173,28 +191,34 @@ impl SshSecurityConfig {
|
||||
target == *pattern
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
if !allowed {
|
||||
warn!("Target {}:{} not in PermitOpen whitelist", host_to_connect, port_to_connect);
|
||||
return Err(anyhow!("Target {}:{} not in PermitOpen whitelist",
|
||||
host_to_connect, port_to_connect));
|
||||
warn!(
|
||||
"Target {}:{} not in PermitOpen whitelist",
|
||||
host_to_connect, port_to_connect
|
||||
);
|
||||
return Err(anyhow!(
|
||||
"Target {}:{} not in PermitOpen whitelist",
|
||||
host_to_connect,
|
||||
port_to_connect
|
||||
));
|
||||
}
|
||||
info!("PermitOpen check passed: target={}", target);
|
||||
} else {
|
||||
// permit_open为空,允许所有目标(不安全,仅用于开发)
|
||||
info!("PermitOpen whitelist empty, allowing all targets (development mode)");
|
||||
}
|
||||
|
||||
|
||||
// 3. 端口范围检查
|
||||
if port_to_connect < 1 || port_to_connect > 65535 {
|
||||
if !(1..=65535).contains(&port_to_connect) {
|
||||
warn!("Invalid port number: {}", port_to_connect);
|
||||
return Err(anyhow!("Invalid port number: {}", port_to_connect));
|
||||
}
|
||||
|
||||
|
||||
info!("direct-tcpip channel validated successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 增加活动会话数
|
||||
pub fn increment_sessions(&mut self) -> Result<()> {
|
||||
if self.active_sessions >= self.max_sessions {
|
||||
@@ -204,7 +228,7 @@ impl SshSecurityConfig {
|
||||
info!("Active sessions: {}", self.active_sessions);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 减少活动会话数
|
||||
pub fn decrement_sessions(&mut self) {
|
||||
if self.active_sessions > 0 {
|
||||
@@ -217,56 +241,76 @@ impl SshSecurityConfig {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_enterprise_default_config() {
|
||||
let config = SshSecurityConfig::enterprise_default();
|
||||
|
||||
|
||||
assert_eq!(config.gateway_ports, false);
|
||||
assert_eq!(config.permit_open, vec!["localhost:*".to_string()]);
|
||||
assert_eq!(config.allow_tcp_forwarding, true);
|
||||
assert_eq!(config.max_sessions, 10);
|
||||
assert_eq!(config.connect_timeout, 30);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_validate_tcpip_forward_request() {
|
||||
let config = SshSecurityConfig::enterprise_default();
|
||||
|
||||
|
||||
// 正常请求应该通过
|
||||
assert!(config.validate_tcpip_forward_request("127.0.0.1", 8080).is_ok());
|
||||
assert!(config.validate_tcpip_forward_request("localhost", 8080).is_ok());
|
||||
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("127.0.0.1", 8080)
|
||||
.is_ok());
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("localhost", 8080)
|
||||
.is_ok());
|
||||
|
||||
// GatewayPorts=false时,0.0.0.0应该被拒绝
|
||||
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_err());
|
||||
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("0.0.0.0", 8080)
|
||||
.is_err());
|
||||
|
||||
// 特权端口应该被拒绝
|
||||
assert!(config.validate_tcpip_forward_request("127.0.0.1", 80).is_err());
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("127.0.0.1", 80)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_validate_direct_tcpip_channel() {
|
||||
let config = SshSecurityConfig::enterprise_default();
|
||||
|
||||
|
||||
// localhost:*应该通过(通配符匹配)
|
||||
assert!(config.validate_direct_tcpip_channel("localhost", 3000).is_ok());
|
||||
assert!(config.validate_direct_tcpip_channel("localhost", 4000).is_ok());
|
||||
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("localhost", 3000)
|
||||
.is_ok());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("localhost", 4000)
|
||||
.is_ok());
|
||||
|
||||
// 其他host应该被拒绝
|
||||
assert!(config.validate_direct_tcpip_channel("192.168.1.100", 3000).is_err());
|
||||
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_err());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("192.168.1.100", 3000)
|
||||
.is_err());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("example.com", 80)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_development_default_config() {
|
||||
let config = SshSecurityConfig::development_default();
|
||||
|
||||
|
||||
assert_eq!(config.gateway_ports, true);
|
||||
assert_eq!(config.permit_open.len(), 0); // 空数组表示允许所有
|
||||
assert_eq!(config.max_sessions, 20);
|
||||
|
||||
|
||||
// 开发配置应该允许所有请求
|
||||
assert!(config.validate_tcpip_forward_request("0.0.0.0", 8080).is_ok());
|
||||
assert!(config.validate_direct_tcpip_channel("example.com", 80).is_ok());
|
||||
assert!(config
|
||||
.validate_tcpip_forward_request("0.0.0.0", 8080)
|
||||
.is_ok());
|
||||
assert!(config
|
||||
.validate_direct_tcpip_channel("example.com", 80)
|
||||
.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
// SSH Buffer 零拷贝实现(参考 OpenSSH sshbuf.c)
|
||||
// 提供高效的 buffer 管理,消除临时 buffer
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::io::{Read, Write};
|
||||
|
||||
/// SSH Buffer(参考 OpenSSH struct sshbuf)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// struct sshbuf {
|
||||
@@ -16,10 +16,10 @@ use std::io::{Read, Write};
|
||||
/// };
|
||||
/// ```
|
||||
pub struct SshBuf {
|
||||
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
|
||||
off: usize, // Offset (对应 OpenSSH buf->off)
|
||||
size: usize, // Size (对应 OpenSSH buf->size)
|
||||
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
|
||||
data: Vec<u8>, // Data buffer (对应 OpenSSH buf->d)
|
||||
off: usize, // Offset (对应 OpenSSH buf->off)
|
||||
size: usize, // Size (对应 OpenSSH buf->size)
|
||||
max_size: usize, // Maximum size (对应 OpenSSH buf->max_size)
|
||||
}
|
||||
|
||||
impl SshBuf {
|
||||
@@ -32,7 +32,7 @@ impl SshBuf {
|
||||
max_size: 128 * 1024 * 1024, // 128MB (OpenSSH SSHBUF_SIZE_MAX)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 创建指定大小的 SSH Buffer
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
@@ -42,7 +42,7 @@ impl SshBuf {
|
||||
max_size: 128 * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 设置最大大小
|
||||
pub fn set_max_size(&mut self, max_size: usize) -> Result<()> {
|
||||
if max_size > 128 * 1024 * 1024 {
|
||||
@@ -51,47 +51,47 @@ impl SshBuf {
|
||||
self.max_size = max_size;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 获取 buffer 长度(对应 OpenSSH sshbuf_len)
|
||||
///
|
||||
///
|
||||
/// OpenSSH: `sshbuf_len = buf->size - buf->off`
|
||||
pub fn len(&self) -> usize {
|
||||
self.size - self.off
|
||||
}
|
||||
|
||||
|
||||
/// 检查 buffer 是否为空
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
|
||||
/// 获取可用空间(对应 OpenSSH sshbuf_avail)
|
||||
///
|
||||
///
|
||||
/// OpenSSH: `sshbuf_avail = buf->max_size - buf->size`
|
||||
pub fn avail(&self) -> usize {
|
||||
self.max_size - self.size
|
||||
}
|
||||
|
||||
|
||||
/// 获取可变指针(对应 OpenSSH sshbuf_mutable_ptr)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// u_char *sshbuf_mutable_ptr(const struct sshbuf *buf) {
|
||||
/// return buf->d + buf->off;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:返回 `&mut [u8]` slice(零拷贝)
|
||||
pub fn mutable_ptr(&mut self) -> &mut [u8] {
|
||||
&mut self.data[self.off..self.size]
|
||||
}
|
||||
|
||||
|
||||
/// 获取不可变指针(对应 OpenSSH sshbuf_ptr)
|
||||
pub fn ptr(&self) -> &[u8] {
|
||||
&self.data[self.off..self.size]
|
||||
}
|
||||
|
||||
|
||||
/// 预分配空间(对应 OpenSSH sshbuf_reserve)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_reserve(struct sshbuf *buf, size_t len, u_char **dpp) {
|
||||
@@ -104,31 +104,31 @@ impl SshBuf {
|
||||
/// return 0;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:返回 `&mut [u8]` slice(零拷贝,可直接 write)
|
||||
pub fn reserve(&mut self, len: usize) -> Result<&mut [u8]> {
|
||||
if len > self.avail() {
|
||||
return Err(anyhow!("no buffer space (avail={})", self.avail()));
|
||||
}
|
||||
|
||||
|
||||
// 预分配空间
|
||||
let current_size = self.size;
|
||||
let new_size = current_size + len;
|
||||
|
||||
|
||||
// 确保 Vec 有足够容量
|
||||
if new_size > self.data.len() {
|
||||
self.data.resize(new_size, 0);
|
||||
}
|
||||
|
||||
|
||||
// 更新 size
|
||||
self.size = new_size;
|
||||
|
||||
|
||||
// 返回新空间的 slice(零拷贝)
|
||||
Ok(&mut self.data[current_size..new_size])
|
||||
}
|
||||
|
||||
|
||||
/// 消费数据(对应 OpenSSH sshbuf_consume)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_consume(struct sshbuf *buf, size_t len) {
|
||||
@@ -140,29 +140,33 @@ impl SshBuf {
|
||||
/// return 0;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:移动偏移量(零拷贝,不实际删除数据)
|
||||
pub fn consume(&mut self, len: usize) -> Result<()> {
|
||||
if len > self.len() {
|
||||
return Err(anyhow!("message incomplete (len={}, consume={})", self.len(), len));
|
||||
return Err(anyhow!(
|
||||
"message incomplete (len={}, consume={})",
|
||||
self.len(),
|
||||
len
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
self.off += len;
|
||||
|
||||
|
||||
// 如果 buffer 空,重置
|
||||
if self.off == self.size {
|
||||
self.off = 0;
|
||||
self.size = 0;
|
||||
|
||||
|
||||
// OpenSSH: pack buffer(移除已消费的数据)
|
||||
// Rust: 我们保留 Vec,但重置指针
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 从末尾消费数据(对应 OpenSSH sshbuf_consume_end)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_consume_end(struct sshbuf *buf, size_t len) {
|
||||
@@ -174,13 +178,13 @@ impl SshBuf {
|
||||
if len > self.len() {
|
||||
return Err(anyhow!("message incomplete"));
|
||||
}
|
||||
|
||||
|
||||
self.size -= len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 直接从 fd read 到 buffer(对应 OpenSSH sshbuf_read)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// int sshbuf_read(int fd, struct sshbuf *buf, size_t maxlen, size_t *rlen) {
|
||||
@@ -195,71 +199,75 @@ impl SshBuf {
|
||||
/// return 0;
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:零拷贝,直接 read 到 buffer
|
||||
pub fn read_from<R: Read>(&mut self, reader: &mut R, maxlen: usize) -> Result<usize> {
|
||||
// 1. reserve 空间
|
||||
let space = self.reserve(maxlen)?;
|
||||
|
||||
|
||||
// 2. 直接 read 到 buffer(零拷贝)
|
||||
let n = reader.read(space)?;
|
||||
|
||||
|
||||
// 3. 调整大小(移除未使用的空间)
|
||||
if maxlen > n {
|
||||
self.consume_end(maxlen - n)?;
|
||||
}
|
||||
|
||||
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
|
||||
/// 直接从 buffer write 到 fd(对应 OpenSSH channel_handle_wfd)
|
||||
///
|
||||
///
|
||||
/// OpenSSH 实现:
|
||||
/// ```c
|
||||
/// buf = sshbuf_mutable_ptr(c->output); // 获取指针
|
||||
/// len = write(c->wfd, buf, dlen); // 直接 write
|
||||
/// sshbuf_consume(c->output, len); // 消费已写入的数据
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// Rust 实现:零拷贝,直接 write 从 buffer
|
||||
pub fn write_to<W: Write>(&mut self, writer: &mut W) -> Result<usize> {
|
||||
if self.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
|
||||
// 1. 获取数据指针(零拷贝)
|
||||
let data = self.ptr();
|
||||
|
||||
|
||||
// 2. 直接 write(零拷贝)
|
||||
let n = writer.write(data)?;
|
||||
|
||||
|
||||
// 3. 消费已写入的数据(零拷贝,只移动偏移)
|
||||
self.consume(n)?;
|
||||
|
||||
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
|
||||
/// 添加数据(对应 OpenSSH sshbuf_put)
|
||||
///
|
||||
///
|
||||
/// 用于不需要零拷贝的场景
|
||||
pub fn put(&mut self, data: &[u8]) -> Result<()> {
|
||||
let space = self.reserve(data.len())?;
|
||||
space.copy_from_slice(data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 清空 buffer
|
||||
pub fn reset(&mut self) {
|
||||
self.off = 0;
|
||||
self.size = 0;
|
||||
// OpenSSH: 保留 Vec,只重置指针
|
||||
}
|
||||
|
||||
|
||||
/// Debug: 打印 buffer 状态
|
||||
pub fn debug_info(&self) -> String {
|
||||
format!(
|
||||
"SshBuf: off={}, size={}, len={}, alloc={}, max_size={}",
|
||||
self.off, self.size, self.len(), self.data.len(), self.max_size
|
||||
self.off,
|
||||
self.size,
|
||||
self.len(),
|
||||
self.data.len(),
|
||||
self.max_size
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -274,11 +282,11 @@ impl Default for SshBuf {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_basic() {
|
||||
let mut buf = SshBuf::new();
|
||||
|
||||
|
||||
// Test reserve - write into reserved space
|
||||
{
|
||||
let space = buf.reserve(10).unwrap();
|
||||
@@ -286,57 +294,57 @@ mod tests {
|
||||
space[0] = 1;
|
||||
space[1] = 2;
|
||||
} // space dropped, buf accessible
|
||||
|
||||
|
||||
// Verify buffer length after reserve
|
||||
assert_eq!(buf.len(), 10);
|
||||
let ptr = buf.mutable_ptr();
|
||||
assert_eq!(ptr[0], 1);
|
||||
assert_eq!(ptr[1], 2);
|
||||
|
||||
|
||||
// Test consume
|
||||
buf.consume(2).unwrap();
|
||||
assert_eq!(buf.len(), 8);
|
||||
assert_eq!(buf.off, 2);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_zero_copy_read() {
|
||||
let mut buf = SshBuf::with_capacity(100);
|
||||
let mut reader = Cursor::new("hello world");
|
||||
|
||||
|
||||
// 零拷贝 read
|
||||
let n = buf.read_from(&mut reader, 20).unwrap();
|
||||
assert_eq!(n, 11); // "hello world" length
|
||||
assert_eq!(buf.len(), 11);
|
||||
|
||||
|
||||
// 检查数据
|
||||
let data = buf.ptr();
|
||||
assert_eq!(data, "hello world".as_bytes());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_zero_copy_write() {
|
||||
let mut buf = SshBuf::new();
|
||||
buf.put("hello world".as_bytes()).unwrap();
|
||||
|
||||
|
||||
let mut writer = Vec::new();
|
||||
|
||||
|
||||
// 零拷贝 write
|
||||
let n = buf.write_to(&mut writer).unwrap();
|
||||
assert_eq!(n, 11);
|
||||
assert_eq!(buf.len(), 0); // 已消费
|
||||
|
||||
|
||||
// 检查数据
|
||||
assert_eq!(writer, "hello world".as_bytes());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sshbuf_max_size() {
|
||||
let mut buf = SshBuf::new();
|
||||
buf.set_max_size(1000).unwrap();
|
||||
|
||||
|
||||
// 尝试 reserve 超过 max_size
|
||||
let result = buf.reserve(2000);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
// 参考OpenSSH sshd.c: ssh_exchange_identification()
|
||||
|
||||
use anyhow::Result;
|
||||
use log::{debug, info};
|
||||
use std::io::{Read, Write};
|
||||
use log::{info, debug};
|
||||
|
||||
/// SSH版本字符串
|
||||
pub const SSH_VERSION: &str = "SSH-2.0-MarkBaseSSH_1.0";
|
||||
@@ -15,93 +15,96 @@ impl VersionExchange {
|
||||
/// 执行版本交换(服务器端)
|
||||
pub fn exchange<T: Read + Write>(stream: &mut T) -> Result<String> {
|
||||
info!("Starting SSH version exchange");
|
||||
|
||||
|
||||
// 1. 发送服务器版本
|
||||
Self::send_version(stream)?;
|
||||
|
||||
|
||||
// 2. 接收客户端版本
|
||||
let client_version = Self::receive_version(stream)?;
|
||||
|
||||
info!("Version exchange completed: server={}, client={}", SSH_VERSION, client_version);
|
||||
|
||||
info!(
|
||||
"Version exchange completed: server={}, client={}",
|
||||
SSH_VERSION, client_version
|
||||
);
|
||||
Ok(client_version)
|
||||
}
|
||||
|
||||
|
||||
/// 发送服务器版本(参考OpenSSH ssh_exchange_identification)
|
||||
fn send_version<T: Write>(stream: &mut T) -> Result<()> {
|
||||
let version_line = format!("{}\r\n", SSH_VERSION);
|
||||
stream.write_all(version_line.as_bytes())?;
|
||||
stream.flush()?;
|
||||
|
||||
|
||||
debug!("Sent version: {}", SSH_VERSION);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 接收客户端版本(参考OpenSSH ssh_exchange_identification)
|
||||
fn receive_version<T: Read>(stream: &mut T) -> Result<String> {
|
||||
let mut buffer = Vec::new();
|
||||
let mut byte = [0u8; 1];
|
||||
|
||||
|
||||
// 读取直到遇到'\n'(参考OpenSSH实现)
|
||||
loop {
|
||||
stream.read_exact(&mut byte)?;
|
||||
|
||||
|
||||
// OpenSSH兼容性处理:跳过前导空行和调试信息
|
||||
if buffer.is_empty() && byte[0] == '\n' as u8 {
|
||||
continue; // 跳过空行
|
||||
if buffer.is_empty() && byte[0] == b'\n' {
|
||||
continue; // 跳过空行
|
||||
}
|
||||
|
||||
|
||||
// 调试信息行(以'#'开头),跳过
|
||||
if buffer.is_empty() && byte[0] == '#' as u8 {
|
||||
if buffer.is_empty() && byte[0] == b'#' {
|
||||
// 读取整行调试信息
|
||||
while byte[0] != '\n' as u8 {
|
||||
while byte[0] != b'\n' {
|
||||
stream.read_exact(&mut byte)?;
|
||||
}
|
||||
buffer.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
buffer.push(byte[0]);
|
||||
|
||||
|
||||
// 遇到'\n'结束
|
||||
if byte[0] == '\n' as u8 {
|
||||
if byte[0] == b'\n' {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// 缓冲区溢出保护(OpenSSH限制:255字节)
|
||||
if buffer.len() > 255 {
|
||||
return Err(anyhow::anyhow!("Version string too long"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 解析版本字符串
|
||||
let version_line = String::from_utf8(buffer)?;
|
||||
let version = version_line.trim().trim_matches('\r');
|
||||
|
||||
|
||||
// 验证版本格式(SSH-2.0-*)
|
||||
if !version.starts_with("SSH-2.0-") {
|
||||
return Err(anyhow::anyhow!("Invalid SSH version: {}", version));
|
||||
}
|
||||
|
||||
|
||||
debug!("Received version: {}", version);
|
||||
Ok(version.to_string())
|
||||
}
|
||||
|
||||
|
||||
/// 解析客户端版本信息(兼容性检查)
|
||||
pub fn parse_client_version(version: &str) -> Result<ClientVersionInfo> {
|
||||
// 格式:SSH-protoversion-softwareversion SP comments
|
||||
let parts: Vec<&str> = version.split_whitespace().collect();
|
||||
|
||||
|
||||
let main_part = parts.first().map_or(version, |v| v);
|
||||
let dash_parts: Vec<&str> = main_part.split('-').collect();
|
||||
|
||||
|
||||
if dash_parts.len() < 3 {
|
||||
return Err(anyhow::anyhow!("Invalid version format: {}", version));
|
||||
}
|
||||
|
||||
|
||||
let proto_version = dash_parts.get(1).map_or("2.0", |v| v);
|
||||
let software_version = dash_parts.get(2).map_or("unknown", |v| v);
|
||||
let comments = parts.get(1).map(|s| s.to_string());
|
||||
|
||||
|
||||
Ok(ClientVersionInfo {
|
||||
proto_version: proto_version.to_string(),
|
||||
software_version: software_version.to_string(),
|
||||
@@ -120,12 +123,12 @@ pub struct ClientVersionInfo {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_version_format() {
|
||||
assert!(SSH_VERSION.starts_with("SSH-2.0-"));
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_parse_client_version() {
|
||||
let version = "SSH-2.0-OpenSSH_10.2";
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
// SSH Window Size管理(Phase 13.6)
|
||||
// 参考RFC 4254 Section 5.2: Window Size Adjustment
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use log::{info, warn, debug};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use crate::ssh_server::packet::PacketType;
|
||||
use anyhow::{anyhow, Result};
|
||||
use byteorder::{BigEndian, WriteBytesExt};
|
||||
use log::{info, warn};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// Window Size管理器(Phase 13.6)
|
||||
pub struct WindowManager {
|
||||
initial_window_size: u32, // RFC 4254: 2MB默认
|
||||
initial_window_size: u32, // RFC 4254: 2MB默认
|
||||
current_window_size: Arc<Mutex<u32>>,
|
||||
max_packet_size: u32, // RFC 4254: 32KB默认
|
||||
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
|
||||
max_packet_size: u32, // RFC 4254: 32KB默认
|
||||
consumed_bytes: Arc<Mutex<u32>>, // 已消耗bytes统计
|
||||
}
|
||||
|
||||
impl WindowManager {
|
||||
@@ -25,89 +25,103 @@ impl WindowManager {
|
||||
consumed_bytes: Arc::new(Mutex::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// RFC 4254默认window size(2MB)
|
||||
pub fn rfc_default() -> Self {
|
||||
Self::new(2097152, 32768) // 2MB window, 32KB packet
|
||||
Self::new(2097152, 32768) // 2MB window, 32KB packet
|
||||
}
|
||||
|
||||
|
||||
/// 检查window size是否足够(Phase 13.6)
|
||||
pub fn check_window_available(&self, data_size: u32) -> bool {
|
||||
let window = self.current_window_size.lock().unwrap();
|
||||
let available = *window >= data_size;
|
||||
|
||||
|
||||
if !available {
|
||||
warn!("Window size insufficient: need {}, have {}", data_size, *window);
|
||||
warn!(
|
||||
"Window size insufficient: need {}, have {}",
|
||||
data_size, *window
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
available
|
||||
}
|
||||
|
||||
|
||||
/// 消耗window size(Phase 13.6:发送数据后)
|
||||
pub fn consume_window(&self, data_size: u32) -> Result<()> {
|
||||
let mut window = self.current_window_size.lock().unwrap();
|
||||
|
||||
|
||||
if *window < data_size {
|
||||
return Err(anyhow!("Window size insufficient: need {}, have {}", data_size, *window));
|
||||
return Err(anyhow!(
|
||||
"Window size insufficient: need {}, have {}",
|
||||
data_size,
|
||||
*window
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
*window -= data_size;
|
||||
|
||||
|
||||
// 统计已消耗bytes
|
||||
let mut consumed = self.consumed_bytes.lock().unwrap();
|
||||
*consumed += data_size;
|
||||
|
||||
info!("Window size consumed: {} bytes, remaining {}, total consumed {}",
|
||||
data_size, *window, *consumed);
|
||||
|
||||
|
||||
info!(
|
||||
"Window size consumed: {} bytes, remaining {}, total consumed {}",
|
||||
data_size, *window, *consumed
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// 调整window size(Phase 13.6:收到SSH_MSG_CHANNEL_WINDOW_ADJUST)
|
||||
pub fn adjust_window(&self, bytes_to_add: u32) {
|
||||
let mut window = self.current_window_size.lock().unwrap();
|
||||
*window += bytes_to_add;
|
||||
|
||||
info!("Window size adjusted: added {} bytes, total {}", bytes_to_add, *window);
|
||||
|
||||
info!(
|
||||
"Window size adjusted: added {} bytes, total {}",
|
||||
bytes_to_add, *window
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_CHANNEL_WINDOW_ADJUST packet(Phase 13.6)
|
||||
pub fn build_window_adjust_packet(channel_id: u32, bytes_to_add: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_WINDOW_ADJUST (type 93)
|
||||
packet.write_u8(PacketType::SSH_MSG_CHANNEL_WINDOW_ADJUST as u8)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
// Bytes to add
|
||||
packet.write_u32::<BigEndian>(bytes_to_add)?;
|
||||
|
||||
info!("Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
|
||||
channel_id, bytes_to_add);
|
||||
|
||||
|
||||
info!(
|
||||
"Built SSH_MSG_CHANNEL_WINDOW_ADJUST for channel {}: +{} bytes",
|
||||
channel_id, bytes_to_add
|
||||
);
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
|
||||
/// 获取当前window size(Phase 13.6)
|
||||
pub fn get_current_window(&self) -> u32 {
|
||||
*self.current_window_size.lock().unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// 获取已消耗bytes(Phase 13.6)
|
||||
pub fn get_consumed_bytes(&self) -> u32 {
|
||||
*self.consumed_bytes.lock().unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// 重置window size(Phase 13.6:channel重置)
|
||||
pub fn reset_window(&self) {
|
||||
let mut window = self.current_window_size.lock().unwrap();
|
||||
*window = self.initial_window_size;
|
||||
|
||||
|
||||
let mut consumed = self.consumed_bytes.lock().unwrap();
|
||||
*consumed = 0;
|
||||
|
||||
|
||||
info!("Window size reset to initial: {}", self.initial_window_size);
|
||||
}
|
||||
}
|
||||
@@ -128,63 +142,63 @@ impl ChannelLifecycle {
|
||||
close_received: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_CHANNEL_EOF packet(Phase 13.7)
|
||||
pub fn build_eof_packet(channel_id: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_EOF (type 96)
|
||||
packet.write_u8(PacketType::SSH_MSG_CHANNEL_EOF as u8)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
info!("Built SSH_MSG_CHANNEL_EOF for channel {}", channel_id);
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
|
||||
/// 构建SSH_MSG_CHANNEL_CLOSE packet(Phase 13.7)
|
||||
pub fn build_close_packet(channel_id: u32) -> Result<Vec<u8>> {
|
||||
let mut packet = Vec::new();
|
||||
|
||||
|
||||
// Packet type: SSH_MSG_CHANNEL_CLOSE (type 97)
|
||||
packet.write_u8(PacketType::SSH_MSG_CHANNEL_CLOSE as u8)?;
|
||||
|
||||
|
||||
// Recipient channel ID
|
||||
packet.write_u32::<BigEndian>(channel_id)?;
|
||||
|
||||
|
||||
info!("Built SSH_MSG_CHANNEL_CLOSE for channel {}", channel_id);
|
||||
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
|
||||
/// 标记EOF已发送(Phase 13.7)
|
||||
pub fn mark_eof_sent(&mut self) {
|
||||
self.eof_sent = true;
|
||||
info!("Channel {} EOF marked as sent", self.channel_id);
|
||||
}
|
||||
|
||||
|
||||
/// 标记CLOSE已接收(Phase 13.7)
|
||||
pub fn mark_close_received(&mut self) {
|
||||
self.close_received = true;
|
||||
info!("Channel {} CLOSE marked as received", self.channel_id);
|
||||
}
|
||||
|
||||
|
||||
/// 检查是否可以清理channel(Phase 13.7)
|
||||
pub fn can_cleanup(&self) -> bool {
|
||||
self.eof_sent && self.close_received
|
||||
}
|
||||
|
||||
|
||||
/// 清理channel资源(Phase 13.7)
|
||||
pub fn cleanup_channel(&self) -> Result<()> {
|
||||
info!("Cleaning up channel {} resources", self.channel_id);
|
||||
|
||||
|
||||
// Phase 13.7: 实际清理逻辑需要在ChannelManager中实现
|
||||
// - 移除channel记录
|
||||
// - 关闭TCP连接
|
||||
// - 清理监听器(如果是forwarded-tcpip)
|
||||
|
||||
|
||||
info!("Channel {} cleanup completed", self.channel_id);
|
||||
Ok(())
|
||||
}
|
||||
@@ -193,42 +207,42 @@ impl ChannelLifecycle {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_manager_creation() {
|
||||
let manager = WindowManager::rfc_default();
|
||||
assert_eq!(manager.get_current_window(), 2097152);
|
||||
assert_eq!(manager.max_packet_size, 32768);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_consumption() {
|
||||
let manager = WindowManager::rfc_default();
|
||||
|
||||
|
||||
// 消耗1000 bytes
|
||||
manager.consume_window(1000).unwrap();
|
||||
assert_eq!(manager.get_current_window(), 2097152 - 1000);
|
||||
assert_eq!(manager.get_consumed_bytes(), 1000);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_window_adjustment() {
|
||||
let manager = WindowManager::rfc_default();
|
||||
|
||||
|
||||
// 消耗1000 bytes
|
||||
manager.consume_window(1000).unwrap();
|
||||
|
||||
|
||||
// 调整500 bytes
|
||||
manager.adjust_window(500);
|
||||
assert_eq!(manager.get_current_window(), 2097152 - 1000 + 500);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_build_eof_packet() {
|
||||
let packet = ChannelLifecycle::build_eof_packet(1).unwrap();
|
||||
assert_eq!(packet[0], PacketType::SSH_MSG_CHANNEL_EOF as u8);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_build_close_packet() {
|
||||
let packet = ChannelLifecycle::build_close_packet(1).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user