diff --git a/modified_llama.py b/modified_llama.py index e1b395b..d647c6c 100644 --- a/modified_llama.py +++ b/modified_llama.py @@ -34,7 +34,7 @@ def forward(self, x): gate_proj = self.gate_proj.weight[:self.current_subset_hd] up_proj = self.up_proj.weight[:self.current_subset_hd] down_proj = self.down_proj.weight[:, :self.current_subset_hd] - down_proj = F.linear(self.act_fn(F.linear(x, gate_proj) * F.linear(x, up_proj)), down_proj) + down_proj = F.linear(self.act_fn(F.linear(x, gate_proj)) * F.linear(x, up_proj), down_proj) self.current_subset_hd = None return down_proj