Add SkyRLInferenceClient adapter and Tinker API tests (Stage 2) #929
Conversation
Stage 2 of Tinker Sampling API implementation:

skyrl-tx changes:
- Add SkyRLInferenceClient adapter that converts Tinker types to/from skyrl-train's InferenceEngineClient.sample()
- Add attach_skyrl_inference() helper for runtime integration

skyrl-train changes:
- Add test_tinker_api_integration.py with type conversion and sample tests
- Add test_tinker_api_e2e.py with end-to-end Tinker flow tests
- All 5 tests pass on 8xH200 GPUs

This enables the skyrl-tx API server to route /api/v1/asample requests to skyrl-train's inference engines via direct Python calls.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Force-pushed from be2121f to 3afe85f:
- Fix import ordering and formatting
- Improve test docstrings and comments

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Code Review
This pull request introduces the SkyRLInferenceClient adapter for direct Python integration between skyrl-tx and skyrl-train, along with corresponding integration and end-to-end tests. While the adapter implementation is well-structured, there are security concerns regarding information exposure in error handling and potential Denial of Service (DoS) due to a lack of validation on resource-intensive sampling parameters like num_samples and max_tokens. Additionally, the integration and e2e tests duplicate the adapter's conversion logic instead of importing and using the actual SkyRLInferenceClient, which means they may not accurately verify the production code. I've also provided suggestions to improve the maintainability of the SkyRLInferenceClient implementation by adding type hints and simplifying some code.
```python
class MockSkyRLTxApp:
    """A mock skyrl-tx app that tests the SkyRLInferenceClient directly.

    This simulates what the real skyrl-tx /api/v1/asample endpoint does,
    but without needing the full FastAPI app and database.
    """

    inference_client: InferenceEngineClient

    async def asample(self, request: dict) -> dict:
        """Simulate the /api/v1/asample endpoint behavior.

        Takes a Tinker-style request, converts it, calls sample(), converts response.
        """
        # Import the conversion functions from our adapter
        # (In real deployment, these would be in skyrl-tx)
        from tests.gpu.gpu_ci.test_tinker_api_integration import (
            ModelInput,
            ModelInputChunk,
            TinkerSamplingParams,
            extract_prompt_tokens,
            convert_sampling_params,
            convert_to_sample_output,
        )

        # Parse request (simulating SampleRequest from API)
        prompt_chunks = [ModelInputChunk(tokens=chunk["tokens"]) for chunk in request["prompt"]["chunks"]]
        tinker_input = ModelInput(chunks=prompt_chunks)

        tinker_params = TinkerSamplingParams(
            temperature=request["sampling_params"]["temperature"],
            max_tokens=request["sampling_params"]["max_tokens"],
            seed=request["sampling_params"].get("seed"),
            top_k=request["sampling_params"].get("top_k", -1),
            top_p=request["sampling_params"].get("top_p", 1.0),
        )

        num_samples = request.get("num_samples", 1)

        # Convert to skyrl-train format (same as SkyRLInferenceClient._sample)
        converted_tokens = extract_prompt_tokens(tinker_input)
        converted_params = convert_sampling_params(tinker_params)

        # Call skyrl-train's sample()
        output = await self.inference_client.sample(
            prompt_token_ids=converted_tokens,
            num_samples=num_samples,
            sampling_params=converted_params,
        )

        # Convert to Tinker format
        tinker_output = convert_to_sample_output(output)

        # Return as dict (simulating JSON response)
        return {
            "sequences": [
                {
                    "tokens": seq.tokens,
                    "logprobs": seq.logprobs,
                    "stop_reason": seq.stop_reason,
                }
                for seq in tinker_output.sequences
            ],
            "prompt_logprobs": tinker_output.prompt_logprobs,
        }
```
The MockSkyRLTxApp re-implements the conversion logic from the SkyRLInferenceClient adapter by importing duplicated logic from test_tinker_api_integration.py. This defeats the purpose of an end-to-end test, which should be testing the actual adapter code, not a copy of it. This can lead to tests passing even if the production adapter code is broken.
The test should be refactored to use the actual SkyRLInferenceClient from skyrl-tx. This will ensure you are testing the real integration between skyrl-train and the skyrl-tx adapter. You can achieve this by having MockSkyRLTxApp hold an instance of the real SkyRLInferenceClient and calling its methods, using mocks for dependencies like the database engine.
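A minimal sketch of that refactor, assuming `SkyRLInferenceClient` can be constructed from just the inference client plus a mocked database engine, and that its `_sample()` coroutine accepts the Tinker-style request (both are assumptions about the adapter's API, not confirmed by this PR):

```python
# Sketch only: constructor arguments and the _sample() entry point are assumed;
# adjust to the real adapter API in tx/tinker/extra/skyrl_inference.py.
from unittest.mock import MagicMock

from tx.tinker.extra.skyrl_inference import SkyRLInferenceClient


class MockSkyRLTxApp:
    """Drives the real adapter so the test fails if the production code breaks."""

    def __init__(self, inference_client):
        self.adapter = SkyRLInferenceClient(
            inference_client=inference_client,
            db_engine=MagicMock(),  # hypothetical dependency, mocked out for the test
        )

    async def asample(self, request: dict) -> dict:
        # Delegate conversion and sampling to the real adapter
        # instead of re-implementing them in the test.
        return await self.adapter._sample(request)
```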
```python
# Tinker-compatible types (mirrors skyrl-tx/tx/tinker/types.py)
@dataclass
class ModelInputChunk:
    tokens: list[int]


@dataclass
class ModelInput:
    chunks: list[ModelInputChunk]

    @classmethod
    def from_tokens(cls, tokens: list[int]) -> "ModelInput":
        return cls(chunks=[ModelInputChunk(tokens=tokens)])


@dataclass
class TinkerSamplingParams:
    temperature: float
    max_tokens: int
    seed: int | None = None
    stop_tokens: list[int] | None = None
    stop_strings: list[str] | None = None
    top_k: int = -1
    top_p: float = 1.0


@dataclass
class GeneratedSequence:
    stop_reason: Literal["length", "stop"]
    tokens: list[int]
    logprobs: list[float]


@dataclass
class SampleOutput:
    sequences: list[GeneratedSequence]
    prompt_logprobs: list[float] | None = None


# Conversion functions (mirrors SkyRLInferenceClient logic)
def extract_prompt_tokens(model_input: ModelInput) -> list[int]:
    """Extract flat token list from ModelInput."""
    tokens = []
    for chunk in model_input.chunks:
        tokens.extend(chunk.tokens)
    return tokens


def convert_sampling_params(params: TinkerSamplingParams) -> dict:
    """Convert Tinker SamplingParams to skyrl-train format."""
    result = {
        "temperature": params.temperature,
        "max_tokens": params.max_tokens,
        "top_k": params.top_k,
        "top_p": params.top_p,
    }
    if params.seed is not None:
        result["seed"] = params.seed
    if params.stop_tokens:
        result["stop_token_ids"] = params.stop_tokens
    if params.stop_strings:
        result["stop"] = params.stop_strings
    return result


def convert_to_sample_output(output: dict) -> SampleOutput:
    """Convert skyrl-train's InferenceEngineOutput to Tinker SampleOutput."""
    sequences = []
    num_samples = len(output["response_ids"])

    for i in range(num_samples):
        stop_reason = output["stop_reasons"][i]
        if stop_reason in ("stop", "eos"):
            tinker_stop_reason = "stop"
        else:
            tinker_stop_reason = "length"

        logprobs = []
        if output.get("response_logprobs") and output["response_logprobs"][i]:
            logprobs = output["response_logprobs"][i]

        sequences.append(
            GeneratedSequence(
                tokens=output["response_ids"][i],
                logprobs=logprobs,
                stop_reason=tinker_stop_reason,
            )
        )

    return SampleOutput(sequences=sequences, prompt_logprobs=None)
```
This file duplicates a significant amount of logic from skyrl-tx (types and conversion functions) instead of importing it. This is problematic for several reasons:

- Maintainability: If the production code in `skyrl-tx` changes, these tests will not automatically reflect that, leading to a maintenance burden to keep them in sync.
- Incorrect Tests: The tests are verifying a copy of the integration logic, not the actual production code. This means the tests could pass even if the real `SkyRLInferenceClient` adapter is broken.
To fix this, you should refactor the tests to import the necessary types and conversion functions directly from the skyrl-tx package. This will ensure you are testing against the actual implementation and that your tests will fail if the contract breaks.
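For instance, the duplicated dataclasses and helpers could be replaced by direct imports. The module paths below are inferred from this PR's file layout and from the test's own "mirrors skyrl-tx/tx/tinker/types.py" comment; they would need to match the real package:

```python
# Replace the local copies with the production definitions.
# Paths are inferred from this PR's file layout; adjust if the package differs.
from tx.tinker.types import ModelInput, ModelInputChunk, SampleOutput, SamplingParams
from tx.tinker.extra.skyrl_inference import SkyRLInferenceClient
```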
```python
            status = RequestStatus.COMPLETED
        except Exception as e:
            logger.exception("SkyRL inference error")
            result_data = {"error": str(e), "status": "failed"}
```
The call_and_store_result method catches all exceptions during the inference process and stores the string representation of the exception (str(e)) directly in the result_data field of the FutureDB record. This field is intended to hold the output of the inference and is likely exposed to users through API endpoints that query the status of their requests. Storing raw exception messages can leak sensitive internal information, such as system paths, database details, or fragments of the input data that caused the error. It is recommended to store a generic error message for the user and log the detailed exception internally.
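A minimal sketch of that recommendation applied to the snippet above; the `RequestStatus.FAILED` member is an assumption made by analogy with `COMPLETED`:

```python
        try:
            ...  # run inference and build result_data (unchanged)
            status = RequestStatus.COMPLETED
        except Exception:
            # Full details stay in the server logs; the stored result carries
            # only a generic message, so internal paths/data don't reach API users.
            logger.exception("SkyRL inference error")
            result_data = {"error": "Inference failed due to an internal error.", "status": "failed"}
            status = RequestStatus.FAILED  # assumed enum member, mirror of COMPLETED
```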
```python
        # Call skyrl-train's sample() method
        output = await self.inference_client.sample(
            prompt_token_ids=prompt_token_ids,
            num_samples=request.num_samples,
```
The SkyRLInferenceClient adapter passes sampling parameters such as num_samples directly from the user-supplied request to the underlying inference engine without any validation or boundary checks. An attacker could provide an extremely large value for num_samples, leading to excessive resource consumption (CPU, GPU memory) and potentially causing a Denial of Service (DoS) for the inference service. Consider implementing strict validation and upper bounds for this parameter.
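A hedged sketch of such a bound check, to run before the `sample()` call; the helper name and the limit value are illustrative, not part of this PR:

```python
MAX_NUM_SAMPLES = 16  # illustrative upper bound; tune to deployment capacity

def validate_num_samples(num_samples: int) -> None:
    """Reject requests that would fan out into excessive generation work."""
    if not 1 <= num_samples <= MAX_NUM_SAMPLES:
        raise ValueError(f"num_samples must be in [1, {MAX_NUM_SAMPLES}], got {num_samples}")
```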
| """ | ||
| result = { | ||
| "temperature": params.temperature, | ||
| "max_tokens": params.max_tokens, |
The _convert_sampling_params method passes the max_tokens parameter directly from the user-supplied params to the inference engine without validation. A very large max_tokens value can lead to resource exhaustion and Denial of Service (DoS). It is recommended to enforce a maximum limit on the number of tokens that can be requested.
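For example, a cap could be enforced before the parameters are forwarded; the helper name and limit below are illustrative assumptions:

```python
MAX_TOKENS_LIMIT = 4096  # illustrative cap; align with the engine's context window

def validate_max_tokens(max_tokens: int) -> None:
    """Bound per-request generation length before it reaches the engine."""
    if not 1 <= max_tokens <= MAX_TOKENS_LIMIT:
        raise ValueError(f"max_tokens must be in [1, {MAX_TOKENS_LIMIT}], got {max_tokens}")
```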
```python
    def _extract_prompt_tokens(self, model_input) -> list[int]:
        """Extract flat token list from ModelInput.

        Args:
            model_input: ModelInput with chunks of tokens.

        Returns:
            Flat list of token IDs.
        """
        tokens = []
        for chunk in model_input.chunks:
            tokens.extend(chunk.tokens)
        return tokens
```
The model_input parameter lacks a type hint. It appears to be of type tx.tinker.types.ModelInput. Adding this type hint will improve readability and enable static type checking.
Additionally, the loop to flatten the list of token chunks can be expressed more concisely using a list comprehension.
Suggested change:

```diff
-    def _extract_prompt_tokens(self, model_input) -> list[int]:
+    def _extract_prompt_tokens(self, model_input: types.ModelInput) -> list[int]:
         """Extract flat token list from ModelInput.

         Args:
             model_input: ModelInput with chunks of tokens.

         Returns:
             Flat list of token IDs.
         """
-        tokens = []
-        for chunk in model_input.chunks:
-            tokens.extend(chunk.tokens)
-        return tokens
+        return [token for chunk in model_input.chunks for token in chunk.tokens]
```
```python
        return result

    def _convert_to_sample_output(self, output) -> types.SampleOutput:
```
The output parameter lacks a type hint. Based on the call site in _sample, it receives the return value of self.inference_client.sample, which is an InferenceEngineOutput. Adding this type hint will improve code clarity and allow static analysis tools to catch potential errors.
You'll need to add from skyrl_train.inference_engines.base import InferenceEngineOutput inside the TYPE_CHECKING block at the top of the file.
Suggested change:

```diff
-    def _convert_to_sample_output(self, output) -> types.SampleOutput:
+    def _convert_to_sample_output(self, output: "InferenceEngineOutput") -> types.SampleOutput:
```
## Summary

Stage 2 of Tinker Sampling API implementation:

- Add `SkyRLInferenceClient` adapter in skyrl-tx for direct Python integration
- Add integration and e2e tests in skyrl-train

## Changes

**skyrl-tx:**

- `tx/tinker/extra/skyrl_inference.py` - SkyRLInferenceClient adapter
  - Converts Tinker `ModelInput` → flat token list
  - Converts Tinker `SamplingParams` → dict for skyrl-train
  - Converts `InferenceEngineOutput` → Tinker `SampleOutput`
- `attach_skyrl_inference(app, inference_client)` helper for runtime integration

**skyrl-train:**

- `test_tinker_api_integration.py` - Type conversion and sample() tests
- `test_tinker_api_e2e.py` - End-to-end Tinker flow tests

## Test plan

- [x] test_tinker_type_conversion - PASSED
- [x] test_tinker_sample_integration - PASSED
- [x] test_tinker_stop_tokens - PASSED
- [x] test_e2e_tinker_sample_flow - PASSED
- [x] test_e2e_multiple_requests - PASSED

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
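A minimal usage sketch of the runtime integration helper; `attach_skyrl_inference(app, inference_client)` is the signature given above, while the construction of the FastAPI app and the inference client is elided:

```python
# Sketch: route skyrl-tx's /api/v1/asample to skyrl-train via direct Python calls.
from tx.tinker.extra.skyrl_inference import attach_skyrl_inference

# `app` is the skyrl-tx FastAPI app; `inference_client` is a configured
# skyrl-train InferenceEngineClient (setup elided here).
attach_skyrl_inference(app, inference_client)
```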