@@ -111,39 +111,41 @@ def generate_command(args):
     history_ids = set(input_ids[0].tolist())
 
     print("-" * 50)
-    print(f"Prompt: {prompt}")
-    print("Generated Continuation:")
-
-    for step in range(max_new_tokens):
-        # Check if we should use autocast (skip if model uses float32)
-        use_autocast = True
-        if config.torch_dtype == torch.float32:
-            use_autocast = False
-
-        if use_autocast:
-            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
-                outputs = model(generated_ids)
-                logits = outputs['logits']
-                next_token_logits = logits[:, -1, :]
-        else:
+    print(f"Prompt: {prompt}")
+    print("Generated Continuation:")
+
+    # Start generation loop
+    for step in range(max_new_tokens):
+        # Check if we should use autocast (skip if model uses float32)
+        use_autocast = True
+        if config.torch_dtype == torch.float32:
+            use_autocast = False
+
+        if use_autocast:
+            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
                 outputs = model(generated_ids)
                 logits = outputs['logits']
                 next_token_logits = logits[:, -1, :]
-
-        # --- DEBUG: Print Top Predictions for First Step ---
-        if step == 0:
-            probs = F.softmax(next_token_logits, dim=-1)
-            top_probs, top_indices = torch.topk(probs, 5)
-            print("\n[DEBUG] Step 0 Top-5 Predictions:")
-            for i in range(5):
-                token_idx = top_indices[0, i].item()
-                prob = top_probs[0, i].item()
-                token_str = tokenizer.decode([token_idx])
-                print(f"  {i + 1}. '{token_str}' ({prob:.4f})")
-            print("-----------------------------------")
-        # ---------------------------------------------------
-
-        # Repetition penalty for token_id in history_ids:
+        else:
+            outputs = model(generated_ids)
+            logits = outputs['logits']
+            next_token_logits = logits[:, -1, :]
+
+        # --- DEBUG: Print Top Predictions for First Step ---
+        if step == 0:
+            probs = F.softmax(next_token_logits, dim=-1)
+            top_probs, top_indices = torch.topk(probs, 5)
+            print("\n[DEBUG] Step 0 Top-5 Predictions:")
+            for i in range(5):
+                token_idx = top_indices[0, i].item()
+                prob = top_probs[0, i].item()
+                token_str = tokenizer.decode([token_idx])
+                print(f"  {i + 1}. '{token_str}' ({prob:.4f})")
+            print("-----------------------------------")
+        # ---------------------------------------------------
+
+        # Repetition penalty
+        for token_id in history_ids:
             if token_id < next_token_logits.size(-1):
                 logit = next_token_logits[0, token_id].item()
                 if logit > 0:
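
Note: the key fix here is splitting the old line `# Repetition penalty for token_id in history_ids:`, which had accidentally commented out the loop header, into a comment plus a real `for` statement. The hunk ends before the penalty is applied, so the branch under `if logit > 0:` is not shown. The usual CTRL-style rule is to divide positive logits and multiply negative ones by the penalty factor. A minimal sketch of that rule, assuming a scalar penalty such as 1.2 (the helper name and default value are illustrative, not taken from this commit):

import torch

def apply_repetition_penalty(next_token_logits, history_ids, penalty=1.2):
    # CTRL-style penalty: shrink positive logits and push negative logits
    # further down, so tokens already in the history are less likely to repeat.
    for token_id in history_ids:
        if token_id < next_token_logits.size(-1):
            logit = next_token_logits[0, token_id].item()
            if logit > 0:
                next_token_logits[0, token_id] = logit / penalty
            else:
                next_token_logits[0, token_id] = logit * penalty
    return next_token_logits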
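
The `if use_autocast:` / `else:` branches still duplicate the forward pass; a no-op context manager would remove that duplication. A sketch of the refactor, keeping the same float32 check, with `forward_last_logits` as a hypothetical helper (a suggested simplification, not part of this commit):

import contextlib
import torch

def forward_last_logits(model, generated_ids, device, use_autocast, amp_dtype):
    # Autocast only when requested; contextlib.nullcontext() is a no-op
    # otherwise, so the forward pass is written once instead of twice.
    ctx = (torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=amp_dtype)
           if use_autocast else contextlib.nullcontext())
    with ctx:
        outputs = model(generated_ids)
        logits = outputs['logits']
    return logits[:, -1, :]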