
Conversation

@tyler-griggs (Member) commented Jan 23, 2026

Summary

Stage 2 of Tinker Sampling API implementation:

  • Add SkyRLInferenceClient adapter in skyrl-tx for direct Python integration
  • Add integration and e2e tests in skyrl-train

Changes

skyrl-tx:

  • tx/tinker/extra/skyrl_inference.py - SkyRLInferenceClient adapter
    • Converts Tinker ModelInput → flat token list
    • Converts Tinker SamplingParams → dict for skyrl-train
    • Converts InferenceEngineOutput → Tinker SampleOutput
  • attach_skyrl_inference(app, inference_client) helper for runtime integration

skyrl-train:

  • test_tinker_api_integration.py - Type conversion and sample() tests
  • test_tinker_api_e2e.py - End-to-end Tinker flow tests
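For context, a minimal sketch of how the adapter is intended to be wired in at runtime. Only attach_skyrl_inference(app, inference_client) and the module path come from this PR; the FastAPI app setup and the build_inference_client() helper below are illustrative assumptions, not part of the change:

# Illustrative wiring sketch; engine construction is assumed, not part of this PR.
from fastapi import FastAPI

from tx.tinker.extra.skyrl_inference import attach_skyrl_inference

app = FastAPI()

# inference_client is skyrl-train's InferenceEngineClient, built elsewhere.
inference_client = build_inference_client()  # hypothetical helper

# Route /api/v1/asample requests to skyrl-train's engines via direct Python calls.
attach_skyrl_inference(app, inference_client)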

Test plan

  • test_tinker_type_conversion - PASSED
  • test_tinker_sample_integration - PASSED
  • test_tinker_stop_tokens - PASSED
  • test_e2e_tinker_sample_flow - PASSED
  • test_e2e_multiple_requests - PASSED

Stage 2 of Tinker Sampling API implementation:

skyrl-tx changes:
- Add SkyRLInferenceClient adapter that converts Tinker types to/from
  skyrl-train's InferenceEngineClient.sample()
- Add attach_skyrl_inference() helper for runtime integration

skyrl-train changes:
- Add test_tinker_api_integration.py with type conversion and sample tests
- Add test_tinker_api_e2e.py with end-to-end Tinker flow tests
- All 5 tests pass on 8xH200 GPUs

This enables the skyrl-tx API server to route /api/v1/asample requests
to skyrl-train's inference engines via direct Python calls.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tyler-griggs force-pushed the tgriggs/tinker_sample_api_stage2 branch from be2121f to 3afe85f on January 24, 2026 17:33
- Fix import ordering and formatting
- Improve test docstrings and comments

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@tyler-griggs marked this pull request as ready for review on January 24, 2026 17:40
@tyler-griggs merged commit 53b53c9 into NovaSky-AI:main on Jan 24, 2026
4 checks passed
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces the SkyRLInferenceClient adapter for direct Python integration between skyrl-tx and skyrl-train, along with corresponding integration and end-to-end tests. While the adapter implementation is well-structured, there are security concerns regarding information exposure in error handling and potential Denial of Service (DoS) due to a lack of validation on resource-intensive sampling parameters like num_samples and max_tokens. Additionally, the integration and e2e tests duplicate the adapter's conversion logic instead of importing and using the actual SkyRLInferenceClient, which means they may not accurately verify the production code. I've also provided suggestions to improve the maintainability of the SkyRLInferenceClient implementation by adding type hints and simplifying some code.

Comment on lines +66 to +130
class MockSkyRLTxApp:
    """A mock skyrl-tx app that tests the SkyRLInferenceClient directly.

    This simulates what the real skyrl-tx /api/v1/asample endpoint does,
    but without needing the full FastAPI app and database.
    """

    inference_client: InferenceEngineClient

    async def asample(self, request: dict) -> dict:
        """Simulate the /api/v1/asample endpoint behavior.

        Takes a Tinker-style request, converts it, calls sample(), converts response.
        """
        # Import the conversion functions from our adapter
        # (In real deployment, these would be in skyrl-tx)
        from tests.gpu.gpu_ci.test_tinker_api_integration import (
            ModelInput,
            ModelInputChunk,
            TinkerSamplingParams,
            extract_prompt_tokens,
            convert_sampling_params,
            convert_to_sample_output,
        )

        # Parse request (simulating SampleRequest from API)
        prompt_chunks = [ModelInputChunk(tokens=chunk["tokens"]) for chunk in request["prompt"]["chunks"]]
        tinker_input = ModelInput(chunks=prompt_chunks)

        tinker_params = TinkerSamplingParams(
            temperature=request["sampling_params"]["temperature"],
            max_tokens=request["sampling_params"]["max_tokens"],
            seed=request["sampling_params"].get("seed"),
            top_k=request["sampling_params"].get("top_k", -1),
            top_p=request["sampling_params"].get("top_p", 1.0),
        )

        num_samples = request.get("num_samples", 1)

        # Convert to skyrl-train format (same as SkyRLInferenceClient._sample)
        converted_tokens = extract_prompt_tokens(tinker_input)
        converted_params = convert_sampling_params(tinker_params)

        # Call skyrl-train's sample()
        output = await self.inference_client.sample(
            prompt_token_ids=converted_tokens,
            num_samples=num_samples,
            sampling_params=converted_params,
        )

        # Convert to Tinker format
        tinker_output = convert_to_sample_output(output)

        # Return as dict (simulating JSON response)
        return {
            "sequences": [
                {
                    "tokens": seq.tokens,
                    "logprobs": seq.logprobs,
                    "stop_reason": seq.stop_reason,
                }
                for seq in tinker_output.sequences
            ],
            "prompt_logprobs": tinker_output.prompt_logprobs,
        }

high

The MockSkyRLTxApp re-implements the conversion logic from the SkyRLInferenceClient adapter by importing duplicated logic from test_tinker_api_integration.py. This defeats the purpose of an end-to-end test, which should be testing the actual adapter code, not a copy of it. This can lead to tests passing even if the production adapter code is broken.

The test should be refactored to use the actual SkyRLInferenceClient from skyrl-tx. This will ensure you are testing the real integration between skyrl-train and the skyrl-tx adapter. You can achieve this by having MockSkyRLTxApp hold an instance of the real SkyRLInferenceClient and calling its methods, using mocks for dependencies like the database engine.
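To make the suggestion concrete, a rough sketch of what that refactor could look like. The adapter's constructor arguments, dependency names, and method name below are assumptions, not the actual SkyRLInferenceClient API:

# Sketch only: constructor signature, DB dependency, and asample() name are assumed.
from unittest.mock import MagicMock

from tx.tinker.extra.skyrl_inference import SkyRLInferenceClient


class MockSkyRLTxApp:
    """Holds the real adapter so the e2e test exercises production conversion code."""

    def __init__(self, inference_client):
        self.adapter = SkyRLInferenceClient(
            inference_client=inference_client,  # assumed keyword name
            db_engine=MagicMock(),              # assumed dependency, mocked out
        )

    async def asample(self, request: dict) -> dict:
        # Delegate to the production adapter instead of re-implementing conversions.
        return await self.adapter.asample(request)  # assumed method name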

Comment on lines +27 to +116
# Tinker-compatible types (mirrors skyrl-tx/tx/tinker/types.py)
@dataclass
class ModelInputChunk:
    tokens: list[int]


@dataclass
class ModelInput:
    chunks: list[ModelInputChunk]

    @classmethod
    def from_tokens(cls, tokens: list[int]) -> "ModelInput":
        return cls(chunks=[ModelInputChunk(tokens=tokens)])


@dataclass
class TinkerSamplingParams:
    temperature: float
    max_tokens: int
    seed: int | None = None
    stop_tokens: list[int] | None = None
    stop_strings: list[str] | None = None
    top_k: int = -1
    top_p: float = 1.0


@dataclass
class GeneratedSequence:
    stop_reason: Literal["length", "stop"]
    tokens: list[int]
    logprobs: list[float]


@dataclass
class SampleOutput:
    sequences: list[GeneratedSequence]
    prompt_logprobs: list[float] | None = None


# Conversion functions (mirrors SkyRLInferenceClient logic)
def extract_prompt_tokens(model_input: ModelInput) -> list[int]:
    """Extract flat token list from ModelInput."""
    tokens = []
    for chunk in model_input.chunks:
        tokens.extend(chunk.tokens)
    return tokens


def convert_sampling_params(params: TinkerSamplingParams) -> dict:
    """Convert Tinker SamplingParams to skyrl-train format."""
    result = {
        "temperature": params.temperature,
        "max_tokens": params.max_tokens,
        "top_k": params.top_k,
        "top_p": params.top_p,
    }
    if params.seed is not None:
        result["seed"] = params.seed
    if params.stop_tokens:
        result["stop_token_ids"] = params.stop_tokens
    if params.stop_strings:
        result["stop"] = params.stop_strings
    return result


def convert_to_sample_output(output: dict) -> SampleOutput:
    """Convert skyrl-train's InferenceEngineOutput to Tinker SampleOutput."""
    sequences = []
    num_samples = len(output["response_ids"])

    for i in range(num_samples):
        stop_reason = output["stop_reasons"][i]
        if stop_reason in ("stop", "eos"):
            tinker_stop_reason = "stop"
        else:
            tinker_stop_reason = "length"

        logprobs = []
        if output.get("response_logprobs") and output["response_logprobs"][i]:
            logprobs = output["response_logprobs"][i]

        sequences.append(
            GeneratedSequence(
                tokens=output["response_ids"][i],
                logprobs=logprobs,
                stop_reason=tinker_stop_reason,
            )
        )

    return SampleOutput(sequences=sequences, prompt_logprobs=None)

high

This file duplicates a significant amount of logic from skyrl-tx (types and conversion functions) instead of importing it. This is problematic for several reasons:

  1. Maintainability: If the production code in skyrl-tx changes, these tests will not automatically reflect that, leading to a maintenance burden to keep them in sync.
  2. Incorrect Tests: The tests are verifying a copy of the integration logic, not the actual production code. This means the tests could pass even if the real SkyRLInferenceClient adapter is broken.

To fix this, you should refactor the tests to import the necessary types and conversion functions directly from the skyrl-tx package. This will ensure you are testing against the actual implementation and that your tests will fail if the contract breaks.
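For example, the duplicated definitions could be replaced with imports along these lines (module paths and names are inferred from tx/tinker/types.py and tx/tinker/extra/skyrl_inference.py mentioned in this PR, and may need adjusting):

# Inferred import paths; adjust to the actual skyrl-tx package layout.
from tx.tinker.types import ModelInput, ModelInputChunk, SamplingParams, SampleOutput
from tx.tinker.extra.skyrl_inference import SkyRLInferenceClient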

    status = RequestStatus.COMPLETED
except Exception as e:
    logger.exception("SkyRL inference error")
    result_data = {"error": str(e), "status": "failed"}

security-medium medium

The call_and_store_result method catches all exceptions during the inference process and stores the string representation of the exception (str(e)) directly in the result_data field of the FutureDB record. This field is intended to hold the output of the inference and is likely exposed to users through API endpoints that query the status of their requests. Storing raw exception messages can leak sensitive internal information, such as system paths, database details, or fragments of the input data that caused the error. It is recommended to store a generic error message for the user and log the detailed exception internally.
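One possible shape for that change, sketched against the snippet above (only the names in the quoted code are taken as given; the generic message is illustrative):

    status = RequestStatus.COMPLETED
except Exception:
    # Log full details server-side; return only a generic message to the caller.
    logger.exception("SkyRL inference error")
    result_data = {"error": "Internal error during sampling", "status": "failed"}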

# Call skyrl-train's sample() method
output = await self.inference_client.sample(
    prompt_token_ids=prompt_token_ids,
    num_samples=request.num_samples,

security-medium medium

The SkyRLInferenceClient adapter passes sampling parameters such as num_samples directly from the user-supplied request to the underlying inference engine without any validation or boundary checks. An attacker could provide an extremely large value for num_samples, leading to excessive resource consumption (CPU, GPU memory) and potentially causing a Denial of Service (DoS) for the inference service. Consider implementing strict validation and upper bounds for this parameter.
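A simple guard before forwarding the request would address this; the specific limit and error type below are illustrative placeholders, not existing config values:

# Illustrative bound check; MAX_NUM_SAMPLES is an assumed, deployment-specific limit.
MAX_NUM_SAMPLES = 64

if not 1 <= request.num_samples <= MAX_NUM_SAMPLES:
    raise ValueError(f"num_samples must be between 1 and {MAX_NUM_SAMPLES}, got {request.num_samples}")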

"""
result = {
"temperature": params.temperature,
"max_tokens": params.max_tokens,

security-medium medium

The _convert_sampling_params method passes the max_tokens parameter directly from the user-supplied params to the inference engine without validation. A very large max_tokens value can lead to resource exhaustion and Denial of Service (DoS). It is recommended to enforce a maximum limit on the number of tokens that can be requested.
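Similarly, the converted params could clamp the requested length to a server-side ceiling; the constant below is an assumed placeholder, not an existing config value:

# Illustrative cap applied during sampling-params conversion.
MAX_TOKENS_LIMIT = 4096  # assumed server-side ceiling

result["max_tokens"] = min(params.max_tokens, MAX_TOKENS_LIMIT)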

Comment on lines +117 to +129
def _extract_prompt_tokens(self, model_input) -> list[int]:
    """Extract flat token list from ModelInput.

    Args:
        model_input: ModelInput with chunks of tokens.

    Returns:
        Flat list of token IDs.
    """
    tokens = []
    for chunk in model_input.chunks:
        tokens.extend(chunk.tokens)
    return tokens

medium

The model_input parameter lacks a type hint. It appears to be of type tx.tinker.types.ModelInput. Adding this type hint will improve readability and enable static type checking.

Additionally, the loop to flatten the list of token chunks can be expressed more concisely using a list comprehension.

Suggested change
def _extract_prompt_tokens(self, model_input) -> list[int]:
    """Extract flat token list from ModelInput.
    Args:
        model_input: ModelInput with chunks of tokens.
    Returns:
        Flat list of token IDs.
    """
    tokens = []
    for chunk in model_input.chunks:
        tokens.extend(chunk.tokens)
    return tokens

def _extract_prompt_tokens(self, model_input: types.ModelInput) -> list[int]:
    """Extract flat token list from ModelInput.
    Args:
        model_input: ModelInput with chunks of tokens.
    Returns:
        Flat list of token IDs.
    """
    return [token for chunk in model_input.chunks for token in chunk.tokens]


    return result

def _convert_to_sample_output(self, output) -> types.SampleOutput:

medium

The output parameter lacks a type hint. Based on the call site in _sample, it receives the return value of self.inference_client.sample, which is an InferenceEngineOutput. Adding this type hint will improve code clarity and allow static analysis tools to catch potential errors.

You'll need to add from skyrl_train.inference_engines.base import InferenceEngineOutput inside the TYPE_CHECKING block at the top of the file.

Suggested change
def _convert_to_sample_output(self, output) -> types.SampleOutput:
def _convert_to_sample_output(self, output: "InferenceEngineOutput") -> types.SampleOutput:

tanmaysachan pushed a commit to tanmaysachan/SkyRL that referenced this pull request Jan 25, 2026
…Sky-AI#929)
