Skip to content

Commit d7e6a0d

Browse files
Perf: parallel extract workers + download largest wheels first
Two optimizations to reduce pipeline stalls:

1. Sort queue by size descending — start downloading the biggest wheels (torch 873MB, nvidia-cudnn 674MB) first, so extraction can begin sooner and pipeline better with the remaining downloads.
2. Multiple extract workers (4) — instead of one extract worker processing wheels serially, spawn N workers pulling from the same channel. Each gets extract_threads/N rayon threads. This prevents small wheels from queuing behind large ones.

Cold start: 34.9s → 32.6s (health), 43.7s → 41.8s (inference)
Warm start: 5.3s → 4.6s

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a88c2f2 commit d7e6a0d

2 files changed

Lines changed: 123 additions & 86 deletions

File tree

crates/zs-fast-wheel/src/daemon.rs

Lines changed: 104 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ pub struct DaemonConfig {
2929
pub extract_threads: usize,
3030
}
3131

32+
impl DaemonConfig {
33+
/// Number of parallel extract workers. Each worker pulls wheels from the
34+
/// channel and extracts independently, preventing small wheels from queuing
35+
/// behind large ones. Each worker gets `extract_threads / workers` rayon threads.
36+
pub fn extract_workers(&self) -> usize {
37+
// 4 workers is a good default: allows 4 wheels to extract simultaneously,
38+
// each with ~6-7 threads on a 26-core machine.
39+
// Minimum 1, cap at extract_threads (no point having more workers than threads).
40+
4.min(self.extract_threads)
41+
}
42+
}
43+
3244
impl Default for DaemonConfig {
3345
fn default() -> Self {
3446
Self {
@@ -121,10 +133,11 @@ impl DaemonEngine {
121133
.pool_max_idle_per_host(config.parallel_downloads)
122134
.build()?;
123135

124-
// Channel: downloaded wheels flow from download workers → extract worker.
125-
// Small capacity (4) provides backpressure — if extraction is slow,
126-
// downloads pause rather than filling disk with temp files.
127-
let (tx, rx) = tokio::sync::mpsc::channel::<DownloadedWheel>(4);
136+
// Channel: downloaded wheels flow from download workers → extract workers.
137+
// Capacity = 2 * extract_workers provides enough buffering for workers to
138+
// stay busy while providing backpressure to avoid filling disk with temp files.
139+
let num_workers = config.extract_workers();
140+
let (tx, rx) = tokio::sync::mpsc::channel::<DownloadedWheel>(num_workers * 2);
128141

129142
let tmp_dir = tempfile::tempdir().context("failed to create temp dir")?;
130143
let tmp_path = tmp_dir.path().to_path_buf();
@@ -191,76 +204,95 @@ impl DaemonEngine {
191204
drop(tx);
192205

193206
// === Extract stage ===
194-
// Single blocking loop: receives downloaded wheels, extracts each immediately.
195-
// Extraction uses all extract_threads for parallelism within a single wheel.
196-
let site_packages = config.site_packages.clone();
197-
let ext_threads = config.extract_threads;
198-
let stats = self.stats.clone();
199-
let completion = self.completion.clone();
200-
let queue = self.queue.clone();
201-
let total_wheels = self.total_wheels;
202-
203-
let extract_handle = tokio::task::spawn_blocking(move || {
204-
let rx = rx;
205-
// blocking_recv in a loop — channel closes when all downloads finish
206-
let mut rx = rx;
207-
while let Some(downloaded) = rx.blocking_recv() {
208-
let dist = downloaded.spec.distribution.clone();
209-
let extract_start = Instant::now();
210-
211-
let result = extract::extract_wheel_atomic(
212-
&downloaded.path,
213-
&site_packages,
214-
&dist,
215-
ext_threads,
216-
true,
217-
&stats,
218-
);
219-
220-
let (lock, cvar) = &*completion;
221-
222-
match result {
223-
Ok(()) => {
224-
let elapsed = extract_start.elapsed();
225-
tracing::info!(
226-
"[{dist}] extracted in {:.1}s",
227-
elapsed.as_secs_f64()
228-
);
229-
230-
{
231-
let mut q = queue.lock().unwrap();
232-
q.mark_done(&dist);
207+
// Multiple extract workers pull from the same channel, extracting different
208+
// wheels in parallel. Each worker gets a share of the total extract threads.
209+
// This prevents small wheels from queuing behind large ones.
210+
let num_extract_workers = config.extract_workers();
211+
let threads_per_worker = (config.extract_threads / num_extract_workers).max(1);
212+
let rx = Arc::new(tokio::sync::Mutex::new(rx));
213+
214+
let mut extract_handles = Vec::new();
215+
for worker_id in 0..num_extract_workers {
216+
let site_packages = config.site_packages.clone();
217+
let stats = self.stats.clone();
218+
let completion = self.completion.clone();
219+
let queue = self.queue.clone();
220+
let total_wheels = self.total_wheels;
221+
let rx = rx.clone();
222+
223+
let handle = tokio::task::spawn_blocking(move || {
224+
loop {
225+
// Lock channel briefly to receive next wheel
226+
let downloaded = {
227+
let mut rx = rx.blocking_lock();
228+
rx.blocking_recv()
229+
};
230+
let downloaded = match downloaded {
231+
Some(d) => d,
232+
None => break, // channel closed
233+
};
234+
235+
let dist = downloaded.spec.distribution.clone();
236+
let extract_start = Instant::now();
237+
238+
tracing::debug!("[{dist}] extract worker {worker_id} starting");
239+
240+
let result = extract::extract_wheel_atomic(
241+
&downloaded.path,
242+
&site_packages,
243+
&dist,
244+
threads_per_worker,
245+
true,
246+
&stats,
247+
);
248+
249+
let (lock, cvar) = &*completion;
250+
251+
match result {
252+
Ok(()) => {
253+
let elapsed = extract_start.elapsed();
254+
tracing::info!(
255+
"[{dist}] extracted in {:.1}s (worker {worker_id})",
256+
elapsed.as_secs_f64()
257+
);
258+
259+
{
260+
let mut q = queue.lock().unwrap();
261+
q.mark_done(&dist);
262+
}
263+
264+
let mut state = lock.lock().unwrap();
265+
state.done.insert(dist);
266+
if state.done.len() + state.failed.len() >= total_wheels {
267+
state.all_finished = true;
268+
}
269+
cvar.notify_all();
233270
}
234-
235-
let mut state = lock.lock().unwrap();
236-
state.done.insert(dist);
237-
if state.done.len() + state.failed.len() >= total_wheels {
238-
state.all_finished = true;
271+
Err(e) => {
272+
let err_msg = format!("{e:#}");
273+
tracing::error!("[{dist}] extraction failed: {err_msg}");
274+
275+
{
276+
let mut q = queue.lock().unwrap();
277+
q.mark_failed(&dist);
278+
}
279+
280+
let mut state = lock.lock().unwrap();
281+
state.failed.insert(dist, err_msg);
282+
if state.done.len() + state.failed.len() >= total_wheels {
283+
state.all_finished = true;
284+
}
285+
cvar.notify_all();
239286
}
240-
cvar.notify_all();
241287
}
242-
Err(e) => {
243-
let err_msg = format!("{e:#}");
244-
tracing::error!("[{dist}] extraction failed: {err_msg}");
245-
246-
{
247-
let mut q = queue.lock().unwrap();
248-
q.mark_failed(&dist);
249-
}
250288

251-
let mut state = lock.lock().unwrap();
252-
state.failed.insert(dist, err_msg);
253-
if state.done.len() + state.failed.len() >= total_wheels {
254-
state.all_finished = true;
255-
}
256-
cvar.notify_all();
257-
}
289+
// Clean up temp file
290+
let _ = std::fs::remove_file(&downloaded.path);
258291
}
292+
});
259293

260-
// Clean up temp file
261-
let _ = std::fs::remove_file(&downloaded.path);
262-
}
263-
});
294+
extract_handles.push(handle);
295+
}
264296

265297
// === Wait for download failures ===
266298
// Collect download errors and mark them as failed
@@ -277,8 +309,10 @@ impl DaemonEngine {
277309
}
278310
}
279311

280-
// Wait for extract worker to finish
281-
extract_handle.await?;
312+
// Wait for all extract workers to finish
313+
for handle in extract_handles {
314+
handle.await?;
315+
}
282316

283317
// Mark all finished
284318
{

crates/zs-fast-wheel/src/queue.rs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use crate::manifest::WheelSpec;
44

55
/// Priority queue for wheel installation scheduling.
66
///
7-
/// Default order: small wheels first, large last.
7+
/// Default order: large wheels first (start downloading big wheels early so
8+
/// extraction can pipeline with remaining downloads).
89
/// Supports demand-driven reprioritization via `prioritize()`.
910
pub struct InstallQueue {
1011
/// Wheels not yet started, ordered by priority
@@ -17,9 +18,10 @@ pub struct InstallQueue {
1718

1819
impl InstallQueue {
1920
/// Create a new queue from a list of wheel specs.
20-
/// Sorts by size ascending (small wheels first).
21+
/// Sorts by size descending (large wheels first — download big wheels early
22+
/// so extraction starts sooner and pipelines better with remaining downloads).
2123
pub fn new(mut wheels: Vec<WheelSpec>) -> Self {
22-
wheels.sort_by_key(|w| w.size);
24+
wheels.sort_by_key(|w| std::cmp::Reverse(w.size));
2325
Self {
2426
pending: wheels.into(),
2527
in_progress: HashSet::new(),
@@ -99,7 +101,7 @@ mod tests {
99101
}
100102

101103
#[test]
102-
fn test_sorts_by_size_ascending() {
104+
fn test_sorts_by_size_descending() {
103105
let wheels = vec![
104106
make_wheel("torch", 900_000_000),
105107
make_wheel("six", 12_000),
@@ -108,11 +110,11 @@ mod tests {
108110
let mut queue = InstallQueue::new(wheels);
109111

110112
let first = queue.next().unwrap();
111-
assert_eq!(first.distribution, "six");
113+
assert_eq!(first.distribution, "torch");
112114
let second = queue.next().unwrap();
113115
assert_eq!(second.distribution, "numpy");
114116
let third = queue.next().unwrap();
115-
assert_eq!(third.distribution, "torch");
117+
assert_eq!(third.distribution, "six");
116118
assert!(queue.next().is_none());
117119
}
118120

@@ -125,9 +127,10 @@ mod tests {
125127
];
126128
let mut queue = InstallQueue::new(wheels);
127129

128-
queue.prioritize("torch");
130+
// six is last (smallest) — prioritize moves it to front
131+
queue.prioritize("six");
129132
let first = queue.next().unwrap();
130-
assert_eq!(first.distribution, "torch");
133+
assert_eq!(first.distribution, "six");
131134
}
132135

133136
#[test]
@@ -139,13 +142,13 @@ mod tests {
139142
let mut queue = InstallQueue::new(wheels);
140143

141144
let first = queue.next().unwrap();
142-
assert_eq!(first.distribution, "six");
143-
queue.mark_done("six");
145+
assert_eq!(first.distribution, "torch"); // largest first
146+
queue.mark_done("torch");
144147

145148
// Prioritizing a done package should be a no-op
146-
queue.prioritize("six");
149+
queue.prioritize("torch");
147150
let second = queue.next().unwrap();
148-
assert_eq!(second.distribution, "torch");
151+
assert_eq!(second.distribution, "six");
149152
}
150153

151154
#[test]
@@ -157,13 +160,13 @@ mod tests {
157160
];
158161
let mut queue = InstallQueue::new(wheels);
159162

160-
let first = queue.next().unwrap(); // six is now in_progress
161-
assert_eq!(first.distribution, "six");
163+
let first = queue.next().unwrap(); // torch is now in_progress (largest first)
164+
assert_eq!(first.distribution, "torch");
162165

163166
// Prioritizing an in-progress package should be a no-op
164-
queue.prioritize("six");
167+
queue.prioritize("torch");
165168
let second = queue.next().unwrap();
166-
assert_eq!(second.distribution, "numpy"); // not six again
169+
assert_eq!(second.distribution, "numpy"); // not torch again
167170
}
168171

169172
#[test]

0 commit comments

Comments
 (0)