Skip to content

Commit f976c40

Browse files
fix: apply moreh gradient checkpointing in forward (was dead code in __init__) and use the correct loop index
1 parent cb51166 commit f976c40

2 files changed

Lines changed: 13 additions & 9 deletions

File tree

src/transformers/models/gpt2/modeling_gpt2_moreh.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,10 +1028,9 @@ def __init__(self, config):
10281028
# If moreh_gradient_checkpoint_layers_step is N,
10291029
# then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
10301030
self.moreh_gradient_checkpoint_layers_step = None
1031-
if self.moreh_gradient_checkpoint_layers_step is not None and (
1032-
layer_idx %
1033-
self.moreh_gradient_checkpoint_layers_step) == 0:
1034-
hidden_states = torch.moreh.checkpoint_assign(hidden_states)
1031+
if moreh_config is not None and "gradient_checkpoint_layers_step" in moreh_config:
1032+
self.moreh_gradient_checkpoint_layers_step = moreh_config[
1033+
"gradient_checkpoint_layers_step"]
10351034

10361035
@add_start_docstrings(PARALLELIZE_DOCSTRING)
10371036
def parallelize(self, device_map=None):
@@ -1224,7 +1223,7 @@ def forward(
12241223
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
12251224
# Gradient checkpoint assign
12261225
if self.moreh_gradient_checkpoint_layers_step is not None and (
1227-
layer_idx %
1226+
i %
12281227
self.moreh_gradient_checkpoint_layers_step) == 0:
12291228
hidden_states = torch.moreh.checkpoint_assign(hidden_states)
12301229

src/transformers/models/mistral/modeling_mistral_moreh.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -930,10 +930,9 @@ def __init__(self, config: MistralMorehConfig):
930930
# If moreh_gradient_checkpoint_layers_step is N,
931931
# then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
932932
self.moreh_gradient_checkpoint_layers_step = None
933-
if self.moreh_gradient_checkpoint_layers_step is not None and (
934-
layer_idx %
935-
self.moreh_gradient_checkpoint_layers_step) == 0:
936-
hidden_states = torch.moreh.checkpoint_assign(hidden_states)
933+
if moreh_config is not None and "gradient_checkpoint_layers_step" in moreh_config:
934+
self.moreh_gradient_checkpoint_layers_step = moreh_config[
935+
"gradient_checkpoint_layers_step"]
937936

938937
def get_input_embeddings(self):
939938
return self.embed_tokens
@@ -1008,6 +1007,12 @@ def forward(
10081007
next_decoder_cache = None
10091008

10101009
for layer_idx, decoder_layer in enumerate(self.layers):
1010+
# Gradient checkpoint assign
1011+
if self.moreh_gradient_checkpoint_layers_step is not None and (
1012+
layer_idx %
1013+
self.moreh_gradient_checkpoint_layers_step) == 0:
1014+
hidden_states = torch.moreh.checkpoint_assign(hidden_states)
1015+
10111016
if output_hidden_states:
10121017
all_hidden_states += (hidden_states,)
10131018

0 commit comments

Comments
 (0)