diff --git a/facechain/merge_lora.py b/facechain/merge_lora.py
index dbd19833..a205537f 100644
--- a/facechain/merge_lora.py
+++ b/facechain/merge_lora.py
@@ -77,7 +77,13 @@ def merge_lora(pipeline,
                 alpha = 1.0
 
             curr_layer.weight.data = curr_layer.weight.data.to(device)
-            if len(weight_up.shape) == 4:
+
+            if weight_down.size()[2:4] == (3, 3):
+                if not hasattr(curr_layer.weight, 'data_restore'):
+                    curr_layer.weight.data_restore = curr_layer.weight.data.clone()
+                curr_layer.weight.data += multiplier * alpha * torch.nn.functional.conv2d(
+                    weight_down.permute(1, 0, 2, 3), weight_up).permute(1, 0, 2, 3)
+            elif len(weight_up.shape) == 4:
                 if not hasattr(curr_layer.weight, 'data_restore'):
                     curr_layer.weight.data_restore = curr_layer.weight.data.clone()
                 curr_layer.weight.data += multiplier * alpha * torch.mm(
@@ -163,4 +169,4 @@ def restore_lora(pipeline,
         curr_layer.weight.data = curr_layer.weight.data.to(device)
         curr_layer.weight.data = curr_layer.weight.data_restore.clone()
 
-    return pipeline
\ No newline at end of file
+    return pipeline