Files
markbase/markbase-core/src/s3.rs
Warren 097521b35d
Some checks failed
Test / build (push) Has been cancelled
Test / test (push) Has been cancelled
P2: Fix S3 multipart route - use query param for action
- Change route from /s3/multipart/:bucket/*key/init to /s3/multipart/:bucket/*key?action=init
- Add multipart_handler to unify all multipart operations
- Use Response type instead of impl IntoResponse for type compatibility
2026-06-22 01:22:16 +08:00

1106 lines
37 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
body::Body,
extract::{Path, State},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Json},
};
use filetree::{
node::FileNode,
FileTree,
};
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::io::Write;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::io::ReaderStream;
/// An S3-style access key pair together with its owner and granted permissions.
///
/// `generate_s3_key` builds these and pushes them onto `AppState::s3_keys`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct S3AccessKey {
/// Access key identifier (`generate_s3_key` uses `markbase_access_key_<uuid>`).
pub access_key: String,
/// Secret half of the pair (`generate_s3_key` uses `markbase_secret_key_<uuid>`).
pub secret_key: String,
/// Identifier of the user the key belongs to.
pub user_id: String,
/// Granted permission names, e.g. `"GetObject"` or `"ListBucket"`.
pub permissions: Vec<String>,
/// Creation time; `generate_s3_key` stores an RFC 3339 UTC timestamp here.
pub created_at: String,
}
/// Lists all buckets as the XML document produced by `crate::s3_xml::list_buckets_xml`.
///
/// A bucket is any entry in `state.db_dir` whose file name ends in `.sqlite`;
/// the bucket name is that file name with the trailing `.sqlite` removed.
/// An unreadable directory yields an empty listing rather than an error.
pub async fn list_buckets(State(state): State<crate::server::AppState>) -> impl IntoResponse {
    let mut buckets: Vec<String> = Vec::new();
    if let Ok(dir) = std::fs::read_dir(&state.db_dir) {
        for entry in dir.flatten() {
            let file_name = entry.file_name();
            // Strip only the trailing extension. `replace` would also eat a
            // `.sqlite` in the middle of the name (`a.sqlite.bak.sqlite`).
            if let Some(bucket_name) = file_name
                .to_str()
                .and_then(|name| name.strip_suffix(".sqlite"))
            {
                buckets.push(bucket_name.to_string());
            }
        }
    }
    let (headers, xml_body) = crate::s3_xml::list_buckets_xml(&buckets);
    (StatusCode::OK, headers, xml_body).into_response()
}
/// Lists every file node of a bucket as the XML document produced by
/// `crate::s3_xml::list_objects_xml`.
///
/// Responds `404` when the bucket database cannot be opened and `500` when
/// the file tree fails to load.
pub async fn list_objects(
    Path(bucket): Path<String>,
    State(_state): State<crate::server::AppState>,
) -> impl IntoResponse {
    println!("S3 List Objects: bucket={}", bucket);

    let conn = match FileTree::open_user_db(&bucket) {
        Ok(db) => db,
        Err(err) => {
            println!("Error opening DB: {}", err);
            return (StatusCode::NOT_FOUND, "Bucket not found").into_response();
        }
    };
    let tree = match FileTree::load(&conn, &bucket, "untitled folder") {
        Ok(loaded) => loaded,
        Err(err) => {
            println!("Error loading tree: {}", err);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to load tree").into_response();
        }
    };

    // One JSON record per file node; folders are not S3 objects.
    let mut objects: Vec<Value> = Vec::new();
    for file_node in tree.nodes.iter() {
        if file_node.node_type != filetree::node::NodeType::File {
            continue;
        }
        objects.push(serde_json::json!({
            "Key": build_s3_key(&tree, file_node),
            "LastModified": file_node.registered_at.clone().unwrap_or_default(),
            "ETag": file_node.sha256.clone().unwrap_or_default(),
            "Size": file_node.file_size.unwrap_or(0),
        }));
    }
    println!("Listed {} objects for bucket {}", objects.len(), bucket);

    let (headers, xml_body) = crate::s3_xml::list_objects_xml(&bucket, &objects);
    (StatusCode::OK, headers, xml_body).into_response()
}
/// Streams an object's bytes back to the client.
///
/// Flow: bucket-policy check (`s3:GetObject`), node lookup in the bucket's
/// file tree, resolution of the on-disk location, then either a ranged
/// response (when a `Range` header is present) or the whole file.
///
/// Responds `403` on policy denial, `404` when the bucket, object or file
/// location is unknown, and `500` when the tree or the file cannot be opened.
pub async fn get_object(
    Path((bucket, key)): Path<(String, String)>,
    State(state): State<crate::server::AppState>,
    headers: HeaderMap,
) -> impl IntoResponse {
    println!("S3 GET Object: bucket={}, key={}", bucket, key);

    // Policy check - the caller needs GetObject permission.
    let user_id = match extract_user_from_auth(&headers) {
        Some(id) => id,
        None => "anonymous".to_string(),
    };
    let resource = format!("arn:aws:s3:::{}", bucket);
    if !check_bucket_policy(&bucket, "s3:GetObject", &resource, &user_id) {
        return (StatusCode::FORBIDDEN, "Policy denied").into_response();
    }

    let conn = match FileTree::open_user_db(&bucket) {
        Ok(db) => db,
        Err(err) => {
            println!("Error opening DB: {}", err);
            return (StatusCode::NOT_FOUND, "Bucket not found").into_response();
        }
    };
    let tree = match FileTree::load(&conn, &bucket, "untitled folder") {
        Ok(loaded) => loaded,
        Err(err) => {
            println!("Error loading tree: {}", err);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to load tree").into_response();
        }
    };
    println!("Tree loaded, {} nodes", tree.nodes.len());

    let Some(node) = find_node_by_s3_key(&tree, &key) else {
        println!("Node not found for key: {}", key);
        return (StatusCode::NOT_FOUND, "Object not found").into_response();
    };
    let file_uuid = node.file_uuid.clone().unwrap_or_default();
    println!("Node found: file_uuid={}", file_uuid);
    let file_size = node.file_size.unwrap_or(0);
    let sha256 = node.sha256.clone().unwrap_or_default();

    let Some(real_path) = get_real_file_path(&conn, &file_uuid) else {
        println!("File location not found for uuid: {}", file_uuid);
        return (StatusCode::NOT_FOUND, "File location not found").into_response();
    };
    println!("Real path: {}", real_path);

    // A Range header switches to the partial-content path.
    if let Some(range) = headers.get("Range").and_then(|value| value.to_str().ok()) {
        println!("Range request: {}", range);
        return handle_range_request(real_path, range, file_size, sha256).await;
    }

    // Full-file download.
    let file = match tokio::fs::File::open(&real_path).await {
        Ok(opened) => opened,
        Err(err) => {
            println!("Error opening file: {}", err);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to open file").into_response();
        }
    };
    println!("File opened successfully for streaming");

    let body = Body::from_stream(ReaderStream::new(file));
    let mut response_headers = HeaderMap::new();
    response_headers.insert("Content-Type", "application/octet-stream".parse().unwrap());
    response_headers.insert("ETag", format!("\"{}\"", sha256).parse().unwrap());
    response_headers.insert("Content-Length", file_size.into());
    response_headers.insert("Accept-Ranges", "bytes".parse().unwrap());
    (StatusCode::OK, response_headers, body).into_response()
}
/// Stores the request body as an object and registers it in the bucket's file tree.
///
/// The body is streamed to `<base_dir>/<bucket>/<key>` while a SHA-256 digest
/// and the byte count are accumulated; the digest doubles as the file UUID and
/// the ETag. Afterwards a file node and a location row are inserted on a
/// blocking thread (the bucket database is created/initialised on demand).
///
/// Responds `403` on policy denial, `400` for an unsafe bucket/key, an
/// unreadable body or a body larger than 100 GB, and `500` for file-system or
/// database failures. On every failure after the file was created the partial
/// file is removed again.
pub async fn put_object(
    Path((bucket, key)): Path<(String, String)>,
    State(state): State<crate::server::AppState>,
    headers: HeaderMap,
    body: Body,
) -> impl IntoResponse {
    /// True when any `/`- or `\`-separated segment is a parent reference.
    fn has_parent_ref(value: &str) -> bool {
        value
            .split(|c: char| c == '/' || c == '\\')
            .any(|segment| segment == "..")
    }

    println!("S3 PUT Object: bucket={}, key={}", bucket, key);

    // Policy check - the caller needs PutObject permission.
    let user_id = extract_user_from_auth(&headers).unwrap_or_else(|| "anonymous".to_string());
    if !check_bucket_policy(&bucket, "s3:PutObject", &format!("arn:aws:s3:::{}", bucket), &user_id) {
        return (StatusCode::FORBIDDEN, "Policy denied").into_response();
    }

    // Bucket and key come straight from the URL; refuse anything that could
    // climb out of the data directory.
    if bucket.is_empty() || has_parent_ref(&bucket) || has_parent_ref(&key) {
        return (StatusCode::BAD_REQUEST, "Invalid bucket or key").into_response();
    }

    let base_dir = "/Users/accusys/momentry/var/sftpgo/data";
    let file_path = format!("{}/{}/{}", base_dir, bucket, key);

    // Create the directory that will actually contain the file, so keys with
    // `/` in them (e.g. `photos/2024/a.jpg`) work and not only flat keys.
    let parent_dir = std::path::Path::new(&file_path)
        .parent()
        .map(|p| p.to_path_buf())
        .unwrap_or_else(|| std::path::PathBuf::from(format!("{}/{}", base_dir, bucket)));
    if let Err(e) = tokio::fs::create_dir_all(&parent_dir).await {
        println!("Error creating directory: {}", e);
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            "Failed to create directory",
        )
            .into_response();
    }

    let file = match tokio::fs::File::create(&file_path).await {
        Ok(f) => f,
        Err(e) => {
            println!("Error creating file: {}", e);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to create file").into_response();
        }
    };

    // Stream the body to disk, hashing and counting as we go.
    let mut writer = tokio::io::BufWriter::with_capacity(64 * 1024, file);
    let mut hasher = Sha256::new();
    let mut total_size: u64 = 0;
    let mut stream = body.into_data_stream();
    while let Some(chunk_result) = stream.next().await {
        let chunk = match chunk_result {
            Ok(c) => c,
            Err(e) => {
                println!("Error reading chunk: {}", e);
                let _ = tokio::fs::remove_file(&file_path).await;
                return (StatusCode::BAD_REQUEST, "Failed to read body").into_response();
            }
        };
        total_size += chunk.len() as u64;
        if total_size > 100_000_000_000 {
            println!("File too large: {} bytes", total_size);
            let _ = tokio::fs::remove_file(&file_path).await;
            return (StatusCode::BAD_REQUEST, "File too large (>100GB)").into_response();
        }
        hasher.update(&chunk);
        if let Err(e) = writer.write_all(&chunk).await {
            println!("Error writing chunk: {}", e);
            let _ = tokio::fs::remove_file(&file_path).await;
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to write file").into_response();
        }
    }
    if let Err(e) = writer.flush().await {
        println!("Error flushing writer: {}", e);
        let _ = tokio::fs::remove_file(&file_path).await;
        return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to flush file").into_response();
    }

    let sha256_hash = format!("{:x}", hasher.finalize());
    println!("File written: {} bytes, SHA256={}", total_size, sha256_hash);

    // The tree stores only the last path segment as the node label.
    let sha256_hash_clone = sha256_hash.clone();
    let file_path_clone = file_path.clone();
    let label = key.split('/').next_back().unwrap_or(&key).to_string();

    // The database work is synchronous, so keep it off the async workers.
    let result = tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
        let conn = match FileTree::open_user_db(&bucket) {
            Ok(c) => {
                // An existing database file may still lack the schema.
                let has_tables: bool = c
                    .query_row(
                        "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='file_nodes'",
                        [],
                        |row| row.get::<_, i32>(0),
                    )
                    .unwrap_or(0)
                    > 0;
                if !has_tables {
                    c.execute_batch(filetree::CREATE_TABLES)?;
                }
                c
            }
            Err(_) => FileTree::init_user_db(&bucket)?,
        };

        let file_uuid = sha256_hash_clone.clone();
        let (node, _) = FileTree::new_file_node(
            &label,
            &file_uuid,
            Some(&sha256_hash_clone),
            &label,
            Some(total_size as i64),
            None,
            None,
            None,
        );
        let mut tree = FileTree::load(&conn, &bucket, "untitled folder")?;
        tree.insert_node(&conn, &node)?;
        FileTree::add_location(&conn, &file_uuid, &file_path_clone, Some(&label))?;
        Ok(())
    })
    .await;

    match result {
        Ok(Ok(_)) => {
            println!("PutObject success: {}", key);
            let mut headers = HeaderMap::new();
            headers.insert("ETag", format!("\"{}\"", sha256_hash).parse().unwrap());
            (StatusCode::OK, headers).into_response()
        }
        Ok(Err(e)) => {
            println!("DB error: {}", e);
            let _ = tokio::fs::remove_file(&file_path).await;
            (StatusCode::INTERNAL_SERVER_ERROR, "Database error").into_response()
        }
        Err(e) => {
            println!("Task error: {}", e);
            let _ = tokio::fs::remove_file(&file_path).await;
            (StatusCode::INTERNAL_SERVER_ERROR, "Task error").into_response()
        }
    }
}
pub async fn head_object(
Path((bucket, key)): Path<(String, String)>,
State(_state): State<crate::server::AppState>,
) -> impl IntoResponse {
let conn = match FileTree::open_user_db(&bucket) {
Ok(c) => c,
Err(_) => return (StatusCode::NOT_FOUND, HeaderMap::new()),
};
let tree = match FileTree::load(&conn, &bucket, "untitled folder") {
Ok(t) => t,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()),
};
let node = find_node_by_s3_key(&tree, &key);
if node.is_none() {
return (StatusCode::NOT_FOUND, HeaderMap::new());
}
let node = node.unwrap();
let mut headers = HeaderMap::new();
headers.insert("Content-Type", "application/octet-stream".parse().unwrap());
headers.insert(
"ETag",
node.sha256.clone().unwrap_or_default().parse().unwrap(),
);
headers.insert("Content-Length", node.file_size.unwrap_or(0).into());
(StatusCode::OK, headers)
}
/// Reports the S3 endpoint status as JSON: a fixed endpoint/region plus the
/// number of bucket databases in `state.db_dir` and of keys in `state.s3_keys`.
pub async fn s3_status(State(state): State<crate::server::AppState>) -> impl IntoResponse {
    let buckets_count = count_buckets(&state.db_dir);
    // Take the lock only long enough to read the length.
    let keys_count = {
        let keys = state.s3_keys.lock().unwrap();
        keys.len()
    };
    let status = serde_json::json!({
        "enabled": true,
        "endpoint": "http://localhost:11438/s3",
        "region": "us-east-1",
        "buckets_count": buckets_count,
        "keys_count": keys_count
    });
    Json(status)
}
/// Creates a fresh access/secret key pair, stores it in `state.s3_keys`, and
/// returns the pair together with its user id as JSON.
///
/// The key is always issued to the user `warren` with the `GetObject` and
/// `ListBucket` permissions.
pub async fn generate_s3_key(State(state): State<crate::server::AppState>) -> impl IntoResponse {
    let access_key = format!("markbase_access_key_{}", uuid::Uuid::new_v4());
    let secret_key = format!("markbase_secret_key_{}", uuid::Uuid::new_v4());
    let issued = S3AccessKey {
        access_key,
        secret_key,
        user_id: "warren".to_string(),
        permissions: vec!["GetObject".to_string(), "ListBucket".to_string()],
        created_at: chrono::Utc::now().to_rfc3339(),
    };

    {
        let mut keys = state.s3_keys.lock().unwrap();
        keys.push(issued.clone());
    }

    let payload = serde_json::json!({
        "access_key": issued.access_key,
        "secret_key": issued.secret_key,
        "user_id": issued.user_id
    });
    Json(payload)
}
/// Deletes an object: removes its backing file and its node in the file tree.
///
/// Responds `403` on policy denial, `404` when no node matches the key,
/// `204` on success and `500` for any database, file-system or task failure.
pub async fn delete_object(
    Path((bucket, key)): Path<(String, String)>,
    State(state): State<crate::server::AppState>,
    headers: HeaderMap,
) -> impl IntoResponse {
    println!("S3 DELETE Object: bucket={}, key={}", bucket, key);

    // Policy check - the caller needs DeleteObject permission.
    let user_id = extract_user_from_auth(&headers).unwrap_or_else(|| "anonymous".to_string());
    if !check_bucket_policy(&bucket, "s3:DeleteObject", &format!("arn:aws:s3:::{}", bucket), &user_id) {
        return (StatusCode::FORBIDDEN, "Policy denied").into_response();
    }

    // `Ok(true)` = deleted, `Ok(false)` = no such object. Carrying the
    // not-found case in the value avoids matching on error message text.
    let result = tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
        let conn = FileTree::open_user_db(&bucket)?;
        let mut tree = FileTree::load(&conn, &bucket, "untitled folder")?;
        let Some(node) = find_node_by_s3_key(&tree, &key) else {
            return Ok(false);
        };

        let file_uuid = node.file_uuid.clone().unwrap_or_default();
        if let Some(path) = get_real_file_path(&conn, &file_uuid) {
            match std::fs::remove_file(&path) {
                Ok(()) => {}
                // The file is already gone; still drop the node, otherwise
                // the object could never be deleted.
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
                Err(e) => return Err(e.into()),
            }
        }
        tree.delete_node(&conn, &node.node_id)?;
        Ok(true)
    })
    .await;

    match result {
        Ok(Ok(true)) => (StatusCode::NO_CONTENT, HeaderMap::new()).into_response(),
        Ok(Ok(false)) => {
            println!("Delete error: Object not found");
            (StatusCode::NOT_FOUND, "Object not found").into_response()
        }
        Ok(Err(e)) => {
            println!("Delete error: {}", e);
            (StatusCode::INTERNAL_SERVER_ERROR, "Delete error").into_response()
        }
        Err(e) => {
            println!("Task error: {}", e);
            (StatusCode::INTERNAL_SERVER_ERROR, "Task error").into_response()
        }
    }
}
/// Builds the S3 key of `node` by joining the labels of its ancestors
/// (outermost first) and its own label with `/`.
///
/// The walk stops at the first node without a parent or whose parent id is
/// not present in `tree`.
fn build_s3_key(tree: &FileTree, node: &FileNode) -> String {
    let mut path_parts = Vec::new();
    let mut current_parent = node.parent_id.clone();

    // A well-formed ancestor chain cannot be longer than the node count, so
    // bounding the walk keeps a corrupted (cyclic) parent chain from hanging
    // the request forever.
    for _ in 0..tree.nodes.len() {
        let Some(parent_id) = current_parent else {
            break;
        };
        let Some(parent) = tree.nodes.iter().find(|n| n.node_id == parent_id) else {
            break;
        };
        path_parts.push(parent.label.clone());
        current_parent = parent.parent_id.clone();
    }

    // Ancestors were collected innermost-first.
    path_parts.reverse();
    path_parts.push(node.label.clone());
    path_parts.join("/")
}
/// Finds the file node addressed by an S3 key.
///
/// Two strategies are tried in order:
/// 1. the first file node whose full path (see `build_s3_key`) equals `key`;
/// 2. as a fallback, the first file node whose label equals the last
///    `/`-separated segment of `key`.
fn find_node_by_s3_key(tree: &FileTree, key: &str) -> Option<FileNode> {
    // Strategy 1: exact match on the full path.
    for candidate in tree.nodes.iter() {
        let is_file = candidate.node_type == filetree::node::NodeType::File;
        if is_file && build_s3_key(tree, candidate) == key {
            return Some(candidate.clone());
        }
    }

    // Strategy 2: fall back to matching the bare file name.
    let filename = key.rsplit('/').next().unwrap_or(key);
    for candidate in tree.nodes.iter() {
        let is_file = candidate.node_type == filetree::node::NodeType::File;
        if is_file && candidate.label == filename {
            return Some(candidate.clone());
        }
    }
    None
}
/// Looks up the stored location of a file in the `file_locations` table.
///
/// Returns `None` when the statement cannot be prepared, no row matches
/// `file_uuid`, or the column cannot be read as a string.
fn get_real_file_path(conn: &rusqlite::Connection, file_uuid: &str) -> Option<String> {
    let sql = "SELECT location FROM file_locations WHERE file_uuid = ?1 LIMIT 1";
    match conn.prepare(sql) {
        Ok(mut stmt) => {
            let location = stmt.query_row([file_uuid], |row| row.get(0));
            location.ok()
        }
        Err(_) => None,
    }
}
/// Counts the entries in `db_dir` whose file name ends in `.sqlite`.
///
/// Returns `0` when the directory cannot be read; entries that fail to read
/// or whose names are not valid UTF-8 are not counted.
fn count_buckets(db_dir: &str) -> usize {
    let Ok(entries) = std::fs::read_dir(db_dir) else {
        return 0;
    };

    let mut total = 0;
    for entry in entries.flatten() {
        let name = entry.file_name();
        if name.to_str().map_or(false, |s| s.ends_with(".sqlite")) {
            total += 1;
        }
    }
    total
}
/// Serves a single byte range of the file at `real_path` as `206 Partial Content`.
///
/// `range` is the raw `Range` header value, `file_size` the stored object
/// size and `sha256` the digest used for the `ETag`. Responds `400` when
/// `parse_range_header` rejects the header and `500` when the file cannot be
/// opened or positioned.
async fn handle_range_request(
    real_path: String,
    range: &str,
    file_size: i64,
    sha256: String,
) -> axum::response::Response<Body> {
    use tokio::io::AsyncSeekExt;

    let Some((start, end)) = parse_range_header(range, file_size) else {
        println!("Invalid Range header: {}", range);
        return (StatusCode::BAD_REQUEST, "Invalid Range header").into_response();
    };
    // Both bounds are inclusive.
    let content_length = end - start + 1;
    println!(
        "Range request: bytes {}-{}, content_length={}",
        start, end, content_length
    );

    let mut file = match tokio::fs::File::open(&real_path).await {
        Ok(opened) => opened,
        Err(err) => {
            println!("Error opening file for range: {}", err);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to open file").into_response();
        }
    };

    // Move to the first requested byte.
    if let Err(err) = file.seek(tokio::io::SeekFrom::Start(start)).await {
        println!("Error seeking file: {}", err);
        return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to seek file").into_response();
    }

    // `take` caps the stream at exactly the requested number of bytes.
    let body = Body::from_stream(ReaderStream::new(file.take(content_length)));

    let content_range = format!("bytes {}-{}/{}", start, end, file_size);
    let etag = format!("\"{}\"", sha256);

    let mut headers = HeaderMap::new();
    headers.insert("Content-Type", "application/octet-stream".parse().unwrap());
    headers.insert("Content-Range", content_range.parse().unwrap());
    headers.insert("Content-Length", content_length.into());
    headers.insert("ETag", etag.parse().unwrap());
    headers.insert("Accept-Ranges", "bytes".parse().unwrap());
    (StatusCode::PARTIAL_CONTENT, headers, body).into_response()
}
/// Parses a single-range `Range` header into inclusive `(start, end)` byte offsets.
///
/// Supported forms (multi-range requests containing `,` are rejected):
/// - `bytes=N-M` — bytes `N` through `M`; an `M` past the end of the file is
///   clamped to the last byte;
/// - `bytes=N-`  — byte `N` through the end of the file;
/// - `bytes=-N`  — the last `N` bytes (the whole file if `N` exceeds its size).
///
/// Returns `None` for anything that cannot be satisfied: a missing `bytes=`
/// prefix, non-numeric bounds, `start > end`, a start at or beyond the end of
/// the file, a zero-length suffix, or a file that is empty (or has a negative
/// recorded size).
fn parse_range_header(range: &str, file_size: i64) -> Option<(u64, u64)> {
    // No byte of an empty file can be addressed. Checking this first also
    // keeps `size - 1` from underflowing and a negative size from wrapping.
    let size = u64::try_from(file_size).ok().filter(|&s| s > 0)?;
    let last = size - 1;

    let spec = range.strip_prefix("bytes=")?;
    if spec.contains(',') {
        return None;
    }
    // Splitting on the first `-` only: anything with extra dashes fails the
    // numeric parse below.
    let (first, second) = spec.split_once('-')?;

    let (start, end) = if first.is_empty() {
        // "bytes=-N": the final N bytes.
        let suffix_length = second.parse::<u64>().ok()?;
        if suffix_length == 0 {
            return None;
        }
        (size.saturating_sub(suffix_length), last)
    } else if second.is_empty() {
        // "bytes=N-": from N to the end.
        (first.parse::<u64>().ok()?, last)
    } else {
        // "bytes=N-M": an end beyond the file means "to the end".
        let start = first.parse::<u64>().ok()?;
        let end = second.parse::<u64>().ok()?;
        (start, end.min(last))
    };

    if start > end {
        return None;
    }
    Some((start, end))
}
// ===== Multipart Upload Support =====
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Bookkeeping record for one in-progress multipart upload.
///
/// Created by `initiate_multipart_upload` and stored in `MULTIPART_UPLOADS`
/// under its `upload_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipartUpload {
/// Server-generated identifier (a UUID v4 string) of this upload.
pub upload_id: String,
/// Bucket the finished object will be written to.
pub bucket: String,
/// Object key the finished object will be written to.
pub key: String,
/// Parts received so far; `upload_part` keeps this sorted by part number.
pub parts: Vec<UploadedPart>,
/// UTC time at which the upload was initiated.
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Metadata of one part received for a multipart upload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UploadedPart {
/// Client-chosen part number (taken from the request's query string).
pub part_number: u32,
/// Lower-case hex SHA-256 digest of the part's bytes, as computed by `upload_part`.
pub etag: String,
/// Size of the part in bytes.
pub size: u64,
}
/// Process-wide registry of in-progress multipart uploads, keyed by upload id.
///
/// The map lives in memory only; nothing in this file persists or reloads it.
/// It is guarded by a tokio `RwLock`, so guards must be released before the
/// same task asks for the write lock.
static MULTIPART_UPLOADS: once_cell::sync::Lazy<Arc<RwLock<HashMap<String, MultipartUpload>>>> =
once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Starts a multipart upload and returns its id in the XML document produced
/// by `crate::s3_xml::initiate_multipart_upload_xml`.
///
/// Responds `403` when signature verification or the `s3:PutObject` policy
/// check fails. On success a new `MultipartUpload` is registered in
/// `MULTIPART_UPLOADS` under a fresh UUID.
pub async fn initiate_multipart_upload(
    Path((bucket, key)): Path<(String, String)>,
    State(state): State<crate::server::AppState>,
    headers: HeaderMap,
) -> impl IntoResponse {
    // Authentication check
    let signed_path = format!("/s3/multipart/{}/{}?uploads", bucket, key);
    if !crate::s3_auth::verify_signature(headers.clone(), "POST", &signed_path) {
        return (StatusCode::FORBIDDEN, "Access denied").into_response();
    }

    // Policy check - the caller needs PutObject permission.
    let user_id = match extract_user_from_auth(&headers) {
        Some(id) => id,
        None => "anonymous".to_string(),
    };
    let resource = format!("arn:aws:s3:::{}/*", bucket);
    if !check_bucket_policy(&bucket, "s3:PutObject", &resource, &user_id) {
        return (StatusCode::FORBIDDEN, "Policy denied").into_response();
    }

    let upload_id = Uuid::new_v4().to_string();
    let record = MultipartUpload {
        upload_id: upload_id.clone(),
        bucket: bucket.clone(),
        key: key.clone(),
        parts: Vec::new(),
        created_at: chrono::Utc::now(),
    };
    // The guard is a temporary, so the lock is released at the end of this statement.
    MULTIPART_UPLOADS
        .write()
        .await
        .insert(upload_id.clone(), record);

    let (response_headers, xml_body) =
        crate::s3_xml::initiate_multipart_upload_xml(&bucket, &key, &upload_id);
    (StatusCode::OK, response_headers, xml_body).into_response()
}
pub async fn upload_part(
Path((bucket, key)): Path<(String, String)>,
State(state): State<crate::server::AppState>,
query: axum::extract::Query<UploadPartQuery>,
headers: HeaderMap,
body: Body,
) -> impl IntoResponse {
// Authentication check
if !crate::s3_auth::verify_signature(headers.clone(), "PUT", &format!("/s3/multipart/{}/{}?uploadId={}&partNumber={}", bucket, key, query.upload_id, query.part_number)) {
return (StatusCode::FORBIDDEN, "Access denied").into_response();
}
// Policy check
let user_id = extract_user_from_auth(&headers).unwrap_or_else(|| "anonymous".to_string());
if !check_bucket_policy(&bucket, "s3:PutObject", &format!("arn:aws:s3:::{}/*", bucket), &user_id) {
return (StatusCode::FORBIDDEN, "Policy denied").into_response();
}
let upload_id = query.upload_id.clone();
let part_number = query.part_number;
let uploads = MULTIPART_UPLOADS.read().await;
let upload = uploads.get(&upload_id);
if upload.is_none() {
return (StatusCode::NOT_FOUND, "Upload not found").into_response();
}
let upload = upload.unwrap();
if upload.bucket != bucket || upload.key != key {
return (StatusCode::BAD_REQUEST, "Bucket/key mismatch").into_response();
}
// Collect body data
let mut total_size: u64 = 0;
let mut hasher = Sha256::new();
let mut stream = body.into_data_stream();
// Create temp file for part data
let temp_dir = std::env::temp_dir();
let part_file_path = temp_dir.join(format!("s3_multipart_{}_{}_{}.tmp", upload_id, part_number, Uuid::new_v4()));
let part_file = match tokio::fs::File::create(&part_file_path).await {
Ok(f) => f,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create temp file: {}", e)).into_response(),
};
let mut writer = tokio::io::BufWriter::new(part_file);
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => return (StatusCode::BAD_REQUEST, format!("Failed to read chunk: {}", e)).into_response(),
};
total_size += chunk.len() as u64;
hasher.update(&chunk);
if let Err(e) = writer.write_all(&chunk).await {
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to write chunk: {}", e)).into_response();
}
}
if let Err(e) = writer.flush().await {
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to flush: {}", e)).into_response();
}
let etag = format!("{:x}", hasher.finalize());
// Update multipart upload with new part
{
let mut uploads = MULTIPART_UPLOADS.write().await;
if let Some(upload) = uploads.get_mut(&upload_id) {
upload.parts.push(UploadedPart {
part_number,
etag: etag.clone(),
size: total_size,
});
upload.parts.sort_by_key(|p| p.part_number);
}
}
let mut headers = HeaderMap::new();
headers.insert("ETag", format!("\"{}\"", etag).parse().unwrap());
(StatusCode::OK, headers).into_response()
}
/// Query-string parameters consumed by `upload_part`.
///
/// Field names are deserialized as written (`upload_id`, `part_number`).
#[derive(Debug, serde::Deserialize)]
pub struct UploadPartQuery {
/// Id of the multipart upload the part belongs to.
pub upload_id: String,
/// Number of the part being uploaded.
pub part_number: u32,
}
pub async fn complete_multipart_upload(
Path((bucket, key)): Path<(String, String)>,
State(state): State<crate::server::AppState>,
query: axum::extract::Query<CompleteMultipartQuery>,
headers: HeaderMap,
body: Body,
) -> impl IntoResponse {
// Authentication check
if !crate::s3_auth::verify_signature(headers.clone(), "POST", &format!("/s3/multipart/{}/{}?uploadId={}", bucket, key, query.upload_id)) {
return (StatusCode::FORBIDDEN, "Access denied").into_response();
}
// Policy check
let user_id = extract_user_from_auth(&headers).unwrap_or_else(|| "anonymous".to_string());
if !check_bucket_policy(&bucket, "s3:PutObject", &format!("arn:aws:s3:::{}/*", bucket), &user_id) {
return (StatusCode::FORBIDDEN, "Policy denied").into_response();
}
let upload_id = query.upload_id.clone();
let uploads = MULTIPART_UPLOADS.read().await;
let upload = uploads.get(&upload_id);
if upload.is_none() {
return (StatusCode::NOT_FOUND, "Upload not found").into_response();
}
let upload = upload.unwrap();
if upload.bucket != bucket || upload.key != key {
return (StatusCode::BAD_REQUEST, "Bucket/key mismatch").into_response();
}
// Parse CompleteMultipartUpload XML from body
let body_bytes = axum::body::to_bytes(body, 10000).await.ok();
let part_list = body_bytes.as_ref().and_then(|b| parse_complete_multipart_xml(b));
if part_list.is_none() {
return (StatusCode::BAD_REQUEST, "Invalid CompleteMultipartUpload XML").into_response();
}
// Combine parts into final file
let base_dir = "/Users/accusys/momentry/var/sftpgo/data";
let file_path = format!("{}/{}/{}", base_dir, bucket, key);
if let Err(e) = tokio::fs::create_dir_all(&format!("{}/{}", base_dir, bucket)).await {
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create dir: {}", e)).into_response();
}
let final_file = match tokio::fs::File::create(&file_path).await {
Ok(f) => f,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create file: {}", e)).into_response(),
};
let mut final_writer = tokio::io::BufWriter::new(final_file);
let temp_dir = std::env::temp_dir();
let mut final_hasher = Sha256::new();
let mut final_size: u64 = 0;
for part in &upload.parts {
let _part_file_path = temp_dir.join(format!("s3_multipart_{}_{}_*.tmp", upload_id, part.part_number));
// Find the actual part file (with UUID suffix)
let part_files: Option<Vec<_>> = std::fs::read_dir(&temp_dir).ok().map(|dir| dir.filter_map(|e| e.ok())
.filter(|e| e.file_name().to_str().unwrap_or("").starts_with(&format!("s3_multipart_{}_{}_", upload_id, part.part_number)))
.collect::<Vec<_>>());
if let Some(files) = part_files {
if let Some(part_file_entry) = files.first() {
let part_file = part_file_entry.path();
if let Ok(data) = tokio::fs::read(&part_file).await {
final_hasher.update(&data);
final_size += data.len() as u64;
if let Err(e) = final_writer.write_all(&data).await {
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to write part: {}", e)).into_response();
}
// Clean up temp file
let _ = tokio::fs::remove_file(&part_file).await;
}
}
}
}
if let Err(e) = final_writer.flush().await {
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to flush final: {}", e)).into_response();
}
let final_etag = format!("{:x}", final_hasher.finalize());
// Remove upload from tracking
{
let mut uploads = MULTIPART_UPLOADS.write().await;
uploads.remove(&upload_id);
}
let (headers, xml_body) = crate::s3_xml::complete_multipart_upload_xml(&bucket, &key, &final_etag);
(StatusCode::OK, headers, xml_body).into_response()
}
/// Query-string parameters consumed by `complete_multipart_upload`.
#[derive(Debug, serde::Deserialize)]
pub struct CompleteMultipartQuery {
/// Id of the multipart upload to complete.
pub upload_id: String,
}
/// Aborts a multipart upload: unregisters it and deletes its temp part files.
///
/// Responds `403` when signature verification fails, `404` for an unknown
/// upload id, `400` when the bucket/key do not match the upload, and `204`
/// on success. Temp-file deletion is best-effort.
pub async fn abort_multipart_upload(
    Path((bucket, key)): Path<(String, String)>,
    State(state): State<crate::server::AppState>,
    query: axum::extract::Query<AbortMultipartQuery>,
    headers: HeaderMap,
) -> impl IntoResponse {
    // Authentication check
    let signed_path = format!("/s3/multipart/{}/{}?uploadId={}", bucket, key, query.upload_id);
    if !crate::s3_auth::verify_signature(headers.clone(), "DELETE", &signed_path) {
        return (StatusCode::FORBIDDEN, "Access denied").into_response();
    }

    let upload_id = query.upload_id.clone();

    // Validate inside a short scope so the read guard is released before the
    // write lock is requested below; holding both in one task would deadlock.
    {
        let uploads = MULTIPART_UPLOADS.read().await;
        let Some(upload) = uploads.get(&upload_id) else {
            return (StatusCode::NOT_FOUND, "Upload not found").into_response();
        };
        if upload.bucket != bucket || upload.key != key {
            return (StatusCode::BAD_REQUEST, "Bucket/key mismatch").into_response();
        }
    }

    // Unregister first so no further parts are recorded for this upload.
    MULTIPART_UPLOADS.write().await.remove(&upload_id);

    // Clean up the temp files written by `upload_part`.
    let prefix = format!("s3_multipart_{}_", upload_id);
    if let Ok(mut entries) = tokio::fs::read_dir(std::env::temp_dir()).await {
        while let Ok(Some(entry)) = entries.next_entry().await {
            let belongs_to_upload = entry
                .file_name()
                .to_str()
                .map_or(false, |name| name.starts_with(&prefix));
            if belongs_to_upload {
                let _ = tokio::fs::remove_file(entry.path()).await;
            }
        }
    }

    (StatusCode::NO_CONTENT, HeaderMap::new()).into_response()
}
/// Query-string parameters consumed by `abort_multipart_upload`.
#[derive(Debug, serde::Deserialize)]
pub struct AbortMultipartQuery {
/// Id of the multipart upload to abort.
pub upload_id: String,
}
/// Extracts `(part number, ETag)` pairs from a `CompleteMultipartUpload` XML body.
///
/// This is a lightweight tag scan, not a full XML parser. For every
/// `<Part>…</Part>` element the text of `<PartNumber>` and `<ETag>` is read;
/// surrounding whitespace is ignored and quotes around the ETag — literal `"`
/// or the XML entity `&quot;` — are removed. Elements that lack either value
/// or whose part number is not a `u32` are skipped.
///
/// Returns `None` only when `xml` is not valid UTF-8; a body without any
/// usable part yields `Some(vec![])`.
fn parse_complete_multipart_xml(xml: &[u8]) -> Option<Vec<(u32, String)>> {
    /// Text between the first `open` tag and the following `close` tag
    /// (or the rest of the input when `close` is missing).
    fn element_text<'a>(haystack: &'a str, open: &str, close: &str) -> Option<&'a str> {
        let (_, after_open) = haystack.split_once(open)?;
        Some(after_open.split(close).next().unwrap_or(after_open))
    }

    let xml_str = std::str::from_utf8(xml).ok()?;
    let parts = xml_str
        .split("<Part>")
        .filter_map(|segment| {
            // Only look inside the element itself.
            let (part_body, _) = segment.split_once("</Part>")?;
            // Pretty-printed manifests put whitespace around the values.
            let number = element_text(part_body, "<PartNumber>", "</PartNumber>")?
                .trim()
                .parse::<u32>()
                .ok()?;
            // XML serializers commonly escape the ETag's quotes as `&quot;`.
            let etag = element_text(part_body, "<ETag>", "</ETag>")?
                .replace("&quot;", "")
                .replace('"', "")
                .trim()
                .to_string();
            Some((number, etag))
        })
        .collect();
    Some(parts)
}
// ===== Bucket Policy Support =====
use crate::s3_policy::BucketPolicy;
/// Process-wide map from bucket name to its policy, guarded by a tokio `RwLock`.
///
/// `put_bucket_policy` also writes each policy to
/// `data/s3_policies/<bucket>/policy.json`; nothing in this file loads those
/// files back into the map.
static BUCKET_POLICIES: once_cell::sync::Lazy<Arc<RwLock<HashMap<String, BucketPolicy>>>> =
once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Returns the in-memory policy of a bucket as pretty-printed JSON.
///
/// Responds `404` when no policy is registered for the bucket. If the policy
/// cannot be serialized, `{}` is returned instead.
pub async fn get_bucket_policy(
    Path(bucket): Path<String>,
    State(_state): State<crate::server::AppState>,
) -> impl IntoResponse {
    let policies = BUCKET_POLICIES.read().await;
    match policies.get(&bucket) {
        None => (StatusCode::NOT_FOUND, "Bucket policy not found").into_response(),
        Some(policy) => {
            let json = match serde_json::to_string_pretty(policy) {
                Ok(text) => text,
                Err(_) => "{}".to_string(),
            };
            let mut headers = HeaderMap::new();
            headers.insert("Content-Type", "application/json".parse().unwrap());
            (StatusCode::OK, headers, json).into_response()
        }
    }
}
pub async fn put_bucket_policy(
Path(bucket): Path<String>,
State(_state): State<crate::server::AppState>,
body: Body,
) -> impl IntoResponse {
let body_bytes = axum::body::to_bytes(body, 100000).await.ok();
if body_bytes.is_none() {
return (StatusCode::BAD_REQUEST, "Empty body").into_response();
}
let policy: BucketPolicy = match serde_json::from_slice(&body_bytes.unwrap()) {
Ok(p) => p,
Err(e) => return (StatusCode::BAD_REQUEST, format!("Invalid policy JSON: {}", e)).into_response(),
};
// Persist to file first (before moving policy)
let policy_path = format!("data/s3_policies/{}/policy.json", bucket);
if let Err(e) = std::fs::create_dir_all(format!("data/s3_policies/{}", bucket)) {
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create policy dir: {}", e)).into_response();
}
let policy_json = serde_json::to_string_pretty(&policy).unwrap_or_default();
if let Err(e) = std::fs::write(&policy_path, &policy_json) {
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to write policy: {}", e)).into_response();
}
// Now move policy to in-memory storage
{
let mut policies = BUCKET_POLICIES.write().await;
policies.insert(bucket.clone(), policy);
}
(StatusCode::NO_CONTENT, HeaderMap::new()).into_response()
}
/// Removes a bucket's policy from memory and deletes its persisted JSON file.
///
/// Always responds `204`; a missing in-memory entry or a failed file removal
/// is ignored.
pub async fn delete_bucket_policy(
    Path(bucket): Path<String>,
    State(_state): State<crate::server::AppState>,
) -> impl IntoResponse {
    // The guard is a temporary, so the lock is released at the end of this statement.
    BUCKET_POLICIES.write().await.remove(&bucket);

    // Best-effort removal of the on-disk copy.
    let policy_path = format!("data/s3_policies/{}/policy.json", bucket);
    std::fs::remove_file(&policy_path).ok();

    (StatusCode::NO_CONTENT, HeaderMap::new()).into_response()
}
/// Evaluates a bucket's policy for one request.
///
/// Returns the policy's `is_allowed(action, resource, user_id)` verdict, or
/// `true` when no policy is registered for `bucket`.
///
/// This is a synchronous function that the async handlers in this file call
/// directly. `RwLock::blocking_read` panics when invoked from inside an async
/// execution context, so the lock is taken with the non-blocking `try_read`
/// instead, retrying until it succeeds. That is cheap here because the only
/// writers (`put_bucket_policy`, `delete_bucket_policy`) hold the lock just
/// for a single map insert/remove and never across an `.await`.
pub fn check_bucket_policy(bucket: &str, action: &str, resource: &str, user_id: &str) -> bool {
    let policies = loop {
        match BUCKET_POLICIES.try_read() {
            Ok(guard) => break guard,
            // A writer is mid-update; give it a moment and try again.
            Err(_) => std::thread::yield_now(),
        }
    };
    match policies.get(bucket) {
        Some(policy) => policy.is_allowed(action, resource, user_id),
        None => true,
    }
}
/// Resolves the user id behind a request's AWS Signature V4 `Authorization` header.
///
/// The access key is taken from the header's `Credential=<access_key>/…`
/// parameter and looked up in `data/s3_keys.json` (a JSON array of objects
/// with `access_key` and `user_id`).
///
/// Returns `None` when the header is absent or not `AWS4-HMAC-SHA256`, has no
/// `Credential` parameter, the key file cannot be read or parsed, or the
/// access key is not listed.
fn extract_user_from_auth(headers: &HeaderMap) -> Option<String> {
    #[derive(serde::Deserialize)]
    struct S3Key {
        access_key: String,
        user_id: String,
    }

    let auth_header = headers
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");

    // In a SigV4 header the algorithm is followed by a space and then the
    // comma-separated parameters:
    //   AWS4-HMAC-SHA256 Credential=KEY/date/region/service/aws4_request, SignedHeaders=…, Signature=…
    // Splitting the whole header on `,` would leave `Credential=` glued to
    // the algorithm token, so strip the algorithm first.
    let params = auth_header.strip_prefix("AWS4-HMAC-SHA256")?;
    let credential = params
        .split(',')
        .map(str::trim)
        .find_map(|param| param.strip_prefix("Credential="))?;
    let access_key = credential.split('/').next()?;

    // Look up the user id in s3_keys.json.
    let s3_keys_json = std::fs::read_to_string("data/s3_keys.json").ok()?;
    let s3_keys: Vec<S3Key> = serde_json::from_str(&s3_keys_json).ok()?;
    s3_keys
        .into_iter()
        .find(|k| k.access_key == access_key)
        .map(|k| k.user_id)
}
/// Query-string parameters accepted by `multipart_handler`, which selects the
/// multipart operation from `action` instead of from the URL path.
#[derive(serde::Deserialize)]
pub struct MultipartActionQuery {
/// Operation selector: `init`, `part`, `complete` or `abort`.
action: Option<String>,
/// Upload id; `multipart_handler` substitutes an empty string when absent.
upload_id: Option<String>,
/// Part number for `part`; `multipart_handler` substitutes `1` when absent.
part_number: Option<u32>,
}
/// Single entry point for all multipart operations, dispatched on the
/// `action` query parameter.
///
/// - `init`     → `initiate_multipart_upload`
/// - `part`     → `upload_part` (missing `part_number` defaults to `1`)
/// - `complete` → `complete_multipart_upload`
/// - `abort`    → `abort_multipart_upload`
///
/// A missing `upload_id` is forwarded as an empty string. Any other (or
/// absent) action initiates an upload when the method is `POST` and is
/// answered with `400` otherwise.
pub async fn multipart_handler(
    method: axum::http::Method,
    Path((bucket, key)): Path<(String, String)>,
    State(state): State<crate::server::AppState>,
    query: axum::extract::Query<MultipartActionQuery>,
    headers: HeaderMap,
    body: Body,
) -> axum::response::Response {
    let axum::extract::Query(params) = query;
    let upload_id = params.upload_id.unwrap_or_default();
    let target = Path((bucket, key));

    match params.action.as_deref() {
        Some("init") => initiate_multipart_upload(target, State(state), headers)
            .await
            .into_response(),
        Some("part") => {
            let part_query = axum::extract::Query(UploadPartQuery {
                upload_id,
                part_number: params.part_number.unwrap_or(1),
            });
            upload_part(target, State(state), part_query, headers, body)
                .await
                .into_response()
        }
        Some("complete") => {
            let complete_query = axum::extract::Query(CompleteMultipartQuery { upload_id });
            complete_multipart_upload(target, State(state), complete_query, headers, body)
                .await
                .into_response()
        }
        Some("abort") => {
            let abort_query = axum::extract::Query(AbortMultipartQuery { upload_id });
            abort_multipart_upload(target, State(state), abort_query, headers)
                .await
                .into_response()
        }
        // No recognised action: a bare POST still means "start an upload".
        _ if method == axum::http::Method::POST => {
            initiate_multipart_upload(target, State(state), headers)
                .await
                .into_response()
        }
        _ => (StatusCode::BAD_REQUEST, "Missing action parameter").into_response(),
    }
}