Skip to content

Commit 436071f

Browse files
committed
fix: pass gts argument in _dump_generations call in _train_step
1 parent c746af2 commit 436071f

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

agentlightning/verl/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,10 +417,14 @@ def _train_step(self, batch_dict: dict) -> dict:
417417
print(batch.batch.keys())
418418
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
419419
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
420+
sample_gts = [
421+
item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch
422+
]
420423
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
421424
self._dump_generations(
422425
inputs=inputs,
423426
outputs=outputs,
427+
gts=sample_gts,
424428
scores=scores,
425429
reward_extra_infos_dict=reward_extra_infos_dict,
426430
dump_path=rollout_data_dir,

0 commit comments

Comments
 (0)