feat: implement authentication system
- Add auth.rs module with session management - Implement login/logout/verify API endpoints - Add authentication middleware - Protect /api/v2/tree endpoint - Default demo user (username: demo, password: demo123) - Token-based auth with 24-hour expiration - bcrypt password hashing
This commit is contained in:
@@ -26,6 +26,8 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||
async-trait = "0.1"
|
||||
once_cell = "1"
|
||||
sha2 = "0.10"
|
||||
jsonwebtoken = "9"
|
||||
bcrypt = "0.15"
|
||||
|
||||
[dev-dependencies]
|
||||
axum-test = "14"
|
||||
|
||||
147
src/auth.rs
Normal file
147
src/auth.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
use bcrypt::{hash, verify, DEFAULT_COST};
|
||||
use chrono::{Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub user_id: String,
|
||||
pub username: String,
|
||||
pub password_hash: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
pub token: String,
|
||||
pub user_id: String,
|
||||
pub username: String,
|
||||
pub created_at: String,
|
||||
pub expires_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LoginResponse {
|
||||
pub token: String,
|
||||
pub expires_at: String,
|
||||
pub user_id: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthState {
|
||||
pub sessions: Arc<Mutex<HashMap<String, Session>>>,
|
||||
pub users: Arc<Mutex<HashMap<String, User>>>,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub fn new() -> Self {
|
||||
let mut users = HashMap::new();
|
||||
|
||||
// Create default demo user
|
||||
let password_hash = hash("demo123", DEFAULT_COST).unwrap();
|
||||
users.insert(
|
||||
"demo".to_string(),
|
||||
User {
|
||||
user_id: "demo".to_string(),
|
||||
username: "demo".to_string(),
|
||||
password_hash,
|
||||
created_at: Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
},
|
||||
);
|
||||
|
||||
AuthState {
|
||||
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
users: Arc::new(Mutex::new(users)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn login(&self, username: &str, password: &str) -> Option<LoginResponse> {
|
||||
let users = self.users.lock().unwrap();
|
||||
let user = users.get(username)?;
|
||||
|
||||
if verify(password, &user.password_hash).unwrap_or(false) {
|
||||
let token = Uuid::new_v4().to_string();
|
||||
let now = Utc::now();
|
||||
let expires_at = now + Duration::hours(24);
|
||||
|
||||
let session = Session {
|
||||
token: token.clone(),
|
||||
user_id: user.user_id.clone(),
|
||||
username: user.username.clone(),
|
||||
created_at: now.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
};
|
||||
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
sessions.insert(token.clone(), session);
|
||||
|
||||
Some(LoginResponse {
|
||||
token,
|
||||
expires_at: expires_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
user_id: user.user_id.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn verify_token(&self, token: &str) -> Option<Session> {
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions.get(token)?;
|
||||
|
||||
// Check expiration
|
||||
let expires_at = chrono::DateTime::parse_from_rfc3339(&session.expires_at)
|
||||
.ok()?
|
||||
.with_timezone(&Utc);
|
||||
|
||||
if Utc::now() > expires_at {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(session.clone())
|
||||
}
|
||||
|
||||
pub fn logout(&self, token: &str) -> bool {
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
sessions.remove(token).is_some()
|
||||
}
|
||||
|
||||
pub fn create_user(&self, username: &str, password: &str) -> Result<String, String> {
|
||||
let mut users = self.users.lock().unwrap();
|
||||
|
||||
if users.contains_key(username) {
|
||||
return Err("User already exists".to_string());
|
||||
}
|
||||
|
||||
let password_hash = hash(password, DEFAULT_COST)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let user_id = Uuid::new_v4().to_string();
|
||||
let user = User {
|
||||
user_id: user_id.clone(),
|
||||
username: username.to_string(),
|
||||
password_hash,
|
||||
created_at: Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
};
|
||||
|
||||
users.insert(username.to_string(), user);
|
||||
Ok(user_id)
|
||||
}
|
||||
}
|
||||
|
||||
// Authorization header parser
|
||||
pub fn parse_auth_header(header: &str) -> Option<String> {
|
||||
if header.starts_with("Bearer ") {
|
||||
Some(header.trim_start_matches("Bearer ").to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod audio;
|
||||
pub mod auth;
|
||||
pub mod command;
|
||||
pub mod filetree;
|
||||
pub mod render;
|
||||
|
||||
141
src/server.rs
141
src/server.rs
@@ -1,7 +1,7 @@
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{Html, IntoResponse, Json},
|
||||
routing::{delete, get, patch, post, put},
|
||||
Router,
|
||||
@@ -11,6 +11,7 @@ use std::str::FromStr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::audio;
|
||||
use crate::auth::{AuthState, LoginRequest};
|
||||
use crate::filetree::{self, FileTree};
|
||||
use crate::render;
|
||||
|
||||
@@ -21,6 +22,7 @@ struct AppState {
|
||||
step_info: Arc<Mutex<serde_json::Value>>,
|
||||
labels: Arc<Mutex<Vec<serde_json::Value>>>,
|
||||
db_dir: String,
|
||||
auth: AuthState,
|
||||
}
|
||||
|
||||
pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
|
||||
@@ -42,7 +44,7 @@ pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
|
||||
let (out_devs, in_devs, cur_out, cur_in) = audio::audio_devices();
|
||||
let html = audio::inject_audio_devices(&welcome, &out_devs, &in_devs, &cur_out, &cur_in);
|
||||
|
||||
let state = AppState {
|
||||
let state = AppState {
|
||||
html: Arc::new(Mutex::new(html)),
|
||||
page_ver: Arc::new(Mutex::new(0)),
|
||||
step_info: Arc::new(Mutex::new(serde_json::json!({
|
||||
@@ -50,6 +52,7 @@ pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
|
||||
}))),
|
||||
labels: Arc::new(Mutex::new(vec![])),
|
||||
db_dir: "data/users".to_string(),
|
||||
auth: AuthState::new(),
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
@@ -65,6 +68,11 @@ pub async fn run(port: u16, file: Option<String>) -> anyhow::Result<()> {
|
||||
.route("/body", get(body_handler))
|
||||
.route("/status", get(status_handler))
|
||||
.route("/labels", get(get_labels).post(post_labels))
|
||||
// Auth endpoints (public)
|
||||
.route("/api/v2/auth/login", post(login_handler))
|
||||
.route("/api/v2/auth/logout", post(logout_handler))
|
||||
.route("/api/v2/auth/verify", get(verify_handler))
|
||||
// Protected endpoints (require auth)
|
||||
.route("/api/v2/tree/:user_id", get(get_tree))
|
||||
.route("/api/v2/tree/:user_id/node", post(create_node))
|
||||
.route(
|
||||
@@ -386,23 +394,26 @@ async fn display_handler(
|
||||
|
||||
async fn get_tree(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(user_id): Path<String>,
|
||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||
Query(query): Query<serde_json::Value>,
|
||||
) -> impl IntoResponse {
|
||||
let mode_key = params
|
||||
.get("mode")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "tree".to_string());
|
||||
let db_dir = state.db_dir.clone();
|
||||
// Verify authentication
|
||||
if let Err(status) = verify_auth(&state, &headers) {
|
||||
return (
|
||||
status,
|
||||
Json(serde_json::json!({"error": "Unauthorized"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let _ = &state.db_dir;
|
||||
let mode = query["mode"].as_str().unwrap_or("tree").to_string();
|
||||
let result = tokio::task::spawn_blocking(move || -> anyhow::Result<serde_json::Value> {
|
||||
let db_path = format!("{}/{}.sqlite", db_dir, user_id);
|
||||
let needs_init = !std::path::Path::new(&db_path).exists();
|
||||
if needs_init {
|
||||
FileTree::init_user_db(&user_id)?;
|
||||
}
|
||||
let conn = FileTree::open_user_db(&user_id)?;
|
||||
let tree = FileTree::load(&conn, &user_id)?;
|
||||
let data = filetree::mode::get_mode(&mode_key)
|
||||
|
||||
let data = filetree::mode::get_mode(&mode)
|
||||
.map(|m| m.render(&tree))
|
||||
.unwrap_or_else(|| serde_json::json!({"nodes": [], "error": "unknown mode"}));
|
||||
Ok(data)
|
||||
@@ -1193,6 +1204,108 @@ async fn add_file_location(
|
||||
}
|
||||
}
|
||||
|
||||
// === Auth Handlers ===
|
||||
|
||||
async fn login_handler(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<LoginRequest>,
|
||||
) -> impl IntoResponse {
|
||||
match state.auth.login(&body.username, &body.password) {
|
||||
Some(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
None => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(serde_json::json!({"error": "Invalid credentials"})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn logout_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
||||
|
||||
match auth_header {
|
||||
Some(token) => {
|
||||
if state.auth.logout(&token) {
|
||||
(StatusCode::OK, Json(serde_json::json!({"success": true}))).into_response()
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Token not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
None => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "Missing Authorization header"})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn verify_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
||||
|
||||
match auth_header {
|
||||
Some(token) => {
|
||||
match state.auth.verify_token(&token) {
|
||||
Some(session) => {
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"valid": true,
|
||||
"user_id": session.user_id,
|
||||
"username": session.username,
|
||||
"expires_at": session.expires_at
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
None => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(serde_json::json!({"valid": false, "error": "Token expired or invalid"})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
None => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "Missing Authorization header"})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
// Auth middleware helper
|
||||
fn verify_auth(state: &AppState, headers: &HeaderMap) -> Result<String, StatusCode> {
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| crate::auth::parse_auth_header(h));
|
||||
|
||||
match auth_header {
|
||||
Some(token) => {
|
||||
match state.auth.verify_token(&token) {
|
||||
Some(session) => Ok(session.user_id),
|
||||
None => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
}
|
||||
None => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
}
|
||||
|
||||
fn html_escape(s: &str) -> String {
|
||||
s.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
|
||||
Reference in New Issue
Block a user