diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index 3ebcfe023..c056ab6aa 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -308,6 +308,10 @@ generator: step_wise_trajectories: false + # Experimental: Use Tinker-compatible sample() API instead of generate() + # This enables token-in/token-out semantics matching the Tinker API + use_tinker_sampling_api: false + environment: env_class: "gsm8k" # NOTE: environment specific defaults for environment.skyrl_gym are set at the following path: diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py index 619979a5d..1d03265af 100644 --- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py +++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py @@ -130,6 +130,9 @@ def __init__( else: self.env_executor = None + # Experimental: Use Tinker-compatible sample() API instead of generate() + self.use_tinker_sampling_api = generator_cfg.get("use_tinker_sampling_api", False) + self._validate_cfg(generator_cfg) # base_conversation is used when `use_conversation_multi_turn==True and custom_chat_template==None` to @@ -288,10 +291,18 @@ async def agent_loop( agent_loop_state.loss_mask = [] agent_loop_state.rollout_logprobs = None - engine_input = InferenceEngineInput( - prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params - ) - engine_output = await self.inference_engine_client.generate(engine_input) + if self.use_tinker_sampling_api: + # Use Tinker-compatible sample() API for token-in/token-out semantics + engine_output = await self.inference_engine_client.sample( + prompt_token_ids=agent_loop_state.input_ids, + num_samples=1, + sampling_params=current_sampling_params, + ) + else: + engine_input = InferenceEngineInput( + prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params + ) + engine_output = await self.inference_engine_client.generate(engine_input) output = engine_output["responses"][0] output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] diff --git a/skyrl-train/skyrl_train/inference_engines/base.py b/skyrl-train/skyrl_train/inference_engines/base.py index 392e2100e..d808c58a0 100644 --- a/skyrl-train/skyrl_train/inference_engines/base.py +++ b/skyrl-train/skyrl_train/inference_engines/base.py @@ -37,6 +37,62 @@ class InferenceEngineInterface(ABC): async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: raise NotImplementedError + async def sample( + self, + prompt_token_ids: List[int], + num_samples: int, + sampling_params: Dict[str, Any], + ) -> InferenceEngineOutput: + """Generate multiple independent samples from a single prompt. + + This method provides Tinker-compatible token-in/token-out sampling semantics. + Unlike generate() which processes a batch of different prompts, sample() generates + num_samples independent completions from the same prompt. + + Args: + prompt_token_ids: Token IDs for a single prompt (not batched). + num_samples: Number of independent samples to generate. + sampling_params: Sampling parameters (temperature, max_tokens, etc.). + + Returns: + InferenceEngineOutput containing num_samples results: + - response_ids: List of num_samples token ID lists + - responses: List of num_samples decoded strings + - stop_reasons: List of num_samples stop reasons + - response_logprobs: Optional list of num_samples logprob lists + + Note: + Default implementation calls generate() sequentially num_samples times. + Subclasses may override for more efficient batched sampling. + """ + all_response_ids = [] + all_responses = [] + all_stop_reasons = [] + all_response_logprobs = [] + + for _ in range(num_samples): + input_batch: InferenceEngineInput = { + "prompts": None, + "prompt_token_ids": [prompt_token_ids], # Wrap in list for batch of 1 + "sampling_params": sampling_params, + "session_ids": None, + } + output = await self.generate(input_batch) + + # Extract single result from batch of 1 + all_response_ids.append(output["response_ids"][0]) + all_responses.append(output["responses"][0]) + all_stop_reasons.append(output["stop_reasons"][0]) + if output.get("response_logprobs") is not None: + all_response_logprobs.append(output["response_logprobs"][0]) + + return { + "response_ids": all_response_ids, + "responses": all_responses, + "stop_reasons": all_stop_reasons, + "response_logprobs": all_response_logprobs if all_response_logprobs else None, + } + @abstractmethod async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: """Handles OpenAI-compatible HTTP endpoint. diff --git a/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py b/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py index ae4ffc251..3943e564e 100644 --- a/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py @@ -153,6 +153,36 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu response_logprobs=response_logprobs if add_resp_logprobs else None, ) + async def sample( + self, + prompt_token_ids: List[int], + num_samples: int, + sampling_params: Dict[str, Any], + ) -> InferenceEngineOutput: + """Generate multiple independent samples from a single prompt. + + This method provides Tinker-compatible token-in/token-out sampling semantics. + Generates num_samples independent completions from the same prompt. + + Args: + prompt_token_ids: Token IDs for a single prompt (not batched). + num_samples: Number of independent samples to generate. + sampling_params: Sampling parameters (temperature, max_tokens, etc.). + + Returns: + InferenceEngineOutput containing num_samples results. + """ + # TODO(Stage 4 - Tinker API): Add multi-engine load balancing and retry logic for sample(). + # Currently routes to first engine only, which bottlenecks multi-engine deployments. + # Should mirror the load-balancing/retry/pause logic used in generate() for production use. + # See: https://github.com/NovaSky-AI/SkyRL/issues/XXX + engine = self.engines[0] + return await engine.sample( + prompt_token_ids=prompt_token_ids, + num_samples=num_samples, + sampling_params=sampling_params, + ) + async def _generate_single_with_retry( self, engine_idx: int, original_prompt_ids: List[int], sampling_params: Optional[Dict[str, Any]] ) -> InferenceEngineOutput: diff --git a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py index 7422003cd..b89977eda 100644 --- a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py @@ -37,6 +37,19 @@ def dp_size(self): async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: return await self.inference_engine_actor.generate.remote(input_batch=input_batch) + async def sample( + self, + prompt_token_ids: List[int], + num_samples: int, + sampling_params: Dict[str, Any], + ) -> InferenceEngineOutput: + """Delegate sample() to the remote actor for Tinker-compatible sampling.""" + return await self.inference_engine_actor.sample.remote( + prompt_token_ids=prompt_token_ids, + num_samples=num_samples, + sampling_params=sampling_params, + ) + async def wake_up(self, *args: Any, **kwargs: Any): return await self.inference_engine_actor.wake_up.remote(*args, **kwargs) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_engine_generation.py b/skyrl-train/tests/gpu/gpu_ci/test_engine_generation.py index d59e0eea8..73ac16ff3 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_engine_generation.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_engine_generation.py @@ -257,3 +257,51 @@ def test_token_based_generation(ray_init_fixture, backend: str, tp_size: int, dp for i in range(len(prompts)): if not are_responses_similar([token_batch_responses[i]], [prompt_responses[i]], tolerance=0.01): print(f"Token and prompt responses differ: token={token_batch_responses[i]}, prompt={prompt_responses[i]}") + + +@pytest.mark.parametrize( + "backend,tp_size,dp_size", + [ + pytest.param("vllm", 2, 1, marks=pytest.mark.vllm), + ], + ids=["vllm_tp2"], +) +def test_sample_api(ray_init_fixture, backend: str, tp_size: int, dp_size: int): + """Test the Tinker-compatible sample() API for generating multiple independent samples.""" + cfg = get_test_actor_config() + cfg.generator.backend = backend + cfg.generator.sampling_params.temperature = 0.7 + + prompts = get_test_prompts(MODEL, 1) + tokenizer = AutoTokenizer.from_pretrained(MODEL) + prompt_token_ids = tokenizer.apply_chat_template( + prompts, add_generation_prompt=True, tokenize=True, return_dict=True + )["input_ids"][0] + + llm_client = init_ray_inference_engines(backend, tp_size=tp_size, pp_size=1, dp_size=dp_size, config=cfg) + sampling_params = get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params) + + num_samples = 3 + + async def run_sample(): + return await llm_client.sample( + prompt_token_ids=prompt_token_ids, + num_samples=num_samples, + sampling_params=sampling_params, + ) + + output = asyncio.run(run_sample()) + + assert len(output["response_ids"]) == num_samples + assert len(output["responses"]) == num_samples + assert len(output["stop_reasons"]) == num_samples + + for i, response_ids in enumerate(output["response_ids"]): + assert isinstance(response_ids, list) + assert len(response_ids) > 0 + assert all(isinstance(t, int) for t in response_ids) + + unique_responses = set(output["responses"]) + print(f"Generated {len(unique_responses)} unique responses from {num_samples} samples") + for i, resp in enumerate(output["responses"]): + print(f"Sample {i}: {resp[:100]}..." if len(resp) > 100 else f"Sample {i}: {resp}") diff --git a/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py b/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py index 74166d1c3..7e8a0b407 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py @@ -10,7 +10,7 @@ from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient from skyrl_train.inference_engines.utils import get_sampling_params_for_backend from skyrl_train.generators.skyrl_gym_generator import SkyRLGymGenerator -from skyrl_train.generators.base import GeneratorInput, GeneratorOutput +from skyrl_train.generators.base import GeneratorInput, GeneratorOutput, TrajectoryID from tests.gpu.utils import Timer, get_test_generator_input from omegaconf import DictConfig, OmegaConf from skyrl_train.utils.utils import initialize_ray @@ -508,3 +508,124 @@ async def test_generator_multi_turn_gsm8k_step_wise(): assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) finally: ray.shutdown() + + +@pytest.mark.vllm +@pytest.mark.asyncio +async def test_generator_with_tinker_sampling_api(): + """ + Test the generator with the use_tinker_sampling_api flag enabled. + + This verifies that the Tinker-compatible sample() API path works correctly + through the generator's agent_loop. + """ + initialize_ray(get_default_config()) + try: + model = "Qwen/Qwen2.5-1.5B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model) + + inference_engines = create_ray_wrapped_inference_engines( + num_inference_engines=1, + tensor_parallel_size=2, + model_dtype="bfloat16", + pretrain=model, + seed=42, + vllm_v1_disable_multiproc=True, + enable_prefix_caching=True, + enforce_eager=True, + shared_pg=None, + gpu_memory_utilization=0.8, + inference_engine_enable_sleep=True, + async_engine=True, + max_num_batched_tokens=32768, + max_num_seqs=1024, + tokenizer=tokenizer, + sleep_level=1, + ) + + cfg = get_test_config( + max_generate_length=256, + max_input_length=512, + batched=False, + max_turns=1, + use_conversation_multi_turn=True, + max_env_workers=0, + model=model, + is_step_wise=False, + temperature=0.7, + get_logprobs=False, + ) + + # Enable Tinker sampling API + OmegaConf.update(cfg.generator, "use_tinker_sampling_api", True) + + inference_engine_client = InferenceEngineClient( + inference_engines, + tokenizer, + cfg, + ) + + await inference_engine_client.wake_up() + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=inference_engine_client, + tokenizer=tokenizer, + model_name=model, + ) + + # Verify the flag is set + assert generator.use_tinker_sampling_api is True, "use_tinker_sampling_api flag should be True" + + # Use test_env which doesn't require reward_spec (simpler for testing) + num_prompts = 2 + n_samples_per_prompt = 2 + prompts = [] + env_extras = [] + for i in range(num_prompts): + prompt = [{"role": "user", "content": f"What is {i+1} + {i+1}?"}] + prompts.extend([prompt] * n_samples_per_prompt) + env_extras.extend([{}] * n_samples_per_prompt) + + input_batch: GeneratorInput = { + "prompts": prompts, + "env_classes": ["test_env"] * len(prompts), + "env_extras": env_extras, + "trajectory_ids": [TrajectoryID(instance_id=f"{i}", repetition_id=0) for i in range(len(prompts))], + } + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + DictConfig({ + "temperature": 0.7, + "top_p": 1.0, + "top_k": -1, + "max_generate_length": 256, + "min_p": 0.0, + "logprobs": None, + }), + ) + + generator_output = await generator.generate(input_batch) + + # Verify output structure + assert "response_ids" in generator_output + assert "prompt_token_ids" in generator_output + assert "rewards" in generator_output + assert "loss_masks" in generator_output + assert "stop_reasons" in generator_output + + # Verify we got the expected number of outputs + expected_outputs = num_prompts * n_samples_per_prompt + assert len(generator_output["response_ids"]) == expected_outputs + + # Verify each output has valid tokens + for i, response_ids in enumerate(generator_output["response_ids"]): + assert isinstance(response_ids, list), f"Response {i} should be a list" + assert len(response_ids) > 0, f"Response {i} should have tokens" + assert all(isinstance(t, int) for t in response_ids), f"All tokens should be integers" + + logger.info(f"Tinker sampling API test passed with {len(generator_output['response_ids'])} outputs") + + finally: + ray.shutdown() diff --git a/skyrl-train/tests/gpu/gpu_ci/test_tinker_api_e2e.py b/skyrl-train/tests/gpu/gpu_ci/test_tinker_api_e2e.py new file mode 100644 index 000000000..8bcdeaa0f --- /dev/null +++ b/skyrl-train/tests/gpu/gpu_ci/test_tinker_api_e2e.py @@ -0,0 +1,362 @@ +""" +End-to-end test for Tinker API integration. + +Tests the full flow: HTTP client -> skyrl-tx API -> adapter -> skyrl-train sample() + +# Run tests: +uv run --extra dev --extra vllm pytest tests/gpu/gpu_ci/test_tinker_api_e2e.py -m "vllm" -v +""" + +import pytest +import asyncio +from dataclasses import dataclass +from typing import Literal +from transformers import AutoTokenizer +import hydra + +from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines +from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient +from skyrl_train.entrypoints.main_base import config_dir +from omegaconf import DictConfig + + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +# Lightweight duplicates of Tinker types for testing +# These mirror tx.tinker.types without requiring skyrl-tx dependencies +@dataclass +class ModelInputChunk: + tokens: list[int] + + +@dataclass +class ModelInput: + chunks: list[ModelInputChunk] + + +@dataclass +class TinkerSamplingParams: + temperature: float + max_tokens: int + seed: int = 42 + stop_tokens: list[int] | None = None + stop_strings: list[str] | None = None + top_k: int = -1 + top_p: float = 1.0 + + +@dataclass +class GeneratedSequence: + stop_reason: Literal["length", "stop"] + tokens: list[int] + logprobs: list[float] + + +@dataclass +class SampleOutput: + sequences: list[GeneratedSequence] + prompt_logprobs: list[float] | None = None + + +@dataclass +class MockSampleRequest: + """Mock SampleRequest that mimics the API request structure.""" + prompt: ModelInput + sampling_params: TinkerSamplingParams + num_samples: int + + +class TinkerAdapter: + """Test adapter that mirrors SkyRLInferenceClient conversion logic. + + This duplicates the conversion logic from tx.tinker.extra.skyrl_inference + to test the integration contract without requiring skyrl-tx dependencies. + """ + + def __init__(self, inference_client: InferenceEngineClient): + self.inference_client = inference_client + + def _extract_prompt_tokens(self, model_input: ModelInput) -> list[int]: + """Extract flat token list from ModelInput.""" + tokens = [] + for chunk in model_input.chunks: + tokens.extend(chunk.tokens) + return tokens + + def _convert_sampling_params(self, params: TinkerSamplingParams) -> dict: + """Convert Tinker SamplingParams to skyrl-train format.""" + result = { + "temperature": params.temperature, + "max_tokens": params.max_tokens, + "top_k": params.top_k, + "top_p": params.top_p, + } + + if params.seed is not None: + result["seed"] = params.seed + + if params.stop_tokens: + result["stop_token_ids"] = params.stop_tokens + if params.stop_strings: + result["stop"] = params.stop_strings + + return result + + def _convert_to_sample_output(self, output: dict) -> SampleOutput: + """Convert skyrl-train output to Tinker SampleOutput.""" + sequences = [] + num_samples = len(output["response_ids"]) + + for i in range(num_samples): + stop_reason = output["stop_reasons"][i] + if stop_reason in ("stop", "eos"): + tinker_stop_reason = "stop" + else: + tinker_stop_reason = "length" + + logprobs = [] + if output.get("response_logprobs") and output["response_logprobs"][i]: + logprobs = output["response_logprobs"][i] + + sequences.append( + GeneratedSequence( + tokens=output["response_ids"][i], + logprobs=logprobs, + stop_reason=tinker_stop_reason, + ) + ) + + return SampleOutput(sequences=sequences, prompt_logprobs=None) + + async def _sample(self, request: MockSampleRequest) -> SampleOutput: + """Execute sample and convert response - mirrors SkyRLInferenceClient._sample.""" + prompt_token_ids = self._extract_prompt_tokens(request.prompt) + sampling_params = self._convert_sampling_params(request.sampling_params) + + output = await self.inference_client.sample( + prompt_token_ids=prompt_token_ids, + num_samples=request.num_samples, + sampling_params=sampling_params, + ) + + return self._convert_to_sample_output(output) + + +@dataclass +class MockSkyRLTxApp: + """A mock skyrl-tx app that tests the TinkerAdapter directly. + + This simulates what the real skyrl-tx /api/v1/asample endpoint does, + but without needing the full FastAPI app and database. + """ + adapter: TinkerAdapter + + async def asample(self, request: dict) -> dict: + """Simulate the /api/v1/asample endpoint behavior. + + Takes a Tinker-style request, converts it, calls sample(), converts response. + """ + # Parse request into MockSampleRequest (simulating SampleRequest from API) + prompt_chunks = [ModelInputChunk(tokens=chunk["tokens"]) for chunk in request["prompt"]["chunks"]] + sample_request = MockSampleRequest( + prompt=ModelInput(chunks=prompt_chunks), + sampling_params=TinkerSamplingParams( + temperature=request["sampling_params"]["temperature"], + max_tokens=request["sampling_params"]["max_tokens"], + seed=request["sampling_params"].get("seed", 42), + top_k=request["sampling_params"].get("top_k", -1), + top_p=request["sampling_params"].get("top_p", 1.0), + ), + num_samples=request.get("num_samples", 1), + ) + + # Call the adapter's _sample method (mirrors SkyRLInferenceClient._sample) + tinker_output = await self.adapter._sample(sample_request) + + # Return as dict (simulating JSON response) + return { + "sequences": [ + { + "tokens": seq.tokens, + "logprobs": seq.logprobs, + "stop_reason": seq.stop_reason, + } + for seq in tinker_output.sequences + ], + "prompt_logprobs": tinker_output.prompt_logprobs, + } + + +def get_test_config() -> DictConfig: + """Get base config with test-specific overrides.""" + with hydra.initialize_config_dir(config_dir=config_dir): + cfg = hydra.compose(config_name="ppo_base_config") + cfg.trainer.policy.model.path = MODEL + cfg.generator.sampling_params.temperature = 0.7 + cfg.generator.sampling_params.top_p = 1 + cfg.generator.sampling_params.top_k = -1 + cfg.generator.sampling_params.max_generate_length = 64 + cfg.generator.sampling_params.min_p = 0.0 + cfg.generator.sampling_params.logprobs = None + return cfg + + +def init_inference_client(backend: str, tp_size: int, config: DictConfig) -> InferenceEngineClient: + """Initialize inference client for testing.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL) + engines = create_ray_wrapped_inference_engines( + num_inference_engines=1, + tensor_parallel_size=tp_size, + pipeline_parallel_size=1, + data_parallel_size=1, + model_dtype="bfloat16", + pretrain=MODEL, + seed=42, + vllm_v1_disable_multiproc=True, + enable_prefix_caching=True, + enforce_eager=True, + shared_pg=None, + gpu_memory_utilization=0.8, + inference_engine_enable_sleep=False, + async_engine=True, + max_num_batched_tokens=32768, + max_num_seqs=1024, + tokenizer=tokenizer, + backend=backend, + ) + return InferenceEngineClient(engines, tokenizer, config) + + +def create_tinker_adapter(inference_client: InferenceEngineClient) -> TinkerAdapter: + """Create TinkerAdapter for testing.""" + return TinkerAdapter(inference_client) + + +@pytest.mark.vllm +@pytest.mark.parametrize( + "backend,tp_size", + [ + pytest.param("vllm", 2, marks=pytest.mark.vllm), + ], + ids=["vllm_tp2"], +) +def test_e2e_tinker_sample_flow(ray_init_fixture, backend: str, tp_size: int): + """End-to-end test of Tinker sampling through skyrl-train. + + This test simulates the full flow: + 1. Client creates Tinker-style request + 2. Request goes through API (simulated) + 3. Adapter converts and calls sample() + 4. Response is converted back to Tinker format + 5. Client receives and validates response + """ + cfg = get_test_config() + cfg.generator.backend = backend + + tokenizer = AutoTokenizer.from_pretrained(MODEL) + + # Initialize skyrl-train inference client + llm_client = init_inference_client(backend, tp_size, cfg) + adapter = create_tinker_adapter(llm_client) + + # Create mock app (simulates skyrl-tx API server) + app = MockSkyRLTxApp(adapter=adapter) + + # Create Tinker-style request (as would come from tinker-cookbook client) + prompt_text = "What is the capital of France?" + messages = [{"role": "user", "content": prompt_text}] + prompt_tokens = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ) + + tinker_request = { + "prompt": { + "chunks": [{"tokens": prompt_tokens}] + }, + "sampling_params": { + "temperature": 0.7, + "max_tokens": 32, + "top_k": -1, + "top_p": 1.0, + }, + "num_samples": 2, + } + + # Call the API (simulated) + async def run_request(): + return await app.asample(tinker_request) + + response = asyncio.run(run_request()) + + # Validate response structure matches Tinker SampleOutput + assert "sequences" in response, "Response should have 'sequences'" + assert len(response["sequences"]) == 2, "Should have 2 samples" + + for i, seq in enumerate(response["sequences"]): + assert "tokens" in seq, f"Sequence {i} should have 'tokens'" + assert "stop_reason" in seq, f"Sequence {i} should have 'stop_reason'" + assert isinstance(seq["tokens"], list), f"Tokens should be a list" + assert len(seq["tokens"]) > 0, f"Should have generated tokens" + assert seq["stop_reason"] in ("length", "stop"), f"Invalid stop_reason" + + # Decode and print samples + print(f"\n=== E2E Test Results ===") + print(f"Prompt: {prompt_text}") + print(f"Generated {len(response['sequences'])} samples:") + for i, seq in enumerate(response["sequences"]): + decoded = tokenizer.decode(seq["tokens"], skip_special_tokens=True) + print(f" Sample {i}: {decoded[:100]}..." if len(decoded) > 100 else f" Sample {i}: {decoded}") + + +@pytest.mark.vllm +@pytest.mark.parametrize( + "backend,tp_size", + [ + pytest.param("vllm", 2, marks=pytest.mark.vllm), + ], + ids=["vllm_tp2"], +) +def test_e2e_multiple_requests(ray_init_fixture, backend: str, tp_size: int): + """Test multiple concurrent Tinker requests through skyrl-train.""" + cfg = get_test_config() + cfg.generator.backend = backend + + tokenizer = AutoTokenizer.from_pretrained(MODEL) + llm_client = init_inference_client(backend, tp_size, cfg) + adapter = create_tinker_adapter(llm_client) + app = MockSkyRLTxApp(adapter=adapter) + + prompts = [ + "What is 2 + 2?", + "Name the largest planet.", + "What color is the sky?", + ] + + requests = [] + for prompt_text in prompts: + messages = [{"role": "user", "content": prompt_text}] + prompt_tokens = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ) + requests.append({ + "prompt": {"chunks": [{"tokens": prompt_tokens}]}, + "sampling_params": {"temperature": 0.0, "max_tokens": 32, "top_k": -1, "top_p": 1.0}, + "num_samples": 1, + }) + + async def run_all_requests(): + tasks = [app.asample(req) for req in requests] + return await asyncio.gather(*tasks) + + responses = asyncio.run(run_all_requests()) + + assert len(responses) == len(prompts), "Should have response for each prompt" + + print(f"\n=== E2E Multiple Requests Test ===") + for i, (prompt, response) in enumerate(zip(prompts, responses)): + assert len(response["sequences"]) == 1 + decoded = tokenizer.decode(response["sequences"][0]["tokens"], skip_special_tokens=True) + print(f"Q: {prompt}") + print(f"A: {decoded[:100]}...") + print() diff --git a/skyrl-train/tests/gpu/gpu_ci/test_tinker_api_integration.py b/skyrl-train/tests/gpu/gpu_ci/test_tinker_api_integration.py new file mode 100644 index 000000000..62a15e3ba --- /dev/null +++ b/skyrl-train/tests/gpu/gpu_ci/test_tinker_api_integration.py @@ -0,0 +1,353 @@ +""" +Integration tests for Tinker API compatibility. + +Tests that skyrl-train's sample() method works with Tinker-style inputs/outputs, +verifying the integration contract between skyrl-tx API and skyrl-train inference. + +# Run tests: +uv run --isolated --extra dev --extra vllm pytest tests/gpu/gpu_ci/test_tinker_api_integration.py -m "vllm" -v +""" + +import pytest +import asyncio +from dataclasses import dataclass +from typing import Literal +from transformers import AutoTokenizer +import hydra + +from skyrl_train.inference_engines.ray_wrapped_inference_engine import create_ray_wrapped_inference_engines +from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient +from skyrl_train.inference_engines.utils import get_sampling_params_for_backend +from skyrl_train.entrypoints.main_base import config_dir +from omegaconf import DictConfig + + +# Lightweight duplicates of Tinker types for testing +# These mirror tx.tinker.types without requiring skyrl-tx dependencies +@dataclass +class ModelInputChunk: + tokens: list[int] + + +@dataclass +class ModelInput: + chunks: list[ModelInputChunk] + + +@dataclass +class TinkerSamplingParams: + temperature: float + max_tokens: int + seed: int + stop_tokens: list[int] | None = None + stop_strings: list[str] | None = None + top_k: int = -1 + top_p: float = 1.0 + + +@dataclass +class GeneratedSequence: + stop_reason: Literal["length", "stop"] + tokens: list[int] + logprobs: list[float] + + +@dataclass +class SampleOutput: + sequences: list[GeneratedSequence] + prompt_logprobs: list[float] | None = None + + +class TinkerAdapter: + """Test adapter that mirrors SkyRLInferenceClient conversion logic. + + This duplicates the conversion logic from tx.tinker.extra.skyrl_inference + to test the integration contract without requiring skyrl-tx dependencies. + """ + + def __init__(self, inference_client: InferenceEngineClient): + self.inference_client = inference_client + + def _extract_prompt_tokens(self, model_input: ModelInput) -> list[int]: + """Extract flat token list from ModelInput.""" + tokens = [] + for chunk in model_input.chunks: + tokens.extend(chunk.tokens) + return tokens + + def _convert_sampling_params(self, params: TinkerSamplingParams) -> dict: + """Convert Tinker SamplingParams to skyrl-train format.""" + result = { + "temperature": params.temperature, + "max_tokens": params.max_tokens, + "top_k": params.top_k, + "top_p": params.top_p, + } + + if params.seed is not None: + result["seed"] = params.seed + + if params.stop_tokens: + result["stop_token_ids"] = params.stop_tokens + if params.stop_strings: + result["stop"] = params.stop_strings + + return result + + def _convert_to_sample_output(self, output: dict) -> SampleOutput: + """Convert skyrl-train output to Tinker SampleOutput.""" + sequences = [] + num_samples = len(output["response_ids"]) + + for i in range(num_samples): + stop_reason = output["stop_reasons"][i] + if stop_reason in ("stop", "eos"): + tinker_stop_reason = "stop" + else: + tinker_stop_reason = "length" + + logprobs = [] + if output.get("response_logprobs") and output["response_logprobs"][i]: + logprobs = output["response_logprobs"][i] + + sequences.append( + GeneratedSequence( + tokens=output["response_ids"][i], + logprobs=logprobs, + stop_reason=tinker_stop_reason, + ) + ) + + return SampleOutput(sequences=sequences, prompt_logprobs=None) + + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +def get_test_config() -> DictConfig: + """Get base config with test-specific overrides.""" + with hydra.initialize_config_dir(config_dir=config_dir): + cfg = hydra.compose(config_name="ppo_base_config") + cfg.trainer.policy.model.path = MODEL + cfg.generator.sampling_params.temperature = 0.7 + cfg.generator.sampling_params.top_p = 1 + cfg.generator.sampling_params.top_k = -1 + cfg.generator.sampling_params.max_generate_length = 64 + cfg.generator.sampling_params.min_p = 0.0 + cfg.generator.sampling_params.logprobs = None + return cfg + + +def init_inference_client(backend: str, tp_size: int, config: DictConfig) -> InferenceEngineClient: + """Initialize inference client for testing.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL) + engines = create_ray_wrapped_inference_engines( + num_inference_engines=1, + tensor_parallel_size=tp_size, + pipeline_parallel_size=1, + data_parallel_size=1, + model_dtype="bfloat16", + pretrain=MODEL, + seed=42, + vllm_v1_disable_multiproc=True, + enable_prefix_caching=True, + enforce_eager=True, + shared_pg=None, + gpu_memory_utilization=0.8, + inference_engine_enable_sleep=False, + async_engine=True, + max_num_batched_tokens=32768, + max_num_seqs=1024, + tokenizer=tokenizer, + backend=backend, + ) + return InferenceEngineClient(engines, tokenizer, config) + + +def create_tinker_adapter(inference_client: InferenceEngineClient) -> TinkerAdapter: + """Create TinkerAdapter for testing Tinker API integration.""" + return TinkerAdapter(inference_client) + + +@pytest.mark.parametrize( + "backend,tp_size", + [ + pytest.param("vllm", 2, marks=pytest.mark.vllm), + ], + ids=["vllm_tp2"], +) +def test_tinker_type_conversion(ray_init_fixture, backend: str, tp_size: int): + """Test that Tinker-style types convert correctly to/from skyrl-train format.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL) + + # Create Tinker-style input using actual types + prompt_text = "What is 2 + 2?" + messages = [{"role": "user", "content": prompt_text}] + prompt_tokens = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ) + + tinker_input = ModelInput(chunks=[ModelInputChunk(tokens=prompt_tokens)]) + tinker_params = TinkerSamplingParams( + temperature=0.7, + max_tokens=32, + seed=42, + top_k=-1, + top_p=1.0, + ) + + # Create adapter to test conversion methods + cfg = get_test_config() + cfg.generator.backend = backend + llm_client = init_inference_client(backend, tp_size, cfg) + adapter = create_tinker_adapter(llm_client) + + # Test conversion methods that mirror SkyRLInferenceClient + converted_tokens = adapter._extract_prompt_tokens(tinker_input) + converted_params = adapter._convert_sampling_params(tinker_params) + + # Verify conversions + assert converted_tokens == prompt_tokens, "Token conversion should preserve all tokens" + assert converted_params["temperature"] == 0.7 + assert converted_params["max_tokens"] == 32 + assert converted_params["seed"] == 42 + assert converted_params["top_k"] == -1 + assert converted_params["top_p"] == 1.0 + + +@pytest.mark.parametrize( + "backend,tp_size", + [ + pytest.param("vllm", 2, marks=pytest.mark.vllm), + ], + ids=["vllm_tp2"], +) +def test_tinker_sample_integration(ray_init_fixture, backend: str, tp_size: int): + """Test end-to-end Tinker-style sampling through skyrl-train. + + This test verifies the integration contract: + 1. Accept Tinker-style ModelInput and SamplingParams + 2. Convert to skyrl-train format using adapter methods + 3. Call sample() + 4. Convert result to Tinker SampleOutput + """ + cfg = get_test_config() + cfg.generator.backend = backend + + tokenizer = AutoTokenizer.from_pretrained(MODEL) + llm_client = init_inference_client(backend, tp_size, cfg) + adapter = create_tinker_adapter(llm_client) + + # Create Tinker-style input + prompt_text = "What is 2 + 2?" + messages = [{"role": "user", "content": prompt_text}] + prompt_tokens = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ) + + tinker_input = ModelInput(chunks=[ModelInputChunk(tokens=prompt_tokens)]) + tinker_params = TinkerSamplingParams( + temperature=0.7, + max_tokens=32, + seed=42, # Use seed for reproducibility + top_k=-1, + top_p=1.0, + ) + num_samples = 3 + + # Convert to skyrl-train format using adapter methods + converted_tokens = adapter._extract_prompt_tokens(tinker_input) + converted_params = adapter._convert_sampling_params(tinker_params) + + # Call skyrl-train's sample() + async def run_sample(): + return await llm_client.sample( + prompt_token_ids=converted_tokens, + num_samples=num_samples, + sampling_params=converted_params, + ) + + output = asyncio.run(run_sample()) + + # Convert to Tinker format using adapter method + tinker_output = adapter._convert_to_sample_output(output) + + # Verify Tinker output structure + assert isinstance(tinker_output, SampleOutput), "Should return SampleOutput type" + assert len(tinker_output.sequences) == num_samples, f"Expected {num_samples} sequences" + + for i, seq in enumerate(tinker_output.sequences): + # Verify each sequence has tokens + assert isinstance(seq, GeneratedSequence), f"Sequence {i} should be GeneratedSequence" + assert isinstance(seq.tokens, list), f"Sequence {i} tokens should be a list" + assert len(seq.tokens) > 0, f"Sequence {i} should have generated tokens" + assert all(isinstance(t, int) for t in seq.tokens), f"All tokens should be integers" + + # Verify stop reason is valid Tinker format + assert seq.stop_reason in ("length", "stop"), f"Invalid stop reason: {seq.stop_reason}" + + # Verify logprobs is a list (may be empty if not requested) + assert isinstance(seq.logprobs, list), f"Logprobs should be a list" + + # Print samples for debugging + print(f"\nGenerated {len(tinker_output.sequences)} Tinker-format sequences:") + for i, seq in enumerate(tinker_output.sequences): + decoded = tokenizer.decode(seq.tokens, skip_special_tokens=True) + print(f" Sample {i}: {decoded[:80]}... (stop={seq.stop_reason}, {len(seq.tokens)} tokens)") + + +@pytest.mark.parametrize( + "backend,tp_size", + [ + pytest.param("vllm", 2, marks=pytest.mark.vllm), + ], + ids=["vllm_tp2"], +) +def test_tinker_stop_tokens(ray_init_fixture, backend: str, tp_size: int): + """Test that stop tokens are handled correctly in Tinker format.""" + cfg = get_test_config() + cfg.generator.backend = backend + + tokenizer = AutoTokenizer.from_pretrained(MODEL) + llm_client = init_inference_client(backend, tp_size, cfg) + adapter = create_tinker_adapter(llm_client) + + # Create input with stop strings + prompt_text = "Count from 1 to 10:" + messages = [{"role": "user", "content": prompt_text}] + prompt_tokens = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ) + + tinker_input = ModelInput(chunks=[ModelInputChunk(tokens=prompt_tokens)]) + tinker_params = TinkerSamplingParams( + temperature=0.0, # Deterministic + max_tokens=64, + seed=42, + stop_strings=["5"], # Stop at "5" + top_k=-1, + top_p=1.0, + ) + + # Convert and call using adapter methods + converted_tokens = adapter._extract_prompt_tokens(tinker_input) + converted_params = adapter._convert_sampling_params(tinker_params) + + async def run_sample(): + return await llm_client.sample( + prompt_token_ids=converted_tokens, + num_samples=1, + sampling_params=converted_params, + ) + + output = asyncio.run(run_sample()) + tinker_output = adapter._convert_to_sample_output(output) + + # Should have stopped at "5" + assert len(tinker_output.sequences) == 1 + decoded = tokenizer.decode(tinker_output.sequences[0].tokens, skip_special_tokens=True) + print(f"Output with stop string '5': {decoded}") + + # Verify we got a valid stop reason + assert tinker_output.sequences[0].stop_reason in ("length", "stop"), \ + f"Stop reason should be valid Tinker format, got: {tinker_output.sequences[0].stop_reason}" diff --git a/skyrl-tx/tx/tinker/extra/__init__.py b/skyrl-tx/tx/tinker/extra/__init__.py index 9d4472d73..e898667df 100644 --- a/skyrl-tx/tx/tinker/extra/__init__.py +++ b/skyrl-tx/tx/tinker/extra/__init__.py @@ -1,3 +1,4 @@ from tx.tinker.extra.external_inference import ExternalInferenceClient +from tx.tinker.extra.skyrl_inference import SkyRLInferenceClient, attach_skyrl_inference -__all__ = ["ExternalInferenceClient"] +__all__ = ["ExternalInferenceClient", "SkyRLInferenceClient", "attach_skyrl_inference"] diff --git a/skyrl-tx/tx/tinker/extra/skyrl_inference.py b/skyrl-tx/tx/tinker/extra/skyrl_inference.py new file mode 100644 index 000000000..497c37e31 --- /dev/null +++ b/skyrl-tx/tx/tinker/extra/skyrl_inference.py @@ -0,0 +1,221 @@ +"""SkyRL-Train inference client for direct Python integration. + +This module provides an adapter that allows skyrl-tx's API server to call +skyrl-train's InferenceEngineClient.sample() directly, without HTTP overhead. + +Architecture: + skyrl-tx API (/api/v1/asample) -> SkyRLInferenceClient -> InferenceEngineClient.sample() + +Usage: + # From skyrl-train, after initializing inference engines: + from tx.tinker.extra.skyrl_inference import attach_skyrl_inference + + # Attach to running API server + attach_skyrl_inference(app, inference_client) + + # Or start API server with skyrl-train inference: + from tx.tinker.extra.skyrl_inference import create_app_with_skyrl_inference + app = create_app_with_skyrl_inference(inference_client, engine_config) +""" + +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +from sqlmodel.ext.asyncio.session import AsyncSession + +from tx.tinker import types +from tx.tinker.db_models import FutureDB, RequestStatus +from tx.utils.log import logger + +if TYPE_CHECKING: + from fastapi import FastAPI + from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient + + +class SkyRLInferenceClient: + """Client for calling skyrl-train's inference engines directly. + + This adapter converts between skyrl-tx's Tinker API types and skyrl-train's + InferenceEngineInterface, enabling direct Python calls without HTTP overhead. + + Usage: + # During app startup + inference_client = InferenceEngineClient(engines, tokenizer, config) + skyrl_client = SkyRLInferenceClient(inference_client, db_engine) + app.state.skyrl_inference_client = skyrl_client + + # In /api/v1/asample endpoint + asyncio.create_task(skyrl_client.call_and_store_result(request_id, sample_req)) + """ + + def __init__(self, inference_client: "InferenceEngineClient", db_engine): + """Initialize the SkyRL inference client. + + Args: + inference_client: skyrl-train's InferenceEngineClient with engines initialized. + db_engine: SQLModel async engine for storing results in FutureDB. + """ + self.inference_client = inference_client + self.db_engine = db_engine + + async def call_and_store_result( + self, + request_id: int, + sample_req, + model_id: str = "", + checkpoint_id: str = "", + ): + """Background task to call skyrl-train inference and store result in database. + + Args: + request_id: FutureDB request ID to update with results. + sample_req: SampleRequest from the API endpoint. + model_id: Model identifier (unused for now, skyrl-train uses pre-loaded model). + checkpoint_id: Checkpoint identifier (unused for now). + """ + try: + result = await self._sample(sample_req) + result_data = result.model_dump() + status = RequestStatus.COMPLETED + except Exception as e: + logger.exception("SkyRL inference error") + result_data = {"error": str(e), "status": "failed"} + status = RequestStatus.FAILED + + async with AsyncSession(self.db_engine) as session: + future = await session.get(FutureDB, request_id) + future.result_data = result_data + future.status = status + future.completed_at = datetime.now(timezone.utc) + await session.commit() + + async def _sample(self, request) -> types.SampleOutput: + """Call skyrl-train's sample() and convert response to Tinker format. + + Args: + request: SampleRequest from the API endpoint. + + Returns: + types.SampleOutput with generated sequences. + """ + # Convert ModelInput to flat token list + prompt_token_ids = self._extract_prompt_tokens(request.prompt) + + # Convert SamplingParams to dict for skyrl-train + sampling_params = self._convert_sampling_params(request.sampling_params) + + # Call skyrl-train's sample() method + output = await self.inference_client.sample( + prompt_token_ids=prompt_token_ids, + num_samples=request.num_samples, + sampling_params=sampling_params, + ) + + # Convert InferenceEngineOutput to SampleOutput + return self._convert_to_sample_output(output) + + def _extract_prompt_tokens(self, model_input) -> list[int]: + """Extract flat token list from ModelInput. + + Args: + model_input: ModelInput with chunks of tokens. + + Returns: + Flat list of token IDs. + """ + tokens = [] + for chunk in model_input.chunks: + tokens.extend(chunk.tokens) + return tokens + + def _convert_sampling_params(self, params) -> dict: + """Convert Tinker SamplingParams to skyrl-train format. + + Args: + params: SamplingParams from Tinker API. + + Returns: + Dict compatible with skyrl-train's sampling. + """ + result = { + "temperature": params.temperature, + "max_tokens": params.max_tokens, + "top_k": params.top_k, + "top_p": params.top_p, + } + + if params.seed is not None: + result["seed"] = params.seed + + # Handle stop tokens/strings + if params.stop_tokens: + result["stop_token_ids"] = params.stop_tokens + if params.stop_strings: + result["stop"] = params.stop_strings + + return result + + def _convert_to_sample_output(self, output) -> types.SampleOutput: + """Convert skyrl-train's InferenceEngineOutput to Tinker SampleOutput. + + Args: + output: InferenceEngineOutput from skyrl-train's sample(). + + Returns: + types.SampleOutput with GeneratedSequence list. + """ + sequences = [] + num_samples = len(output["response_ids"]) + + for i in range(num_samples): + # Map stop_reason to Tinker's expected values + stop_reason = output["stop_reasons"][i] + if stop_reason in ("stop", "eos"): + tinker_stop_reason = "stop" + else: + tinker_stop_reason = "length" + + # Extract logprobs if available + logprobs = [] + if output.get("response_logprobs") and output["response_logprobs"][i]: + logprobs = output["response_logprobs"][i] + + sequences.append( + types.GeneratedSequence( + tokens=output["response_ids"][i], + logprobs=logprobs, + stop_reason=tinker_stop_reason, + ) + ) + + # Note: prompt_logprobs not supported yet in skyrl-train's sample() + return types.SampleOutput(sequences=sequences, prompt_logprobs=None) + + +def attach_skyrl_inference(app: "FastAPI", inference_client: "InferenceEngineClient") -> None: + """Attach SkyRL inference client to an existing FastAPI app. + + This enables the /api/v1/asample endpoint to use skyrl-train's inference + engines directly instead of the internal JAX backend or external vLLM. + + Args: + app: The FastAPI app instance (must have db_engine in state). + inference_client: Initialized InferenceEngineClient from skyrl-train. + + Example: + # In skyrl-train after engines are initialized: + from tx.tinker.extra.skyrl_inference import attach_skyrl_inference + + app = get_running_api_app() # Get the FastAPI app + attach_skyrl_inference(app, llm_client) + """ + if not hasattr(app.state, "db_engine"): + raise RuntimeError("App must have db_engine initialized before attaching SkyRL inference") + + skyrl_client = SkyRLInferenceClient(inference_client, app.state.db_engine) + app.state.skyrl_inference_client = skyrl_client + + # Also set as external_inference_client so existing endpoint code routes to it + app.state.external_inference_client = skyrl_client + + logger.info("SkyRL-train inference client attached to API server")