Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ generator:
# Disable CUDA graphs by default for stability. Set to false for higher performance, but this may affect convergence for long-running and/or long context training jobs.
enforce_eager: true
fully_sharded_loras: false
# Enable Ray Prometheus stats logger for vLLM inference engine metrics (vLLM v1 only)
# When enabled, uses vllm.v1.metrics.ray_wrappers.RayPrometheusStatLogger
enable_ray_prometheus_stats: false
gpu_memory_utilization: 0.8
max_num_seqs: 1024
remote_inference_engine_urls: ["127.0.0.1:8001"]
Expand Down
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/entrypoints/main_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_p
"tokenizer": tokenizer,
"backend": cfg.generator.backend,
"engine_init_kwargs": cfg.generator.engine_init_kwargs,
"enable_ray_prometheus_stats": cfg.generator.enable_ray_prometheus_stats,
}

# Conditionally add LoRA parameters if LoRA is enabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def create_ray_wrapped_inference_engines(
engine_init_kwargs: Dict[str, Any] = {},
rope_scaling: Dict[str, Any] = {},
rope_theta: float | None = None,
enable_ray_prometheus_stats: bool = False,
) -> List[InferenceEngineInterface]:
"""
Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface
Expand Down Expand Up @@ -221,6 +222,7 @@ def create_ray_wrapped_inference_engines(
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
max_logprobs=1, # only need chosen-token logprobs
enable_ray_prometheus_stats=enable_ray_prometheus_stats,
**dp_kwargs,
**engine_init_kwargs,
**lora_kwargs,
Expand Down
41 changes: 40 additions & 1 deletion skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ def _create_engine(self, *args, **kwargs):
"Pipeline parallelism is only supported with AsyncVLLMInferenceEngine. "
"Please set `generator.async_engine=true` in your config."
)
# Pop enable_ray_prometheus_stats - only supported for async engine
enable_ray_prometheus_stats = kwargs.pop("enable_ray_prometheus_stats", False)
if enable_ray_prometheus_stats:
logger.warning(
"enable_ray_prometheus_stats is only supported with AsyncVLLMInferenceEngine. "
"Set `generator.async_engine=true` to enable Ray Prometheus stats logging."
)
return vllm.LLM(*args, **kwargs)

async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
Expand Down Expand Up @@ -350,12 +357,20 @@ def __init__(self, *args, **kwargs):

def _create_engine(self, *args, **kwargs):
openai_kwargs = pop_openai_kwargs(kwargs)
enable_ray_prometheus_stats = kwargs.pop("enable_ray_prometheus_stats", False)

# TODO (erictang000): potentially enable log requests for a debugging mode
if version.parse(vllm.__version__) >= version.parse("0.10.0"):
engine_args = vllm.AsyncEngineArgs(enable_log_requests=False, **kwargs)
else:
engine_args = vllm.AsyncEngineArgs(disable_log_requests=True, **kwargs)
engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)

# Setup stat loggers for vLLM v1 if Ray Prometheus stats are enabled
stat_loggers = None
if enable_ray_prometheus_stats:
stat_loggers = self._create_ray_prometheus_stat_loggers()

engine = vllm.AsyncLLMEngine.from_engine_args(engine_args, stat_loggers=stat_loggers)

# Adapted from https://github.com/volcengine/verl/blob/e90f18c40aa639cd25092b78a5ff7e2d2508c088/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L327
model_config = engine.model_config
Expand Down Expand Up @@ -388,6 +403,30 @@ def _create_engine(self, *args, **kwargs):
)
return engine

def _create_ray_prometheus_stat_loggers(self):
    """Build the ``stat_loggers`` argument for vLLM's ``from_engine_args()``.

    Returns a list of StatLoggerFactory callables when the vLLM v1
    metrics API (0.9.0+) is importable; otherwise returns ``None`` so
    stat logging stays disabled on older vLLM versions.

    See: https://docs.vllm.ai/en/latest/api/vllm/v1/metrics/ray_wrappers/
    """
    try:
        # The v1 metrics API (vLLM 0.9.0+) exposes Ray-aware stat loggers.
        from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
    except ImportError:
        logger.warning(
            "RayPrometheusStatLogger not available in this vLLM version. "
            "For Ray-integrated metrics, upgrade to vLLM >= 0.9.0. "
            "Stat logging will be disabled."
        )
        return None

    logger.info("Enabling RayPrometheusStatLogger for vLLM inference engine metrics")
    # vLLM v1 expects a list of factory callables, not logger instances.
    return [RayPrometheusStatLogger]

async def _load_lora_from_disk(self, lora_path: str):
"""Load LoRA adapters from disk using vLLM's native add_lora method."""
lora_id = int(time.time_ns() % 0x7FFFFFFF)
Expand Down
Empty file.
Empty file.
103 changes: 103 additions & 0 deletions skyrl-train/tests/cpu/inf_engines/vllm/test_ray_prometheus_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
Test for RayPrometheusStatLogger integration in the vLLM engine.

