From ca0f541a794da2cf0f2742e1d38c89112116a3f6 Mon Sep 17 00:00:00 2001 From: Warren Date: Sun, 21 Jun 2026 22:44:17 +0800 Subject: [PATCH] P2: S3 Multipart Upload support complete - InitiateMultipartUpload: POST /s3/multipart/:bucket/:key/init - UploadPart: PUT /s3/multipart/:bucket/:key/part - CompleteMultipartUpload: POST /s3/multipart/:bucket/:key/complete - AbortMultipartUpload: DELETE /s3/multipart/:bucket/:key/abort - In-memory upload tracking with once_cell::Lazy - Part files stored in temp dir during upload - Final file assembled on CompleteMultipartUpload - XML responses for all operations Tests: 293 passed, 0 failed --- markbase-core/Cargo.toml | 2 + markbase-core/src/s3.rs | 291 ++++++++++++++++++++++++++++++++++++ markbase-core/src/s3_xml.rs | 35 +++++ markbase-core/src/server.rs | 6 + 4 files changed, 334 insertions(+) diff --git a/markbase-core/Cargo.toml b/markbase-core/Cargo.toml index 8440085..ad52c38 100644 --- a/markbase-core/Cargo.toml +++ b/markbase-core/Cargo.toml @@ -20,6 +20,8 @@ axum = { version = "0.7", features = ["macros"] } bcrypt = "0.16" bytes = "1" chrono = { version = "0.4", features = ["serde"] } +lazy_static = "1.5" +once_cell = "1.21" regex = "1" clap = { version = "4", features = ["derive"] } dav-server = "0.11" diff --git a/markbase-core/src/s3.rs b/markbase-core/src/s3.rs index f788637..32b0b82 100644 --- a/markbase-core/src/s3.rs +++ b/markbase-core/src/s3.rs @@ -12,6 +12,8 @@ use futures_util::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::Value; use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::io::Write; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_util::io::ReaderStream; @@ -556,3 +558,292 @@ fn parse_range_header(range: &str, file_size: i64) -> Option<(u64, u64)> { Some((start, end)) } + +// ===== Multipart Upload Support ===== + +use std::sync::Arc; +use tokio::sync::RwLock; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MultipartUpload { + 
pub upload_id: String, + pub bucket: String, + pub key: String, + pub parts: Vec, + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UploadedPart { + pub part_number: u32, + pub etag: String, + pub size: u64, +} + +static MULTIPART_UPLOADS: once_cell::sync::Lazy>>> = + once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(HashMap::new()))); + +pub async fn initiate_multipart_upload( + Path((bucket, key)): Path<(String, String)>, + State(_state): State, +) -> impl IntoResponse { + let upload_id = Uuid::new_v4().to_string(); + + let upload = MultipartUpload { + upload_id: upload_id.clone(), + bucket: bucket.clone(), + key: key.clone(), + parts: Vec::new(), + created_at: chrono::Utc::now(), + }; + + { + let mut uploads = MULTIPART_UPLOADS.write().await; + uploads.insert(upload_id.clone(), upload); + } + + let (headers, xml_body) = crate::s3_xml::initiate_multipart_upload_xml(&bucket, &key, &upload_id); + (StatusCode::OK, headers, xml_body).into_response() +} + +/// Streams one part of a multipart upload into a temp file and records it. +/// +/// Replies 404 when `upload_id` is not registered, 400 when it was opened for a +/// different bucket/key, and otherwise 200 with the part's SHA-256 hex digest as `ETag`. +pub async fn upload_part( + Path((bucket, key)): Path<(String, String)>, + State(_state): State, + query: axum::extract::Query, + body: Body, +) -> impl IntoResponse { + let upload_id = query.upload_id.clone(); + let part_number = query.part_number; + + // Validate inside a block so the read guard is gone before the write lock is + // requested further down; holding both in one task would never resolve. + { + let uploads = MULTIPART_UPLOADS.read().await; + let Some(upload) = uploads.get(&upload_id) else { + return (StatusCode::NOT_FOUND, "Upload not found").into_response(); + }; + if upload.bucket != bucket || upload.key != key { + return (StatusCode::BAD_REQUEST, "Bucket/key mismatch").into_response(); + } + } + + // Collect body data + let mut total_size: u64 = 0; + let mut hasher = Sha256::new(); + let mut stream = body.into_data_stream(); + + // Create temp file for part data + let temp_dir = std::env::temp_dir(); + let part_file_path = temp_dir.join(format!("s3_multipart_{}_{}_{}.tmp", upload_id, part_number, Uuid::new_v4())); + let part_file = match
tokio::fs::File::create(&part_file_path).await { + Ok(f) => f, + Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create temp file: {}", e)).into_response(), + }; + let mut writer = tokio::io::BufWriter::new(part_file); + + while let Some(chunk_result) = stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + // Best effort: do not leave a half-written part behind. + let _ = tokio::fs::remove_file(&part_file_path).await; + return (StatusCode::BAD_REQUEST, format!("Failed to read chunk: {}", e)).into_response(); + } + }; + + total_size += chunk.len() as u64; + hasher.update(&chunk); + + if let Err(e) = writer.write_all(&chunk).await { + let _ = tokio::fs::remove_file(&part_file_path).await; + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to write chunk: {}", e)).into_response(); + } + } + + if let Err(e) = writer.flush().await { + let _ = tokio::fs::remove_file(&part_file_path).await; + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to flush: {}", e)).into_response(); + } + + let etag = format!("{:x}", hasher.finalize()); + + // A retried part number supersedes the earlier attempt: remove the older temp + // files so completion cannot pick up stale data for this part. + let stale_prefix = format!("s3_multipart_{}_{}_", upload_id, part_number); + if let Ok(mut dir) = tokio::fs::read_dir(&temp_dir).await { + while let Ok(Some(entry)) = dir.next_entry().await { + let path = entry.path(); + if path != part_file_path && entry.file_name().to_str().unwrap_or("").starts_with(&stale_prefix) { + let _ = tokio::fs::remove_file(&path).await; + } + } + } + + // Update multipart upload with new part + let registered = { + let mut uploads = MULTIPART_UPLOADS.write().await; + match uploads.get_mut(&upload_id) { + Some(upload) => { + // Replace, rather than duplicate, an entry with the same part number. + upload.parts.retain(|p| p.part_number != part_number); + upload.parts.push(UploadedPart { + part_number, + etag: etag.clone(), + size: total_size, + }); + upload.parts.sort_by_key(|p| p.part_number); + true + } + None => false, + } + }; + + if !registered { + // The upload was completed or aborted while this part was streaming. + let _ = tokio::fs::remove_file(&part_file_path).await; + return (StatusCode::NOT_FOUND, "Upload not found").into_response(); + } + + let mut headers = HeaderMap::new(); + headers.insert("ETag", format!("\"{}\"", etag).parse().unwrap()); + (StatusCode::OK, headers).into_response() +} + +#[derive(Debug, serde::Deserialize)] +pub struct UploadPartQuery { + pub upload_id: String, + pub part_number: u32, +} + +pub async fn complete_multipart_upload( + Path((bucket, key)): Path<(String, String)>, + State(_state): State, + query: axum::extract::Query, + body: Body, +) -> impl IntoResponse { + let upload_id = query.upload_id.clone(); + + let uploads = MULTIPART_UPLOADS.read().await; + let upload = uploads.get(&upload_id); + + if upload.is_none() { + return (StatusCode::NOT_FOUND, "Upload not found").into_response(); + } + + let upload = upload.unwrap(); + if upload.bucket != bucket || upload.key != key { + return
(StatusCode::BAD_REQUEST, "Bucket/key mismatch").into_response(); + } + + // Snapshot the recorded parts and release the read guard now: the write lock + // requested at the end of this handler cannot be granted while this task + // still holds a read guard, so keeping it would hang the request. + let parts = upload.parts.clone(); + drop(uploads); + + // Refuse names that could climb out of the storage root. + if bucket.contains('/') || bucket == ".." || key.split('/').any(|segment| segment == "..") { + return (StatusCode::BAD_REQUEST, "Invalid bucket or key").into_response(); + } + + // Parse CompleteMultipartUpload XML from body. The limit leaves room for long + // part lists; 10000 bytes only fits roughly a hundred Part entries. + // NOTE(review): the parsed list is only checked for presence; the parts that get + // assembled are the ones recorded server-side. + let body_bytes = axum::body::to_bytes(body, 2 * 1024 * 1024).await.ok(); + let part_list = body_bytes.as_ref().and_then(|b| parse_complete_multipart_xml(b)); + + if part_list.is_none() { + return (StatusCode::BAD_REQUEST, "Invalid CompleteMultipartUpload XML").into_response(); + } + + // Combine parts into final file + // NOTE(review): developer-machine absolute path; presumably the other object + // handlers use the same root, so it is kept - confirm and move to configuration. + let base_dir = "/Users/accusys/momentry/var/sftpgo/data"; + let file_path = format!("{}/{}/{}", base_dir, bucket, key); + + // Create the object's parent directory (covers keys that contain "/"). + if let Some(object_dir) = std::path::Path::new(&file_path).parent() { + if let Err(e) = tokio::fs::create_dir_all(object_dir).await { + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create dir: {}", e)).into_response(); + } + } + + let final_file = match tokio::fs::File::create(&file_path).await { + Ok(f) => f, + Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create file: {}", e)).into_response(), + }; + let mut final_writer = tokio::io::BufWriter::new(final_file); + + let temp_dir = std::env::temp_dir(); + let mut final_hasher = Sha256::new(); + + for part in &parts { + // Part files carry a random UUID suffix, so locate this part's file by its prefix. + let prefix = format!("s3_multipart_{}_{}_", upload_id, part.part_number); + let mut part_file = None; + if let Ok(mut dir) = tokio::fs::read_dir(&temp_dir).await { + while let Ok(Some(entry)) = dir.next_entry().await { + if entry.file_name().to_str().unwrap_or("").starts_with(&prefix) { + part_file = Some(entry.path()); + break; + } + } + } + + // A missing or unreadable part must fail the request; skipping it would + // store a truncated object and still answer 200. + let Some(part_file) = part_file else { + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Missing data for part {}", part.part_number)).into_response(); + }; + let data = match tokio::fs::read(&part_file).await { + Ok(d) => d, + Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read part {}: {}", part.part_number, e)).into_response(), + }; + + final_hasher.update(&data); + if let Err(e) = final_writer.write_all(&data).await { + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to write part: {}", e)).into_response(); + } + //
Clean up temp file + let _ = tokio::fs::remove_file(&part_file).await; + } + + if let Err(e) = final_writer.flush().await { + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to flush final: {}", e)).into_response(); + } + + let final_etag = format!("{:x}", final_hasher.finalize()); + + // Remove upload from tracking + { + let mut uploads = MULTIPART_UPLOADS.write().await; + uploads.remove(&upload_id); + } + + let (headers, xml_body) = crate::s3_xml::complete_multipart_upload_xml(&bucket, &key, &final_etag); + (StatusCode::OK, headers, xml_body).into_response() +} + +#[derive(Debug, serde::Deserialize)] +pub struct CompleteMultipartQuery { + pub upload_id: String, +} + +/// Cancels a multipart upload: forgets it and deletes the part files stored for it. +/// +/// Replies 404 for an unknown `upload_id`, 400 on a bucket/key mismatch, else 204. +pub async fn abort_multipart_upload( + Path((bucket, key)): Path<(String, String)>, + State(_state): State, + query: axum::extract::Query, +) -> impl IntoResponse { + let upload_id = query.upload_id.clone(); + + // Validate and unregister under a single write lock. Checking under a read + // guard and then asking for the write lock in the same task never resolves. + { + let mut uploads = MULTIPART_UPLOADS.write().await; + let Some(upload) = uploads.get(&upload_id) else { + return (StatusCode::NOT_FOUND, "Upload not found").into_response(); + }; + if upload.bucket != bucket || upload.key != key { + return (StatusCode::BAD_REQUEST, "Bucket/key mismatch").into_response(); + } + uploads.remove(&upload_id); + } + + // Clean up temp files + let temp_dir = std::env::temp_dir(); + let prefix = format!("s3_multipart_{}_", upload_id); + if let Ok(mut dir) = tokio::fs::read_dir(&temp_dir).await { + while let Ok(Some(entry)) = dir.next_entry().await { + if entry.file_name().to_str().unwrap_or("").starts_with(&prefix) { + let _ = tokio::fs::remove_file(entry.path()).await; + } + } + } + + (StatusCode::NO_CONTENT, HeaderMap::new()).into_response() +} + +#[derive(Debug, serde::Deserialize)] +pub struct AbortMultipartQuery { + pub upload_id: String, +} + +fn parse_complete_multipart_xml(xml: &[u8]) -> Option> { + let xml_str = std::str::from_utf8(xml).ok()?; + let
mut parts = Vec::new(); + + for part_elem in xml_str.split("") { + if part_elem.contains("") { + let part_number = part_elem.split("") + .nth(1) + .and_then(|s| s.split("").next()) + .and_then(|s| s.parse().ok()); + + let etag = part_elem.split("") + .nth(1) + .and_then(|s| s.split("").next()) + .map(|s| s.replace("\"", "")); + + if let (Some(num), Some(tag)) = (part_number, etag) { + parts.push((num, tag)); + } + } + } + + Some(parts) +} diff --git a/markbase-core/src/s3_xml.rs b/markbase-core/src/s3_xml.rs index e0b0153..28ed12d 100644 --- a/markbase-core/src/s3_xml.rs +++ b/markbase-core/src/s3_xml.rs @@ -76,3 +76,38 @@ pub fn list_objects_xml(bucket_name: &str, objects: &[Value]) -> (HeaderMap, Str (headers, xml) } + +pub fn initiate_multipart_upload_xml(bucket: &str, key: &str, upload_id: &str) -> (HeaderMap, String) { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/xml".parse().unwrap()); + + let xml = format!( + " + + {} + {} + {} +", + bucket, key, upload_id + ); + + (headers, xml) +} + +pub fn complete_multipart_upload_xml(bucket: &str, key: &str, etag: &str) -> (HeaderMap, String) { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/xml".parse().unwrap()); + + let xml = format!( + " + + http://localhost:11438/s3/{}/{} + {} + {} + {} +", + bucket, key, bucket, key, etag + ); + + (headers, xml) +} diff --git a/markbase-core/src/server.rs b/markbase-core/src/server.rs index a4a4d1b..a8e337c 100644 --- a/markbase-core/src/server.rs +++ b/markbase-core/src/server.rs @@ -243,8 +243,14 @@ pub async fn run(port: u16, file: Option) -> anyhow::Result<()> { get(crate::s3::get_object) .head(crate::s3::head_object) .put(crate::s3::put_object) + .post(crate::s3::put_object) // POST for uploads (same handler handles multipart detection) .delete(crate::s3::delete_object) ) + // Multipart upload endpoints (":key", not "*key": axum 0.7 only accepts a catch-all as the final path segment and panics at registration otherwise) + .route("/s3/multipart/:bucket/:key/init", post(crate::s3::initiate_multipart_upload)) +
.route("/s3/multipart/:bucket/:key/part", put(crate::s3::upload_part)) + .route("/s3/multipart/:bucket/:key/complete", post(crate::s3::complete_multipart_upload)) + .route("/s3/multipart/:bucket/:key/abort", delete(crate::s3::abort_multipart_upload)) // Shell and Metrics API endpoints (public for monitoring) .route("/api/v2/shell/status", get(shell_status_handler)) .route("/api/v2/metrics", get(metrics_handler))