Skip to content

Commit cb51166

Browse files
apply gradient checkpoint config
1 parent b6f47b1 commit cb51166

2 files changed

Lines changed: 31 additions & 4 deletions

File tree

src/transformers/models/gpt2/modeling_gpt2_moreh.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,11 +1017,21 @@ def __init__(self, config):
10171017
self.post_init()
10181018

10191019
# Moreh Config
1020-
self.moreh_pipeline_layers = []
10211020
moreh_config = getattr(config, "moreh_config", None)
1021+
1022+
# Moreh Pipeline Layers
1023+
self.moreh_pipeline_layers = []
10221024
if moreh_config is not None and "pipeline_layers" in moreh_config:
10231025
self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
10241026

1027+
# Moreh Gradient Checkpoint Layers Step
1028+
# If moreh_gradient_checkpoint_layers_step is N,
1029+
# then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
1030+
self.moreh_gradient_checkpoint_layers_step = None
1031+
if self.moreh_gradient_checkpoint_layers_step is not None and (
1032+
layer_idx %
1033+
self.moreh_gradient_checkpoint_layers_step) == 0:
1034+
hidden_states = torch.moreh.checkpoint_assign(hidden_states)
10251035

10261036
@add_start_docstrings(PARALLELIZE_DOCSTRING)
10271037
def parallelize(self, device_map=None):
@@ -1212,6 +1222,12 @@ def forward(
12121222
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
12131223
all_hidden_states = () if output_hidden_states else None
12141224
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
1225+
# Gradient checkpoint assign
1226+
if self.moreh_gradient_checkpoint_layers_step is not None and (
1227+
layer_idx %
1228+
self.moreh_gradient_checkpoint_layers_step) == 0:
1229+
hidden_states = torch.moreh.checkpoint_assign(hidden_states)
1230+
12151231
# Model parallel
12161232
if self.model_parallel:
12171233
torch.cuda.set_device(hidden_states.device)
@@ -2075,4 +2091,4 @@ def _reorder_cache(
20752091
# hidden_states=outputs.hidden_states,
20762092
# attentions=outputs.attentions,
20772093
# )
2078-
#
2094+
#

src/transformers/models/mistral/modeling_mistral_moreh.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -919,11 +919,22 @@ def __init__(self, config: MistralMorehConfig):
919919
self.post_init()
920920

921921
# Moreh Config
922-
self.moreh_pipeline_layers = []
923922
moreh_config = getattr(config, "moreh_config", None)
923+
924+
# Moreh Pipeline Layers
925+
self.moreh_pipeline_layers = []
924926
if moreh_config is not None and "pipeline_layers" in moreh_config:
925927
self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
926928

929+
# Moreh Gradient Checkpoint Layers Step
930+
# If moreh_gradient_checkpoint_layers_step is N,
931+
# then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
932+
self.moreh_gradient_checkpoint_layers_step = None
933+
if self.moreh_gradient_checkpoint_layers_step is not None and (
934+
layer_idx %
935+
self.moreh_gradient_checkpoint_layers_step) == 0:
936+
hidden_states = torch.moreh.checkpoint_assign(hidden_states)
937+
927938
def get_input_embeddings(self):
928939
return self.embed_tokens
929940

@@ -1579,4 +1590,4 @@ def _reorder_cache(past_key_values, beam_idx):
15791590
# hidden_states=outputs.hidden_states,
15801591
# attentions=outputs.attentions,
15811592
# )
1582-
#
1593+
#

0 commit comments

Comments (0)