From 43e44ca73912ea0f9cd0467908baa90c6653f5a5 Mon Sep 17 00:00:00 2001 From: thuong Date: Fri, 21 Nov 2025 08:50:28 +0700 Subject: [PATCH 1/9] feat: add ETag to track the changed status of an object on the server --- infera/Cargo.toml | 9 + infera/src/http.rs | 423 +++++++++++++++++++++++++++++++++++++++------ infera/src/lib.rs | 11 +- 3 files changed, 391 insertions(+), 52 deletions(-) diff --git a/infera/Cargo.toml b/infera/Cargo.toml index d7d89f1..89b12f4 100644 --- a/infera/Cargo.toml +++ b/infera/Cargo.toml @@ -25,6 +25,15 @@ sha2 = "0.10" hex = "0.4" filetime = "0.2" +axum = "0.6" +tokio = { version = "1", features = ["full"] } +tokio-util = "0.7" # for ReaderStream +hyper = { version = "0.14", features = ["full"] } +http_cache_tags_axum="0.1.0-alpha.5" +actix-middleware-etag = "0.4.6" +actix-web = "4.12.0" +bytes = "1.11.0" + [dev-dependencies] tempfile = "3.10" mockito = "1.7.0" diff --git a/infera/src/http.rs b/infera/src/http.rs index d508551..f275884 100644 --- a/infera/src/http.rs +++ b/infera/src/http.rs @@ -1,4 +1,3 @@ -// Handles downloading and caching of remote models. use crate::config::{LogLevel, CONFIG}; use crate::error::InferaError; @@ -10,6 +9,33 @@ use std::path::{Path, PathBuf}; use std::thread; use std::time::{Duration, SystemTime}; +use reqwest::blocking::Client; +use reqwest::header::{IF_NONE_MATCH, ETAG}; +use std::io::{Write}; + + + +use axum::{ + Router, + routing::get, + middleware, + body::StreamBody, + response::IntoResponse, +}; +use std::net::SocketAddr; +use tokio_util::io::ReaderStream; + + + + +use bytes::Bytes; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio::runtime::Runtime; +use tokio::time::sleep; + + + /// A guard that guarantees a temporary file is deleted when it goes out of scope. /// This is used to implement a panic-safe cleanup of partial downloads. struct TempFileGuard<'a> { @@ -17,6 +43,7 @@ struct TempFileGuard<'a> { committed: bool, } + impl<'a> TempFileGuard<'a> { /// Creates a new guard for the given path. fn new(path: &'a Path) -> Self { @@ -26,6 +53,7 @@ impl<'a> TempFileGuard<'a> { } } + /// Marks the file as "committed," preventing its deletion on drop. /// This should be called only after the file has been successfully and /// atomically moved to its final destination. @@ -34,6 +62,7 @@ impl<'a> TempFileGuard<'a> { } } + impl<'a> Drop for TempFileGuard<'a> { fn drop(&mut self) { if !self.committed { @@ -42,16 +71,19 @@ impl<'a> Drop for TempFileGuard<'a> { } } + /// Return the cache directory path used by Infera for remote models. pub(crate) fn cache_dir() -> PathBuf { CONFIG.cache_dir.clone() } + /// Gets the cache size limit in bytes from environment variable or default. fn get_cache_size_limit() -> u64 { CONFIG.cache_size_limit } + /// Updates the access time of a cached file by touching it. fn touch_cache_file(path: &Path) -> Result<(), InferaError> { if path.exists() { @@ -61,6 +93,7 @@ fn touch_cache_file(path: &Path) -> Result<(), InferaError> { Ok(()) } + /// Gets metadata about cached files sorted by access time (oldest first). fn get_cached_files_by_access_time() -> Result, InferaError> { let dir = cache_dir(); @@ -68,6 +101,7 @@ fn get_cached_files_by_access_time() -> Result, return Ok(Vec::new()); } + let mut files = Vec::new(); for entry in fs::read_dir(&dir) .map_err(|e| InferaError::IoError(e.to_string()))? 
@@ -83,42 +117,52 @@ fn get_cached_files_by_access_time() -> Result, } } + // Sort by access time, oldest first files.sort_by_key(|(_, time, _)| *time); Ok(files) } + /// Calculates total cache size in bytes. fn get_cache_size() -> Result { let files = get_cached_files_by_access_time()?; Ok(files.iter().map(|(_, _, size)| size).sum()) } + /// Evicts least recently used cache files until cache size is below limit. fn evict_cache_if_needed(required_space: u64) -> Result<(), InferaError> { + println!("evict_cache_if_needed"); let limit = get_cache_size_limit(); let current_size = get_cache_size()?; + if current_size + required_space <= limit { return Ok(()); } + let target_size = limit.saturating_sub(required_space); let mut freed_size = 0u64; let files = get_cached_files_by_access_time()?; + for (path, _, size) in files { if current_size - freed_size <= target_size { break; } + println!("consider to remove file: {:?}", path); fs::remove_file(&path).map_err(|e| InferaError::IoError(e.to_string()))?; freed_size += size; } + Ok(()) } + /// Clears the entire cache directory by deleting its contents. /// If the directory does not exist, this is a no-op. pub(crate) fn clear_cache() -> Result<(), InferaError> { @@ -140,28 +184,31 @@ pub(crate) fn clear_cache() -> Result<(), InferaError> { Ok(()) } -/// Handles the download and caching of a remote model from a URL. -/// -/// If the model for the given URL is already present in the local cache, this -/// function updates its access time and returns the path. Otherwise, it downloads -/// the file, evicts old cache entries if needed, stores it in the cache directory, -/// and then returns the path. -/// -/// The cache uses an LRU (Least Recently Used) eviction policy with a configurable -/// size limit (default 1GB, configurable via INFERA_CACHE_SIZE_LIMIT env var). -/// -/// Downloads support automatic retries with exponential backoff. -/// -/// # Arguments -/// -/// * `url` - The HTTP/HTTPS URL of the ONNX model to be downloaded. -/// -/// # Returns -/// -/// A `Result` which is: -/// * `Ok(PathBuf)`: The local file path of the cached model. -/// * `Err(InferaError)`: An error indicating failure in creating the cache directory, -/// making the HTTP request, or writing the file to disk. + + + +// / Handles the download and caching of a remote model from a URL. +// / +// / If the model for the given URL is already present in the local cache, this +// / function updates its access time and returns the path. Otherwise, it downloads +// / the file, evicts old cache entries if needed, stores it in the cache directory, +// / and then returns the path. +// / +// / The cache uses an LRU (Least Recently Used) eviction policy with a configurable +// / size limit (default 1GB, configurable via INFERA_CACHE_SIZE_LIMIT env var). +// / +// / Downloads support automatic retries with exponential backoff. +// / +// / # Arguments +// / +// / * `url` - The HTTP/HTTPS URL of the ONNX model to be downloaded. +// / +// / # Returns +// / +// / A `Result` which is: +// / * `Ok(PathBuf)`: The local file path of the cached model. +// / * `Err(InferaError)`: An error indicating failure in creating the cache directory, +// / making the HTTP request, or writing the file to disk. 
pub(crate) fn handle_remote_model(url: &str) -> Result { let cache_dir = cache_dir(); if !cache_dir.exists() { @@ -180,22 +227,27 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { return Ok(cached_path); } + log!( LogLevel::Info, "Cache miss for URL: {}, downloading...", url ); + let temp_path = cached_path.with_extension("onnx.part"); let mut guard = TempFileGuard::new(&temp_path); + // Download with retry logic let max_attempts = CONFIG.http_retry_attempts; let retry_delay_ms = CONFIG.http_retry_delay_ms; let timeout_secs = CONFIG.http_timeout_secs; + let mut last_error = None; + for attempt in 1..=max_attempts { log!( LogLevel::Debug, @@ -205,21 +257,26 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { url ); + match download_file(url, &temp_path, timeout_secs) { Ok(_) => { log!(LogLevel::Info, "Successfully downloaded: {}", url); + // Check file size and evict cache if needed let file_size = fs::metadata(&temp_path) .map_err(|e| InferaError::IoError(e.to_string()))? .len(); + log!(LogLevel::Debug, "Downloaded file size: {} bytes", file_size); evict_cache_if_needed(file_size)?; + fs::rename(&temp_path, &cached_path) .map_err(|e| InferaError::IoError(e.to_string()))?; + guard.commit(); return Ok(cached_path); } @@ -233,6 +290,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { ); last_error = Some(e); + // Don't sleep after the last attempt if attempt < max_attempts { let delay = Duration::from_millis(retry_delay_ms * attempt as u64); @@ -243,6 +301,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { } } + log!( LogLevel::Error, "Failed to download after {} attempts: {}", @@ -252,6 +311,10 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { Err(last_error.unwrap_or_else(|| InferaError::HttpRequestError("Unknown error".to_string()))) } + + + + /// Download a file from a URL to a local path with timeout fn download_file(url: &str, dest: &Path, timeout_secs: u64) -> Result<(), InferaError> { let client = reqwest::blocking::Client::builder() @@ -259,6 +322,7 @@ fn download_file(url: &str, dest: &Path, timeout_secs: u64) -> Result<(), Infera .build() .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; + let mut response = client .get(url) .send() @@ -266,26 +330,155 @@ fn download_file(url: &str, dest: &Path, timeout_secs: u64) -> Result<(), Infera .error_for_status() .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; + let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?; io::copy(&mut response, &mut file).map_err(|e| InferaError::IoError(e.to_string()))?; + Ok(()) } + +pub(crate) fn handle_remote_model_3(url: &str) -> Result { + let max_attempts = CONFIG.http_retry_attempts; + let retry_delay_ms = CONFIG.http_retry_delay_ms; + let timeout_secs = CONFIG.http_timeout_secs; + + let cache_dir = cache_dir(); + println!("Cache dir: {:?}", cache_dir); + println!("Handling remote model for URL: {}", url); + if !cache_dir.exists() { + fs::create_dir_all(&cache_dir).map_err(|e| InferaError::CacheDirError(e.to_string()))?; + } + + // Compute cache key based on URL hash + let mut hasher = Sha256::new(); + hasher.update(url.as_bytes()); + let hash_hex = hex::encode(hasher.finalize()); + let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); + let etag_path = cache_dir.join(format!("{}.etag", hash_hex)); + + // Load cached ETag if available + let etag_trimmed = match fs::read_to_string(&etag_path) { + Ok(etag_value) => etag_value.trim().to_string(), + Err(_) => "abc".to_string(), + }; + 
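+    // Note: an If-None-Match value the server never issued (like the "abc"
+    // placeholder above) matches no stored ETag, so per RFC 7232 the server
+    // answers 200 with the full body rather than 304 Not Modified.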
println!("Using ETag for request: {}", etag_trimmed); + + let temp_path = cached_path.with_extension("onnx.part"); + let mut guard = TempFileGuard::new(&temp_path); + + let mut last_error = None; + + for attempt in 1..=max_attempts { + println!("Download attempt {}/{} for URL: {}", attempt, max_attempts, url); + + // Perform download using helper function + match download_file_with_etag(url, &temp_path, timeout_secs, &etag_trimmed) { + Ok(response_etag) => { + println!("Download succeeded for URL: {}", url); + println!("response_etag: {:?}", response_etag); + + if let Some(etag_str) = &response_etag { + // Only if new file was downloaded: + println!("temp_path: {:?}, cached_path: {:?}", temp_path, cached_path); + let file_size = fs::metadata(&temp_path).map_err(|e| InferaError::IoError(e.to_string()))?.len(); + evict_cache_if_needed(file_size)?; + fs::rename(&temp_path, &cached_path).map_err(|e| InferaError::IoError(e.to_string()))?; + + // Update ETag file + fs::write(&etag_path, etag_str).map_err(|e| InferaError::IoError(e.to_string()))?; + } else { + // 304 Not Modified response: update access time and use cached file + touch_cache_file(&cached_path)?; + } + + guard.commit(); + return Ok(cached_path); + } + Err(e) => { + last_error = Some(e); + println!( + "Download attempt {}/{} failed: {}", + attempt, + max_attempts, + last_error.as_ref().unwrap() + ); + } + } + + if attempt < max_attempts { + let delay = Duration::from_millis(retry_delay_ms * 2_u64.pow((attempt - 1) as u32)); + println!("Waiting {:?} before retry", delay); + std::thread::sleep(delay); + } + } + + println!("Failed to download after {} attempts: {}", max_attempts, url); + + Err(last_error.unwrap_or_else(|| InferaError::HttpRequestError("Unknown download error".to_string()))) +} + +fn download_file_with_etag( + url: &str, + dest: &Path, + timeout_secs: u64, + etag: &str, +) -> Result, InferaError> { + let client = reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build() + .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; + + let mut request = client.get(url).header(IF_NONE_MATCH, etag); + + let mut response = request + .send() + .map_err(|e| InferaError::HttpRequestError(e.to_string()))? 
+ .error_for_status() + .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; + + println!("response.status(): {:?}", response.status()); + if response.status() == reqwest::StatusCode::NOT_MODIFIED { + // Not modified, no file write needed, return None etag + return Ok(None); + } + println!("prepare to create dest file: {:?}", dest); + let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?; + io::copy(&mut response, &mut file).map_err(|e| InferaError::IoError(e.to_string()))?; + + // Extract ETag header if present for updating cache metadata + let etag_header = response.headers().get(ETAG).and_then(|v| v.to_str().ok()).map(|s| s.to_owned()); + + Ok(etag_header) +} + + +fn get_file_modification_time(path: &std::path::Path) -> std::io::Result { + let metadata = fs::metadata(path)?; + let modified_time = metadata.modified()?; + Ok(modified_time) + } #[cfg(test)] mod tests { use super::*; - use mockito::Server; + use axum::extract::path; + use hyper; + use mockito::{Server, ServerOpts}; use std::env; // moved here: used in tests only use std::thread; use tiny_http::{Header, Response, Server as TinyServer}; - + use tokio::fs::File as TokioFile; + use actix_web::{web, App, HttpResponse, Error,HttpServer}; + use actix_middleware_etag::Etag; + #[test] fn test_handle_remote_model_cleanup_on_incomplete_download() { let server = TinyServer::http("127.0.0.1:0").unwrap(); let port = server.server_addr().to_ip().unwrap().port(); let model_url = format!("http://127.0.0.1:{}/incomplete_model.onnx", port); + let server_handle = thread::spawn(move || { if let Ok(request) = server.recv() { let mut response = Response::from_string("incomplete data"); @@ -295,24 +488,30 @@ mod tests { } }); + // The download should fail because the response body is shorter than the content-length. - let result = handle_remote_model(&model_url); + let result = handle_remote_model_3(&model_url); assert!(result.is_err()); + // Ensure no partial or final file is left in the cache. let cache_dir = env::temp_dir().join("infera_cache"); + println!("Cache dir: {:?}", cache_dir); let mut hasher = Sha256::new(); hasher.update(model_url.as_bytes()); let hash_hex = hex::encode(hasher.finalize()); let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); let temp_path = cached_path.with_extension("onnx.part"); + assert!(!cached_path.exists(), "Final cache file should not exist"); assert!(!temp_path.exists(), "Partial file should be cleaned up"); + server_handle.join().unwrap(); } + #[test] fn test_handle_remote_model_download_error() { // Simulate a server error instead of an interrupted download, @@ -323,10 +522,13 @@ mod tests { .with_status(500) .create(); + let url = server.url(); let model_url = format!("{}/server_error_model.onnx", url); - let result = handle_remote_model(&model_url); + + let result = handle_remote_model_3(&model_url); + // The download should fail because of the server error. assert!( @@ -334,6 +536,7 @@ mod tests { "handle_remote_model should return an error on 500 status" ); + // Ensure no partial or final file is left in the cache. 
let cache_dir = env::temp_dir().join("infera_cache"); let mut hasher = Sha256::new(); @@ -342,16 +545,19 @@ mod tests { let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); let temp_path = cached_path.with_extension("onnx.part"); + assert!(!cached_path.exists(), "Final cache file should not exist"); assert!(!temp_path.exists(), "Partial file should be cleaned up"); } + #[test] fn test_handle_remote_model_cleanup_on_connection_drop() { let server = TinyServer::http("127.0.0.1:0").unwrap(); let port = server.server_addr().to_ip().unwrap().port(); let model_url = format!("http://127.0.0.1:{}/dropped_connection.onnx", port); + let server_handle = thread::spawn(move || { if let Ok(request) = server.recv() { // By responding with a Content-Length header but then dropping the @@ -366,13 +572,15 @@ mod tests { } }); + // The download should fail because the server closes the connection prematurely. - let result = handle_remote_model(&model_url); + let result = handle_remote_model_3(&model_url); assert!( result.is_err(), "The download should fail on a connection drop" ); + // After the failure, the temporary file should be cleaned up. let cache_dir = env::temp_dir().join("infera_cache"); let mut hasher = Sha256::new(); @@ -381,38 +589,155 @@ mod tests { let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); let temp_path = cached_path.with_extension("onnx.part"); + assert!(!cached_path.exists(), "Final cache file should not exist"); assert!( !temp_path.exists(), "Partial file should be cleaned up after a connection drop" ); - server_handle.join().unwrap(); } - #[test] - fn test_handle_remote_model_success_and_cache() { - // Serve a small body with an accurate Content-Length - let mut server = Server::new(); - let body = b"onnxdata".to_vec(); - let _m = server - .mock("GET", "/ok_model.onnx") - .with_status(200) - .with_header("Content-Length", &body.len().to_string()) - .with_body(body.clone()) - .create(); - let url = format!("{}/ok_model.onnx", server.url()); - let path1 = handle_remote_model(&url).expect("download should succeed"); - assert!(path1.exists(), "cached file must exist"); +// #[test] +// async fn test_handle_remote_model_success_and_cache() { +// // Serve a small body with an accurate Content-Length +// let mut server = Server::new(); +// let body = b"onnxdata".to_vec(); +// let _m = server +// .mock("GET", "/ok_model.onnx") +// .with_status(200) +// .with_header("Content-Length", &body.len().to_string()) +// .with_body(body.clone()) +// .create(); +// let url = format!("{}/ok_model.onnx", server.url()); + + +// let path1 = handle_remote_model(&url).expect("download should succeed"); +// assert!(path1.exists(), "cached file must exist"); +// let content1 = fs::read(&path1).expect("read cached file"); +// assert_eq!(content1, body); + + +// // Second call should hit cache and return same path without network +// let path2 = handle_remote_model(&url).expect("cache should hit"); +// assert_eq!(path1, path2); + + + +// // Remove the old mock to simulate file removal +// server.reset(); // or server.delete_mock(m) if your mock library supports it + + +// // Create a new mock representing the updated file +// let updated_body = b"new_onnx_data".to_vec(); +// let _m2 = server +// .mock("GET", "/ok_model.onnx") +// .with_status(200) +// .with_header("Content-Length", &updated_body.len().to_string()) +// .with_body(updated_body.clone()) +// .create(); + + +// let path3 = handle_remote_model(&url).expect("download should succeed"); +// assert!(path3.exists(), "cached file 
must exist"); +// let content3 = fs::read(&path3).expect("read cached file"); +// assert_ne!(content1, content3, "content should changed due to caching invalidation"); + + +// // let temp_path = path1.with_extension("onnx.part"); +// // assert!(!temp_path.exists(), "no partial file should remain"); + + + + +// } + + #[actix_web::test] + async fn test_handle_remote_model_3_with_actix_server() { + use std::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind random port"); + let addr = listener.local_addr().unwrap(); + + let file_content = Arc::new(RwLock::new(b"initial content".to_vec())); + let file_content_server = file_content.clone(); + + let server = HttpServer::new(move || { + let content = file_content_server.clone(); + App::new() + .wrap(Etag::default()) + .route("/model.onnx", web::get().to(move || { + let content = content.clone(); + async move { + let data = content.read().await; + let bytes = Bytes::copy_from_slice(&*data); + Ok::<_, Error>( + HttpResponse::Ok() + .content_type("application/octet-stream") + .body(bytes), + ) + } + })) + }) + .listen(listener) + .expect("Failed to bind server") + .run(); + + // Spawn server in background + let srv_handle = actix_web::rt::spawn(server); + + let url = format!("http://{}:{}/model.onnx", addr.ip(), addr.port()); + let second_call_url = url.clone(); + let third_call_url = url.clone(); + // Call your blocking cache-and-download function in blocking task + let path1 = tokio::task::spawn_blocking(move || handle_remote_model_3(&url)) + .await + .expect("Task panicked") + .expect("handle_remote_model_3 failed"); + + assert!(path1.exists()); let content1 = fs::read(&path1).expect("read cached file"); - assert_eq!(content1, body); + let path1_modification_time = get_file_modification_time(&path1).unwrap(); + + // Call again, should refresh cache + let path2 = tokio::task::spawn_blocking(move || handle_remote_model_3(&second_call_url)) + .await + .expect("Task panicked") + .expect("handle_remote_model_3 failed"); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let content2 = fs::read(&path2).expect("read cached file"); + let path2_modification_time = get_file_modification_time(&path2).unwrap(); - // Second call should hit cache and return same path without network - let path2 = handle_remote_model(&url).expect("cache should hit"); assert_eq!(path1, path2); - let temp_path = path1.with_extension("onnx.part"); - assert!(!temp_path.exists(), "no partial file should remain"); + assert_eq!(content1, content2); + assert_eq!(path1_modification_time, path2_modification_time); + + + // Modify content to simulate update + { + let mut content_write = file_content.write().await; + *content_write = b"updated content".to_vec(); + } + + tokio::time::sleep(Duration::from_secs(1)).await; + + // Call again, should refresh cache + let path3 = tokio::task::spawn_blocking(move || handle_remote_model_3(&third_call_url)) + .await + .expect("Task panicked") + .expect("handle_remote_model_3 failed"); + + let content3 = fs::read(&path3).expect("read cached file"); + let path3_modification_time = get_file_modification_time(&path3).unwrap(); + assert_eq!(path1, path3); + assert_ne!(content1, content3); + assert_ne!(path1_modification_time, path3_modification_time); + + // Cleanup server + srv_handle.abort(); } #[test] @@ -425,4 +750,4 @@ mod tests { clear_cache().unwrap(); assert!(!dummy.exists()); } -} +} \ No newline at end of file diff --git a/infera/src/lib.rs b/infera/src/lib.rs index aca7de8..1c11470 100644 --- 
a/infera/src/lib.rs +++ b/infera/src/lib.rs @@ -338,6 +338,7 @@ pub extern "C" fn infera_get_cache_info() -> *mut c_char { .flatten() { let path = entry.path(); + println!("path: {:?}", path); if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("onnx") { if let Ok(metadata) = fs::metadata(&path) { total_size += metadata.len(); @@ -428,7 +429,7 @@ pub unsafe extern "C" fn infera_set_autoload_dir(path: *const c_char) -> *mut c_ mod tests { use super::*; use std::ffi::CString; - use std::fs; + use std::{fs, result}; use tempfile::tempdir; #[test] @@ -437,6 +438,7 @@ mod tests { let version_json = unsafe { CStr::from_ptr(version_ptr).to_str().unwrap() }; let version_data: serde_json::Value = serde_json::from_str(version_json).unwrap(); + println!("version_data {:?}", version_data); assert!(version_data["version"].is_string()); assert!(version_data["onnx_backend"].is_string()); assert!(version_data["model_cache_dir"].is_string()); @@ -454,7 +456,7 @@ mod tests { let result_ptr = unsafe { infera_set_autoload_dir(path_cstr.as_ptr()) }; let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() }; let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap(); - + println!("result_data {:?}", result_data); assert_eq!(result_data["loaded"].as_array().unwrap().len(), 1); assert_eq!(result_data["loaded"][0], "linear"); assert_eq!(result_data["errors"].as_array().unwrap().len(), 0); @@ -470,7 +472,7 @@ mod tests { let result_ptr = unsafe { infera_set_autoload_dir(path_cstr.as_ptr()) }; let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() }; let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap(); - + println!("result_data {:?}", result_data); assert!(result_data["error"].is_string()); unsafe { infera_free(result_ptr) }; @@ -487,6 +489,7 @@ mod tests { let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() }; let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap(); + println!("result_data {:?}", result_data); assert_eq!(result_data["loaded"].as_array().unwrap().len(), 0); assert_eq!(result_data["errors"].as_array().unwrap().len(), 1); assert_eq!( @@ -647,7 +650,9 @@ mod tests { fn test_infera_get_cache_info_includes_configured_limit() { let cache_info_ptr = infera_get_cache_info(); let cache_info_json = unsafe { CStr::from_ptr(cache_info_ptr).to_str().unwrap() }; + println!("cache_info_json: {:?}", cache_info_json); let value: serde_json::Value = serde_json::from_str(cache_info_json).unwrap(); + println!("value: {:?}", value); let size_limit = value["size_limit_bytes"] .as_u64() .expect("size_limit_bytes should be u64"); From 499117231466278dfc11f64bf718a01f123383c2 Mon Sep 17 00:00:00 2001 From: thuong Date: Fri, 21 Nov 2025 22:10:31 +0700 Subject: [PATCH 2/9] feat: refactor testcase and import --- infera/src/http.rs | 380 ++++++++++----------------------------------- infera/src/lib.rs | 2 +- 2 files changed, 82 insertions(+), 300 deletions(-) diff --git a/infera/src/http.rs b/infera/src/http.rs index f275884..865ba13 100644 --- a/infera/src/http.rs +++ b/infera/src/http.rs @@ -1,4 +1,3 @@ - use crate::config::{LogLevel, CONFIG}; use crate::error::InferaError; use crate::log; @@ -9,32 +8,7 @@ use std::path::{Path, PathBuf}; use std::thread; use std::time::{Duration, SystemTime}; -use reqwest::blocking::Client; -use reqwest::header::{IF_NONE_MATCH, ETAG}; -use std::io::{Write}; - - - -use axum::{ - Router, - routing::get, - middleware, - body::StreamBody, - 
response::IntoResponse, -}; -use std::net::SocketAddr; -use tokio_util::io::ReaderStream; - - - - -use bytes::Bytes; -use std::sync::Arc; -use tokio::sync::RwLock; -use tokio::runtime::Runtime; -use tokio::time::sleep; - - +use reqwest::header::{ETAG, IF_NONE_MATCH}; /// A guard that guarantees a temporary file is deleted when it goes out of scope. /// This is used to implement a panic-safe cleanup of partial downloads. @@ -43,7 +17,6 @@ struct TempFileGuard<'a> { committed: bool, } - impl<'a> TempFileGuard<'a> { /// Creates a new guard for the given path. fn new(path: &'a Path) -> Self { @@ -53,7 +26,6 @@ impl<'a> TempFileGuard<'a> { } } - /// Marks the file as "committed," preventing its deletion on drop. /// This should be called only after the file has been successfully and /// atomically moved to its final destination. @@ -62,7 +34,6 @@ impl<'a> TempFileGuard<'a> { } } - impl<'a> Drop for TempFileGuard<'a> { fn drop(&mut self) { if !self.committed { @@ -71,19 +42,16 @@ impl<'a> Drop for TempFileGuard<'a> { } } - /// Return the cache directory path used by Infera for remote models. pub(crate) fn cache_dir() -> PathBuf { CONFIG.cache_dir.clone() } - /// Gets the cache size limit in bytes from environment variable or default. fn get_cache_size_limit() -> u64 { CONFIG.cache_size_limit } - /// Updates the access time of a cached file by touching it. fn touch_cache_file(path: &Path) -> Result<(), InferaError> { if path.exists() { @@ -93,7 +61,6 @@ fn touch_cache_file(path: &Path) -> Result<(), InferaError> { Ok(()) } - /// Gets metadata about cached files sorted by access time (oldest first). fn get_cached_files_by_access_time() -> Result, InferaError> { let dir = cache_dir(); @@ -101,7 +68,6 @@ fn get_cached_files_by_access_time() -> Result, return Ok(Vec::new()); } - let mut files = Vec::new(); for entry in fs::read_dir(&dir) .map_err(|e| InferaError::IoError(e.to_string()))? @@ -117,37 +83,31 @@ fn get_cached_files_by_access_time() -> Result, } } - // Sort by access time, oldest first files.sort_by_key(|(_, time, _)| *time); Ok(files) } - /// Calculates total cache size in bytes. fn get_cache_size() -> Result { let files = get_cached_files_by_access_time()?; Ok(files.iter().map(|(_, _, size)| size).sum()) } - /// Evicts least recently used cache files until cache size is below limit. fn evict_cache_if_needed(required_space: u64) -> Result<(), InferaError> { println!("evict_cache_if_needed"); let limit = get_cache_size_limit(); let current_size = get_cache_size()?; - if current_size + required_space <= limit { return Ok(()); } - let target_size = limit.saturating_sub(required_space); let mut freed_size = 0u64; let files = get_cached_files_by_access_time()?; - for (path, _, size) in files { if current_size - freed_size <= target_size { break; @@ -158,11 +118,9 @@ fn evict_cache_if_needed(required_space: u64) -> Result<(), InferaError> { freed_size += size; } - Ok(()) } - /// Clears the entire cache directory by deleting its contents. /// If the directory does not exist, this is a no-op. pub(crate) fn clear_cache() -> Result<(), InferaError> { @@ -184,9 +142,6 @@ pub(crate) fn clear_cache() -> Result<(), InferaError> { Ok(()) } - - - // / Handles the download and caching of a remote model from a URL. // / // / If the model for the given URL is already present in the local cache, this @@ -209,45 +164,37 @@ pub(crate) fn clear_cache() -> Result<(), InferaError> { // / * `Ok(PathBuf)`: The local file path of the cached model. 
// / * `Err(InferaError)`: An error indicating failure in creating the cache directory, // / making the HTTP request, or writing the file to disk. + pub(crate) fn handle_remote_model(url: &str) -> Result { + let max_attempts = CONFIG.http_retry_attempts; + let retry_delay_ms = CONFIG.http_retry_delay_ms; + let timeout_secs = CONFIG.http_timeout_secs; + let cache_dir = cache_dir(); if !cache_dir.exists() { log!(LogLevel::Info, "Creating cache directory: {:?}", cache_dir); fs::create_dir_all(&cache_dir).map_err(|e| InferaError::CacheDirError(e.to_string()))?; } + + // Compute cache key based on URL hash let mut hasher = Sha256::new(); hasher.update(url.as_bytes()); let hash_hex = hex::encode(hasher.finalize()); let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); + let etag_path = cache_dir.join(format!("{}.etag", hash_hex)); - if cached_path.exists() { - log!(LogLevel::Info, "Cache hit for URL: {}", url); - // Update access time for LRU tracking - touch_cache_file(&cached_path)?; - return Ok(cached_path); - } - - - log!( - LogLevel::Info, - "Cache miss for URL: {}, downloading...", - url - ); - + // Load cached ETag if available + let etag_trimmed = match fs::read_to_string(&etag_path) { + Ok(etag_value) => etag_value.trim().to_string(), + Err(_) => "abc".to_string(), + }; + println!("Using ETag for request: {}", etag_trimmed); let temp_path = cached_path.with_extension("onnx.part"); let mut guard = TempFileGuard::new(&temp_path); - - // Download with retry logic - let max_attempts = CONFIG.http_retry_attempts; - let retry_delay_ms = CONFIG.http_retry_delay_ms; - let timeout_secs = CONFIG.http_timeout_secs; - - let mut last_error = None; - for attempt in 1..=max_attempts { log!( LogLevel::Debug, @@ -257,25 +204,32 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { url ); + // Perform download using helper function + match download_file_with_etag(url, &temp_path, timeout_secs, &etag_trimmed) { + Ok(Some((false, etag_str))) => { + // New file was downloaded (first value is false) + log!(LogLevel::Info, "Download succeeded for URL: {}", url); - match download_file(url, &temp_path, timeout_secs) { - Ok(_) => { - log!(LogLevel::Info, "Successfully downloaded: {}", url); - - - // Check file size and evict cache if needed let file_size = fs::metadata(&temp_path) .map_err(|e| InferaError::IoError(e.to_string()))? 
.len(); - - - log!(LogLevel::Debug, "Downloaded file size: {} bytes", file_size); evict_cache_if_needed(file_size)?; - - fs::rename(&temp_path, &cached_path) .map_err(|e| InferaError::IoError(e.to_string()))?; + // Update ETag file + fs::write(&etag_path, &etag_str) + .map_err(|e| InferaError::IoError(e.to_string()))?; + + guard.commit(); + return Ok(cached_path); + } + Ok(Some((true, _))) => { + // status is 304 Not Modified, use cached file (first value is true) + log!(LogLevel::Info, "Cache hit for URL: {}", url); + + // Update access time for LRU tracking + touch_cache_file(&cached_path)?; guard.commit(); return Ok(cached_path); @@ -289,134 +243,29 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { e ); last_error = Some(e); - - - // Don't sleep after the last attempt if attempt < max_attempts { let delay = Duration::from_millis(retry_delay_ms * attempt as u64); log!(LogLevel::Debug, "Waiting {:?} before retry", delay); thread::sleep(delay); } } + Ok(None) => { + // theoretically unreachable, but necessary to satisfy exhaustiveness + // Handle as error or log warning, or panic + log!(LogLevel::Error, "Can't exist None for this matching"); + } } } - log!( LogLevel::Error, "Failed to download after {} attempts: {}", max_attempts, url ); - Err(last_error.unwrap_or_else(|| InferaError::HttpRequestError("Unknown error".to_string()))) -} - - - - -/// Download a file from a URL to a local path with timeout -fn download_file(url: &str, dest: &Path, timeout_secs: u64) -> Result<(), InferaError> { - let client = reqwest::blocking::Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .build() - .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; - - - let mut response = client - .get(url) - .send() - .map_err(|e| InferaError::HttpRequestError(e.to_string()))? 
- .error_for_status() - .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; - - - let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?; - io::copy(&mut response, &mut file).map_err(|e| InferaError::IoError(e.to_string()))?; - - - Ok(()) -} - - -pub(crate) fn handle_remote_model_3(url: &str) -> Result { - let max_attempts = CONFIG.http_retry_attempts; - let retry_delay_ms = CONFIG.http_retry_delay_ms; - let timeout_secs = CONFIG.http_timeout_secs; - - let cache_dir = cache_dir(); - println!("Cache dir: {:?}", cache_dir); - println!("Handling remote model for URL: {}", url); - if !cache_dir.exists() { - fs::create_dir_all(&cache_dir).map_err(|e| InferaError::CacheDirError(e.to_string()))?; - } - - // Compute cache key based on URL hash - let mut hasher = Sha256::new(); - hasher.update(url.as_bytes()); - let hash_hex = hex::encode(hasher.finalize()); - let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); - let etag_path = cache_dir.join(format!("{}.etag", hash_hex)); - - // Load cached ETag if available - let etag_trimmed = match fs::read_to_string(&etag_path) { - Ok(etag_value) => etag_value.trim().to_string(), - Err(_) => "abc".to_string(), - }; - println!("Using ETag for request: {}", etag_trimmed); - - let temp_path = cached_path.with_extension("onnx.part"); - let mut guard = TempFileGuard::new(&temp_path); - - let mut last_error = None; - - for attempt in 1..=max_attempts { - println!("Download attempt {}/{} for URL: {}", attempt, max_attempts, url); - - // Perform download using helper function - match download_file_with_etag(url, &temp_path, timeout_secs, &etag_trimmed) { - Ok(response_etag) => { - println!("Download succeeded for URL: {}", url); - println!("response_etag: {:?}", response_etag); - - if let Some(etag_str) = &response_etag { - // Only if new file was downloaded: - println!("temp_path: {:?}, cached_path: {:?}", temp_path, cached_path); - let file_size = fs::metadata(&temp_path).map_err(|e| InferaError::IoError(e.to_string()))?.len(); - evict_cache_if_needed(file_size)?; - fs::rename(&temp_path, &cached_path).map_err(|e| InferaError::IoError(e.to_string()))?; - - // Update ETag file - fs::write(&etag_path, etag_str).map_err(|e| InferaError::IoError(e.to_string()))?; - } else { - // 304 Not Modified response: update access time and use cached file - touch_cache_file(&cached_path)?; - } - - guard.commit(); - return Ok(cached_path); - } - Err(e) => { - last_error = Some(e); - println!( - "Download attempt {}/{} failed: {}", - attempt, - max_attempts, - last_error.as_ref().unwrap() - ); - } - } - - if attempt < max_attempts { - let delay = Duration::from_millis(retry_delay_ms * 2_u64.pow((attempt - 1) as u32)); - println!("Waiting {:?} before retry", delay); - std::thread::sleep(delay); - } - } - - println!("Failed to download after {} attempts: {}", max_attempts, url); - - Err(last_error.unwrap_or_else(|| InferaError::HttpRequestError("Unknown download error".to_string()))) + Err(last_error + .unwrap_or_else(|| InferaError::HttpRequestError("Unknown download error".to_string()))) } fn download_file_with_etag( @@ -424,13 +273,13 @@ fn download_file_with_etag( dest: &Path, timeout_secs: u64, etag: &str, -) -> Result, InferaError> { +) -> Result, InferaError> { let client = reqwest::blocking::Client::builder() .timeout(Duration::from_secs(timeout_secs)) .build() .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; - let mut request = client.get(url).header(IF_NONE_MATCH, etag); + let request = 
client.get(url).header(IF_NONE_MATCH, etag); let mut response = request .send() @@ -439,46 +288,49 @@ fn download_file_with_etag( .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; println!("response.status(): {:?}", response.status()); - if response.status() == reqwest::StatusCode::NOT_MODIFIED { + if etag != "" && response.status() == reqwest::StatusCode::NOT_MODIFIED { // Not modified, no file write needed, return None etag - return Ok(None); + return Ok(Some((true, etag.to_string()))); } println!("prepare to create dest file: {:?}", dest); let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?; io::copy(&mut response, &mut file).map_err(|e| InferaError::IoError(e.to_string()))?; // Extract ETag header if present for updating cache metadata - let etag_header = response.headers().get(ETAG).and_then(|v| v.to_str().ok()).map(|s| s.to_owned()); - - Ok(etag_header) + let etag_header = response + .headers() + .get(ETAG) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned()); + let etag_str = etag_header.unwrap_or_else(|| "".to_string()); + + return Ok(Some((false, etag_str))); } - -fn get_file_modification_time(path: &std::path::Path) -> std::io::Result { - let metadata = fs::metadata(path)?; - let modified_time = metadata.modified()?; - Ok(modified_time) - } #[cfg(test)] mod tests { use super::*; - use axum::extract::path; - use hyper; - use mockito::{Server, ServerOpts}; + use actix_middleware_etag::Etag; + use actix_web::{web, App, Error, HttpResponse, HttpServer}; + use bytes::Bytes; + use mockito::Server; use std::env; // moved here: used in tests only + use std::sync::Arc; use std::thread; use tiny_http::{Header, Response, Server as TinyServer}; - use tokio::fs::File as TokioFile; - use actix_web::{web, App, HttpResponse, Error,HttpServer}; - use actix_middleware_etag::Etag; - + use tokio::sync::RwLock; + + fn get_file_modification_time(path: &std::path::Path) -> std::io::Result { + let metadata = fs::metadata(path)?; + let modified_time = metadata.modified()?; + Ok(modified_time) + } #[test] fn test_handle_remote_model_cleanup_on_incomplete_download() { let server = TinyServer::http("127.0.0.1:0").unwrap(); let port = server.server_addr().to_ip().unwrap().port(); let model_url = format!("http://127.0.0.1:{}/incomplete_model.onnx", port); - let server_handle = thread::spawn(move || { if let Ok(request) = server.recv() { let mut response = Response::from_string("incomplete data"); @@ -488,12 +340,10 @@ mod tests { } }); - // The download should fail because the response body is shorter than the content-length. - let result = handle_remote_model_3(&model_url); + let result = handle_remote_model(&model_url); assert!(result.is_err()); - // Ensure no partial or final file is left in the cache. 
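        // The cached file name is the SHA-256 hex digest of the URL (mirroring
        // handle_remote_model), so the test recomputes it to locate the entry.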
let cache_dir = env::temp_dir().join("infera_cache"); println!("Cache dir: {:?}", cache_dir); @@ -503,15 +353,12 @@ mod tests { let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); let temp_path = cached_path.with_extension("onnx.part"); - assert!(!cached_path.exists(), "Final cache file should not exist"); assert!(!temp_path.exists(), "Partial file should be cleaned up"); - server_handle.join().unwrap(); } - #[test] fn test_handle_remote_model_download_error() { // Simulate a server error instead of an interrupted download, @@ -522,13 +369,10 @@ mod tests { .with_status(500) .create(); - let url = server.url(); let model_url = format!("{}/server_error_model.onnx", url); - - let result = handle_remote_model_3(&model_url); - + let result = handle_remote_model(&model_url); // The download should fail because of the server error. assert!( @@ -536,7 +380,6 @@ mod tests { "handle_remote_model should return an error on 500 status" ); - // Ensure no partial or final file is left in the cache. let cache_dir = env::temp_dir().join("infera_cache"); let mut hasher = Sha256::new(); @@ -545,19 +388,16 @@ mod tests { let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); let temp_path = cached_path.with_extension("onnx.part"); - assert!(!cached_path.exists(), "Final cache file should not exist"); assert!(!temp_path.exists(), "Partial file should be cleaned up"); } - #[test] fn test_handle_remote_model_cleanup_on_connection_drop() { let server = TinyServer::http("127.0.0.1:0").unwrap(); let port = server.server_addr().to_ip().unwrap().port(); let model_url = format!("http://127.0.0.1:{}/dropped_connection.onnx", port); - let server_handle = thread::spawn(move || { if let Ok(request) = server.recv() { // By responding with a Content-Length header but then dropping the @@ -572,15 +412,13 @@ mod tests { } }); - // The download should fail because the server closes the connection prematurely. - let result = handle_remote_model_3(&model_url); + let result = handle_remote_model(&model_url); assert!( result.is_err(), "The download should fail on a connection drop" ); - // After the failure, the temporary file should be cleaned up. 
let cache_dir = env::temp_dir().join("infera_cache"); let mut hasher = Sha256::new(); @@ -589,7 +427,6 @@ mod tests { let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); let temp_path = cached_path.with_extension("onnx.part"); - assert!(!cached_path.exists(), "Final cache file should not exist"); assert!( !temp_path.exists(), @@ -598,63 +435,8 @@ mod tests { server_handle.join().unwrap(); } - -// #[test] -// async fn test_handle_remote_model_success_and_cache() { -// // Serve a small body with an accurate Content-Length -// let mut server = Server::new(); -// let body = b"onnxdata".to_vec(); -// let _m = server -// .mock("GET", "/ok_model.onnx") -// .with_status(200) -// .with_header("Content-Length", &body.len().to_string()) -// .with_body(body.clone()) -// .create(); -// let url = format!("{}/ok_model.onnx", server.url()); - - -// let path1 = handle_remote_model(&url).expect("download should succeed"); -// assert!(path1.exists(), "cached file must exist"); -// let content1 = fs::read(&path1).expect("read cached file"); -// assert_eq!(content1, body); - - -// // Second call should hit cache and return same path without network -// let path2 = handle_remote_model(&url).expect("cache should hit"); -// assert_eq!(path1, path2); - - - -// // Remove the old mock to simulate file removal -// server.reset(); // or server.delete_mock(m) if your mock library supports it - - -// // Create a new mock representing the updated file -// let updated_body = b"new_onnx_data".to_vec(); -// let _m2 = server -// .mock("GET", "/ok_model.onnx") -// .with_status(200) -// .with_header("Content-Length", &updated_body.len().to_string()) -// .with_body(updated_body.clone()) -// .create(); - - -// let path3 = handle_remote_model(&url).expect("download should succeed"); -// assert!(path3.exists(), "cached file must exist"); -// let content3 = fs::read(&path3).expect("read cached file"); -// assert_ne!(content1, content3, "content should changed due to caching invalidation"); - - -// // let temp_path = path1.with_extension("onnx.part"); -// // assert!(!temp_path.exists(), "no partial file should remain"); - - - - -// } - - #[actix_web::test] - async fn test_handle_remote_model_3_with_actix_server() { + #[actix_web::test] + async fn test_handle_remote_model_success_and_cache() { use std::net::TcpListener; let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind random port"); @@ -665,9 +447,9 @@ mod tests { let server = HttpServer::new(move || { let content = file_content_server.clone(); - App::new() - .wrap(Etag::default()) - .route("/model.onnx", web::get().to(move || { + App::new().wrap(Etag::default()).route( + "/model.onnx", + web::get().to(move || { let content = content.clone(); async move { let data = content.read().await; @@ -678,7 +460,8 @@ mod tests { .body(bytes), ) } - })) + }), + ) }) .listen(listener) .expect("Failed to bind server") @@ -691,20 +474,20 @@ mod tests { let second_call_url = url.clone(); let third_call_url = url.clone(); // Call your blocking cache-and-download function in blocking task - let path1 = tokio::task::spawn_blocking(move || handle_remote_model_3(&url)) + let path1 = tokio::task::spawn_blocking(move || handle_remote_model(&url)) .await .expect("Task panicked") - .expect("handle_remote_model_3 failed"); + .expect("handle_remote_model failed"); assert!(path1.exists()); let content1 = fs::read(&path1).expect("read cached file"); let path1_modification_time = get_file_modification_time(&path1).unwrap(); // Call again, should refresh cache - let path2 = 
tokio::task::spawn_blocking(move || handle_remote_model_3(&second_call_url)) + let path2 = tokio::task::spawn_blocking(move || handle_remote_model(&second_call_url)) .await .expect("Task panicked") - .expect("handle_remote_model_3 failed"); + .expect("handle_remote_model failed"); tokio::time::sleep(Duration::from_secs(1)).await; @@ -715,7 +498,6 @@ mod tests { assert_eq!(content1, content2); assert_eq!(path1_modification_time, path2_modification_time); - // Modify content to simulate update { let mut content_write = file_content.write().await; @@ -725,10 +507,10 @@ mod tests { tokio::time::sleep(Duration::from_secs(1)).await; // Call again, should refresh cache - let path3 = tokio::task::spawn_blocking(move || handle_remote_model_3(&third_call_url)) + let path3 = tokio::task::spawn_blocking(move || handle_remote_model(&third_call_url)) .await .expect("Task panicked") - .expect("handle_remote_model_3 failed"); + .expect("handle_remote_model failed"); let content3 = fs::read(&path3).expect("read cached file"); let path3_modification_time = get_file_modification_time(&path3).unwrap(); @@ -750,4 +532,4 @@ mod tests { clear_cache().unwrap(); assert!(!dummy.exists()); } -} \ No newline at end of file +} diff --git a/infera/src/lib.rs b/infera/src/lib.rs index 1c11470..9645501 100644 --- a/infera/src/lib.rs +++ b/infera/src/lib.rs @@ -429,7 +429,7 @@ pub unsafe extern "C" fn infera_set_autoload_dir(path: *const c_char) -> *mut c_ mod tests { use super::*; use std::ffi::CString; - use std::{fs, result}; + use std::fs; use tempfile::tempdir; #[test] From 62f92ba929c534095bd2a6d8a220239622d5a917 Mon Sep 17 00:00:00 2001 From: thuong Date: Fri, 21 Nov 2025 22:40:02 +0700 Subject: [PATCH 3/9] feat: add test cases for etag enabling and disabling --- infera/src/http.rs | 114 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 103 insertions(+), 11 deletions(-) diff --git a/infera/src/http.rs b/infera/src/http.rs index 865ba13..dae691d 100644 --- a/infera/src/http.rs +++ b/infera/src/http.rs @@ -96,7 +96,6 @@ fn get_cache_size() -> Result { /// Evicts least recently used cache files until cache size is below limit. 
fn evict_cache_if_needed(required_space: u64) -> Result<(), InferaError> { - println!("evict_cache_if_needed"); let limit = get_cache_size_limit(); let current_size = get_cache_size()?; @@ -113,7 +112,6 @@ fn evict_cache_if_needed(required_space: u64) -> Result<(), InferaError> { break; } - println!("consider to remove file: {:?}", path); fs::remove_file(&path).map_err(|e| InferaError::IoError(e.to_string()))?; freed_size += size; } @@ -186,9 +184,8 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { // Load cached ETag if available let etag_trimmed = match fs::read_to_string(&etag_path) { Ok(etag_value) => etag_value.trim().to_string(), - Err(_) => "abc".to_string(), + Err(_) => "".to_string(), }; - println!("Using ETag for request: {}", etag_trimmed); let temp_path = cached_path.with_extension("onnx.part"); let mut guard = TempFileGuard::new(&temp_path); @@ -207,12 +204,20 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { // Perform download using helper function match download_file_with_etag(url, &temp_path, timeout_secs, &etag_trimmed) { Ok(Some((false, etag_str))) => { - // New file was downloaded (first value is false) - log!(LogLevel::Info, "Download succeeded for URL: {}", url); + // first value is false, it mean the object in server is new or changed, take the downloading + log!( + LogLevel::Info, + "Cache miss for URL: {}, downloading...", + url + ); + + log!(LogLevel::Info, "Successfully downloaded: {}", url); let file_size = fs::metadata(&temp_path) .map_err(|e| InferaError::IoError(e.to_string()))? .len(); + + log!(LogLevel::Debug, "Downloaded file size: {} bytes", file_size); evict_cache_if_needed(file_size)?; fs::rename(&temp_path, &cached_path) .map_err(|e| InferaError::IoError(e.to_string()))?; @@ -225,7 +230,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { return Ok(cached_path); } Ok(Some((true, _))) => { - // status is 304 Not Modified, use cached file (first value is true) + // first value is true, it mean the object in server is Not Modified, use cached file log!(LogLevel::Info, "Cache hit for URL: {}", url); // Update access time for LRU tracking @@ -287,12 +292,10 @@ fn download_file_with_etag( .error_for_status() .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; - println!("response.status(): {:?}", response.status()); if etag != "" && response.status() == reqwest::StatusCode::NOT_MODIFIED { // Not modified, no file write needed, return None etag return Ok(Some((true, etag.to_string()))); } - println!("prepare to create dest file: {:?}", dest); let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?; io::copy(&mut response, &mut file).map_err(|e| InferaError::IoError(e.to_string()))?; @@ -346,7 +349,6 @@ mod tests { // Ensure no partial or final file is left in the cache. 
let cache_dir = env::temp_dir().join("infera_cache"); - println!("Cache dir: {:?}", cache_dir); let mut hasher = Sha256::new(); hasher.update(model_url.as_bytes()); let hash_hex = hex::encode(hasher.finalize()); @@ -436,7 +438,7 @@ mod tests { } #[actix_web::test] - async fn test_handle_remote_model_success_and_cache() { + async fn test_remote_model_download_and_cache_with_etag_enabled() { use std::net::TcpListener; let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind random port"); @@ -522,6 +524,96 @@ mod tests { srv_handle.abort(); } + + #[actix_web::test] + async fn test_remote_model_download_and_cache_with_etag_disabled() { + use std::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind random port"); + let addr = listener.local_addr().unwrap(); + + let file_content = Arc::new(RwLock::new(b"initial content".to_vec())); + let file_content_server = file_content.clone(); + + let server = HttpServer::new(move || { + let content = file_content_server.clone(); + App::new().route( + "/model.onnx", + web::get().to(move || { + let content = content.clone(); + async move { + let data = content.read().await; + let bytes = Bytes::copy_from_slice(&*data); + Ok::<_, Error>( + HttpResponse::Ok() + .content_type("application/octet-stream") + .body(bytes), + ) + } + }), + ) + }) + .listen(listener) + .expect("Failed to bind server") + .run(); + + // Spawn server in background + let srv_handle = actix_web::rt::spawn(server); + + let url = format!("http://{}:{}/model.onnx", addr.ip(), addr.port()); + let second_call_url = url.clone(); + let third_call_url = url.clone(); + // Call your blocking cache-and-download function in blocking task + let path1 = tokio::task::spawn_blocking(move || handle_remote_model(&url)) + .await + .expect("Task panicked") + .expect("handle_remote_model failed"); + + assert!(path1.exists()); + let content1 = fs::read(&path1).expect("read cached file"); + let path1_modification_time = get_file_modification_time(&path1).unwrap(); + + // Call again, should refresh cache + let path2 = tokio::task::spawn_blocking(move || handle_remote_model(&second_call_url)) + .await + .expect("Task panicked") + .expect("handle_remote_model failed"); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let content2 = fs::read(&path2).expect("read cached file"); + let path2_modification_time = get_file_modification_time(&path2).unwrap(); + + assert_eq!(path1, path2); + assert_eq!(content1, content2); + assert_ne!(path1_modification_time, path2_modification_time); + + // Modify content to simulate update + { + let mut content_write = file_content.write().await; + *content_write = b"updated content".to_vec(); + } + + tokio::time::sleep(Duration::from_secs(1)).await; + + // Call again, should refresh cache + let path3 = tokio::task::spawn_blocking(move || handle_remote_model(&third_call_url)) + .await + .expect("Task panicked") + .expect("handle_remote_model failed"); + + let content3 = fs::read(&path3).expect("read cached file"); + let path3_modification_time = get_file_modification_time(&path3).unwrap(); + assert_eq!(path1, path3); + assert_ne!(content1, content3); + assert_ne!(path1_modification_time, path3_modification_time); + + // Cleanup server + srv_handle.abort(); + } + + + #[test] fn test_clear_cache_removes_files() { let dir = cache_dir(); From eaa2874e750b68ba7c3cdb5b22be7c5948898ed6 Mon Sep 17 00:00:00 2001 From: thuong Date: Fri, 21 Nov 2025 22:47:03 +0700 Subject: [PATCH 4/9] feat: add allow(clippy::needless_return) --- 
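Note: `download_file_with_etag` ends with an explicit
`return Ok(Some((false, etag_str)))`, which trips clippy's `needless_return`
lint, hence the `#[allow]` added below. For reference, a minimal sketch of the
conditional-GET round trip the function implements (the `probe` name and the
elided error handling are illustrative, not part of this patch):

    use reqwest::blocking::Client;
    use reqwest::header::{ETAG, IF_NONE_MATCH};

    fn probe(url: &str, cached_etag: &str) -> reqwest::Result<Option<String>> {
        let resp = Client::new()
            .get(url)
            .header(IF_NONE_MATCH, cached_etag)
            .send()?;
        if resp.status() == reqwest::StatusCode::NOT_MODIFIED {
            // Validator still matches: reuse the cached copy.
            return Ok(None);
        }
        // 200 OK: the body changed; remember the new validator.
        Ok(resp
            .headers()
            .get(ETAG)
            .and_then(|v| v.to_str().ok())
            .map(String::from))
    }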
infera/src/http.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/infera/src/http.rs b/infera/src/http.rs index dae691d..dc392a4 100644 --- a/infera/src/http.rs +++ b/infera/src/http.rs @@ -216,7 +216,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { let file_size = fs::metadata(&temp_path) .map_err(|e| InferaError::IoError(e.to_string()))? .len(); - + log!(LogLevel::Debug, "Downloaded file size: {} bytes", file_size); evict_cache_if_needed(file_size)?; fs::rename(&temp_path, &cached_path) @@ -230,7 +230,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { return Ok(cached_path); } Ok(Some((true, _))) => { - // first value is true, it mean the object in server is Not Modified, use cached file + // first value is true, it mean the object in server is Not Modified, use cached file log!(LogLevel::Info, "Cache hit for URL: {}", url); // Update access time for LRU tracking @@ -273,6 +273,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result { .unwrap_or_else(|| InferaError::HttpRequestError("Unknown download error".to_string()))) } +#[allow(clippy::needless_return)] fn download_file_with_etag( url: &str, dest: &Path, @@ -292,7 +293,7 @@ fn download_file_with_etag( .error_for_status() .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; - if etag != "" && response.status() == reqwest::StatusCode::NOT_MODIFIED { + if !etag.is_empty() && response.status() == reqwest::StatusCode::NOT_MODIFIED { // Not modified, no file write needed, return None etag return Ok(Some((true, etag.to_string()))); } @@ -524,7 +525,6 @@ mod tests { srv_handle.abort(); } - #[actix_web::test] async fn test_remote_model_download_and_cache_with_etag_disabled() { use std::net::TcpListener; @@ -612,8 +612,6 @@ mod tests { srv_handle.abort(); } - - #[test] fn test_clear_cache_removes_files() { let dir = cache_dir(); From f87ff6557354a17c18ff5d1ec923752175ad4644 Mon Sep 17 00:00:00 2001 From: thuong Date: Fri, 21 Nov 2025 23:00:19 +0700 Subject: [PATCH 5/9] feat: update comments for http code --- infera/src/http.rs | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/infera/src/http.rs b/infera/src/http.rs index dc392a4..b4ce544 100644 --- a/infera/src/http.rs +++ b/infera/src/http.rs @@ -140,12 +140,12 @@ pub(crate) fn clear_cache() -> Result<(), InferaError> { Ok(()) } -// / Handles the download and caching of a remote model from a URL. -// / -// / If the model for the given URL is already present in the local cache, this -// / function updates its access time and returns the path. Otherwise, it downloads -// / the file, evicts old cache entries if needed, stores it in the cache directory, -// / and then returns the path. +/// Handles downloading and caching a remote model from a URL. +/// +/// If the model for the given URL is already present in the local cache and the ETag of this object +/// has not changed since the last call, this function updates its access time and returns the path. +/// Otherwise, it downloads the file, evicts old cache entries if needed, stores it in the cache directory, +/// and then returns the path. // / // / The cache uses an LRU (Least Recently Used) eviction policy with a configurable // / size limit (default 1GB, configurable via INFERA_CACHE_SIZE_LIMIT env var). 
@@ -256,7 +256,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
         }
         Ok(None) => {
             // theoretically unreachable, but necessary to satisfy exhaustiveness
-            // Handle as error or log warning, or panic
+            // Handle as error
             log!(LogLevel::Error, "Can't exist None for this matching");
         }
     }
@@ -273,6 +273,27 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
         .unwrap_or_else(|| InferaError::HttpRequestError("Unknown download error".to_string())))
 }

+
+/// Downloads a file from a URL with ETag support for caching.
+///
+/// If the ETag is non-empty and the server responds with `304 Not Modified`,
+/// this function returns early with `(true, etag)` indicating no download.
+///
+/// Otherwise, it downloads the file to `dest`, updates the ETag if present,
+/// and returns `(false, new_etag)`.
+///
+/// # Arguments
+///
+/// * `url` - The URL to download from.
+/// * `dest` - The destination path to save the file.
+/// * `timeout_secs` - HTTP request timeout in seconds.
+/// * `etag` - Cached ETag string for conditional requests.
+///
+/// # Returns
+///
+/// `Result<Option<(bool, String)>, InferaError>` tuple where the boolean indicates
+/// if the file was not modified (true) and the string is the ETag.
+
 #[allow(clippy::needless_return)]
 fn download_file_with_etag(
     url: &str,

From 4fe962b73094c05f7ae3b6d8de01aa35b9d44b93 Mon Sep 17 00:00:00 2001
From: thuong
Date: Fri, 21 Nov 2025 23:07:42 +0700
Subject: [PATCH 6/9] feat: update comments format

---
 infera/src/http.rs | 54 ++++++++++++++++++++++++++++----------------------------
 1 file changed, 26 insertions(+), 28 deletions(-)

diff --git a/infera/src/http.rs b/infera/src/http.rs
index b4ce544..bb01cdd 100644
--- a/infera/src/http.rs
+++ b/infera/src/http.rs
@@ -140,12 +140,12 @@ pub(crate) fn clear_cache() -> Result<(), InferaError> {
     Ok(())
 }

-/// Handles downloading and caching a remote model from a URL.
-///
-/// If the model for the given URL is already present in the local cache and the ETag of this object
-/// has not changed since the last call, this function updates its access time and returns the path.
-/// Otherwise, it downloads the file, evicts old cache entries if needed, stores it in the cache directory,
-/// and then returns the path.
+// / Handles downloading and caching a remote model from a URL.
+// /
+// / If the model for the given URL is already present in the local cache and the ETag of this object
+// / has not changed since the last call, this function updates its access time and returns the path.
+// / Otherwise, it downloads the file, evicts old cache entries if needed, stores it in the cache directory,
+// / and then returns the path.
 // /
 // / The cache uses an LRU (Least Recently Used) eviction policy with a configurable
 // / size limit (default 1GB, configurable via INFERA_CACHE_SIZE_LIMIT env var).
@@ -256,7 +256,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
         }
         Ok(None) => {
             // theoretically unreachable, but necessary to satisfy exhaustiveness
-            // Handle as error 
+            // Handle as error
             log!(LogLevel::Error, "Can't exist None for this matching");
         }
     }
@@ -273,27 +273,25 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
         .unwrap_or_else(|| InferaError::HttpRequestError("Unknown download error".to_string())))
 }
-
-/// Downloads a file from a URL with ETag support for caching.
-///
-/// If the ETag is non-empty and the server responds with `304 Not Modified`,
-/// this function returns early with `(true, etag)` indicating no download.
-///
-/// Otherwise, it downloads the file to `dest`, updates the ETag if present,
-/// and returns `(false, new_etag)`.
-///
-/// # Arguments
-///
-/// * `url` - The URL to download from.
-/// * `dest` - The destination path to save the file.
-/// * `timeout_secs` - HTTP request timeout in seconds.
-/// * `etag` - Cached ETag string for conditional requests.
-///
-/// # Returns
-///
-/// `Result<Option<(bool, String)>, InferaError>` tuple where the boolean indicates
-/// if the file was not modified (true) and the string is the ETag.
-
+// / Downloads a file from a URL with ETag support for caching.
+// /
+// / If the ETag is non-empty and the server responds with `304 Not Modified`,
+// / this function returns early with `(true, etag)` indicating no download.
+// /
+// / Otherwise, it downloads the file to `dest`, updates the ETag if present,
+// / and returns `(false, new_etag)`.
+// /
+// / # Arguments
+// /
+// / * `url` - The URL to download from.
+// / * `dest` - The destination path to save the file.
+// / * `timeout_secs` - HTTP request timeout in seconds.
+// / * `etag` - Cached ETag string for conditional requests.
+// /
+// / # Returns
+// /
+// / `Result<Option<(bool, String)>, InferaError>` tuple where the boolean indicates
+// / if the file was not modified (true) and the string is the ETag.
 #[allow(clippy::needless_return)]
 fn download_file_with_etag(
     url: &str,

From 87469e44831de666214665059b18345711a416c8 Mon Sep 17 00:00:00 2001
From: thuong
Date: Sat, 22 Nov 2025 07:15:05 +0700
Subject: [PATCH 7/9] feat: remove redundant libs

---
 infera/Cargo.toml | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/infera/Cargo.toml b/infera/Cargo.toml
index 89b12f4..21026be 100644
--- a/infera/Cargo.toml
+++ b/infera/Cargo.toml
@@ -25,11 +25,8 @@ sha2 = "0.10"
 hex = "0.4"
 filetime = "0.2"

-axum = "0.6"
 tokio = { version = "1", features = ["full"] }
 tokio-util = "0.7" # for ReaderStream
-hyper = { version = "0.14", features = ["full"] }
-http_cache_tags_axum="0.1.0-alpha.5"
 actix-middleware-etag = "0.4.6"
 actix-web = "4.12.0"
 bytes = "1.11.0"

From b5d35bf790b968c7e04ad638432ce97cb65d67aa Mon Sep 17 00:00:00 2001
From: thuong
Date: Sat, 22 Nov 2025 09:17:55 +0700
Subject: [PATCH 8/9] feat: remove redundant comments

---
 infera/src/http.rs | 5 +++--
 infera/src/lib.rs  | 7 -------
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/infera/src/http.rs b/infera/src/http.rs
index bb01cdd..cb847d5 100644
--- a/infera/src/http.rs
+++ b/infera/src/http.rs
@@ -212,7 +212,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
         );

         log!(LogLevel::Info, "Successfully downloaded: {}", url);
-
+        // Check file size and evict cache if needed
         let file_size = fs::metadata(&temp_path)
             .map_err(|e| InferaError::IoError(e.to_string()))?
             .len();
@@ -248,6 +248,7 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
                     e
                 );
                 last_error = Some(e);
+                // Don't sleep after the last attempt
                 if attempt < max_attempts {
                     let delay = Duration::from_millis(retry_delay_ms * attempt as u64);
                     log!(LogLevel::Debug, "Waiting {:?} before retry", delay);
@@ -313,7 +314,7 @@ fn download_file_with_etag(
         .map_err(|e| InferaError::HttpRequestError(e.to_string()))?;

     if !etag.is_empty() && response.status() == reqwest::StatusCode::NOT_MODIFIED {
-        // Not modified, no file write needed, return None etag
+        // Not modified, no file write needed, return true and etag
        return Ok(Some((true, etag.to_string())));
     }
     let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?;

diff --git a/infera/src/lib.rs b/infera/src/lib.rs
index 9645501..8364198 100644
--- a/infera/src/lib.rs
+++ b/infera/src/lib.rs
@@ -338,7 +338,6 @@ pub extern "C" fn infera_get_cache_info() -> *mut c_char {
         .flatten()
     {
         let path = entry.path();
-        println!("path: {:?}", path);
         if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("onnx") {
             if let Ok(metadata) = fs::metadata(&path) {
                 total_size += metadata.len();
@@ -438,7 +437,6 @@ mod tests {
         let version_json = unsafe { CStr::from_ptr(version_ptr).to_str().unwrap() };
         let version_data: serde_json::Value = serde_json::from_str(version_json).unwrap();
-        println!("version_data {:?}", version_data);
         assert!(version_data["version"].is_string());
         assert!(version_data["onnx_backend"].is_string());
         assert!(version_data["model_cache_dir"].is_string());
@@ -456,7 +454,6 @@ mod tests {
         let result_ptr = unsafe { infera_set_autoload_dir(path_cstr.as_ptr()) };
         let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() };
         let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap();
-        println!("result_data {:?}", result_data);
         assert_eq!(result_data["loaded"].as_array().unwrap().len(), 1);
         assert_eq!(result_data["loaded"][0], "linear");
         assert_eq!(result_data["errors"].as_array().unwrap().len(), 0);
@@ -472,7 +469,6 @@ mod tests {
         let result_ptr = unsafe { infera_set_autoload_dir(path_cstr.as_ptr()) };
         let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() };
         let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap();
-        println!("result_data {:?}", result_data);
         assert!(result_data["error"].is_string());

         unsafe { infera_free(result_ptr) };
@@ -489,7 +485,6 @@ mod tests {
         let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() };
         let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap();
-        println!("result_data {:?}", result_data);
         assert_eq!(result_data["loaded"].as_array().unwrap().len(), 0);
         assert_eq!(result_data["errors"].as_array().unwrap().len(), 1);
         assert_eq!(
@@ -650,9 +645,7 @@ mod tests {
     fn test_infera_get_cache_info_includes_configured_limit() {
         let cache_info_ptr = infera_get_cache_info();
         let cache_info_json = unsafe { CStr::from_ptr(cache_info_ptr).to_str().unwrap() };
-        println!("cache_info_json: {:?}", cache_info_json);
         let value: serde_json::Value = serde_json::from_str(cache_info_json).unwrap();
-        println!("value: {:?}", value);
         let size_limit = value["size_limit_bytes"]
             .as_u64()
             .expect("size_limit_bytes should be u64");

From e6949b4c552df325fc9a4e8f0c0f82b6c7ea6bd2 Mon Sep 17 00:00:00 2001
From: thuong
Date: Sat, 22 Nov 2025 09:24:19 +0700
Subject: [PATCH 9/9] feat: remove redundant comments

---
 infera/src/lib.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/infera/src/lib.rs b/infera/src/lib.rs
index 8364198..2c4b21b 100644
--- a/infera/src/lib.rs
+++ b/infera/src/lib.rs
@@ -454,6 +454,7 @@ mod tests {
         let result_ptr = unsafe { infera_set_autoload_dir(path_cstr.as_ptr()) };
         let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() };
         let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap();
+
         assert_eq!(result_data["loaded"].as_array().unwrap().len(), 1);
         assert_eq!(result_data["loaded"][0], "linear");
         assert_eq!(result_data["errors"].as_array().unwrap().len(), 0);
@@ -469,6 +470,7 @@ mod tests {
         let result_ptr = unsafe { infera_set_autoload_dir(path_cstr.as_ptr()) };
         let result_json = unsafe { CStr::from_ptr(result_ptr).to_str().unwrap() };
         let result_data: serde_json::Value = serde_json::from_str(result_json).unwrap();
+
         assert!(result_data["error"].is_string());

         unsafe { infera_free(result_ptr) };
@@ -652,4 +654,4 @@ mod tests {
         assert_eq!(size_limit, crate::config::CONFIG.cache_size_limit);
         unsafe { infera_free(cache_info_ptr) };
     }
-}
+}
\ No newline at end of file
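
Reviewer note: for anyone reading the series end-to-end, the following is a minimal, self-contained sketch of the conditional-GET flow that download_file_with_etag implements, written against reqwest's blocking API (the "blocking" feature, which the crate already uses). The function name conditional_download and the cached_etag parameter are illustrative only, not part of the patched crate, and error handling is simplified to Box<dyn Error> instead of InferaError.

// Hypothetical sketch of the ETag flow, not the crate's actual implementation.
use std::fs::File;
use std::path::Path;
use std::time::Duration;

use reqwest::blocking::Client;
use reqwest::header::{ETAG, IF_NONE_MATCH};
use reqwest::StatusCode;

/// Returns (not_modified, etag). On a 304 the cached ETag is echoed back;
/// otherwise the body is written to `dest` and the server's new ETag
/// (or "" if it sent none) is returned.
fn conditional_download(
    url: &str,
    dest: &Path,
    cached_etag: &str,
) -> Result<(bool, String), Box<dyn std::error::Error>> {
    let client = Client::builder().timeout(Duration::from_secs(30)).build()?;

    let mut request = client.get(url);
    if !cached_etag.is_empty() {
        // Conditional request: the server replies 304 if the ETag still matches.
        request = request.header(IF_NONE_MATCH, cached_etag);
    }

    let response = request.send()?;
    if response.status() == StatusCode::NOT_MODIFIED {
        // Object unchanged on the server; keep the cached file and ETag.
        return Ok((true, cached_etag.to_string()));
    }
    let mut response = response.error_for_status()?;

    // Remember the new ETag (if any) for the next conditional request.
    let new_etag = response
        .headers()
        .get(ETAG)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_string();

    // Stream the body to disk.
    let mut file = File::create(dest)?;
    response.copy_to(&mut file)?;

    Ok((false, new_etag))
}

The series wraps this same pattern in InferaError and an Option, so the retry loop in handle_remote_model can distinguish a 304 cache hit (Ok(Some((true, _)))) from a fresh download (Ok(Some((false, new_etag)))), as the hunks above show.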