diff --git a/ltengine/build.rs b/ltengine/build.rs index 8a25a4b..cd501f9 100644 --- a/ltengine/build.rs +++ b/ltengine/build.rs @@ -2,4 +2,4 @@ use static_files::resource_dir; fn main() -> std::io::Result<()> { resource_dir("./resources").build() -} \ No newline at end of file +} diff --git a/ltengine/src/banner.rs b/ltengine/src/banner.rs index 962b6d3..dd9c1fc 100644 --- a/ltengine/src/banner.rs +++ b/ltengine/src/banner.rs @@ -1,5 +1,6 @@ -pub fn print_banner(){ - println!(r#" +pub fn print_banner() { + println!( + r#" __________________________________________ __ ____________ _ / / /_ __/ ____/___ ____ _(_)___ ___ @@ -10,5 +11,6 @@ pub fn print_banner(){ Local AI Machine Translation ___________________________________________ -"#); -} \ No newline at end of file +"# + ); +} diff --git a/ltengine/src/error_response.rs b/ltengine/src/error_response.rs index 8685d63..be53906 100644 --- a/ltengine/src/error_response.rs +++ b/ltengine/src/error_response.rs @@ -1,18 +1,18 @@ +use actix_web::http::StatusCode; +use actix_web::{HttpResponse, ResponseError, body::BoxBody}; use serde::Serialize; -use std::fmt::{Display, Formatter, Result as FmtResult}; use serde_json::to_string_pretty; -use actix_web::{ResponseError, HttpResponse, body::BoxBody}; -use actix_web::http::StatusCode; +use std::fmt::{Display, Formatter, Result as FmtResult}; #[derive(Debug, Serialize)] pub struct ErrorResponse { - pub error: String, - pub status: u16 + pub error: String, + pub status: u16, } impl Display for ErrorResponse { fn fmt(&self, f: &mut Formatter) -> FmtResult { - write!(f, "{}", to_string_pretty(self).unwrap()) + write!(f, "{}", to_string_pretty(self).unwrap()) } } @@ -22,7 +22,8 @@ impl ResponseError for ErrorResponse { } fn error_response(&self) -> HttpResponse { - HttpResponse::build(self.status_code()).json(serde_json::json!({"error": self.error.clone()})) + HttpResponse::build(self.status_code()) + .json(serde_json::json!({"error": self.error.clone()})) } } @@ -33,4 +34,4 @@ impl From for ErrorResponse { status: err.as_response_error().status_code().as_u16(), } } -} \ No newline at end of file +} diff --git a/ltengine/src/languages.rs b/ltengine/src/languages.rs index febfc82..17531cd 100644 --- a/ltengine/src/languages.rs +++ b/ltengine/src/languages.rs @@ -1,7 +1,7 @@ -use serde::Serialize; use once_cell::sync::Lazy; -use whatlang::{Lang, Detector}; +use serde::Serialize; use std::collections::HashMap; +use whatlang::{Detector, Lang}; const LANGS: &[(&str, &str, &str)] = &[ ("en", "", "English"), @@ -65,24 +65,22 @@ pub struct Language { pub lang_detect: Option<&'static Lang>, #[serde(skip)] - pub internal_code: &'static str + pub internal_code: &'static str, } - pub static LANGUAGES: Lazy> = Lazy::new(|| { - // From whatlang names to our names - let eng_name_map: HashMap<&'static str, &'static str> = [ - ("Mandarin", "Chinese") - ].iter().cloned().collect(); - + let eng_name_map: HashMap<&'static str, &'static str> = + [("Mandarin", "Chinese")].iter().cloned().collect(); + let mut lang_detect_map: HashMap<&'static str, &'static Lang> = HashMap::new(); for lang in Lang::all() { let eng_name = lang.eng_name(); lang_detect_map.insert(eng_name_map.get(eng_name).unwrap_or(&eng_name), lang); } - LANGS.iter() + LANGS + .iter() .map(|&(code, alias, name)| { let targets: Vec<&str> = LANGS .iter() @@ -94,29 +92,35 @@ pub static LANGUAGES: Lazy> = Lazy::new(|| { name, targets: Box::leak(targets.into_boxed_slice()), lang_detect: lang_detect_map.get(name).map(|v| &**v), - internal_code: code + internal_code: code, } }) .collect() }); static LANGUAGES_MAP: Lazy> = Lazy::new(|| { - LANGUAGES.iter().map(|lang| (lang.internal_code, lang)).collect() + LANGUAGES + .iter() + .map(|lang| (lang.internal_code, lang)) + .collect() }); static CODE_TO_INTERNAL_CODE_MAP: Lazy> = Lazy::new(|| { - LANGUAGES.iter().map(|lang| (lang.code, lang.internal_code)).collect() + LANGUAGES + .iter() + .map(|lang| (lang.code, lang.internal_code)) + .collect() }); -pub fn get_language_from_code(code: &String) -> Option<&'static Language>{ +pub fn get_language_from_code(code: &String) -> Option<&'static Language> { let code_str = code.as_str(); let internal_code = CODE_TO_INTERNAL_CODE_MAP.get(code_str).unwrap_or(&code_str); LANGUAGES_MAP.get(internal_code).map(|v| &**v) } -pub struct LangDetect{ +pub struct LangDetect { pub language: &'static Language, - pub confidence: i32 + pub confidence: i32, } pub fn detect_lang(q: &String) -> LangDetect { @@ -130,9 +134,13 @@ pub fn detect_lang(q: &String) -> LangDetect { let lang = info.lang(); let confidence = info.confidence(); - LANGUAGES.iter() + LANGUAGES + .iter() .find(|l| l.lang_detect == Some(&lang)) - .map(|l| LangDetect { language: l, confidence: (confidence * 100.0) as i32 }) + .map(|l| LangDetect { + language: l, + confidence: (confidence * 100.0) as i32, + }) .unwrap_or(LangDetect { language: &LANGUAGES[0], confidence: 0, @@ -143,4 +151,4 @@ pub fn detect_lang(q: &String) -> LangDetect { confidence: 0, } } -} \ No newline at end of file +} diff --git a/ltengine/src/llm.rs b/ltengine/src/llm.rs index 381aa2a..b3e17f2 100644 --- a/ltengine/src/llm.rs +++ b/ltengine/src/llm.rs @@ -1,17 +1,17 @@ +use anyhow::{Context, Result}; +use llama_cpp_2::context::LlamaContext; use llama_cpp_2::context::params::LlamaContextParams; use llama_cpp_2::llama_backend::LlamaBackend; -use llama_cpp_2::model::params::LlamaModelParams; -use llama_cpp_2::model::{AddBos, LlamaModel, LlamaChatMessage}; -use llama_cpp_2::token::LlamaToken; -use llama_cpp_2::context::LlamaContext; use llama_cpp_2::llama_batch::LlamaBatch; +use llama_cpp_2::model::params::LlamaModelParams; +use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaModel}; use llama_cpp_2::sampling::LlamaSampler; -use llama_cpp_2::{send_logs_to_tracing, LogOptions}; +use llama_cpp_2::token::LlamaToken; +use llama_cpp_2::{LogOptions, send_logs_to_tracing}; +use parking_lot::Mutex; use std::num::NonZeroU32; use std::path::PathBuf; -use parking_lot::Mutex; use std::time::Duration; -use anyhow::{Result, Context}; #[derive(Debug, thiserror::Error)] pub enum LLMError { @@ -28,7 +28,11 @@ fn vram_mib() -> Option<(u64, u64)> { #[cfg(feature = "cuda")] { unsafe extern "C" { - fn ggml_backend_cuda_get_device_memory(device: i32, free: *mut usize, total: *mut usize); + fn ggml_backend_cuda_get_device_memory( + device: i32, + free: *mut usize, + total: *mut usize, + ); } unsafe { ggml_backend_cuda_get_device_memory(0, &mut free, &mut total) }; if total > 0 { @@ -75,10 +79,10 @@ pub struct LLM { n_ubatch: u32, } -pub struct LLMContext<'a>{ +pub struct LLMContext<'a> { llm: &'a LLM, ctx: LlamaContext<'a>, - ctx_size: i32 + ctx_size: i32, } impl LLM { @@ -94,7 +98,8 @@ impl LLM { let mut n_gpu = 9999u32; let model = loop { let model = match LlamaModel::load_from_file( - &backend, &model_path, + &backend, + &model_path, &LlamaModelParams::default().with_n_gpu_layers(n_gpu), ) { Ok(m) => m, @@ -102,10 +107,15 @@ impl LLM { // Load failed (likely GPU OOM before probe). On the first failure // jump to 64 (covers most models); after that halve to converge fast. let next = if n_gpu >= 9999 { 64 } else { n_gpu / 2 }; - eprintln!("ltengine: model load failed at {} GPU layers, retrying with {}", n_gpu, next); + eprintln!( + "ltengine: model load failed at {} GPU layers, retrying with {}", + n_gpu, next + ); n_gpu = next; if n_gpu == 0 { - return Err(anyhow::anyhow!("Unable to load model even with 0 GPU layers")); + return Err(anyhow::anyhow!( + "Unable to load model even with 0 GPU layers" + )); } continue; } @@ -113,16 +123,20 @@ impl LLM { // Probe: create a minimal context and decode one token to confirm // the GPU has enough VRAM for compute scratch buffers. - let probe_ok = model.new_context( - &backend, - LlamaContextParams::default() - .with_n_ctx(Some(NonZeroU32::new(8).unwrap())) - .with_n_ubatch(1), - ).ok().and_then(|mut ctx| { - let mut batch = LlamaBatch::new(8, 1); - batch.add(LlamaToken(0), 0, &[0], true).ok()?; - ctx.decode(&mut batch).ok() - }).is_some(); + let probe_ok = model + .new_context( + &backend, + LlamaContextParams::default() + .with_n_ctx(Some(NonZeroU32::new(8).unwrap())) + .with_n_ubatch(1), + ) + .ok() + .and_then(|mut ctx| { + let mut batch = LlamaBatch::new(8, 1); + batch.add(LlamaToken(0), 0, &[0], true).ok()?; + ctx.decode(&mut batch).ok() + }) + .is_some(); if probe_ok { break model; @@ -131,7 +145,10 @@ impl LLM { let actual = model.n_layer() as u32; let current = n_gpu.min(actual); let next = current.saturating_sub((current / 10).max(1)); - eprintln!("ltengine: GPU probe failed at {} layers, retrying with {}", current, next); + eprintln!( + "ltengine: GPU probe failed at {} layers, retrying with {}", + current, next + ); n_gpu = next; drop(model); @@ -146,9 +163,11 @@ impl LLM { (model, gpu_layers) } else { let model = LlamaModel::load_from_file( - &backend, model_path, + &backend, + model_path, &LlamaModelParams::default().with_n_gpu_layers(0), - ).with_context(|| "Unable to load model")?; + ) + .with_context(|| "Unable to load model")?; (model, None) }; @@ -156,39 +175,56 @@ impl LLM { match (use_gpu, gpu_layers) { (false, _) => eprintln!("ltengine: {} model layers, CPU only", model.n_layer()), - (true, None) => eprintln!("ltengine: {} model layers, all offloaded to GPU", model.n_layer()), - (true, Some(n)) => eprintln!("ltengine: {}/{} model layers on GPU, rest on CPU", n, model.n_layer()), + (true, None) => eprintln!( + "ltengine: {} model layers, all offloaded to GPU", + model.n_layer() + ), + (true, Some(n)) => eprintln!( + "ltengine: {}/{} model layers on GPU, rest on CPU", + n, + model.n_layer() + ), } - Ok(LLM { backend, model, prompt_lock: Mutex::new(()), n_ubatch }) + Ok(LLM { + backend, + model, + prompt_lock: Mutex::new(()), + n_ubatch, + }) } - pub fn create_context(&self, ctx_size: i32) -> Result>{ - let ctx_params = - LlamaContextParams::default() - .with_n_ctx(Some(NonZeroU32::new(ctx_size as u32).unwrap())) - .with_n_ubatch(self.n_ubatch); + pub fn create_context(&self, ctx_size: i32) -> Result> { + let ctx_params = LlamaContextParams::default() + .with_n_ctx(Some(NonZeroU32::new(ctx_size as u32).unwrap())) + .with_n_ubatch(self.n_ubatch); // Use all threads // ctx_params = ctx_params.with_n_threads(threads); // ctx_params = ctx_params.with_n_threads_batch(threads_batch); - let ctx = self.model + let ctx = self + .model .new_context(&self.backend, ctx_params) .with_context(|| "Unable to create the llama context")?; - Ok(LLMContext{ llm: self, ctx, ctx_size }) + Ok(LLMContext { + llm: self, + ctx, + ctx_size, + }) } - pub fn run_prompt(&self, system: String, user: String) -> Result{ + pub fn run_prompt(&self, system: String, user: String) -> Result { let messages = [ LlamaChatMessage::new("user".to_string(), format!("{system}\n\n{user}")) - .context("Failed to build chat message")? + .context("Failed to build chat message")?, ]; // Use the model's embedded chat template when llama.cpp can detect it. // Falls back to hardcoded Gemma format when detection fails (e.g. Gemma 4 // until llama-cpp-sys picks up the upstream Gemma 4 template detection fix). - let llm_input = match self.model + let llm_input = match self + .model .chat_template(None) .ok() .and_then(|tmpl| self.model.apply_chat_template(&tmpl, &messages, true).ok()) @@ -196,12 +232,15 @@ impl LLM { Some(s) => s, None => { eprintln!("ltengine: apply_chat_template failed: using hardcoded Gemma format"); - format!("user\n{system}\n\n{user}\nmodel\n") + format!( + "user\n{system}\n\n{user}\nmodel\n" + ) } }; // BOS is not added by apply_chat_template — str_to_token handles it. - let tokens_list = self.model + let tokens_list = self + .model .str_to_token(&llm_input, AddBos::Always) .with_context(|| "Failed to tokenize prompt")?; // for token in &tokens_list { @@ -214,17 +253,19 @@ impl LLM { // as garbage starts to come out when we run inference in parallel // this might need to be investigated and fixed. For now we lock and process requests // one at a time. - let _lock = self.prompt_lock.try_lock_for(Duration::from_secs(120)) + let _lock = self + .prompt_lock + .try_lock_for(Duration::from_secs(120)) .ok_or(LLMError::Busy)?; let mut ctx = self.create_context(ctx_size)?; ctx.process(tokens_list) } } -impl LLMContext<'_>{ - pub fn process(&mut self, tokens_list: Vec) -> Result{ +impl LLMContext<'_> { + pub fn process(&mut self, tokens_list: Vec) -> Result { // let ctx_size: i32 = tokens_list.len() as i32 * 3; - + // We use this object to submit token data for decoding let mut batch = LlamaBatch::new(self.ctx_size.try_into()?, 1); @@ -235,7 +276,8 @@ impl LLMContext<'_>{ batch.add(token, i, &[0], is_last)?; } - self.ctx.decode(&mut batch) + self.ctx + .decode(&mut batch) .with_context(|| "llama_decode() failed")?; let mut n_cur = batch.n_tokens(); @@ -252,13 +294,12 @@ impl LLMContext<'_>{ LlamaSampler::min_p(0.05, 0), LlamaSampler::xtc(0.0, 0.1, 0, 42), LlamaSampler::temp_ext(0.0, 0.0, 1.0), - LlamaSampler::dist(42) + LlamaSampler::dist(42), ]); let mut output = String::new(); while n_cur <= self.ctx_size { - // sample the next token { let token = sampler.sample(&self.ctx, batch.n_tokens() - 1); @@ -269,8 +310,11 @@ impl LLMContext<'_>{ if self.llm.model.is_eog_token(token) { break; } - - let output_string = self.llm.model.token_to_piece(token, &mut decoder, true, None)?; + + let output_string = + self.llm + .model + .token_to_piece(token, &mut decoder, true, None)?; output.push_str(&output_string); batch.clear(); @@ -279,7 +323,9 @@ impl LLMContext<'_>{ n_cur += 1; - self.ctx.decode(&mut batch).with_context(|| "Failed to eval")?; + self.ctx + .decode(&mut batch) + .with_context(|| "Failed to eval")?; } // Gemma 4 thinking mode emits thinking content before the actual response in two forms: diff --git a/ltengine/src/main.rs b/ltengine/src/main.rs index 933d95d..3ade716 100644 --- a/ltengine/src/main.rs +++ b/ltengine/src/main.rs @@ -1,24 +1,24 @@ +use actix_multipart::form::{MultipartForm, text::Text as MPText}; use actix_web::{ - get, post, web, App, HttpRequest, HttpResponse, - HttpServer, Responder, http::header, FromRequest + App, FromRequest, HttpRequest, HttpResponse, HttpServer, Responder, get, http::header, post, + web, }; -use actix_multipart::form::{MultipartForm, text::Text as MPText}; use actix_web_static_files::ResourceFiles; -use std::sync::Arc; use clap::Parser; use serde::{Deserialize, Serialize}; +use std::sync::Arc; +mod banner; mod error_response; mod languages; -mod models; mod llm; -mod banner; +mod models; mod prompt; -use languages::{detect_lang, get_language_from_code, LANGUAGES}; +use banner::print_banner; use error_response::ErrorResponse; +use languages::{LANGUAGES, detect_lang, get_language_from_code}; use models::{MODELS, load_model}; -use banner::print_banner; use prompt::PromptBuilder; include!(concat!(env!("OUT_DIR"), "/generated.rs")); @@ -48,7 +48,7 @@ struct Args { /// Set an API key #[arg(long, default_value = "")] - api_key: String, + api_key: String, /// Use CPU only #[arg(long)] @@ -56,7 +56,7 @@ struct Args { /// Enable verbose logging #[arg(short = 'v', long)] - verbose: bool + verbose: bool, } #[derive(Debug, Deserialize, Serialize)] @@ -66,7 +66,7 @@ struct TranslateRequest { target: Option, format: Option, api_key: Option, - alternatives: Option + alternatives: Option, } #[derive(MultipartForm)] @@ -76,7 +76,7 @@ struct MPTranslateRequest { target: Option>, format: Option>, api_key: Option>, - alternatives: Option> + alternatives: Option>, } impl MPTranslateRequest { fn into_translate_request(self) -> TranslateRequest { @@ -91,27 +91,47 @@ impl MPTranslateRequest { } } -async fn parse_payload(req: HttpRequest, payload: web::Payload) -> Result{ - let content_type = req.headers().get(header::CONTENT_TYPE).map(|h| h.to_str().unwrap_or("")).unwrap_or(""); +async fn parse_payload( + req: HttpRequest, + payload: web::Payload, +) -> Result { + let content_type = req + .headers() + .get(header::CONTENT_TYPE) + .map(|h| h.to_str().unwrap_or("")) + .unwrap_or(""); let body: TranslateRequest; if content_type.starts_with("application/json") { - let json = actix_web::web::Json::::from_request(&req, &mut payload.into_inner()).await?; + let json = + actix_web::web::Json::::from_request(&req, &mut payload.into_inner()) + .await?; body = json.into_inner() } else if content_type.starts_with("application/x-www-form-urlencoded") { - let form = actix_web::web::Form::::from_request(&req, &mut payload.into_inner()).await?; + let form = + actix_web::web::Form::::from_request(&req, &mut payload.into_inner()) + .await?; body = form.into_inner() } else if content_type.starts_with("multipart/form-data") { - let form = MultipartForm::::from_request(&req, &mut payload.into_inner()).await?; + let form = + MultipartForm::::from_request(&req, &mut payload.into_inner()) + .await?; body = form.into_inner().into_translate_request(); } else { - return Err(ErrorResponse{ error: "Unsupported content-type".to_string(), status: 400 }); + return Err(ErrorResponse { + error: "Unsupported content-type".to_string(), + status: 400, + }); } return Ok(body); } -fn check_params(body: &TranslateRequest, args: &Args, required_params: &[(&str, &Option)]) -> Result { +fn check_params( + body: &TranslateRequest, + args: &Args, + required_params: &[(&str, &Option)], +) -> Result { // Validate required params for (key, value) in required_params { if value.as_ref().is_none_or(|v| v.trim().is_empty()) { @@ -121,7 +141,7 @@ fn check_params(body: &TranslateRequest, args: &Args, required_params: &[(&str, }); } } - + // Check key if !args.api_key.is_empty() && body.api_key.as_ref().is_none_or(|key| *key != args.api_key) { return Err(ErrorResponse { @@ -133,7 +153,11 @@ fn check_params(body: &TranslateRequest, args: &Args, required_params: &[(&str, let q = body.q.as_ref().unwrap(); if q.len() > args.char_limit { return Err(ErrorResponse { - error: format!("Invalid request: request ({}) exceeds text limit ({})", q.len(), args.char_limit), + error: format!( + "Invalid request: request ({}) exceeds text limit ({})", + q.len(), + args.char_limit + ), status: 400, }); } @@ -157,16 +181,16 @@ fn improve_formatting(q: &String, translation: &String) -> String { let mut result = t.clone(); const PUNCTUATION_CHARS: [char; 6] = ['!', '?', '.', ',', ';', '。']; - if PUNCTUATION_CHARS.contains(&q_last_char){ - if q_last_char != translation_last_char{ - if PUNCTUATION_CHARS.contains(&translation_last_char){ + if PUNCTUATION_CHARS.contains(&q_last_char) { + if q_last_char != translation_last_char { + if PUNCTUATION_CHARS.contains(&translation_last_char) { result.pop(); } result.push(q_last_char); } - }else if PUNCTUATION_CHARS.contains(&translation_last_char) { - result.pop(); + } else if PUNCTUATION_CHARS.contains(&translation_last_char) { + result.pop(); } if q.chars().all(|c| c.is_lowercase()) { @@ -180,7 +204,7 @@ fn improve_formatting(q: &String, translation: &String) -> String { if let (Some(q0), Some(r0)) = (q.chars().next(), result.chars().next()) { if q0.is_lowercase() && r0.is_uppercase() { result.replace_range(0..r0.len_utf8(), &r0.to_lowercase().to_string()); - }else if q0.is_uppercase() && r0.is_lowercase() { + } else if q0.is_uppercase() && r0.is_lowercase() { result.replace_range(0..r0.len_utf8(), &r0.to_uppercase().to_string()); } } @@ -189,11 +213,13 @@ fn improve_formatting(q: &String, translation: &String) -> String { } #[post("/detect")] -async fn detect(req: HttpRequest, payload: web::Payload, args: web::Data>) -> Result { +async fn detect( + req: HttpRequest, + payload: web::Payload, + args: web::Data>, +) -> Result { let body = parse_payload(req, payload).await?; - check_params(&body, &args, &[ - ("q", &body.q) - ])?; + check_params(&body, &args, &[("q", &body.q)])?; let q = body.q.unwrap(); let d = detect_lang(&q); @@ -210,33 +236,42 @@ fn check_format(format: &str) -> Result { _ => Err(ErrorResponse { error: "Invalid format. Supported formats: text, html".to_string(), status: 400, - }) + }), } } #[post("/translate")] -async fn translate(req: HttpRequest, payload: web::Payload, args: web::Data>, llm: actix_web::web::Data>) -> Result { +async fn translate( + req: HttpRequest, + payload: web::Payload, + args: web::Data>, + llm: actix_web::web::Data>, +) -> Result { let body = parse_payload(req, payload).await?; - check_params(&body, &args, &[ - ("q", &body.q), - ("source", &body.source), - ("target", &body.target), - ])?; + check_params( + &body, + &args, + &[ + ("q", &body.q), + ("source", &body.source), + ("target", &body.target), + ], + )?; let q = body.q.unwrap(); let source = body.source.unwrap(); let target = body.target.unwrap(); let format = body.format.unwrap_or("text".to_string()); check_format(&format)?; - + let mut pb = PromptBuilder::new(); pb.set_format(&format); // TODO: add HTML support - - if source == "auto"{ + + if source == "auto" { pb.set_source_language("auto"); - }else{ + } else { let src_lang = get_language_from_code(&source).ok_or_else(|| ErrorResponse { error: format!("{} is not supported", source), status: 400, @@ -252,7 +287,7 @@ async fn translate(req: HttpRequest, payload: web::Payload, args: web::Data() { 503 } else { 500 }; @@ -263,8 +298,9 @@ async fn translate(req: HttpRequest, payload: web::Payload, args: web::Data Result { - Err(ErrorResponse{ + Err(ErrorResponse { error: "Not implemented".to_string(), - status: 501 + status: 501, }) } #[post("/suggest")] async fn suggest() -> Result { - Err(ErrorResponse{ + Err(ErrorResponse { error: "Not implemented".to_string(), - status: 501 + status: 501, }) } @@ -338,13 +374,15 @@ async fn main() -> std::io::Result<()> { eprintln!("Failed to load model: {}", err); std::process::exit(1); }); - + println!("Loading model: {}", model_path.display()); - let llm = Arc::new(llm::LLM::new(model_path, args.cpu, args.verbose).unwrap_or_else(|err| { - eprintln!("Failed to initialize LLM: {}", err); - std::process::exit(1); - })); + let llm = Arc::new( + llm::LLM::new(model_path, args.cpu, args.verbose).unwrap_or_else(|err| { + eprintln!("Failed to initialize LLM: {}", err); + std::process::exit(1); + }), + ); print_banner(); @@ -369,4 +407,4 @@ async fn main() -> std::io::Result<()> { println!("Running on: http://{}:{}", host, port); return server.await; -} \ No newline at end of file +} diff --git a/ltengine/src/models.rs b/ltengine/src/models.rs index 58f5bf3..7baa845 100644 --- a/ltengine/src/models.rs +++ b/ltengine/src/models.rs @@ -1,7 +1,7 @@ +use anyhow::{Context, Result, anyhow}; +use hf_hub::api::sync::ApiBuilder; use std::collections::HashMap; use std::path::PathBuf; -use hf_hub::api::sync::ApiBuilder; -use anyhow::{anyhow, Context, Result}; #[derive(Clone, Debug)] pub struct HuggingFace { @@ -11,24 +11,50 @@ pub struct HuggingFace { #[derive(Debug)] pub enum Model { - Local { - path: PathBuf, - }, - Remote { - hf: HuggingFace, - }, + Local { path: PathBuf }, + Remote { hf: HuggingFace }, } -pub static MODELS: once_cell::sync::Lazy> = once_cell::sync::Lazy::new(|| { - let mut m = HashMap::new(); - m.insert("gemma3-1b", HuggingFace { repo: "libretranslate/gemma3", model: "gemma-3-1b-it-q4_0.gguf" }); - m.insert("gemma3-4b", HuggingFace { repo: "libretranslate/gemma3", model: "gemma-3-4b-it-q4_0.gguf" }); - m.insert("gemma3-12b", HuggingFace { repo: "libretranslate/gemma3", model: "gemma-3-12b-it-q4_0.gguf" }); - m.insert("gemma3-27b", HuggingFace { repo: "libretranslate/gemma3", model: "gemma-3-27b-it-q4_0.gguf" }); - m.insert("gemma4-e4b", HuggingFace { repo: "bartowski/google_gemma-4-E4B-it-GGUF", model: "google_gemma-4-E4B-it-Q4_0.gguf" }); - m -}); - +pub static MODELS: once_cell::sync::Lazy> = + once_cell::sync::Lazy::new(|| { + let mut m = HashMap::new(); + m.insert( + "gemma3-1b", + HuggingFace { + repo: "libretranslate/gemma3", + model: "gemma-3-1b-it-q4_0.gguf", + }, + ); + m.insert( + "gemma3-4b", + HuggingFace { + repo: "libretranslate/gemma3", + model: "gemma-3-4b-it-q4_0.gguf", + }, + ); + m.insert( + "gemma3-12b", + HuggingFace { + repo: "libretranslate/gemma3", + model: "gemma-3-12b-it-q4_0.gguf", + }, + ); + m.insert( + "gemma3-27b", + HuggingFace { + repo: "libretranslate/gemma3", + model: "gemma-3-27b-it-q4_0.gguf", + }, + ); + m.insert( + "gemma4-e4b", + HuggingFace { + repo: "bartowski/google_gemma-4-E4B-it-GGUF", + model: "google_gemma-4-E4B-it-Q4_0.gguf", + }, + ); + m + }); impl Model { fn load(&self) -> Result { @@ -37,9 +63,12 @@ impl Model { if path.exists() && path.extension().and_then(|ext| ext.to_str()) == Some("gguf") { Ok(path.clone()) } else { - Err(anyhow!(format!("Invalid path or not a .gguf file: {}", path.display()))) + Err(anyhow!(format!( + "Invalid path or not a .gguf file: {}", + path.display() + ))) } - }, + } Model::Remote { hf } => ApiBuilder::new() .with_progress(true) .build() @@ -66,4 +95,4 @@ pub fn load_model(model_id: &String, model_file: &String) -> Result { Ok(path) => Ok(path), Err(e) => Err(e), } -} \ No newline at end of file +} diff --git a/ltengine/src/prompt.rs b/ltengine/src/prompt.rs index 9cc3047..26a8483 100644 --- a/ltengine/src/prompt.rs +++ b/ltengine/src/prompt.rs @@ -1,12 +1,12 @@ -pub struct PromptBuilder{ +pub struct PromptBuilder { source_language: &'static str, target_language: &'static str, format: String, } -pub struct Prompt{ +pub struct Prompt { pub system: String, - pub user: String + pub user: String, } impl PromptBuilder { @@ -23,7 +23,6 @@ impl PromptBuilder { self } - pub fn set_source_language(&mut self, s: &'static str) -> &mut PromptBuilder { self.source_language = s; self @@ -41,20 +40,23 @@ impl PromptBuilder { "You are an expert linguist, specializing in translation. You are able to capture the nuances of the languages you translate. You pay attention to masculine/feminine/plural and proper use of articles and grammar. You always provide natural sounding translations that fully preserve the meaning of the original text. You never provide explanations for your work. You always answer with the translated text and nothing else." }.to_string(); - - let user = (if self.source_language == "auto"{ + let user = (if self.source_language == "auto" { format!( "Translate the text below to {}.\n\nText: {}\n\n{}:\n", self.target_language, q, self.target_language ) - }else{ + } else { format!( "Translate the text below from {} to {}.\n\n{}: {}\n\n{}:\n", - self.source_language, self.target_language, - self.source_language, q, self.target_language + self.source_language, + self.target_language, + self.source_language, + q, + self.target_language ) - }).to_string(); + }) + .to_string(); Prompt { system, user } } -} \ No newline at end of file +}