
Commit f750d0b

fix: correct indentation in cli main loop

1 parent 4808395

1 file changed: aetheris/cli/main.py

Lines changed: 32 additions & 30 deletions
```diff
@@ -111,39 +111,41 @@ def generate_command(args):
     history_ids = set(input_ids[0].tolist())
 
     print("-" * 50)
-    print(f"Prompt: {prompt}")
-    print("Generated Continuation:")
-
-    for step in range(max_new_tokens):
-        # Check if we should use autocast (skip if model uses float32)
-        use_autocast = True
-        if config.torch_dtype == torch.float32:
-            use_autocast = False
-
-        if use_autocast:
-            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
-            outputs = model(generated_ids)
-            logits = outputs['logits']
-            next_token_logits = logits[:, -1, :]
-        else:
+    print(f"Prompt: {prompt}")
+    print("Generated Continuation:")
+
+    # Start generation loop
+    for step in range(max_new_tokens):
+        # Check if we should use autocast (skip if model uses float32)
+        use_autocast = True
+        if config.torch_dtype == torch.float32:
+            use_autocast = False
+
+        if use_autocast:
+            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
                 outputs = model(generated_ids)
                 logits = outputs['logits']
                 next_token_logits = logits[:, -1, :]
-
-            # --- DEBUG: Print Top Predictions for First Step ---
-            if step == 0:
-                probs = F.softmax(next_token_logits, dim=-1)
-                top_probs, top_indices = torch.topk(probs, 5)
-                print("\n[DEBUG] Step 0 Top-5 Predictions:")
-                for i in range(5):
-                    token_idx = top_indices[0, i].item()
-                    prob = top_probs[0, i].item()
-                    token_str = tokenizer.decode([token_idx])
-                    print(f"  {i+1}. '{token_str}' ({prob:.4f})")
-                print("-----------------------------------")
-            # ---------------------------------------------------
-
-            # Repetition penalty for token_id in history_ids:
+        else:
+            outputs = model(generated_ids)
+            logits = outputs['logits']
+            next_token_logits = logits[:, -1, :]
+
+        # --- DEBUG: Print Top Predictions for First Step ---
+        if step == 0:
+            probs = F.softmax(next_token_logits, dim=-1)
+            top_probs, top_indices = torch.topk(probs, 5)
+            print("\n[DEBUG] Step 0 Top-5 Predictions:")
+            for i in range(5):
+                token_idx = top_indices[0, i].item()
+                prob = top_probs[0, i].item()
+                token_str = tokenizer.decode([token_idx])
+                print(f"  {i+1}. '{token_str}' ({prob:.4f})")
+            print("-----------------------------------")
+        # ---------------------------------------------------
+
+        # Repetition penalty
+        for token_id in history_ids:
             if token_id < next_token_logits.size(-1):
                 logit = next_token_logits[0, token_id].item()
                 if logit > 0:
```
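
The hunk cuts off mid-block at `if logit > 0:`; the remainder of the penalty logic lies outside the diff context. For orientation only, a CTRL-style repetition penalty typically completes along these lines — a hypothetical sketch with stand-in values for `repetition_penalty` and the tensors, not the file's actual continuation:

```python
import torch

# Stand-ins for the variables in the hunk (shapes match their usage there).
next_token_logits = torch.randn(1, 50257)  # [batch, vocab_size]
history_ids = {42, 1337}                   # token ids already generated
repetition_penalty = 1.2                   # assumed hyperparameter, not shown in the diff

# CTRL-style penalty: dampen the logits of tokens already in the history.
for token_id in history_ids:
    if token_id < next_token_logits.size(-1):
        logit = next_token_logits[0, token_id].item()
        if logit > 0:
            # Divide positive logits so repeated tokens become less likely
            next_token_logits[0, token_id] = logit / repetition_penalty
        else:
            # Multiply negative logits, pushing them even further down
            next_token_logits[0, token_id] = logit * repetition_penalty
```

Dividing positive logits and multiplying negative ones both lower the probability of already-seen tokens, which is what the `history_ids` set collected at the top of the hunk exists for.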
