- Fix trailing whitespace in kex.rs and s3.rs - Add missing KexProposal import in kex_complete.rs - Auto-fix clippy warnings across all crates - All 153 tests pass
210 lines
5.6 KiB
Rust
210 lines
5.6 KiB
Rust
use axum::http::HeaderMap;
|
||
use hmac::{Hmac, Mac};
|
||
use sha2::Sha256;
|
||
use std::fs;
|
||
|
||
/// HMAC instantiated with SHA-256; used for every keyed hash in this file.
type HmacSha256 = Hmac<Sha256>;
|
||
|
||
pub fn verify_signature(headers: HeaderMap, method: &str, path: &str) -> bool {
|
||
// Load S3 config and check require_auth flag
|
||
let config = crate::s3_config::S3Config::load_default().unwrap_or_default();
|
||
|
||
// Merge environment variables (allows override via MB_S3_REQUIRE_AUTH)
|
||
let mut config = config;
|
||
config.merge_env();
|
||
|
||
if !config.s3.require_auth {
|
||
// Development mode: allow access without authentication
|
||
return true;
|
||
}
|
||
|
||
// 生产模式:必须提供Authorization header
|
||
let auth_header = headers
|
||
.get("Authorization")
|
||
.and_then(|v| v.to_str().ok())
|
||
.unwrap_or("");
|
||
|
||
if !auth_header.starts_with("AWS4-HMAC-SHA256") {
|
||
return false;
|
||
}
|
||
|
||
// 2. Parse Credential
|
||
let credential = extract_credential(auth_header);
|
||
if credential.is_none() {
|
||
return false;
|
||
}
|
||
|
||
let credential = credential.unwrap();
|
||
|
||
// 3. Get secret_key from S3AccessKey database
|
||
let secret_key = get_secret_key(&credential.access_key);
|
||
if secret_key.is_none() {
|
||
return false;
|
||
}
|
||
|
||
let secret_key = secret_key.unwrap();
|
||
|
||
// 4. Calculate Signature
|
||
let calculated_signature = calculate_signature(
|
||
headers.clone(),
|
||
method,
|
||
path,
|
||
&credential.access_key,
|
||
&secret_key,
|
||
&credential.region,
|
||
&credential.service,
|
||
&credential.date,
|
||
);
|
||
|
||
// 5. Extract Signature from header
|
||
let provided_signature = extract_signature(auth_header);
|
||
if provided_signature.is_none() {
|
||
return false;
|
||
}
|
||
|
||
// 6. Compare signatures
|
||
calculated_signature == provided_signature.unwrap()
|
||
}
|
||
|
||
/// Fields taken from the `Credential=` entry of an `Authorization` header.
///
/// The entry is split on `/`; the first four segments are stored here in the
/// order they appear.
struct Credential {
    /// First segment: the access key used to look up the secret key.
    access_key: String,
    /// Second segment: the date string used in the scope and key derivation.
    date: String,
    /// Third segment: the region string.
    region: String,
    /// Fourth segment: the service string.
    service: String,
}
|
||
|
||
fn extract_credential(auth_header: &str) -> Option<Credential> {
|
||
let parts: Vec<&str> = auth_header.split_whitespace().collect();
|
||
if parts.len() < 2 {
|
||
return None;
|
||
}
|
||
|
||
let credential_part = parts.iter().find(|p| p.starts_with("Credential="))?;
|
||
|
||
let credential_str = credential_part.strip_prefix("Credential=")?;
|
||
let credential_parts: Vec<&str> = credential_str.split('/').collect();
|
||
|
||
if credential_parts.len() < 5 {
|
||
return None;
|
||
}
|
||
|
||
Some(Credential {
|
||
access_key: credential_parts[0].to_string(),
|
||
date: credential_parts[1].to_string(),
|
||
region: credential_parts[2].to_string(),
|
||
service: credential_parts[3].to_string(),
|
||
})
|
||
}
|
||
|
||
/// Extracts the value of the `Signature=` entry from an `Authorization`
/// header value.
///
/// Entries are separated on commas as well as whitespace, so both the spaced
/// form (`..., Signature=abc`) and the compact form (`...,Signature=abc`) are
/// accepted, and a trailing comma never ends up in the returned value.
///
/// Returns `None` when no entry starts with `Signature=`.
fn extract_signature(auth_header: &str) -> Option<String> {
    auth_header
        .split(|c: char| c == ',' || c.is_whitespace())
        .find_map(|entry| entry.strip_prefix("Signature="))
        .map(str::to_string)
}
|
||
|
||
fn get_secret_key(access_key: &str) -> Option<String> {
|
||
// Load S3AccessKey database from data/s3_keys.json
|
||
let s3_keys_path = "data/s3_keys.json";
|
||
let s3_keys_json = fs::read_to_string(s3_keys_path).ok()?;
|
||
|
||
#[derive(serde::Deserialize)]
|
||
struct S3Key {
|
||
access_key: String,
|
||
secret_key: String,
|
||
}
|
||
|
||
let s3_keys: Vec<S3Key> = serde_json::from_str(&s3_keys_json).ok()?;
|
||
|
||
s3_keys
|
||
.iter()
|
||
.find(|k| k.access_key == access_key)
|
||
.map(|k| k.secret_key.clone())
|
||
}
|
||
|
||
fn calculate_signature(
|
||
headers: HeaderMap,
|
||
method: &str,
|
||
path: &str,
|
||
_access_key: &str,
|
||
secret_key: &str,
|
||
region: &str,
|
||
service: &str,
|
||
date: &str,
|
||
) -> String {
|
||
// 1. Create Canonical Request
|
||
let canonical_request = create_canonical_request(headers, method, path);
|
||
|
||
// 2. Create String to Sign
|
||
let string_to_sign = create_string_to_sign(date, region, service, &canonical_request);
|
||
|
||
// 3. Calculate Signing Key
|
||
let signing_key = calculate_signing_key(secret_key, date, region, service);
|
||
|
||
// 4. Calculate Signature
|
||
|
||
|
||
hmac_sha256_hex(&signing_key, &string_to_sign)
|
||
}
|
||
|
||
fn create_canonical_request(headers: HeaderMap, method: &str, path: &str) -> String {
|
||
// Simplified implementation for POC
|
||
let host = headers
|
||
.get("Host")
|
||
.and_then(|v| v.to_str().ok())
|
||
.unwrap_or("localhost:11438");
|
||
|
||
format!(
|
||
"{}\n{}\n\nhost:{}\n\nhost\nUNSIGNED-PAYLOAD",
|
||
method, path, host
|
||
)
|
||
}
|
||
|
||
fn create_string_to_sign(
|
||
date: &str,
|
||
region: &str,
|
||
service: &str,
|
||
canonical_request: &str,
|
||
) -> String {
|
||
let canonical_request_hash = sha256_hex(canonical_request);
|
||
|
||
format!(
|
||
"AWS4-HMAC-SHA256\n{}T000000Z\n{}/{}/{}/aws4_request\n{}",
|
||
date, date, region, service, canonical_request_hash
|
||
)
|
||
}
|
||
|
||
fn calculate_signing_key(secret_key: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
|
||
let k_secret = format!("AWS4{}", secret_key);
|
||
let k_date = hmac_sha256(k_secret.as_bytes(), date);
|
||
let k_region = hmac_sha256(&k_date, region);
|
||
let k_service = hmac_sha256(&k_region, service);
|
||
hmac_sha256(&k_service, "aws4_request")
|
||
}
|
||
|
||
fn hmac_sha256(key: &[u8], data: &str) -> Vec<u8> {
|
||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC initialization failed");
|
||
mac.update(data.as_bytes());
|
||
mac.finalize().into_bytes().to_vec()
|
||
}
|
||
|
||
fn hmac_sha256_hex(key: &[u8], data: &str) -> String {
|
||
let result = hmac_sha256(key, data);
|
||
hex_encode(&result)
|
||
}
|
||
|
||
fn sha256_hex(data: &str) -> String {
|
||
use sha2::Digest;
|
||
let mut hasher = Sha256::new();
|
||
hasher.update(data.as_bytes());
|
||
let hash = hasher.finalize();
|
||
hex_encode(&hash)
|
||
}
|
||
|
||
/// Encodes `data` as a lowercase hexadecimal string, two characters per byte.
///
/// An empty slice yields an empty string.
fn hex_encode(data: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(data.len() * 2);
    for &byte in data {
        // High nibble first, then low nibble.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
|