Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 121 additions & 7 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use base64::engine::general_purpose;
use base64::Engine;
use mime2ext::mime2ext;
use nostr_sdk::{Event, PublicKey, SingleLetterTag, TagKind};
use reqwest::Client;
use reqwest::{header, Client};
use serde::{Deserialize, Serialize};
use sqlx::{query_as, Error};
use tower_http::cors::Any;
Expand All @@ -42,6 +42,10 @@ pub struct AuthEvent {
pub sig: String,
}

fn build_range_not_satisfiable_error_response(message: &str) -> Response {
build_error_response(StatusCode::RANGE_NOT_SATISFIABLE, message)
}

fn build_unauthorized_error_response(message: &str) -> Response {
build_error_response(StatusCode::UNAUTHORIZED, message)
}
Expand Down Expand Up @@ -121,6 +125,20 @@ where
}
}

fn parse_range_header(header: &str) -> Option<(u64, u64)> {
let parts: Vec<&str> = header.split("=").collect();
if parts.len() != 2 || parts[0] != "bytes" {
return None;
}
let range_parts: Vec<&str> = parts[1].split("-").collect();
if range_parts.len() != 2 {
return None;
}
let start = range_parts[0].parse::<u64>().ok()?;
let stop = range_parts[1].parse::<u64>().ok()?;
Some((start, stop))
}

pub async fn create_router(app_state: AppState) -> Router {
// Configure CORS policy
let cors = tower_http::cors::CorsLayer::new()
Expand Down Expand Up @@ -155,6 +173,7 @@ pub async fn get_blob_handler(
Path(file_hash): Path<String>,
State(app_state): State<AppState>,
AuthHeader(auth_event): AuthHeader,
headers: HeaderMap,
) -> impl IntoResponse {
// Get the file hash and file type
let (file_hash, _filetype) = split_filehash_and_filetype(file_hash);
Expand Down Expand Up @@ -249,12 +268,45 @@ pub async fn get_blob_handler(
.r#type
.unwrap_or_else(|| "application/octet-stream".to_string());

Response::builder()
.status(StatusCode::OK)
.header("Content-Type", content_type)
.header("Cache-Control", "max-age=31536000, immutable")
.body(file_contents.into())
.unwrap()
// Check for Range header (RFC 7233)
if let Some(range_header) = headers.get(header::RANGE) {
let range = parse_range_header(range_header.to_str().unwrap());
match range {
Some((start, stop)) => {
if start >= stop || start >= file_contents.len() as u64 {
return build_range_not_satisfiable_error_response("Invalid range");
}
let partial_response = file_contents[start as usize..(stop as usize + 1)].to_vec();
let content_range = format!("bytes {}-{}/{}", start, stop, file_contents.len());
let content_length = stop - start + 1;
let headers = [
(
header::CONTENT_TYPE,
HeaderValue::from_str(&content_type).unwrap(),
),
(
header::CONTENT_RANGE,
HeaderValue::from_str(&content_range).unwrap(),
),
(
header::CONTENT_LENGTH,
HeaderValue::from_str(&format!("{}", content_length)).unwrap(),
),
(header::ACCEPT_RANGES, HeaderValue::from_static("bytes")),
];
(StatusCode::PARTIAL_CONTENT, headers, partial_response).into_response()
}
None => build_bad_request_error_response("Invalid Range header"),
}
} else {
// Return the full blob
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", content_type)
.header("Cache-Control", "max-age=31536000, immutable")
.body(file_contents.into())
.unwrap()
}
}

pub async fn has_blob_handler(
Expand Down Expand Up @@ -938,6 +990,68 @@ mod tests {
);
}

#[tokio::test]
async fn get_blob_handler_test_range_request() {
// Set up app config, keypair and axum router
let keypair = Keys::generate();
let (app_state, _temp_dir) = set_up_app_state(ConfigBuilder::new()).await;
let app = create_router(app_state.clone()).await;

// Create a test blob descriptor
let file_hash =
"b1674191a88ec5cdd733e4240a81803105dc412d6c6708d53ab94fc248f4f553".to_string();
let blob_descriptor = BlobDescriptor {
url: format!("{}/{}", app_state.config.server_url, file_hash),
sha256: file_hash.clone(),
size: 1024,
r#type: Some("text/plain".to_string()),
uploaded: 1643723400,
};

// Insert the blob descriptor into the database
sqlx::query(
"INSERT INTO blob_descriptors (url, sha256, size, type, uploaded, pubkey) VALUES (?, ?, ?, ?, ?, ?)",
)
.bind(&blob_descriptor.url)
.bind(&blob_descriptor.sha256)
.bind(blob_descriptor.size)
.bind(&blob_descriptor.r#type)
.bind(blob_descriptor.uploaded)
.bind(keypair.public_key().to_hex())
.execute(&app_state.pool)
.await
.unwrap();

// Create a test file to store in the file directory.
let file_contents = b"Hello, World!";
write_blob_to_file(
&Path::new(&app_state.config.files_directory),
&file_hash,
Bytes::from(file_contents.to_vec()),
)
.unwrap();

// Send a GET request with a Range header to retrieve the blob.
let response = app
.oneshot(
Request::builder()
.method(http::Method::GET)
.uri(&format!("/{}", file_hash))
.header("Range", "bytes=0-5")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

// Verify that the response status code is Partial Content (206).
assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

// Verify that the response body matches the expected range of bytes.
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&body[..], &file_contents[0..6]);
}

#[tokio::test]
async fn has_blob_handler_test() {
// Set up app config and axum router
Expand Down
Loading