Skip to content

Conversation

@tyler-griggs
Copy link
Member

Summary

Implements Tinker-compatible training operations (forward_backward and optim_step) through SkyRL's WorkerDispatch.

Components:

skyrl-train (core training logic):

  • TinkerTrainingAdapter: Converts Tinker format → WorkerDispatch calls
    • Supports loss functions: cross_entropy, importance_sampling, ppo
    • Maps Tinker Datum to TrainingInputBatch with left-padding
    • Async wrappers around WorkerDispatch methods
  • 16 unit tests covering all loss functions and edge cases

skyrl-tx (API integration):

  • SkyRLTrainingClient: Thin wrapper for pydantic conversion + database storage
    • call_forward_backward_and_store(): Background task for async API
    • call_optim_step_and_store(): Background task for optimizer step
    • attach_skyrl_training(): Easy integration with FastAPI app

Architecture:

Tinker API (/api/v1/forward_backward, /api/v1/optim_step)
    ↓
SkyRLTrainingClient (skyrl-tx)
    • Tinker pydantic ↔ plain Python dicts
    • Database storage (FutureDB)
    ↓
TinkerTrainingAdapter (skyrl-train)
    • Datum list → TrainingInputBatch
    • Loss function mapping
    ↓
WorkerDispatch.forward_backward() / optim_step()

