-
Notifications
You must be signed in to change notification settings - Fork 234
Stage 5: Add TinkerTrainingAdapter for forward_backward/optim_step #938
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Stage 5: Add TinkerTrainingAdapter for forward_backward/optim_step #938
Conversation
Implements Tinker-compatible training operations through SkyRL's WorkerDispatch.
**Architecture:**
```
Tinker API (/api/v1/forward_backward, /api/v1/optim_step)
↓
SkyRLTrainingClient (skyrl-tx) - Type conversion + DB storage
↓
TinkerTrainingAdapter (skyrl-train) - Core training logic
↓
WorkerDispatch.forward_backward() / optim_step()
```
**Components:**
skyrl-train (core training logic):
- TinkerTrainingAdapter: Converts Tinker format → WorkerDispatch calls
- Supports loss functions: cross_entropy, importance_sampling, ppo
- Maps Tinker Datum to TrainingInputBatch with left-padding
- Async wrappers around WorkerDispatch methods
- 16 unit tests covering all loss functions and edge cases
skyrl-tx (API integration):
- SkyRLTrainingClient: Thin wrapper for pydantic conversion + database
- call_forward_backward_and_store(): Background task for async API
- call_optim_step_and_store(): Background task for optimizer step
- attach_skyrl_training(): Easy integration with FastAPI app
**Loss Function Support:**
- cross_entropy: Supervised learning (requires: target_tokens, weights)
- importance_sampling: REINFORCE with IS (requires: target_tokens, logprobs, advantages)
- ppo: PPO with clipping (same as IS, mapped to SkyRL's "regular" loss)
**Tests:**
✅ 16/16 CPU unit tests passing
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a TinkerTrainingAdapter to bridge Tinker-style training operations with SkyRL's WorkerDispatch, along with the SkyRLTrainingClient for API integration and a comprehensive set of unit tests. A critical Denial of Service vulnerability was identified where the TinkerTrainingAdapter defines async methods that perform blocking synchronous calls to Ray, which will freeze the FastAPI event loop and make the server unresponsive during training operations. This should be remediated by offloading the blocking calls to a thread pool. Additionally, the review identified a critical issue regarding potential crashes in background tasks due to missing None checks when retrieving database records, a high-severity bug where important metadata is overwritten, and a medium-severity issue related to inconsistent data handling within the adapter.
| status = RequestStatus.FAILED | ||
|
|
||
| async with AsyncSession(self.db_engine) as session: | ||
| future = await session.get(FutureDB, request_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| status = RequestStatus.FAILED | ||
|
|
||
| async with AsyncSession(self.db_engine) as session: | ||
| future = await session.get(FutureDB, request_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # Call WorkerDispatch forward_backward | ||
| # Note: WorkerDispatch.forward_backward is synchronous, but we make this | ||
| # method async for consistency with Tinker's API | ||
| metrics = self.worker_dispatch.forward_backward(self.model_name, training_batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The forward_backward method is defined as asynchronous, but it performs a synchronous, blocking operation by calling self.worker_dispatch.forward_backward. In the provided context, WorkerDispatch.forward_backward internally calls ray.get(), which blocks the execution thread until the remote task completes. Since this is called within the FastAPI event loop (via the background task in SkyRLTrainingClient), it will block the entire event loop, making the API server unresponsive to other concurrent requests during the training operation. This is a significant Denial of Service (DoS) vulnerability.
| metrics = self.worker_dispatch.forward_backward(self.model_name, training_batch) | |
| metrics = await asyncio.to_thread(self.worker_dispatch.forward_backward, self.model_name, training_batch) |
| # Note: SkyRL's optim_step doesn't take learning_rate as an arg; | ||
| # LR is controlled by the scheduler. Tinker's API accepts it for | ||
| # compatibility, but we ignore it here. | ||
| grad_norm = self.worker_dispatch.optim_step(self.model_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the forward_backward method, optim_step is defined as async but calls the synchronous, blocking self.worker_dispatch.optim_step method. This will block the FastAPI event loop, leading to a Denial of Service. The blocking call should be offloaded to a separate thread.
| grad_norm = self.worker_dispatch.optim_step(self.model_name) | |
| grad_norm = await asyncio.to_thread(self.worker_dispatch.optim_step, self.model_name) |
| training_batch.metadata = { | ||
| "loss_fn": self.LOSS_FN_MAP[loss_fn], | ||
| "loss_fn_config": loss_fn_config or {}, | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The metadata from _convert_data_to_batch (which contains num_actions) is being overwritten here. You should update the existing metadata dictionary instead of replacing it to preserve all metadata.
| training_batch.metadata = { | |
| "loss_fn": self.LOSS_FN_MAP[loss_fn], | |
| "loss_fn_config": loss_fn_config or {}, | |
| } | |
| training_batch.metadata.update({ | |
| "loss_fn": self.LOSS_FN_MAP[loss_fn], | |
| "loss_fn_config": loss_fn_config or {}, | |
| }) |
| max_seq_len = max( | ||
| len(d["model_input"].get("tokens", [])) | ||
| for d in data | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation for calculating max_seq_len is not fully robust and duplicates logic from extract_tokens_from_model_input.
- It uses direct key access
d["model_input"]which can raise aKeyError. It's safer to use.get(). - It only considers
"tokens"and ignores the"chunks"format, whichextract_tokens_from_model_inputis designed to handle.
Using the extract_tokens_from_model_input static method here and on line 224 will improve consistency and robustness.
| max_seq_len = max( | |
| len(d["model_input"].get("tokens", [])) | |
| for d in data | |
| ) | |
| max_seq_len = max( | |
| len(TinkerTrainingAdapter.extract_tokens_from_model_input(d.get("model_input", {}))) | |
| for d in data | |
| ) |
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
…ions This commit addresses all critical bugs found in PR NovaSky-AI#938: 1. **Add cross_entropy and importance_sampling loss functions** - Added PolicyLossType.CROSS_ENTROPY and PolicyLossType.IMPORTANCE_SAMPLING - Implemented cross_entropy_loss() for supervised learning (-log_probs) - Implemented importance_sampling_loss() for RL without clipping (ratio * advantages) - Both functions registered in PolicyLossRegistry 2. **Fix missing batch keys in TinkerTrainingAdapter._convert_data_to_batch** - Added required keys: base_action_log_probs, values, returns, response_mask - These are populated with zeros for supervised learning (not used when use_kl_loss=False) - Prevents KeyError crashes in BatchIterator.batch_to_experience 3. **Fix metadata handling** - Added response_length to batch.metadata (set to max_seq_len) - Changed metadata assignment from overwrite to update() to preserve num_actions - Prevents KeyError when batch_to_experience reads metadata["response_length"] 4. **Update LOSS_FN_MAP** - Already correct: cross_entropy→"cross_entropy", importance_sampling→"importance_sampling", ppo→"regular" - Now maps to actual loss functions that exist in PolicyLossRegistry All 16 unit tests passing. Addresses user feedback on PR NovaSky-AI#938 regarding: - Missing required batch keys causing immediate crashes - Missing metadata["response_length"] - Metadata overwrite bug losing num_actions - Wrong loss function names (cross_entropy/importance_sampling didn't exist) References: - tinker-backend loss functions: skyrl-tx/tx/tinker/loss_fns.py - SkyRL loss semantics: ~/claude-docs/skyrl/loss-fn.md - Tinker loss docs: ~/tinker-cookbook/docs/losses.mdx
This commit adds extensive test coverage for Stage 5: **Unit Tests (tests/cpu/algorithms/test_losses.py):** - test_cross_entropy_loss: Verifies cross-entropy ignores old_log_probs/advantages - test_cross_entropy_loss_with_mask: Tests masking for variable-length sequences - test_importance_sampling_loss: Verifies -(ratio * advantages) computation - test_importance_sampling_vs_ppo: Confirms IS differs from PPO when clipping occurs - test_importance_sampling_with_tis: Tests truncated importance sampling support **GPU Integration Tests (tests/gpu/gpu_ci/test_tinker_training_adapter_integration.py):** - test_tinker_adapter_cross_entropy_forward_backward: End-to-end cross-entropy through real workers - test_tinker_adapter_importance_sampling_forward_backward: End-to-end importance sampling - test_tinker_adapter_ppo_forward_backward: End-to-end PPO with clipping - test_tinker_adapter_forward_backward_then_optim_step: Full training cycle test **Test Coverage Summary:** - ✅ New loss functions (cross_entropy, importance_sampling) tested in isolation - ✅ Loss functions tested with masking and different reduction modes - ✅ TinkerTrainingAdapter tested through real workers (not just mocks) - ✅ All three Tinker loss types tested end-to-end - ✅ Full training cycle (forward_backward + optim_step) verified **Previous Test Status:** - 16/16 unit tests for TinkerTrainingAdapter (with mocks) - 5/5 new loss function unit tests - 4/4 GPU integration tests (to be run on GPU CI) Total: 25 tests covering Stage 5 functionality
This commit fixes two critical issues discovered in PR NovaSky-AI#938 review: **Issue #1: loss_fn parameter was completely ignored** - Problem: Workers always used `cfg.trainer.algorithm.policy_loss_type` (set at initialization) instead of checking `batch.metadata["loss_fn"]` (per-request from Tinker API) - Impact: Every API request used the same loss function regardless of the loss_fn parameter! A client requesting cross_entropy would get PPO if that's what the config said. - Fix: Modified `_forward_backward_micro()` to check `experience.metadata["loss_fn"]` first, fall back to config-based policy_loss_fn **Issue #3: KL loss computed with all-zero inputs** - Problem: `base_action_log_probs` are always zeros from Tinker adapter (not provided by API), but KL loss was computed anyway if `use_kl_loss=True` in config (meaningless KL) - Impact: Could destabilize training with random KL values from zero inputs - Fix: Check if `base_action_log_probs` is all zeros, skip KL computation if so **Code Changes:** - `worker.py` lines 714-734: Check metadata["loss_fn"] before using policy_loss_fn - `worker.py` lines 760-770: Verify KL inputs are non-zero before computing KL loss **Testing:** - All 16 unit tests passing - Fixes validated against expected behavior **Documentation:** - Known limitations documented in tinker-sampling-api-proposal.md - target_tokens validation (Issue #2): Low priority, deferred - Per-datum outputs (Issue #4): Requires worker changes, deferred to Stage 6+ References: - PR NovaSky-AI#938 review feedback - Issue #1: loss_fn effectively ignored (CRITICAL) - Issue #3: KL/TIS plumbing missing
22203c9 to
df1fb33
Compare
Summary
Implements Tinker-compatible training operations (
forward_backwardandoptim_step) through SkyRL's WorkerDispatch.Components:
skyrl-train (core training logic):
TinkerTrainingAdapter: Converts Tinker format → WorkerDispatch callscross_entropy,importance_sampling,pposkyrl-tx (API integration):
SkyRLTrainingClient: Thin wrapper for pydantic conversion + database storagecall_forward_backward_and_store(): Background task for async APIcall_optim_step_and_store(): Background task for optimizer stepattach_skyrl_training(): Easy integration with FastAPI appArchitecture:
Loss Function Support:
cross_entropy: Supervised learning (requires: target_tokens, weights)importance_sampling: REINFORCE with IS (requires: target_tokens, logprobs, advantages)ppo: PPO with clipping (mapped to SkyRL's "regular" loss)Tests:
✅ 16/16 CPU unit tests passing
Next Steps (Stage 7-9):