@@ -1017,11 +1017,21 @@ def __init__(self, config):
10171017 self .post_init ()
10181018
10191019 # Moreh Config
1020- self .moreh_pipeline_layers = []
10211020 moreh_config = getattr (config , "moreh_config" , None )
1021+
1022+ # Moreh Pipeline Layers
1023+ self .moreh_pipeline_layers = []
10221024 if moreh_config is not None and "pipeline_layers" in moreh_config :
10231025 self .moreh_pipeline_layers = moreh_config ["pipeline_layers" ]
10241026
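1027+ # Moreh Gradient Checkpoint Layers Step
1028+ # If moreh_gradient_checkpoint_layers_step is N, the input activations of the 1st, (1+N)th, (1+2N)th, ... layers are checkpointed; None disables it.
1029+ # NOTE(review): the guard below is always false here (the attribute was just set to None) and names layer_idx / hidden_states, which this hunk never defines; presumably the step should be read from moreh_config instead - confirm.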
1030+ self .moreh_gradient_checkpoint_layers_step = None
1031+ if self .moreh_gradient_checkpoint_layers_step is not None and (
1032+ layer_idx %
1033+ self .moreh_gradient_checkpoint_layers_step ) == 0 :
1034+ hidden_states = torch .moreh .checkpoint_assign (hidden_states )
10251035
10261036 @add_start_docstrings (PARALLELIZE_DOCSTRING )
10271037 def parallelize (self , device_map = None ):
@@ -1212,6 +1222,12 @@ def forward(
12121222 all_cross_attentions = () if output_attentions and self .config .add_cross_attention else None
12131223 all_hidden_states = () if output_hidden_states else None
12141224 for i , (block , layer_past ) in enumerate (zip (self .h , past_key_values )):
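1225+ # Gradient checkpoint assign: checkpoint this layer's input activations on every N-th layer. NOTE(review): the guard uses layer_idx, but the visible loop binds its index as i - confirm.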
1226+ if self .moreh_gradient_checkpoint_layers_step is not None and (
1227+ layer_idx %
1228+ self .moreh_gradient_checkpoint_layers_step ) == 0 :
1229+ hidden_states = torch .moreh .checkpoint_assign (hidden_states )
1230+
12151231 # Model parallel
12161232 if self .model_parallel :
12171233 torch .cuda .set_device (hidden_states .device )
@@ -2075,4 +2091,4 @@ def _reorder_cache(
20752091# hidden_states=outputs.hidden_states,
20762092# attentions=outputs.attentions,
20772093# )
2078- #
2094+ #
0 commit comments