// SSH密钥交换完整流程(Phase 3剩余) // 参考OpenSSH kex.c: complete implementation use crate::ssh_server::packet::{SshPacket, PacketType}; use crate::ssh_server::kex::{KexProposal, KexResult}; use crate::ssh_server::crypto::{SessionKeys}; use crate::ssh_server::kex_exchange::KexExchangeHandler; use anyhow::{Result, anyhow}; use sha2::{Sha256, Digest}; use byteorder::{BigEndian, WriteBytesExt}; use log::{info, debug}; /// SSH密钥交换完整状态管理(参考OpenSSH struct kex) pub struct KexState { pub client_version: String, pub server_version: String, pub client_kexinit_payload: Vec, pub server_kexinit_payload: Vec, pub exchange_handler: KexExchangeHandler, pub session_keys: Option, pub newkeys_received: bool, pub newkeys_sent: bool, } impl KexState { /// 创建密钥交换状态 pub fn new( client_version: String, server_version: String, kex_result: KexResult, ) -> Result { let exchange_handler = KexExchangeHandler::new(kex_result)?; Ok(Self { client_version, server_version, client_kexinit_payload: Vec::new(), server_kexinit_payload: Vec::new(), exchange_handler, session_keys: None, newkeys_received: false, newkeys_sent: false, }) } /// 保存KEXINIT payloads(用于Exchange Hash计算) /// /// 分析OpenSSH源码后的结论: /// - kex->peer存储的是:incoming_packet剩余内容(payload fields + padding) /// - kex->my存储的是:prop2buf()结果(payload fields,不包括padding) /// /// **但exchange hash必须使用相同的I_C/I_S!** /// /// 疑问:OpenSSH如何确保client和server使用相同的padding? /// 可能答案:OpenSSH在计算exchange hash时,不包括padding? /// /// 暂时保持不包括padding(因为签名验证之前成功) pub fn save_kexinit_payloads( &mut self, client_kexinit: &SshPacket, server_kexinit: &SshPacket, ) { // Only save payload (without padding) for now self.client_kexinit_payload = client_kexinit.payload.clone(); self.server_kexinit_payload = server_kexinit.payload.clone(); info!("Saved KEXINIT payloads (payload only, no padding)"); info!(" client payload: {} bytes", self.client_kexinit_payload.len()); info!(" server payload: {} bytes", self.server_kexinit_payload.len()); } /// 计算Exchange Hash(参考OpenSSH kex.c: kex_hash()) /// H = SHA256(V_C || V_S || I_C || I_S || K_S || K_C || K_S || shared_secret) pub fn compute_exchange_hash( &self, shared_secret: &[u8], server_host_key_blob: &[u8], client_public_key: &[u8], server_public_key: &[u8], ) -> Result> { // 参考OpenSSH kex.c: kex_hash() let mut hasher = Sha256::new(); // V_C: 客户端版本字符串(SSH string格式) write_ssh_string_to_hash(&mut hasher, &self.client_version)?; // V_S: 服务器版本字符串(SSH string格式) write_ssh_string_to_hash(&mut hasher, &self.server_version)?; // OpenSSH kexgex.c: "kexinit messages: fake header: len+SSH2_MSG_KEXINIT" // Remove SSH_MSG_KEXINIT type byte from payloads and prepend it in exchange hash let client_kexinit_without_type = &self.client_kexinit_payload[1..]; let server_kexinit_without_type = &self.server_kexinit_payload[1..]; hasher.update(&((client_kexinit_without_type.len() + 1) as u32).to_be_bytes()); hasher.update(&[20]); // SSH_MSG_KEXINIT type byte hasher.update(client_kexinit_without_type); hasher.update(&((server_kexinit_without_type.len() + 1) as u32).to_be_bytes()); hasher.update(&[20]); // SSH_MSG_KEXINIT type byte hasher.update(server_kexinit_without_type); // K_S: 服务器主机密钥blob(SSH string格式) hasher.update(server_host_key_blob); // K_C: 客户端Curve25519公钥(SSH string格式) write_ssh_bytes_to_hash(&mut hasher, client_public_key)?; // K_S: 服务器Curve25519公钥(SSH string格式) write_ssh_bytes_to_hash(&mut hasher, server_public_key)?; // K: 共享密钥(SSH mpint格式) // OpenSSH要求:去掉前导零 write_ssh_mpint_to_hash(&mut hasher, shared_secret)?; Ok(hasher.finalize().to_vec()) } /// 处理SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_input_newkeys()) pub fn handle_newkeys(&mut self, packet: &SshPacket) -> Result<()> { info!("Processing SSH_MSG_NEWKEYS"); // 验证packet类型 if packet.payload.len() < 1 { return Err(anyhow!("Invalid NEWKEYS packet")); } let packet_type = packet.payload[0]; if packet_type != PacketType::SSH_MSG_NEWKEYS as u8 { return Err(anyhow!("Invalid packet type for NEWKEYS")); } // 标记NEWKEYS接收完成(参考OpenSSH) self.newkeys_received = true; info!("SSH_MSG_NEWKEYS received, encryption channel ready"); Ok(()) } /// 发送SSH_MSG_NEWKEYS(参考OpenSSH kex.c: kex_send_newkeys()) pub fn send_newkeys() -> Result { info!("Sending SSH_MSG_NEWKEYS"); let payload = vec![PacketType::SSH_MSG_NEWKEYS as u8]; Ok(SshPacket::new(payload)) } /// 检查NEWKEYS完成状态(加密通道建立) pub fn is_encryption_ready(&self) -> bool { self.newkeys_received && self.newkeys_sent } } /// SSH string写入到hash(辅助函数) fn write_ssh_string_to_hash(hasher: &mut Sha256, s: &str) -> Result<()> { hasher.update(&(s.len() as u32).to_be_bytes()); hasher.update(s.as_bytes()); Ok(()) } /// SSH bytes写入到hash(辅助函数) fn write_ssh_bytes_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> { hasher.update(&(bytes.len() as u32).to_be_bytes()); hasher.update(bytes); Ok(()) } /// SSH mpint写入到hash(参考OpenSSH sshbuf_put_mpint()) fn write_ssh_mpint_to_hash(hasher: &mut Sha256, bytes: &[u8]) -> Result<()> { // OpenSSH要求:去掉前导零(如果最高位为1) let mpint_bytes = if bytes.len() > 0 && bytes[0] >= 0x80 { // 需要添加前导零(避免负数) let mut mpint = vec![0u8]; mpint.extend_from_slice(bytes); mpint } else { bytes.to_vec() }; hasher.update(&(mpint_bytes.len() as u32).to_be_bytes()); hasher.update(&mpint_bytes); Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_exchange_hash_computation() { let kex_result = KexResult::choose_algorithms( &KexProposal::server_default(), &KexProposal::client_default(), ).unwrap(); let mut state = KexState::new( "SSH-2.0-OpenSSH_10.2".to_string(), "SSH-2.0-MarkBaseSSH_1.0".to_string(), kex_result, ).unwrap(); // Set minimal KEXINIT payloads (need at least 1 byte for packet type) state.client_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte state.server_kexinit_payload = vec![20u8]; // SSH_MSG_KEXINIT type byte let shared_secret = vec![0u8; 32]; let host_key = vec![0u8; 32]; let client_pub = vec![0u8; 32]; let server_pub = vec![0u8; 32]; let hash = state.compute_exchange_hash(&shared_secret, &host_key, &client_pub, &server_pub).unwrap(); assert_eq!(hash.len(), 32); // SHA256输出32字节 } #[test] fn test_newkeys_handling() { let kex_result = KexResult::choose_algorithms( &KexProposal::server_default(), &KexProposal::client_default(), ).unwrap(); let mut state = KexState::new( "SSH-2.0-OpenSSH_10.2".to_string(), "SSH-2.0-MarkBaseSSH_1.0".to_string(), kex_result, ).unwrap(); let newkeys_packet = SshPacket::new(vec![PacketType::SSH_MSG_NEWKEYS as u8]); state.handle_newkeys(&newkeys_packet).unwrap(); assert!(state.newkeys_received); } }