Implement SSH Known Hosts Verification: Parse ~/.ssh/known_hosts + verify host keys + hashed host support
Some checks failed
Test / build (push) Has been cancelled
Test / test (push) Has been cancelled

This commit is contained in:
Warren
2026-06-21 05:24:33 +08:00
parent 5238a84972
commit 30c1e5fff9
2 changed files with 529 additions and 0 deletions

View File

@@ -0,0 +1,528 @@
use anyhow::{anyhow, Result};
use log::{info, warn};
use std::collections::HashMap;
use std::fs;
use std::io::{BufRead, BufReader};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, PartialEq)]
pub enum KnownHostKey {
Ed25519(Vec<u8>),
Rsa(Vec<u8>),
Ecdsa(Vec<u8>),
Dsa(Vec<u8>),
}
#[derive(Debug, Clone)]
pub struct KnownHostEntry {
pub hosts: Vec<String>,
pub key_type: String,
pub key: KnownHostKey,
pub comment: Option<String>,
pub is_hashed: bool,
pub is_cert_authority: bool,
}
impl KnownHostEntry {
pub fn matches_host(&self, hostname: &str, ip: Option<IpAddr>) -> bool {
if self.is_hashed {
return self.matches_hashed_host(hostname, ip);
}
for host in &self.hosts {
if host == hostname {
return true;
}
if let Some(ip_addr) = ip {
if host == &ip_addr.to_string() {
return true;
}
}
if host.contains(',') {
let parts: Vec<&str> = host.split(',').collect();
for part in parts {
if part == hostname {
return true;
}
if let Some(ip_addr) = ip {
if part == &ip_addr.to_string() {
return true;
}
}
}
}
if host.starts_with('|') {
if self.matches_pattern_host(host, hostname) {
return true;
}
}
}
false
}
fn matches_hashed_host(&self, hostname: &str, _ip: Option<IpAddr>) -> bool {
for host in &self.hosts {
if host.starts_with('|') {
if let Ok(decoded) = decode_hashed_host(host) {
if decoded == hostname {
return true;
}
}
}
}
false
}
fn matches_pattern_host(&self, pattern: &str, hostname: &str) -> bool {
if pattern.contains('*') || pattern.contains('?') {
let regex_pattern = pattern.replace('*', ".*").replace('?', ".");
if let Ok(re) = regex::Regex::new(&regex_pattern) {
return re.is_match(hostname);
}
}
false
}
pub fn verify_key(&self, server_key: &[u8], key_type: &str) -> Result<bool> {
if self.key_type != key_type {
return Ok(false);
}
match &self.key {
KnownHostKey::Ed25519(key_bytes) => {
if key_type == "ssh-ed25519" {
Ok(key_bytes == server_key)
} else {
Ok(false)
}
}
KnownHostKey::Rsa(key_bytes) => {
if key_type == "ssh-rsa" || key_type == "rsa-sha2-256" || key_type == "rsa-sha2-512" {
Ok(key_bytes == server_key)
} else {
Ok(false)
}
}
KnownHostKey::Ecdsa(key_bytes) => {
if key_type.starts_with("ecdsa-sha2-") {
Ok(key_bytes == server_key)
} else {
Ok(false)
}
}
KnownHostKey::Dsa(key_bytes) => {
if key_type == "ssh-dss" {
Ok(key_bytes == server_key)
} else {
Ok(false)
}
}
}
}
}
fn decode_hashed_host(hashed: &str) -> Result<String> {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let parts: Vec<&str> = hashed.split('|').collect();
if parts.len() < 4 || parts[0] != "1" {
return Err(anyhow!("Invalid hashed host format"));
}
let salt = STANDARD.decode(parts[1])?;
let hash = STANDARD.decode(parts[2])?;
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(&salt);
hasher.update(parts[3].as_bytes());
let computed_hash = hasher.finalize();
if hash == computed_hash.as_slice() {
Ok(parts[3].to_string())
} else {
Err(anyhow!("Hash mismatch"))
}
}
pub struct KnownHostsParser {
entries: Vec<KnownHostEntry>,
}
impl KnownHostsParser {
pub fn new() -> Self {
Self {
entries: Vec::new(),
}
}
pub fn load_default() -> Result<Self> {
let known_hosts_path = Self::default_known_hosts_path()?;
Self::load_from_file(&known_hosts_path)
}
pub fn load_from_file(path: &Path) -> Result<Self> {
if !path.exists() {
info!("Known hosts file not found: {}", path.display());
return Ok(Self::new());
}
let file = fs::File::open(path)?;
let reader = BufReader::new(file);
let mut parser = Self::new();
for line in reader.lines() {
let line = line?;
if line.is_empty() || line.starts_with('#') {
continue;
}
if let Some(entry) = parser.parse_line(&line) {
parser.entries.push(entry);
}
}
info!("Loaded {} known hosts entries from {}", parser.entries.len(), path.display());
Ok(parser)
}
fn default_known_hosts_path() -> Result<PathBuf> {
let home = std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))
.map_err(|_| anyhow!("Cannot determine home directory"))?;
Ok(PathBuf::from(home).join(".ssh").join("known_hosts"))
}
fn parse_line(&self, line: &str) -> Option<KnownHostEntry> {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() < 3 {
return None;
}
let is_cert_authority = parts[0].starts_with("@cert-authority");
let (hosts_part, key_type, key_base64, rest_parts) = if is_cert_authority {
(parts[1], parts[2], parts[3], &parts[4..])
} else {
(parts[0], parts[1], parts[2], &parts[3..])
};
let comment = if rest_parts.len() > 0 {
Some(rest_parts.join(" "))
} else {
None
};
let hosts: Vec<String> = hosts_part.split(',').map(|s| s.to_string()).collect();
let is_hashed = hosts.iter().any(|h| h.starts_with('|'));
let key = self.decode_key(key_type, key_base64)?;
Some(KnownHostEntry {
hosts,
key_type: key_type.to_string(),
key,
comment,
is_hashed,
is_cert_authority,
})
}
fn decode_key(&self, key_type: &str, key_base64: &str) -> Option<KnownHostKey> {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let key_bytes = STANDARD.decode(key_base64).ok()?;
match key_type {
"ssh-ed25519" => Some(KnownHostKey::Ed25519(key_bytes)),
"ssh-rsa" | "rsa-sha2-256" | "rsa-sha2-512" => Some(KnownHostKey::Rsa(key_bytes)),
"ecdsa-sha2-nistp256" | "ecdsa-sha2-nistp384" | "ecdsa-sha2-nistp521" => {
Some(KnownHostKey::Ecdsa(key_bytes))
}
"ssh-dss" => Some(KnownHostKey::Dsa(key_bytes)),
_ => {
warn!("Unknown key type: {}", key_type);
None
}
}
}
pub fn verify_host_key(
&self,
hostname: &str,
ip: Option<IpAddr>,
server_key: &[u8],
key_type: &str,
) -> Result<VerifyResult> {
let matching_entries: Vec<&KnownHostEntry> = self
.entries
.iter()
.filter(|e| e.matches_host(hostname, ip))
.collect();
if matching_entries.is_empty() {
return Ok(VerifyResult::UnknownHost);
}
for entry in matching_entries {
if entry.verify_key(server_key, key_type)? {
return Ok(VerifyResult::Verified);
}
}
Ok(VerifyResult::KeyMismatch)
}
pub fn add_host_key(
&self,
hostname: &str,
key_type: &str,
key: &[u8],
comment: Option<&str>,
) -> Result<String> {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let key_base64 = STANDARD.encode(key);
let line = if let Some(c) = comment {
format!("{} {} {} {}", hostname, key_type, key_base64, c)
} else {
format!("{} {} {}", hostname, key_type, key_base64)
};
Ok(line)
}
pub fn get_entries(&self) -> &[KnownHostEntry] {
&self.entries
}
pub fn get_entries_for_host(&self, hostname: &str) -> Vec<&KnownHostEntry> {
self.entries
.iter()
.filter(|e| e.matches_host(hostname, None))
.collect()
}
pub fn remove_host(&mut self, hostname: &str) -> usize {
let original_len = self.entries.len();
self.entries.retain(|e| !e.matches_host(hostname, None));
original_len - self.entries.len()
}
pub fn hash_host(&self, hostname: &str) -> Result<String> {
use base64::{engine::general_purpose::STANDARD, Engine as _};
use rand::Rng;
use sha2::{Digest, Sha256};
let salt: [u8; 20] = rand::rngs::OsRng.gen();
let mut hasher = Sha256::new();
hasher.update(&salt);
hasher.update(hostname.as_bytes());
let hash = hasher.finalize();
Ok(format!(
"|1|{}|{}|{}",
STANDARD.encode(&salt),
STANDARD.encode(&hash),
hostname
))
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum VerifyResult {
Verified,
KeyMismatch,
UnknownHost,
}
impl std::fmt::Display for VerifyResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VerifyResult::Verified => write!(f, "Host key verified"),
VerifyResult::KeyMismatch => write!(f, "Host key mismatch - possible MITM attack"),
VerifyResult::UnknownHost => write!(f, "Unknown host - key not found in known_hosts"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use base64::{engine::general_purpose::STANDARD, Engine as _};
use tempfile::TempDir;
#[test]
fn test_parse_simple_entry() {
let parser = KnownHostsParser::new();
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let line = format!("example.com ssh-ed25519 {}", valid_key);
let entry = parser.parse_line(&line).unwrap();
assert_eq!(entry.hosts, vec!["example.com"]);
assert_eq!(entry.key_type, "ssh-ed25519");
assert!(!entry.is_hashed);
assert!(!entry.is_cert_authority);
}
#[test]
fn test_parse_multiple_hosts() {
let parser = KnownHostsParser::new();
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let line = format!("host1,host2,192.168.1.1 ssh-ed25519 {}", valid_key);
let entry = parser.parse_line(&line).unwrap();
assert_eq!(entry.hosts.len(), 3);
assert!(entry.hosts.contains(&"host1".to_string()));
assert!(entry.hosts.contains(&"host2".to_string()));
}
#[test]
fn test_parse_cert_authority() {
let parser = KnownHostsParser::new();
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let line = format!("@cert-authority *.example.com ssh-ed25519 {}", valid_key);
let entry = parser.parse_line(&line).unwrap();
assert!(entry.is_cert_authority);
}
#[test]
fn test_matches_host() {
let parser = KnownHostsParser::new();
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let line = format!("example.com ssh-ed25519 {}", valid_key);
let entry = parser.parse_line(&line).unwrap();
assert!(entry.matches_host("example.com", None));
assert!(!entry.matches_host("other.com", None));
}
#[test]
fn test_matches_ip() {
let parser = KnownHostsParser::new();
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let line = format!("example.com,192.168.1.1 ssh-ed25519 {}", valid_key);
let entry = parser.parse_line(&line).unwrap();
let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
assert!(entry.matches_host("example.com", Some(ip)));
}
#[test]
fn test_verify_host_key() {
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let parser = KnownHostsParser::new();
let line = format!("example.com ssh-ed25519 {}", valid_key);
let entry = parser.parse_line(&line).unwrap();
let mut parser = KnownHostsParser::new();
parser.entries.push(entry);
let key_bytes = STANDARD.decode(valid_key).unwrap();
let result = parser.verify_host_key("example.com", None, &key_bytes, "ssh-ed25519");
assert_eq!(result.unwrap(), VerifyResult::Verified);
let result = parser.verify_host_key("example.com", None, &[0u8; 32], "ssh-ed25519");
assert_eq!(result.unwrap(), VerifyResult::KeyMismatch);
let result = parser.verify_host_key("unknown.com", None, &key_bytes, "ssh-ed25519");
assert_eq!(result.unwrap(), VerifyResult::UnknownHost);
}
#[test]
fn test_load_from_file() {
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let temp_dir = TempDir::new().unwrap();
let known_hosts_path = temp_dir.path().join("known_hosts");
fs::write(
&known_hosts_path,
format!("example.com ssh-ed25519 {}\n", valid_key),
)
.unwrap();
let parser = KnownHostsParser::load_from_file(&known_hosts_path).unwrap();
assert_eq!(parser.entries.len(), 1);
}
#[test]
fn test_add_host_key() {
let parser = KnownHostsParser::new();
let key_bytes = vec![1, 2, 3, 4];
let line = parser.add_host_key("example.com", "ssh-ed25519", &key_bytes, None).unwrap();
assert!(line.contains("example.com"));
assert!(line.contains("ssh-ed25519"));
}
#[test]
fn test_remove_host() {
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let parser = KnownHostsParser::new();
let line = format!("example.com ssh-ed25519 {}", valid_key);
let entry = parser.parse_line(&line).unwrap();
let mut parser = KnownHostsParser::new();
parser.entries.push(entry);
let removed = parser.remove_host("example.com");
assert_eq!(removed, 1);
assert_eq!(parser.entries.len(), 0);
}
#[test]
fn test_hash_host() {
let parser = KnownHostsParser::new();
let hashed = parser.hash_host("example.com").unwrap();
assert!(hashed.starts_with("|1|"));
}
#[test]
fn test_comment_parsing() {
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let parser = KnownHostsParser::new();
let line = format!("example.com ssh-ed25519 {} this is a comment", valid_key);
let entry = parser.parse_line(&line).unwrap();
assert_eq!(entry.comment, Some("this is a comment".to_string()));
}
#[test]
fn test_empty_file() {
let temp_dir = TempDir::new().unwrap();
let known_hosts_path = temp_dir.path().join("known_hosts");
fs::write(&known_hosts_path, "").unwrap();
let parser = KnownHostsParser::load_from_file(&known_hosts_path).unwrap();
assert_eq!(parser.entries.len(), 0);
}
#[test]
fn test_skip_comments() {
let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==";
let temp_dir = TempDir::new().unwrap();
let known_hosts_path = temp_dir.path().join("known_hosts");
fs::write(
&known_hosts_path,
format!("# This is a comment\nexample.com ssh-ed25519 {}\n", valid_key),
)
.unwrap();
let parser = KnownHostsParser::load_from_file(&known_hosts_path).unwrap();
assert_eq!(parser.entries.len(), 1);
}
}

View File

@@ -11,6 +11,7 @@ pub mod host_key;
pub mod kex; pub mod kex;
pub mod kex_complete; pub mod kex_complete;
pub mod kex_exchange; pub mod kex_exchange;
pub mod known_hosts;
pub mod packet; pub mod packet;
pub mod port_forward; pub mod port_forward;
pub mod port_forward_listener; pub mod port_forward_listener;