Loss Function Support:

  • cross_entropy: Supervised learning (requires: target_tokens, weights)
  • importance_sampling: REINFORCE with IS (requires: target_tokens, logprobs, advantages)
  • ppo: PPO with clipping (mapped to SkyRL's "regular" loss)

Tests:
✅ 16/16 CPU unit tests passing

Next Steps (Stage 7-9):

  • Stage 7: ServiceClient factory for creating training/sampling clients
  • Stage 8: Checkpoint management
  • Stage 9: End-to-end integration test with tinker-cookbook scripts

Implements Tinker-compatible training operations through SkyRL's WorkerDispatch.

**Architecture:**
```
Tinker API (/api/v1/forward_backward, /api/v1/optim_step)
    ↓
SkyRLTrainingClient (skyrl-tx) - Type conversion + DB storage
    ↓
TinkerTrainingAdapter (skyrl-train) - Core training logic
    ↓
WorkerDispatch.forward_backward() / optim_step()
```

**Components:**

skyrl-train (core training logic):
- TinkerTrainingAdapter: Converts Tinker format → WorkerDispatch calls
  - Supports loss functions: cross_entropy, importance_sampling, ppo
  - Maps Tinker Datum to TrainingInputBatch with left-padding
  - Async wrappers around WorkerDispatch methods
- 16 unit tests covering all loss functions and edge cases

skyrl-tx (API integration):
- SkyRLTrainingClient: Thin wrapper for pydantic conversion + database
  - call_forward_backward_and_store(): Background task for async API
  - call_optim_step_and_store(): Background task for optimizer step
  - attach_skyrl_training(): Easy integration with FastAPI app

**Loss Function Support:**
- cross_entropy: Supervised learning (requires: target_tokens, weights)
- importance_sampling: REINFORCE with IS (requires: target_tokens, logprobs, advantages)
- ppo: PPO with clipping (same as IS, mapped to SkyRL's "regular" loss)

**Tests:**
✅ 16/16 CPU unit tests passing

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a TinkerTrainingAdapter to bridge Tinker-style training operations with SkyRL's WorkerDispatch, along with the SkyRLTrainingClient for API integration and a comprehensive set of unit tests. A critical Denial of Service vulnerability was identified where the TinkerTrainingAdapter defines async methods that perform blocking synchronous calls to Ray, which will freeze the FastAPI event loop and make the server unresponsive during training operations. This should be remediated by offloading the blocking calls to a thread pool. Additionally, the review identified a critical issue regarding potential crashes in background tasks due to missing None checks when retrieving database records, a high-severity bug where important metadata is overwritten, and a medium-severity issue related to inconsistent data handling within the adapter.

status = RequestStatus.FAILED

async with AsyncSession(self.db_engine) as session:
future = await session.get(FutureDB, request_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

session.get(FutureDB, request_id) can return None if no record is found for the given request_id. Accessing attributes on future (e.g., future.result_data on line 90) would then raise an AttributeError, crashing the background task. You should handle the case where future is None.

status = RequestStatus.FAILED

async with AsyncSession(self.db_engine) as session:
future = await session.get(FutureDB, request_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

session.get(FutureDB, request_id) can return None if no record is found for the given request_id. Accessing attributes on future (e.g., future.result_data on line 119) would then raise an AttributeError, crashing the background task. You should handle the case where future is None.

# Call WorkerDispatch forward_backward
# Note: WorkerDispatch.forward_backward is synchronous, but we make this
# method async for consistency with Tinker's API
metrics = self.worker_dispatch.forward_backward(self.model_name, training_batch)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The forward_backward method is defined as asynchronous, but it performs a synchronous, blocking operation by calling self.worker_dispatch.forward_backward. In the provided context, WorkerDispatch.forward_backward internally calls ray.get(), which blocks the execution thread until the remote task completes. Since this is called within the FastAPI event loop (via the background task in SkyRLTrainingClient), it will block the entire event loop, making the API server unresponsive to other concurrent requests during the training operation. This is a significant Denial of Service (DoS) vulnerability.

Suggested change
metrics = self.worker_dispatch.forward_backward(self.model_name, training_batch)
metrics = await asyncio.to_thread(self.worker_dispatch.forward_backward, self.model_name, training_batch)

# Note: SkyRL's optim_step doesn't take learning_rate as an arg;
# LR is controlled by the scheduler. Tinker's API accepts it for
# compatibility, but we ignore it here.
grad_norm = self.worker_dispatch.optim_step(self.model_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

Similar to the forward_backward method, optim_step is defined as async but calls the synchronous, blocking self.worker_dispatch.optim_step method. This will block the FastAPI event loop, leading to a Denial of Service. The blocking call should be offloaded to a separate thread.

Suggested change
grad_norm = self.worker_dispatch.optim_step(self.model_name)
grad_norm = await asyncio.to_thread(self.worker_dispatch.optim_step, self.model_name)

Comment on lines 146 to 149
training_batch.metadata = {
"loss_fn": self.LOSS_FN_MAP[loss_fn],
"loss_fn_config": loss_fn_config or {},
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The metadata from _convert_data_to_batch (which contains num_actions) is being overwritten here. You should update the existing metadata dictionary instead of replacing it to preserve all metadata.

Suggested change
training_batch.metadata = {
"loss_fn": self.LOSS_FN_MAP[loss_fn],
"loss_fn_config": loss_fn_config or {},
}
training_batch.metadata.update({
"loss_fn": self.LOSS_FN_MAP[loss_fn],
"loss_fn_config": loss_fn_config or {},
})

Comment on lines 207 to 210
max_seq_len = max(
len(d["model_input"].get("tokens", []))
for d in data
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation for calculating max_seq_len is not fully robust and duplicates logic from extract_tokens_from_model_input.

  1. It uses direct key access d["model_input"] which can raise a KeyError. It's safer to use .get().
  2. It only considers "tokens" and ignores the "chunks" format, which extract_tokens_from_model_input is designed to handle.

Using the extract_tokens_from_model_input static method here and on line 224 will improve consistency and robustness.

Suggested change
max_seq_len = max(
len(d["model_input"].get("tokens", []))
for d in data
)
max_seq_len = max(
len(TinkerTrainingAdapter.extract_tokens_from_model_input(d.get("model_input", {})))
for d in data
)

tyler-griggs and others added 4 commits January 24, 2026 23:05
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
…ions

This commit addresses all critical bugs found in PR NovaSky-AI#938:

1. **Add cross_entropy and importance_sampling loss functions**
   - Added PolicyLossType.CROSS_ENTROPY and PolicyLossType.IMPORTANCE_SAMPLING
   - Implemented cross_entropy_loss() for supervised learning (-log_probs)
   - Implemented importance_sampling_loss() for RL without clipping (ratio * advantages)
   - Both functions registered in PolicyLossRegistry

2. **Fix missing batch keys in TinkerTrainingAdapter._convert_data_to_batch**
   - Added required keys: base_action_log_probs, values, returns, response_mask
   - These are populated with zeros for supervised learning (not used when use_kl_loss=False)
   - Prevents KeyError crashes in BatchIterator.batch_to_experience

3. **Fix metadata handling**
   - Added response_length to batch.metadata (set to max_seq_len)
   - Changed metadata assignment from overwrite to update() to preserve num_actions
   - Prevents KeyError when batch_to_experience reads metadata["response_length"]

4. **Update LOSS_FN_MAP**
   - Already correct: cross_entropy→"cross_entropy", importance_sampling→"importance_sampling", ppo→"regular"
   - Now maps to actual loss functions that exist in PolicyLossRegistry

All 16 unit tests passing.

Addresses user feedback on PR NovaSky-AI#938 regarding:
- Missing required batch keys causing immediate crashes
- Missing metadata["response_length"]
- Metadata overwrite bug losing num_actions
- Wrong loss function names (cross_entropy/importance_sampling didn't exist)

References:
- tinker-backend loss functions: skyrl-tx/tx/tinker/loss_fns.py
- SkyRL loss semantics: ~/claude-docs/skyrl/loss-fn.md
- Tinker loss docs: ~/tinker-cookbook/docs/losses.mdx
This commit adds extensive test coverage for Stage 5:

**Unit Tests (tests/cpu/algorithms/test_losses.py):**
- test_cross_entropy_loss: Verifies cross-entropy ignores old_log_probs/advantages
- test_cross_entropy_loss_with_mask: Tests masking for variable-length sequences
- test_importance_sampling_loss: Verifies -(ratio * advantages) computation
- test_importance_sampling_vs_ppo: Confirms IS differs from PPO when clipping occurs
- test_importance_sampling_with_tis: Tests truncated importance sampling support

**GPU Integration Tests (tests/gpu/gpu_ci/test_tinker_training_adapter_integration.py):**
- test_tinker_adapter_cross_entropy_forward_backward: End-to-end cross-entropy through real workers
- test_tinker_adapter_importance_sampling_forward_backward: End-to-end importance sampling
- test_tinker_adapter_ppo_forward_backward: End-to-end PPO with clipping
- test_tinker_adapter_forward_backward_then_optim_step: Full training cycle test

**Test Coverage Summary:**
- ✅ New loss functions (cross_entropy, importance_sampling) tested in isolation
- ✅ Loss functions tested with masking and different reduction modes
- ✅ TinkerTrainingAdapter tested through real workers (not just mocks)
- ✅ All three Tinker loss types tested end-to-end
- ✅ Full training cycle (forward_backward + optim_step) verified

**Previous Test Status:**
- 16/16 unit tests for TinkerTrainingAdapter (with mocks)
- 5/5 new loss function unit tests
- 4/4 GPU integration tests (to be run on GPU CI)

Total: 25 tests covering Stage 5 functionality
This commit fixes two critical issues discovered in PR NovaSky-AI#938 review:

**Issue #1: loss_fn parameter was completely ignored**
- Problem: Workers always used `cfg.trainer.algorithm.policy_loss_type` (set at initialization)
  instead of checking `batch.metadata["loss_fn"]` (per-request from Tinker API)
- Impact: Every API request used the same loss function regardless of the loss_fn parameter!
  A client requesting cross_entropy would get PPO if that's what the config said.
- Fix: Modified `_forward_backward_micro()` to check `experience.metadata["loss_fn"]` first,
  fall back to config-based policy_loss_fn

**Issue #3: KL loss computed with all-zero inputs**
- Problem: `base_action_log_probs` are always zeros from Tinker adapter (not provided by API),
  but KL loss was computed anyway if `use_kl_loss=True` in config (meaningless KL)
- Impact: Could destabilize training with random KL values from zero inputs
- Fix: Check if `base_action_log_probs` is all zeros, skip KL computation if so

**Code Changes:**
- `worker.py` lines 714-734: Check metadata["loss_fn"] before using policy_loss_fn
- `worker.py` lines 760-770: Verify KL inputs are non-zero before computing KL loss

**Testing:**
- All 16 unit tests passing
- Fixes validated against expected behavior

**Documentation:**
- Known limitations documented in tinker-sampling-api-proposal.md
- target_tokens validation (Issue #2): Low priority, deferred
- Per-datum outputs (Issue #4): Requires worker changes, deferred to Stage 6+

References:
- PR NovaSky-AI#938 review feedback
- Issue #1: loss_fn effectively ignored (CRITICAL)
- Issue #3: KL/TIS plumbing missing
@tyler-griggs tyler-griggs force-pushed the tgriggs/tinker_sample_api_stage5 branch from 22203c9 to df1fb33 Compare January 24, 2026 23:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant