4 changes: 4 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -308,6 +308,10 @@ generator:

step_wise_trajectories: false

# Experimental: Use Tinker-compatible sample() API instead of generate()
# This enables token-in/token-out semantics matching the Tinker API
use_tinker_sampling_api: false

environment:
env_class: "gsm8k"
# NOTE: environment specific defaults for environment.skyrl_gym are set at the following path:
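
For reference, the new flag can also be flipped without editing the YAML; a minimal sketch using OmegaConf, mirroring how the tests later in this PR override it (the config path is the file shown above):

from omegaconf import OmegaConf

# Load the base PPO config and enable the experimental Tinker sampling path.
cfg = OmegaConf.load("skyrl-train/skyrl_train/config/ppo_base_config.yaml")
OmegaConf.update(cfg.generator, "use_tinker_sampling_api", True)
assert cfg.generator.use_tinker_sampling_api is True
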
19 changes: 15 additions & 4 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -130,6 +130,9 @@ def __init__(
else:
self.env_executor = None

# Experimental: Use Tinker-compatible sample() API instead of generate()
self.use_tinker_sampling_api = generator_cfg.get("use_tinker_sampling_api", False)

self._validate_cfg(generator_cfg)

# base_conversation is used when `use_conversation_multi_turn==True and custom_chat_template==None` to
@@ -288,10 +291,18 @@ async def agent_loop(
agent_loop_state.loss_mask = []
agent_loop_state.rollout_logprobs = None

engine_input = InferenceEngineInput(
prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params
)
engine_output = await self.inference_engine_client.generate(engine_input)
if self.use_tinker_sampling_api:
# Use Tinker-compatible sample() API for token-in/token-out semantics
engine_output = await self.inference_engine_client.sample(
prompt_token_ids=agent_loop_state.input_ids,
num_samples=1,
                sampling_params=sampling_params,
)
else:
engine_input = InferenceEngineInput(
prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params
)
engine_output = await self.inference_engine_client.generate(engine_input)
output = engine_output["responses"][0]
output_ids = engine_output["response_ids"][0]
stop_reason = engine_output["stop_reasons"][0]
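
Both branches above return an InferenceEngineOutput whose fields are parallel lists, which is why the [0] indexing that follows the branch is unchanged; an illustrative sketch of that shared shape (values are made up):

# Illustrative only: the output contract both generate() and sample() satisfy here.
engine_output = {
    "responses": ["The answer is 4."],   # decoded text, one entry per completion
    "response_ids": [[785, 4226, 374]],  # token IDs, one list per completion (made-up values)
    "stop_reasons": ["stop"],            # one stop reason per completion
}
output = engine_output["responses"][0]
output_ids = engine_output["response_ids"][0]
stop_reason = engine_output["stop_reasons"][0]
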
56 changes: 56 additions & 0 deletions skyrl-train/skyrl_train/inference_engines/base.py
@@ -37,6 +37,62 @@ class InferenceEngineInterface(ABC):
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
raise NotImplementedError

async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.

This method provides Tinker-compatible token-in/token-out sampling semantics.
Unlike generate() which processes a batch of different prompts, sample() generates
num_samples independent completions from the same prompt.

Args:
prompt_token_ids: Token IDs for a single prompt (not batched).
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters (temperature, max_tokens, etc.).

Returns:
InferenceEngineOutput containing num_samples results:
- response_ids: List of num_samples token ID lists
- responses: List of num_samples decoded strings
- stop_reasons: List of num_samples stop reasons
- response_logprobs: Optional list of num_samples logprob lists

Note:
Default implementation calls generate() sequentially num_samples times.
Subclasses may override for more efficient batched sampling.
"""
all_response_ids = []
all_responses = []
all_stop_reasons = []
all_response_logprobs = []

for _ in range(num_samples):
input_batch: InferenceEngineInput = {
"prompts": None,
"prompt_token_ids": [prompt_token_ids], # Wrap in list for batch of 1
"sampling_params": sampling_params,
"session_ids": None,
}
output = await self.generate(input_batch)

# Extract single result from batch of 1
all_response_ids.append(output["response_ids"][0])
all_responses.append(output["responses"][0])
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])

return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
}

@abstractmethod
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
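
As the docstring notes, subclasses may override sample() for batched generation rather than the sequential default; a minimal sketch of one such override, assuming InferenceEngineInterface, InferenceEngineInput, and InferenceEngineOutput are importable from skyrl_train.inference_engines.base and that the backend's generate() accepts a batch of identical prompts (BatchedSamplingEngine is a hypothetical name; other abstract methods are omitted):

from typing import Any, Dict, List

from skyrl_train.inference_engines.base import (  # assumed import location
    InferenceEngineInput,
    InferenceEngineInterface,
    InferenceEngineOutput,
)


class BatchedSamplingEngine(InferenceEngineInterface):
    """Hypothetical subclass: batch the num_samples copies into one generate() call."""

    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
    ) -> InferenceEngineOutput:
        # Duplicate the prompt num_samples times so a single generate() call
        # produces all completions, instead of num_samples sequential calls.
        input_batch: InferenceEngineInput = {
            "prompts": None,
            "prompt_token_ids": [list(prompt_token_ids) for _ in range(num_samples)],
            "sampling_params": sampling_params,
            "session_ids": None,
        }
        # generate() returns parallel lists of length num_samples, which matches
        # the InferenceEngineOutput contract that sample() promises.
        return await self.generate(input_batch)
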
@@ -153,6 +153,36 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
response_logprobs=response_logprobs if add_resp_logprobs else None,
)

async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.

This method provides Tinker-compatible token-in/token-out sampling semantics.
Generates num_samples independent completions from the same prompt.

Args:
prompt_token_ids: Token IDs for a single prompt (not batched).
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters (temperature, max_tokens, etc.).

Returns:
InferenceEngineOutput containing num_samples results.
"""
# TODO(Stage 4 - Tinker API): Add multi-engine load balancing and retry logic for sample().
# Currently routes to first engine only, which bottlenecks multi-engine deployments.
# Should mirror the load-balancing/retry/pause logic used in generate() for production use.
# See: https://github.com/NovaSky-AI/SkyRL/issues/XXX
engine = self.engines[0]
return await engine.sample(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
)

async def _generate_single_with_retry(
self, engine_idx: int, original_prompt_ids: List[int], sampling_params: Optional[Dict[str, Any]]
) -> InferenceEngineOutput:
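
Until the Stage 4 load-balancing work lands, one stopgap callers could experiment with is a plain round-robin over engines; a hedged sketch only, without the retry/pause handling that generate() has (_sample_rr_idx is a hypothetical attribute):

    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
    ) -> InferenceEngineOutput:
        # Hypothetical round-robin variant: rotate across engines instead of
        # always hitting self.engines[0]. No retries, no pause/resume awareness.
        idx = getattr(self, "_sample_rr_idx", 0)
        self._sample_rr_idx = (idx + 1) % len(self.engines)
        return await self.engines[idx].sample(
            prompt_token_ids=prompt_token_ids,
            num_samples=num_samples,
            sampling_params=sampling_params,
        )
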
@@ -37,6 +37,19 @@ def dp_size(self):
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
return await self.inference_engine_actor.generate.remote(input_batch=input_batch)

async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
) -> InferenceEngineOutput:
"""Delegate sample() to the remote actor for Tinker-compatible sampling."""
return await self.inference_engine_actor.sample.remote(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
)

async def wake_up(self, *args: Any, **kwargs: Any):
return await self.inference_engine_actor.wake_up.remote(*args, **kwargs)

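
For context, the await on .remote() above works because Ray ObjectRefs are awaitable inside async functions and Ray actors may expose async methods; a standalone sketch with a hypothetical EchoActor:

import asyncio

import ray


@ray.remote
class EchoActor:
    # Hypothetical actor illustrating the delegation pattern: an async method
    # invoked via .remote() returns an ObjectRef that can be awaited directly.
    async def sample(self, prompt_token_ids, num_samples, sampling_params):
        return {"response_ids": [list(prompt_token_ids)] * num_samples}


async def main():
    actor = EchoActor.remote()
    out = await actor.sample.remote([1, 2, 3], num_samples=2, sampling_params={})
    assert len(out["response_ids"]) == 2


if __name__ == "__main__":
    ray.init()
    asyncio.run(main())
    ray.shutdown()
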
48 changes: 48 additions & 0 deletions skyrl-train/tests/gpu/gpu_ci/test_engine_generation.py
@@ -257,3 +257,51 @@ def test_token_based_generation(ray_init_fixture, backend: str, tp_size: int, dp_size: int):
for i in range(len(prompts)):
if not are_responses_similar([token_batch_responses[i]], [prompt_responses[i]], tolerance=0.01):
print(f"Token and prompt responses differ: token={token_batch_responses[i]}, prompt={prompt_responses[i]}")


@pytest.mark.parametrize(
"backend,tp_size,dp_size",
[
pytest.param("vllm", 2, 1, marks=pytest.mark.vllm),
],
ids=["vllm_tp2"],
)
def test_sample_api(ray_init_fixture, backend: str, tp_size: int, dp_size: int):
"""Test the Tinker-compatible sample() API for generating multiple independent samples."""
cfg = get_test_actor_config()
cfg.generator.backend = backend
cfg.generator.sampling_params.temperature = 0.7

prompts = get_test_prompts(MODEL, 1)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
prompt_token_ids = tokenizer.apply_chat_template(
prompts, add_generation_prompt=True, tokenize=True, return_dict=True
)["input_ids"][0]

llm_client = init_ray_inference_engines(backend, tp_size=tp_size, pp_size=1, dp_size=dp_size, config=cfg)
sampling_params = get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params)

num_samples = 3

async def run_sample():
return await llm_client.sample(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
)

output = asyncio.run(run_sample())

assert len(output["response_ids"]) == num_samples
assert len(output["responses"]) == num_samples
assert len(output["stop_reasons"]) == num_samples

for i, response_ids in enumerate(output["response_ids"]):
assert isinstance(response_ids, list)
assert len(response_ids) > 0
assert all(isinstance(t, int) for t in response_ids)

unique_responses = set(output["responses"])
print(f"Generated {len(unique_responses)} unique responses from {num_samples} samples")
for i, resp in enumerate(output["responses"]):
print(f"Sample {i}: {resp[:100]}..." if len(resp) > 100 else f"Sample {i}: {resp}")
123 changes: 122 additions & 1 deletion skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py
@@ -10,7 +10,7 @@
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
from skyrl_train.inference_engines.utils import get_sampling_params_for_backend
from skyrl_train.generators.skyrl_gym_generator import SkyRLGymGenerator
from skyrl_train.generators.base import GeneratorInput, GeneratorOutput
from skyrl_train.generators.base import GeneratorInput, GeneratorOutput, TrajectoryID
from tests.gpu.utils import Timer, get_test_generator_input
from omegaconf import DictConfig, OmegaConf
from skyrl_train.utils.utils import initialize_ray
@@ -508,3 +508,124 @@ async def test_generator_multi_turn_gsm8k_step_wise():
assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"])
finally:
ray.shutdown()


@pytest.mark.vllm
@pytest.mark.asyncio
async def test_generator_with_tinker_sampling_api():
"""
Test the generator with the use_tinker_sampling_api flag enabled.

This verifies that the Tinker-compatible sample() API path works correctly
through the generator's agent_loop.
"""
initialize_ray(get_default_config())
try:
model = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model)

inference_engines = create_ray_wrapped_inference_engines(
num_inference_engines=1,
tensor_parallel_size=2,
model_dtype="bfloat16",
pretrain=model,
seed=42,
vllm_v1_disable_multiproc=True,
enable_prefix_caching=True,
enforce_eager=True,
shared_pg=None,
gpu_memory_utilization=0.8,
inference_engine_enable_sleep=True,
async_engine=True,
max_num_batched_tokens=32768,
max_num_seqs=1024,
tokenizer=tokenizer,
sleep_level=1,
)

cfg = get_test_config(
max_generate_length=256,
max_input_length=512,
batched=False,
max_turns=1,
use_conversation_multi_turn=True,
max_env_workers=0,
model=model,
is_step_wise=False,
temperature=0.7,
get_logprobs=False,
)

# Enable Tinker sampling API
OmegaConf.update(cfg.generator, "use_tinker_sampling_api", True)

inference_engine_client = InferenceEngineClient(
inference_engines,
tokenizer,
cfg,
)

await inference_engine_client.wake_up()

generator = SkyRLGymGenerator(
generator_cfg=cfg.generator,
skyrl_gym_cfg=cfg.environment.skyrl_gym,
inference_engine_client=inference_engine_client,
tokenizer=tokenizer,
model_name=model,
)

# Verify the flag is set
assert generator.use_tinker_sampling_api is True, "use_tinker_sampling_api flag should be True"

# Use test_env which doesn't require reward_spec (simpler for testing)
num_prompts = 2
n_samples_per_prompt = 2
prompts = []
env_extras = []
for i in range(num_prompts):
prompt = [{"role": "user", "content": f"What is {i+1} + {i+1}?"}]
prompts.extend([prompt] * n_samples_per_prompt)
env_extras.extend([{}] * n_samples_per_prompt)

input_batch: GeneratorInput = {
"prompts": prompts,
"env_classes": ["test_env"] * len(prompts),
"env_extras": env_extras,
"trajectory_ids": [TrajectoryID(instance_id=f"{i}", repetition_id=0) for i in range(len(prompts))],
}
input_batch["sampling_params"] = get_sampling_params_for_backend(
"vllm",
DictConfig({
"temperature": 0.7,
"top_p": 1.0,
"top_k": -1,
"max_generate_length": 256,
"min_p": 0.0,
"logprobs": None,
}),
)

generator_output = await generator.generate(input_batch)

# Verify output structure
assert "response_ids" in generator_output
assert "prompt_token_ids" in generator_output
assert "rewards" in generator_output
assert "loss_masks" in generator_output
assert "stop_reasons" in generator_output

# Verify we got the expected number of outputs
expected_outputs = num_prompts * n_samples_per_prompt
assert len(generator_output["response_ids"]) == expected_outputs

# Verify each output has valid tokens
for i, response_ids in enumerate(generator_output["response_ids"]):
assert isinstance(response_ids, list), f"Response {i} should be a list"
assert len(response_ids) > 0, f"Response {i} should have tokens"
assert all(isinstance(t, int) for t in response_ids), f"All tokens should be integers"

logger.info(f"Tinker sampling API test passed with {len(generator_output['response_ids'])} outputs")

finally:
ray.shutdown()
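
To run only this test, assuming the repo's standard pytest markers and a machine with at least two GPUs for tensor_parallel_size=2:

import pytest

# Hypothetical invocation; requires pytest-asyncio, the vLLM extra, and >= 2 GPUs.
pytest.main([
    "-m", "vllm",
    "skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py::test_generator_with_tinker_sampling_api",
])
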