Skip to content

Commit bb4e4b1

Browse files
committed
Neural-speculator with PARD approach
1 parent 3ec7bb7 commit bb4e4b1

19 files changed

Lines changed: 855 additions & 64 deletions

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ autocxx = "0.30"
7070
cxx = "1.0"
7171
cmake = "0.1"
7272
autocxx-build = "0.30"
73-
7473
# optimize the build script in debug builds
7574
[profile.dev.build-override]
7675
opt-level = 3

crates/cli/src/handlers/run.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use uzu::session::{
1212
types::{Input, Output},
1313
};
1414

15-
use crate::server::load_session;
15+
use crate::{server::load_session, speculator_args::SpeculatorArgs};
1616

1717
fn format_output(output: Output) -> String {
1818
let stats = &output.stats;
@@ -53,11 +53,11 @@ pub fn handle_run(
5353
tokens_limit: usize,
5454
prefill_step_size: Option<usize>,
5555
seed: Option<u64>,
56-
speculator: Option<String>,
5756
mut message: Option<String>,
5857
no_thinking: bool,
58+
speculator_args: SpeculatorArgs,
5959
) {
60-
let mut session = load_session(model_path, prefill_step_size, seed, speculator);
60+
let mut session = load_session(model_path, prefill_step_size, seed, speculator_args);
6161

6262
let is_model_running = Arc::new(AtomicBool::new(false));
6363
let is_model_running_for_ctrlc = is_model_running.clone();

crates/cli/src/handlers/serve.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use tokio::runtime::Runtime;
22

3-
use crate::server::main::run_server;
3+
use crate::{server::main::run_server, speculator_args::SpeculatorArgs};
44

55
pub fn handle_serve(
66
model_path: String,
77
prefill_step_size: Option<usize>,
8+
speculator_args: SpeculatorArgs,
89
) {
910
let runtime = Runtime::new().unwrap();
10-
runtime.block_on(run_server(model_path, prefill_step_size));
11+
runtime.block_on(run_server(model_path, prefill_step_size, speculator_args));
1112
}

crates/cli/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod handlers;
22
pub mod server;
3+
pub mod speculator_args;

crates/cli/src/main.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use clap::{CommandFactory, Parser, Subcommand};
2-
use cli::handlers::{handle_bench, handle_run, handle_serve};
2+
use cli::{
3+
handlers::{handle_bench, handle_run, handle_serve},
4+
speculator_args::SpeculatorArgs,
5+
};
36

47
#[derive(Parser)]
58
struct Cli {
@@ -15,24 +18,26 @@ enum Commands {
1518
model_path: String,
1619
/// Prefill step size
1720
prefill_step_size: Option<usize>,
18-
// Seed
21+
/// Seed
1922
#[arg(long)]
2023
seed: Option<u64>,
21-
// Speculator
22-
#[arg(long)]
23-
speculator: Option<String>,
2424
/// Non-interactive mode: run a single message and exit
2525
#[arg(long, short)]
2626
message: Option<String>,
2727
#[arg(long, short)]
28+
/// Disable thinking mode
2829
no_thinking: bool,
30+
#[command(flatten)]
31+
speculator_args: SpeculatorArgs,
2932
},
3033
/// Start a server with the specified model path
3134
Serve {
3235
/// Folder with model's files
3336
model_path: String,
3437
/// Prefill step size
3538
prefill_step_size: Option<usize>,
39+
#[command(flatten)]
40+
speculator_args: SpeculatorArgs,
3641
},
3742
/// Run benchmarks for the specified model
3843
Bench {
@@ -53,17 +58,26 @@ fn main() {
5358
model_path,
5459
prefill_step_size,
5560
seed,
56-
speculator,
5761
message,
5862
no_thinking,
63+
speculator_args,
5964
}) => {
60-
handle_run(model_path, 2048, prefill_step_size, seed, speculator, message, no_thinking);
65+
handle_run(
66+
model_path,
67+
2048,
68+
prefill_step_size,
69+
seed,
70+
message,
71+
no_thinking,
72+
speculator_args,
73+
);
6174
},
6275
Some(Commands::Serve {
6376
model_path,
6477
prefill_step_size,
78+
speculator_args,
6579
}) => {
66-
handle_serve(model_path, prefill_step_size);
80+
handle_serve(model_path, prefill_step_size, speculator_args);
6781
},
6882
Some(Commands::Bench {
6983
model_path,

crates/cli/src/server/main.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ use std::path::PathBuf;
33
use log::LevelFilter;
44
use rocket::{Config, config::LogLevel, log::private as log, routes};
55

6-
use crate::server::{SessionState, SessionWrapper, handle_chat_completions, handle_models, load_session};
6+
use crate::{
7+
server::{SessionState, SessionWrapper, handle_chat_completions, handle_models, load_session},
8+
speculator_args::SpeculatorArgs,
9+
};
710

811
struct SilentLogger;
912
static SILENT_LOGGER: SilentLogger = SilentLogger;
@@ -26,6 +29,7 @@ impl log::Log for SilentLogger {
2629
pub async fn run_server(
2730
model_path: String,
2831
prefill_step_size: Option<usize>,
32+
speculator_args: SpeculatorArgs,
2933
) {
3034
// Install the silent logger **before** Rocket initializes its own logger.
3135
let _ = log::set_logger(&SILENT_LOGGER).map(|_| log::set_max_level(LevelFilter::Off));
@@ -43,7 +47,7 @@ pub async fn run_server(
4347
println!("🌐 Server will be available at: http://localhost:{}", config.port);
4448
println!("📝 Endpoints:\n POST /chat/completions - Chat completions API\n");
4549

46-
let session = load_session(model_path, prefill_step_size, None, None);
50+
let session = load_session(model_path, prefill_step_size, None, speculator_args);
4751
let state = SessionState {
4852
model_name,
4953
session_wrapper: std::sync::Arc::new(SessionWrapper::new(session)),

crates/cli/src/server/state.rs

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
use console::Style;
77
use indicatif::{ProgressBar, ProgressStyle};
88
use uzu::{
9-
prelude::{SamplingSeed, SpeculatorConfig},
9+
prelude::SamplingSeed,
1010
session::{
1111
Session,
1212
config::{DecodingConfig, RunConfig},
@@ -15,6 +15,8 @@ use uzu::{
1515
},
1616
};
1717

18+
use crate::speculator_args::SpeculatorArgs;
19+
1820
pub trait RunSession {
1921
fn run(
2022
&mut self,
@@ -59,7 +61,7 @@ pub fn load_session(
5961
model_path: String,
6062
prefill_step_size: Option<usize>,
6163
seed: Option<u64>,
62-
speculator: Option<String>,
64+
speculator_args: SpeculatorArgs,
6365
) -> Session {
6466
let style_bold = Style::new().bold();
6567

@@ -71,35 +73,18 @@ pub fn load_session(
7173
progress_bar.set_style(ProgressStyle::default_spinner().template("{spinner:.green} Loading: {msg}").unwrap());
7274
progress_bar.set_message(model_name.clone());
7375

74-
let prefill_step_size_config: PrefillStepSize;
75-
if let Some(value) = prefill_step_size {
76-
prefill_step_size_config = PrefillStepSize::Custom(value);
77-
} else {
78-
prefill_step_size_config = PrefillStepSize::Default;
79-
}
76+
let prefill_step_size_config = match prefill_step_size {
77+
Some(value) => PrefillStepSize::Custom(value),
78+
None => PrefillStepSize::Default,
79+
};
8080

8181
let decoding_config = DecodingConfig::default()
8282
.with_prefill_step_size(prefill_step_size_config)
8383
.with_sampling_seed(match seed {
8484
Some(seed) => SamplingSeed::Custom(seed),
8585
None => SamplingSeed::Default,
8686
})
87-
.with_speculator_config(match speculator {
88-
Some(speculator) => {
89-
let (speculator, number_of_speculated_tokens) =
90-
speculator.split_once(':').unwrap_or((&speculator, "1"));
91-
92-
let number_of_speculated_tokens = number_of_speculated_tokens.parse().unwrap();
93-
94-
let speculator = Arc::new(uzu::speculators::ngram_speculator::NGramSpeculator::load(speculator));
95-
96-
SpeculatorConfig {
97-
number_of_speculated_tokens,
98-
speculator,
99-
}
100-
},
101-
None => SpeculatorConfig::default(),
102-
});
87+
.with_speculator_config(speculator_args.build_speculator_config(&model_path_buf));
10388
let session = Session::new(model_path_buf, decoding_config).expect("Failed to create session");
10489

10590
progress_bar.set_style(ProgressStyle::default_spinner().template("Loaded: {msg}").unwrap());

crates/cli/src/speculator_args.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
2+
use std::{path::Path, sync::Arc};
3+
use clap::{Args, ValueEnum};
4+
use uzu::{
5+
backends::metal::Metal,
6+
prelude::{NGramSpeculator, NeuralSpeculator, SpeculatorConfig},
7+
};
8+
9+
#[derive(ValueEnum, Debug, Clone)]
10+
pub enum SpeculatorType {
11+
Pard,
12+
Ngram,
13+
}
14+
15+
#[derive(Args, Debug, Clone)]
16+
pub struct SpeculatorArgs {
17+
#[arg(long, value_enum)]
18+
/// Type of the speculator to use. If not specified, speculation is disabled.
19+
pub speculator_type: Option<SpeculatorType>,
20+
#[arg(long)]
21+
/// Path to the speculator model.
22+
pub speculator_path: Option<String>,
23+
#[arg(long = "speculator-tokens", default_value_t = 1)]
24+
/// Number of tokens to speculate.
25+
pub speculated_tokens: usize,
26+
}
27+
28+
impl SpeculatorArgs {
29+
pub fn build_speculator_config(
30+
&self,
31+
model_path: &Path,
32+
) -> SpeculatorConfig {
33+
let Some(spec_type) = &self.speculator_type else {
34+
return SpeculatorConfig::default();
35+
};
36+
37+
let n = self.speculated_tokens;
38+
39+
match spec_type {
40+
SpeculatorType::Pard => {
41+
let path_str = self
42+
.speculator_path
43+
.as_deref()
44+
.expect("--speculator-path is required when --speculator-type is pard");
45+
let speculator = NeuralSpeculator::<Metal>::new(Path::new(path_str), n, 8)
46+
.expect("Failed to load PARD draft model");
47+
SpeculatorConfig::new(n + 1, Arc::new(speculator))
48+
},
49+
SpeculatorType::Ngram => {
50+
let ngram_path = match self.speculator_path.as_deref() {
51+
Some(path) => path,
52+
None => &self.resolve_ngram_path(model_path).to_string_lossy().into_owned()
53+
};
54+
let speculator = NGramSpeculator::load(ngram_path).expect("Failed to load NGram speculator");
55+
SpeculatorConfig::new(n, Arc::new(speculator))
56+
},
57+
}
58+
}
59+
60+
fn resolve_ngram_path(
61+
&self,
62+
model_path: &Path,
63+
) -> std::path::PathBuf {
64+
if let Some(explicit) = self.speculator_path.as_deref() {
65+
return std::path::PathBuf::from(explicit);
66+
}
67+
68+
let speculators_dir = model_path.join("speculators");
69+
let mut found: Vec<std::path::PathBuf> = Vec::new();
70+
71+
if let Ok(entries) = std::fs::read_dir(&speculators_dir) {
72+
for entry in entries.flatten() {
73+
let candidate = entry.path().join("model.bin");
74+
if candidate.exists() {
75+
found.push(candidate);
76+
}
77+
}
78+
}
79+
80+
if found.is_empty() {
81+
eprintln!(
82+
"error: no ngram speculator found in {}\n\
83+
Looked for: {}/*/model.bin\n\
84+
Specify a path explicitly with --speculator-path <path>",
85+
speculators_dir.display(),
86+
speculators_dir.display(),
87+
);
88+
std::process::exit(1);
89+
}
90+
91+
if let Some(chat) =
92+
found.iter().find(|p| p.parent().and_then(|d| d.file_name()).map(|n| n == "chat").unwrap_or(false))
93+
{
94+
return chat.clone();
95+
}
96+
97+
found.remove(0)
98+
}
99+
}

crates/uzu/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ is_close.workspace = true
4646
serde_json5.workspace = true
4747
schemars.workspace = true
4848
mpsgraph.workspace = true
49+
tempfile.workspace = true
4950

5051
[build-dependencies]
5152
anyhow.workspace = true

crates/uzu/src/config/language_model.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pub struct InnerModelConfig {
1919
pub embedding_config: EmbeddingConfig,
2020
pub transformer_config: TransformerConfig,
2121
pub vocab_size: usize,
22+
#[serde(default, skip_serializing_if = "Option::is_none")]
23+
pub pard_token: Option<u64>,
2224
}
2325

2426
impl InnerModelConfig {

0 commit comments

Comments
 (0)