diff --git a/src/main.rs b/src/main.rs index e376cfc..733812a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,9 +10,11 @@ use axum::{ use isolang::Language; use linguaspark::Translator; use std::{fs, io, net::SocketAddr, path::PathBuf, sync::Arc}; -use tokio::net::TcpListener; +use tokio::{net::TcpListener, signal}; use tower_http::{ - cors::{Any, CorsLayer}, + cors::{ + AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, AllowPrivateNetwork, CorsLayer, + }, trace::TraceLayer, }; use tracing::{debug, error, info}; @@ -130,6 +132,34 @@ fn load_models_manually( Ok(models) } +async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => { + info!("Received Ctrl+C, shutting down gracefully..."); + }, + _ = terminate => { + info!("Received SIGTERM, shutting down gracefully..."); + }, + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { if std::env::var(ENV_LOG_LEVEL).is_err() { @@ -178,9 +208,11 @@ async fn main() -> anyhow::Result<()> { let app_state = Arc::new(AppState { translator, models }); let cors = CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any); + .allow_origin(AllowOrigin::mirror_request()) + .allow_credentials(AllowCredentials::yes()) + .allow_methods(AllowMethods::mirror_request()) + .allow_headers(AllowHeaders::mirror_request()) + .allow_private_network(AllowPrivateNetwork::yes()); let app = Router::new() .route("/translate", post(endpoint::translate)) @@ -214,7 +246,11 @@ async fn main() -> anyhow::Result<()> { .await .context(format!("Failed to bind to address: {}", addr))?; - axum::serve(listener, app).await.context("Server error")?; + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await + .context("Server error")?; + info!("Server has been shut down gracefully"); Ok(()) } diff --git a/src/translation.rs b/src/translation.rs index e4b6fc4..9dd0298 100644 --- a/src/translation.rs +++ b/src/translation.rs @@ -70,6 +70,11 @@ pub async fn perform_translation( let from_code = get_iso_code(&source_lang)?; let to_code = get_iso_code(&target_lang)?; + // If source and target languages are the same, return the original text + if from_code == to_code { + return Ok((text.to_string(), from_code.to_string(), to_code.to_string())); + } + if !state.translator.is_supported(from_code, to_code)? { return Err(AppError::TranslationError(format!( "Translation from '{}' to '{}' is not supported",