|
| 1 | + |
| 2 | +use std::{path::Path, sync::Arc}; |
| 3 | +use clap::{Args, ValueEnum}; |
| 4 | +use uzu::{ |
| 5 | + backends::metal::Metal, |
| 6 | + prelude::{NGramSpeculator, NeuralSpeculator, SpeculatorConfig}, |
| 7 | +}; |
| 8 | + |
| 9 | +#[derive(ValueEnum, Debug, Clone)] |
| 10 | +pub enum SpeculatorType { |
| 11 | + Pard, |
| 12 | + Ngram, |
| 13 | +} |
| 14 | + |
| 15 | +#[derive(Args, Debug, Clone)] |
| 16 | +pub struct SpeculatorArgs { |
| 17 | + #[arg(long, value_enum)] |
| 18 | + /// Type of the speculator to use. If not specified, speculation is disabled. |
| 19 | + pub speculator_type: Option<SpeculatorType>, |
| 20 | + #[arg(long)] |
| 21 | + /// Path to the speculator model. |
| 22 | + pub speculator_path: Option<String>, |
| 23 | + #[arg(long = "speculator-tokens", default_value_t = 1)] |
| 24 | + /// Number of tokens to speculate. |
| 25 | + pub speculated_tokens: usize, |
| 26 | +} |
| 27 | + |
| 28 | +impl SpeculatorArgs { |
| 29 | + pub fn build_speculator_config( |
| 30 | + &self, |
| 31 | + model_path: &Path, |
| 32 | + ) -> SpeculatorConfig { |
| 33 | + let Some(spec_type) = &self.speculator_type else { |
| 34 | + return SpeculatorConfig::default(); |
| 35 | + }; |
| 36 | + |
| 37 | + let n = self.speculated_tokens; |
| 38 | + |
| 39 | + match spec_type { |
| 40 | + SpeculatorType::Pard => { |
| 41 | + let path_str = self |
| 42 | + .speculator_path |
| 43 | + .as_deref() |
| 44 | + .expect("--speculator-path is required when --speculator-type is pard"); |
| 45 | + let speculator = NeuralSpeculator::<Metal>::new(Path::new(path_str), n, 8) |
| 46 | + .expect("Failed to load PARD draft model"); |
| 47 | + SpeculatorConfig::new(n + 1, Arc::new(speculator)) |
| 48 | + }, |
| 49 | + SpeculatorType::Ngram => { |
| 50 | + let ngram_path = match self.speculator_path.as_deref() { |
| 51 | + Some(path) => path, |
| 52 | + None => &self.resolve_ngram_path(model_path).to_string_lossy().into_owned() |
| 53 | + }; |
| 54 | + let speculator = NGramSpeculator::load(ngram_path).expect("Failed to load NGram speculator"); |
| 55 | + SpeculatorConfig::new(n, Arc::new(speculator)) |
| 56 | + }, |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + fn resolve_ngram_path( |
| 61 | + &self, |
| 62 | + model_path: &Path, |
| 63 | + ) -> std::path::PathBuf { |
| 64 | + if let Some(explicit) = self.speculator_path.as_deref() { |
| 65 | + return std::path::PathBuf::from(explicit); |
| 66 | + } |
| 67 | + |
| 68 | + let speculators_dir = model_path.join("speculators"); |
| 69 | + let mut found: Vec<std::path::PathBuf> = Vec::new(); |
| 70 | + |
| 71 | + if let Ok(entries) = std::fs::read_dir(&speculators_dir) { |
| 72 | + for entry in entries.flatten() { |
| 73 | + let candidate = entry.path().join("model.bin"); |
| 74 | + if candidate.exists() { |
| 75 | + found.push(candidate); |
| 76 | + } |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + if found.is_empty() { |
| 81 | + eprintln!( |
| 82 | + "error: no ngram speculator found in {}\n\ |
| 83 | + Looked for: {}/*/model.bin\n\ |
| 84 | + Specify a path explicitly with --speculator-path <path>", |
| 85 | + speculators_dir.display(), |
| 86 | + speculators_dir.display(), |
| 87 | + ); |
| 88 | + std::process::exit(1); |
| 89 | + } |
| 90 | + |
| 91 | + if let Some(chat) = |
| 92 | + found.iter().find(|p| p.parent().and_then(|d| d.file_name()).map(|n| n == "chat").unwrap_or(false)) |
| 93 | + { |
| 94 | + return chat.clone(); |
| 95 | + } |
| 96 | + |
| 97 | + found.remove(0) |
| 98 | + } |
| 99 | +} |
0 commit comments