diff --git a/src/cogkit/api/services/image_generation.py b/src/cogkit/api/services/image_generation.py index b332f2a..672d8f0 100644 --- a/src/cogkit/api/services/image_generation.py +++ b/src/cogkit/api/services/image_generation.py @@ -57,15 +57,17 @@ def generate( if model not in self._models: raise ValueError(f"Model {model} not loaded") width, height = list(map(int, size.split("x"))) - # TODO: Refactor this to switch by LoRA endpoint API if lora_path != self._current_lora[model]: if lora_path is not None: adapter_name = os.path.basename(lora_path) - _logger.info(f"Loading LORA weights from {adapter_name}") + _logger.info( + f"Loading LORA weights from {adapter_name} and unload previous weights {self._current_lora[model]}" + ) + unload_lora_checkpoint(self._models[model]) load_lora_checkpoint(self._models[model], lora_path, lora_scale) else: - _logger.info("Unloading LORA weights") + _logger.info(f"Unloading LORA weights {self._current_lora[model]}") unload_lora_checkpoint(self._models[model]) self._current_lora[model] = lora_path diff --git a/src/cogkit/utils/lora.py b/src/cogkit/utils/lora.py index f0ccbe0..2c09d97 100644 --- a/src/cogkit/utils/lora.py +++ b/src/cogkit/utils/lora.py @@ -9,8 +9,8 @@ def load_lora_checkpoint( lora_model_id_or_path: str, lora_scale: float = 1.0, ) -> None: - pipeline.load_lora_weights(lora_model_id_or_path) - pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale) + pipeline.load_lora_weights(lora_model_id_or_path, lora_scale=lora_scale) + # pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale) def unload_lora_checkpoint(