Run with:
uv run --isolated --extra dev pytest tests/cpu/inf_engines/vllm/test_ray_prometheus_stats.py
"""

from unittest.mock import patch, MagicMock
import sys


class TestRayPrometheusStatLoggers:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's great that you've added tests for the success case. To make the tests more robust, consider adding a test case for the failure scenario where RayPrometheusStatLogger cannot be imported (e.g., on older vLLM versions). This would verify that the method correctly returns None and logs a warning.

You could add a test like this to TestRayPrometheusStatLoggers:

def test_create_ray_prometheus_stat_loggers_v1_unavailable(self):
    """Test that None is returned when vLLM v1 API is not available."""
    # By setting the module to None in sys.modules, the import will fail.
    with patch.dict(sys.modules, {"vllm.v1.metrics.ray_wrappers": None}):
        from skyrl_train.inference_engines.vllm.vllm_engine import AsyncVLLMInferenceEngine

        # Create a minimal instance without actually initializing the engine
        engine = object.__new__(AsyncVLLMInferenceEngine)

        with patch("skyrl_train.inference_engines.vllm.vllm_engine.logger") as mock_logger:
            result = engine._create_ray_prometheus_stat_loggers()

            assert result is None
            mock_logger.warning.assert_called_once()
            assert "not available in this vLLM version" in mock_logger.warning.call_args[0][0]

"""Test cases for _create_ray_prometheus_stat_loggers method."""

def test_create_ray_prometheus_stat_loggers_v1_available(self):
    """Test that RayPrometheusStatLogger is returned when vLLM v1 API is available."""
    # Stand-in for the real v1 RayPrometheusStatLogger class.
    fake_logger_cls = MagicMock()
    fake_logger_cls.__name__ = "RayPrometheusStatLogger"

    fake_module = MagicMock()
    fake_module.RayPrometheusStatLogger = fake_logger_cls

    # Inject the fake module so the in-method import resolves to it.
    with patch.dict(sys.modules, {"vllm.v1.metrics.ray_wrappers": fake_module}):
        from skyrl_train.inference_engines.vllm.vllm_engine import AsyncVLLMInferenceEngine

        # Bypass __init__ so no real engine gets constructed.
        engine = object.__new__(AsyncVLLMInferenceEngine)
        loggers = engine._create_ray_prometheus_stat_loggers()

        # A single-element list containing the stat logger class is expected.
        assert loggers is not None
        assert isinstance(loggers, list)
        assert len(loggers) == 1
        assert loggers[0] == fake_logger_cls

def test_create_ray_prometheus_stat_loggers_v1_unavailable(self):
    """Test that None is returned when vLLM v1 API is not available."""
    # A None entry in sys.modules makes the in-method import raise ImportError.
    with patch.dict(sys.modules, {"vllm.v1.metrics.ray_wrappers": None}):
        from skyrl_train.inference_engines.vllm.vllm_engine import AsyncVLLMInferenceEngine

        # Bypass __init__ so no real engine gets constructed.
        engine = object.__new__(AsyncVLLMInferenceEngine)

        with patch("skyrl_train.inference_engines.vllm.vllm_engine.logger") as fake_logger:
            outcome = engine._create_ray_prometheus_stat_loggers()

            assert outcome is None
            # Exactly one warning mentioning the missing API should be emitted.
            fake_logger.warning.assert_called_once()
            assert "not available in this vLLM version" in fake_logger.warning.call_args[0][0]


class TestConfigIntegration:
    """Test that configuration flows correctly through the stack."""

    def test_config_default_value(self):
        """Test that enable_ray_prometheus_stats defaults to False in config."""
        from omegaconf import OmegaConf

        # Minimal config mirroring the generator section of the base config.
        cfg = OmegaConf.create("generator:\n  enable_ray_prometheus_stats: false\n")
        assert cfg.generator.enable_ray_prometheus_stats is False

    def test_config_can_be_enabled(self):
        """Test that enable_ray_prometheus_stats can be set to True."""
        from omegaconf import OmegaConf

        cfg = OmegaConf.create("generator:\n  enable_ray_prometheus_stats: true\n")
        assert cfg.generator.enable_ray_prometheus_stats is True


class TestKwargsHandling:
    """Test that enable_ray_prometheus_stats is properly handled in kwargs."""

    def test_enable_ray_prometheus_stats_popped_from_kwargs(self):
        """Test that enable_ray_prometheus_stats is properly popped from kwargs."""
        kwargs = {"enable_ray_prometheus_stats": True, "other_param": "value"}

        # Mirrors the pop performed inside _create_engine: the flag is
        # consumed and must not leak into the engine constructor kwargs.
        flag = kwargs.pop("enable_ray_prometheus_stats", False)

        assert flag is True
        assert "enable_ray_prometheus_stats" not in kwargs
        assert kwargs == {"other_param": "value"}

    def test_enable_ray_prometheus_stats_defaults_to_false(self):
        """Test that enable_ray_prometheus_stats defaults to False when not present."""
        kwargs = {"other_param": "value"}

        # Absent flag pops as False and leaves the remaining kwargs intact.
        flag = kwargs.pop("enable_ray_prometheus_stats", False)

        assert flag is False
        assert kwargs == {"other_param": "value"}
Loading