diff --git a/spnl/src/generate/backend/mistralrs/loader.rs b/spnl/src/generate/backend/mistralrs/loader.rs
index f261595a..4eca3b65 100644
--- a/spnl/src/generate/backend/mistralrs/loader.rs
+++ b/spnl/src/generate/backend/mistralrs/loader.rs
@@ -185,6 +185,12 @@ impl ModelPool {
         }
     }
 
+    /// Unload all models, releasing GPU memory.
+    pub async fn unload_all(&self) {
+        let mut models = self.models.write().await;
+        models.clear();
+    }
+
     /// Get or load a model
     pub async fn get_or_load(&self, model_name: &str) -> anyhow::Result> {
         // Check if model is already loaded
diff --git a/spnl/src/generate/backend/mistralrs/mod.rs b/spnl/src/generate/backend/mistralrs/mod.rs
index 29175d6b..0d612fa6 100644
--- a/spnl/src/generate/backend/mistralrs/mod.rs
+++ b/spnl/src/generate/backend/mistralrs/mod.rs
@@ -271,6 +271,12 @@ pub async fn generate_completion(
     Ok(Query::Par(final_results))
 }
 
+/// Unload all models from the global pool, releasing GPU memory.
+/// Call between benchmark runs to avoid accumulating models in VRAM.
+pub async fn unload_all_models() {
+    get_model_pool().unload_all().await
+}
+
 /// Generate multiple completions for the same input (Repeat operation)
 pub async fn generate_chat(
     spec: Repeat,
diff --git a/spnl/src/lib.rs b/spnl/src/lib.rs
index 2b9e74c7..f2c931c2 100644
--- a/spnl/src/lib.rs
+++ b/spnl/src/lib.rs
@@ -30,3 +30,12 @@ pub mod gce;
 
 #[cfg(feature = "vllm")]
 pub mod vllm;
+
+/// Model pool management. Only available with the `local` feature.
+#[cfg(feature = "local")]
+pub mod model_pool {
+    /// Unload all models from the global pool, releasing GPU memory.
+    pub async fn unload_all() {
+        crate::generate::backend::mistralrs::unload_all_models().await
+    }
+}