Add Trackio rollout trace logging#1697
Conversation
There was a problem hiding this comment.
Code Review
This pull request integrates Trackio for logging GRPO rollout traces, adding documentation, configuration settings, and a new TrackioRolloutLogger utility. Feedback highlights the need to prevent potential IndexError crashes when accessing batch metadata and suggests using a try...finally block to ensure the logger is properly closed. Additionally, the reviewer recommended using standard imports for mandatory dependencies and replacing runtime assertions with explicit conditional checks.
| if batch.indices is not None: | ||
| metadata["dataset_index"] = batch.indices[i] | ||
| if batch.model_steps: | ||
| metadata["model_step"] = batch.model_steps[i] |
There was a problem hiding this comment.
There is a risk of IndexError if batch.indices or batch.model_steps have fewer elements than batch.queries. While these are typically aligned, adding a length check ensures the logging utility doesn't crash the training process if the data structure is unexpected.
| if batch.indices is not None: | |
| metadata["dataset_index"] = batch.indices[i] | |
| if batch.model_steps: | |
| metadata["model_step"] = batch.model_steps[i] | |
| if batch.indices is not None and i < len(batch.indices): | |
| metadata["dataset_index"] = batch.indices[i] | |
| if batch.model_steps and i < len(batch.model_steps): | |
| metadata["model_step"] = batch.model_steps[i] |
| prompt = ( | ||
| batch.raw_queries[i] | ||
| if batch.raw_queries is not None and batch.raw_queries[i] is not None | ||
| else self.tokenizer.decode(batch.queries[i], skip_special_tokens=False) | ||
| ) | ||
| response = ( | ||
| batch.decoded_responses[i] | ||
| if batch.decoded_responses is not None and batch.decoded_responses[i] is not None | ||
| else self.tokenizer.decode(result.responses[i], skip_special_tokens=False) | ||
| ) |
There was a problem hiding this comment.
Accessing batch.raw_queries[i] and batch.decoded_responses[i] without checking their length relative to i can lead to an IndexError. Defensive checks should be added to prevent crashing the training process due to a telemetry failure.
| prompt = ( | |
| batch.raw_queries[i] | |
| if batch.raw_queries is not None and batch.raw_queries[i] is not None | |
| else self.tokenizer.decode(batch.queries[i], skip_special_tokens=False) | |
| ) | |
| response = ( | |
| batch.decoded_responses[i] | |
| if batch.decoded_responses is not None and batch.decoded_responses[i] is not None | |
| else self.tokenizer.decode(result.responses[i], skip_special_tokens=False) | |
| ) | |
| prompt = ( | |
| batch.raw_queries[i] | |
| if batch.raw_queries is not None and i < len(batch.raw_queries) and batch.raw_queries[i] is not None | |
| else self.tokenizer.decode(batch.queries[i], skip_special_tokens=False) | |
| ) | |
| response = ( | |
| batch.decoded_responses[i] | |
| if batch.decoded_responses is not None and i < len(batch.decoded_responses) and batch.decoded_responses[i] is not None | |
| else self.tokenizer.decode(result.responses[i], skip_special_tokens=False) | |
| ) |
| if self.trackio_rollout_logger is not None: | ||
| self.trackio_rollout_logger.close() | ||
| return |
There was a problem hiding this comment.
To ensure that the Trackio session is properly finalized and resources are released, the close() call should ideally be handled in a try...finally block covering the data preparation loop. This ensures trackio.finish() is called even if the loop terminates due to an unhandled exception, preventing orphaned logging sessions.
| ): | ||
| self.tokenizer = tokenizer | ||
| self.max_traces_per_step = max_traces_per_step | ||
| self.trackio = importlib.import_module("trackio") |
There was a problem hiding this comment.
trackio is listed as a mandatory dependency in pyproject.toml. Using importlib.import_module inside __init__ is unnecessary and less idiomatic than a standard top-level import. If trackio is intended to be an optional dependency (only required when trace logging is enabled), it should be moved to optional-dependencies in pyproject.toml. Otherwise, a standard import trackio at the top of the file is preferred.
| if self.max_traces_per_step <= 0: | ||
| return | ||
|
|
||
| assert batch.scores is not None, "batch.scores must not be None when logging Trackio traces" |
There was a problem hiding this comment.
Hi folks! This PR adds trace logging via Trackio, the free, local-first experiment tracking library from Hugging Face 🤗
This PR follows Open-Instruct's existing rollout trace saving path, specifically I did this:
StreamingDataLoaderConfig.trackio_projectsupport to enable Trackio rollout tracestrackio.Tracerecords from the data preparation actortrackio_max_traces_per_stepto cap trace volume per training step, plus optionaltrackio_space_idfor remote loggingI tested it end-to-end and here's what it looks like:
AI assistance was used to prepare this PR.