Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions spnl/src/generate/backend/mistralrs/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ impl ModelPool {
}
}

/// Unload all models, releasing GPU memory.
pub async fn unload_all(&self) {
let mut models = self.models.write().await;
models.clear();
}

/// Get or load a model
pub async fn get_or_load(&self, model_name: &str) -> anyhow::Result<Arc<Model>> {
// Check if model is already loaded
Expand Down
6 changes: 6 additions & 0 deletions spnl/src/generate/backend/mistralrs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ pub async fn generate_completion(
Ok(Query::Par(final_results))
}

/// Unload all models from the global pool, releasing GPU memory.
/// Call between benchmark runs to avoid accumulating models in VRAM.
pub async fn unload_all_models() {
get_model_pool().unload_all().await
}

/// Generate multiple completions for the same input (Repeat operation)
pub async fn generate_chat(
spec: Repeat,
Expand Down
9 changes: 9 additions & 0 deletions spnl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,12 @@ pub mod gce;

#[cfg(feature = "vllm")]
pub mod vllm;

/// Model pool management. Only available with the `local` feature.
#[cfg(feature = "local")]
pub mod model_pool {
/// Unload all models from the global pool, releasing GPU memory.
pub async fn unload_all() {
crate::generate::backend::mistralrs::unload_all_models().await
}
}
Loading