diff --git a/modified_llama.py b/modified_llama.py index e1b395b..f8d363f 100644 --- a/modified_llama.py +++ b/modified_llama.py @@ -6,10 +6,7 @@ class ModifiedLlamaMLP(LlamaMLP): def __init__(self, config, scale_factors): - super().__init__( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act) + super().__init__(config) self.intermediate_size = config.intermediate_size self.scale_factors = scale_factors # List of scale factors for 's', 'm', 'l', 'xl' self.current_subset_hd = None