From a733985a94affdc9a54636d9d51080040ef2eae2 Mon Sep 17 00:00:00 2001
From: gandalf <183387594@qq.com>
Date: Fri, 2 Aug 2024 18:37:32 +0800
Subject: [PATCH] fix 3x3kernel lora_merge

When the LoRA down-projection weight has a 3x3 spatial kernel, the
up/down pair cannot be merged with torch.mm on squeezed matrices.
Compose the two weights with conv2d instead and add the scaled result
to the layer weight. The existing torch.mm path is kept for the other
4-D weights.

---
 facechain/merge_lora.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/facechain/merge_lora.py b/facechain/merge_lora.py
index dbd19833..a205537f 100644
--- a/facechain/merge_lora.py
+++ b/facechain/merge_lora.py
@@ -77,7 +77,13 @@ def merge_lora(pipeline,
             alpha = 1.0
 
         curr_layer.weight.data = curr_layer.weight.data.to(device)
-        if len(weight_up.shape) == 4:
+
+        if weight_down.size()[2:4] == (3, 3):
+            if not hasattr(curr_layer.weight, 'data_restore'):
+                curr_layer.weight.data_restore = curr_layer.weight.data.clone()
+            curr_layer.weight.data += multiplier * alpha * torch.nn.functional.conv2d(
+                weight_down.permute(1, 0, 2, 3), weight_up).permute(1, 0, 2, 3)
+        elif len(weight_up.shape) == 4:
             if not hasattr(curr_layer.weight, 'data_restore'):
                 curr_layer.weight.data_restore = curr_layer.weight.data.clone()
             curr_layer.weight.data += multiplier * alpha * torch.mm(
@@ -163,4 +169,4 @@ def restore_lora(pipeline,
         curr_layer.weight.data = curr_layer.weight.data.to(device)
         curr_layer.weight.data = curr_layer.weight.data_restore.clone()
 
-    return pipeline
\ No newline at end of file
+    return pipeline