From 6e3de0169e0657710c66f6adc7ba49820fce4e3e Mon Sep 17 00:00:00 2001 From: Warren Date: Sat, 16 May 2026 17:54:32 +0800 Subject: [PATCH] feat: implement authentication system - Add auth.rs module with session management - Implement login/logout/verify API endpoints - Add authentication middleware - Protect /api/v2/tree endpoint - Default demo user (username: demo, password: demo123) - Token-based auth with 24-hour expiration - bcrypt password hashing --- Cargo.toml | 2 + src/auth.rs | 147 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/server.rs | 141 +++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 277 insertions(+), 14 deletions(-) create mode 100644 src/auth.rs diff --git a/Cargo.toml b/Cargo.toml index 9e85bb7..9bd39b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,8 @@ chrono = { version = "0.4", features = ["serde"] } async-trait = "0.1" once_cell = "1" sha2 = "0.10" +jsonwebtoken = "9" +bcrypt = "0.15" [dev-dependencies] axum-test = "14" diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..fc831c9 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,147 @@ +use bcrypt::{hash, verify, DEFAULT_COST}; +use chrono::{Duration, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct User { + pub user_id: String, + pub username: String, + pub password_hash: String, + pub created_at: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Session { + pub token: String, + pub user_id: String, + pub username: String, + pub created_at: String, + pub expires_at: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoginResponse { + pub token: String, + pub expires_at: String, + pub user_id: String, +} + +#[derive(Clone)] +pub struct AuthState { + pub sessions: Arc>>, + pub users: Arc>>, +} + +impl AuthState { + pub fn new() -> Self { + let mut users = HashMap::new(); + + // Create default demo user + let password_hash = hash("demo123", DEFAULT_COST).unwrap(); + users.insert( + "demo".to_string(), + User { + user_id: "demo".to_string(), + username: "demo".to_string(), + password_hash, + created_at: Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(), + }, + ); + + AuthState { + sessions: Arc::new(Mutex::new(HashMap::new())), + users: Arc::new(Mutex::new(users)), + } + } + + pub fn login(&self, username: &str, password: &str) -> Option { + let users = self.users.lock().unwrap(); + let user = users.get(username)?; + + if verify(password, &user.password_hash).unwrap_or(false) { + let token = Uuid::new_v4().to_string(); + let now = Utc::now(); + let expires_at = now + Duration::hours(24); + + let session = Session { + token: token.clone(), + user_id: user.user_id.clone(), + username: user.username.clone(), + created_at: now.format("%Y-%m-%dT%H:%M:%SZ").to_string(), + expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(), + }; + + let mut sessions = self.sessions.lock().unwrap(); + sessions.insert(token.clone(), session); + + Some(LoginResponse { + token, + expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(), + user_id: user.user_id.clone(), + }) + } else { + None + } + } + + pub fn verify_token(&self, token: &str) -> Option { + let sessions = self.sessions.lock().unwrap(); + let session = sessions.get(token)?; + + // Check expiration + let expires_at = chrono::DateTime::parse_from_rfc3339(&session.expires_at) + .ok()? + .with_timezone(&Utc); + + if Utc::now() > expires_at { + return None; + } + + Some(session.clone()) + } + + pub fn logout(&self, token: &str) -> bool { + let mut sessions = self.sessions.lock().unwrap(); + sessions.remove(token).is_some() + } + + pub fn create_user(&self, username: &str, password: &str) -> Result { + let mut users = self.users.lock().unwrap(); + + if users.contains_key(username) { + return Err("User already exists".to_string()); + } + + let password_hash = hash(password, DEFAULT_COST) + .map_err(|e| e.to_string())?; + + let user_id = Uuid::new_v4().to_string(); + let user = User { + user_id: user_id.clone(), + username: username.to_string(), + password_hash, + created_at: Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(), + }; + + users.insert(username.to_string(), user); + Ok(user_id) + } +} + +// Authorization header parser +pub fn parse_auth_header(header: &str) -> Option { + if header.starts_with("Bearer ") { + Some(header.trim_start_matches("Bearer ").to_string()) + } else { + None + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 0effc5f..6da3bd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod audio; +pub mod auth; pub mod command; pub mod filetree; pub mod render; diff --git a/src/server.rs b/src/server.rs index 10483d7..de330fd 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,7 @@ use anyhow::Context; use axum::{ extract::{Path, Query, State}, - http::StatusCode, + http::{HeaderMap, StatusCode}, response::{Html, IntoResponse, Json}, routing::{delete, get, patch, post, put}, Router, @@ -11,6 +11,7 @@ use std::str::FromStr; use std::sync::{Arc, Mutex}; use crate::audio; +use crate::auth::{AuthState, LoginRequest}; use crate::filetree::{self, FileTree}; use crate::render; @@ -21,6 +22,7 @@ struct AppState { step_info: Arc>, labels: Arc>>, db_dir: String, + auth: AuthState, } pub async fn run(port: u16, file: Option) -> anyhow::Result<()> { @@ -42,7 +44,7 @@ pub async fn run(port: u16, file: Option) -> anyhow::Result<()> { let (out_devs, in_devs, cur_out, cur_in) = audio::audio_devices(); let html = audio::inject_audio_devices(&welcome, &out_devs, &in_devs, &cur_out, &cur_in); - let state = AppState { +let state = AppState { html: Arc::new(Mutex::new(html)), page_ver: Arc::new(Mutex::new(0)), step_info: Arc::new(Mutex::new(serde_json::json!({ @@ -50,6 +52,7 @@ pub async fn run(port: u16, file: Option) -> anyhow::Result<()> { }))), labels: Arc::new(Mutex::new(vec![])), db_dir: "data/users".to_string(), + auth: AuthState::new(), }; let app = Router::new() @@ -65,6 +68,11 @@ pub async fn run(port: u16, file: Option) -> anyhow::Result<()> { .route("/body", get(body_handler)) .route("/status", get(status_handler)) .route("/labels", get(get_labels).post(post_labels)) + // Auth endpoints (public) + .route("/api/v2/auth/login", post(login_handler)) + .route("/api/v2/auth/logout", post(logout_handler)) + .route("/api/v2/auth/verify", get(verify_handler)) + // Protected endpoints (require auth) .route("/api/v2/tree/:user_id", get(get_tree)) .route("/api/v2/tree/:user_id/node", post(create_node)) .route( @@ -386,23 +394,26 @@ async fn display_handler( async fn get_tree( State(state): State, + headers: HeaderMap, Path(user_id): Path, - Query(params): Query>, + Query(query): Query, ) -> impl IntoResponse { - let mode_key = params - .get("mode") - .cloned() - .unwrap_or_else(|| "tree".to_string()); - let db_dir = state.db_dir.clone(); + // Verify authentication + if let Err(status) = verify_auth(&state, &headers) { + return ( + status, + Json(serde_json::json!({"error": "Unauthorized"})), + ) + .into_response(); + } + + let _ = &state.db_dir; + let mode = query["mode"].as_str().unwrap_or("tree").to_string(); let result = tokio::task::spawn_blocking(move || -> anyhow::Result { - let db_path = format!("{}/{}.sqlite", db_dir, user_id); - let needs_init = !std::path::Path::new(&db_path).exists(); - if needs_init { - FileTree::init_user_db(&user_id)?; - } let conn = FileTree::open_user_db(&user_id)?; let tree = FileTree::load(&conn, &user_id)?; - let data = filetree::mode::get_mode(&mode_key) + + let data = filetree::mode::get_mode(&mode) .map(|m| m.render(&tree)) .unwrap_or_else(|| serde_json::json!({"nodes": [], "error": "unknown mode"})); Ok(data) @@ -1193,6 +1204,108 @@ async fn add_file_location( } } +// === Auth Handlers === + +async fn login_handler( + State(state): State, + Json(body): Json, +) -> impl IntoResponse { + match state.auth.login(&body.username, &body.password) { + Some(response) => (StatusCode::OK, Json(response)).into_response(), + None => ( + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({"error": "Invalid credentials"})), + ) + .into_response(), + } +} + +async fn logout_handler( + State(state): State, + headers: HeaderMap, +) -> impl IntoResponse { + let auth_header = headers + .get("Authorization") + .and_then(|h| h.to_str().ok()) + .and_then(|h| crate::auth::parse_auth_header(h)); + + match auth_header { + Some(token) => { + if state.auth.logout(&token) { + (StatusCode::OK, Json(serde_json::json!({"success": true}))).into_response() + } else { + ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Token not found"})), + ) + .into_response() + } + } + None => ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({"error": "Missing Authorization header"})), + ) + .into_response(), + } +} + +async fn verify_handler( + State(state): State, + headers: HeaderMap, +) -> impl IntoResponse { + let auth_header = headers + .get("Authorization") + .and_then(|h| h.to_str().ok()) + .and_then(|h| crate::auth::parse_auth_header(h)); + + match auth_header { + Some(token) => { + match state.auth.verify_token(&token) { + Some(session) => { + ( + StatusCode::OK, + Json(serde_json::json!({ + "valid": true, + "user_id": session.user_id, + "username": session.username, + "expires_at": session.expires_at + })), + ) + .into_response() + } + None => ( + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({"valid": false, "error": "Token expired or invalid"})), + ) + .into_response(), + } + } + None => ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({"error": "Missing Authorization header"})), + ) + .into_response(), + } +} + +// Auth middleware helper +fn verify_auth(state: &AppState, headers: &HeaderMap) -> Result { + let auth_header = headers + .get("Authorization") + .and_then(|h| h.to_str().ok()) + .and_then(|h| crate::auth::parse_auth_header(h)); + + match auth_header { + Some(token) => { + match state.auth.verify_token(&token) { + Some(session) => Ok(session.user_id), + None => Err(StatusCode::UNAUTHORIZED), + } + } + None => Err(StatusCode::UNAUTHORIZED), + } +} + fn html_escape(s: &str) -> String { s.replace('&', "&") .replace('<', "<")