1515
1616from datasets .utils .logging import disable_progress_bar
1717
18- from patch_phi3_moe import patch_phi3moe
19-
2018def get_args ():
2119 parser = argparse .ArgumentParser ()
2220 parser .add_argument ("--model_name" , type = str , default = "meta-llama/Llama-2-7b-hf" )
@@ -27,6 +25,7 @@ def get_args():
2725 parser .add_argument ("--max_grad_norm" , type = float , default = 1.0 )
2826 parser .add_argument ("--gradient_accumulation_steps" , type = int , default = 1 )
2927 parser .add_argument ("--activation_checkpointing" , action = "store_true" )
28+ parser .add_argument ("--eval" , action = "store_true" )
3029 parser .add_argument ("--dataset_name" , type = str , default = "timdettmers/openassistant-guanaco" )
3130 parser .add_argument ("--num_layers" , type = int , default = 0 )
3231 parser .add_argument ("--attn_impl" , type = str , default = "spda" )
@@ -74,7 +73,7 @@ def main():
7473 args = get_args ()
7574 print (args )
7675
77- if "offload_adam_states" in args .passes :
76+ if args . passes is not None and "offload_adam_states" in args .passes :
7877 os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = 'max_split_size_mb:128'
7978
8079 if args .deterministic :
@@ -98,16 +97,13 @@ def main():
9897 model = AutoModelForCausalLM .from_pretrained (model_weight_path , trust_remote_code = True )
9998 else :
10099 if args .num_layers > 0 :
101- model_config = AutoConfig .from_pretrained (model_name , trust_remote_code = True )
100+ model_config = AutoConfig .from_pretrained (model_name , attn_implementation = args . attn_impl , trust_remote_code = True )
102101 print (f"num_hidden_layers: { model_config .num_hidden_layers } -> { args .num_layers } " )
103102 model_config .num_hidden_layers = args .num_layers
104103 model = AutoModelForCausalLM .from_config (model_config , trust_remote_code = True )
105104 else :
106105 model = AutoModelForCausalLM .from_pretrained (model_name , trust_remote_code = True )
107106
108- if patch_phi3moe (model ) and accelerator .is_main_process :
109- print ("Patched Phi-3.5-MoE model" )
110-
111107 tokenizer = AutoTokenizer .from_pretrained (model_name , trust_remote_code = True )
112108
113109 if args .save_weights and accelerator .is_main_process :
@@ -149,7 +145,6 @@ def tokenize_function(examples):
149145 torch ._dynamo .config .capture_dynamic_output_shape_ops = True
150146 torch ._dynamo .config .capture_scalar_outputs = True
151147
152-
153148 if is_deepspeed :
154149 if args .compile :
155150 schedule = make_schedule (args .passes .split ("," ), warmup = 5 ) if args .passes else None
@@ -185,10 +180,13 @@ def tokenize_function(examples):
185180 on_trace_ready = torch .profiler .tensorboard_trace_handler (prof_dir ),
186181 ) if do_profile else nullcontext ()
187182
188- # Training loop
189- model .train ()
190- global_step = 0
183+ # Training
184+ if args .eval :
185+ model .eval ()
186+ else :
187+ model .train ()
191188
189+ global_step = 0
192190 iter_times = []
193191
194192 # See https://github.com/microsoft/DeepSpeed/issues/6793
0 commit comments