Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions nemo_gym/ray_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import Optional

from ray.util.scheduling_strategies import (
NodeAffinitySchedulingStrategy,
)

from nemo_gym.server_utils import (
get_global_config_dict,
)


def spinup_single_ray_gpu_node_worker(worker_cls, runtime_env: Optional[dict] = None):
    """Spin up one ``worker_cls`` Ray actor on the first GPU node not yet used.

    Iterates the configured Ray nodes (``ray_nodes`` in the global config
    dict), skips nodes that already host a worker (tracked under
    ``_ray_state.spunup_node_ids``), and pins the new actor to the chosen
    node with a hard node-affinity scheduling strategy.

    Args:
        worker_cls: A ``ray.remote``-decorated actor class to instantiate.
        runtime_env: Optional extra runtime-env entries merged over the
            defaults built here; on key conflict the caller's entries win.

    Returns:
        The actor handle created via ``worker_cls.options(...).remote()``.

    Raises:
        RuntimeError: If every configured node already hosts a worker, or no
            nodes are configured.
    """
    cfg = get_global_config_dict()
    # setdefault() returns the stored dict, so mutations below persist in cfg.
    spunup_node_ids = cfg.setdefault("_ray_state", {}).setdefault("spunup_node_ids", {})
    nodes = cfg.get("ray_nodes", [])
    num_gpus_per_node = cfg.get("ray_num_gpus_per_node", 0)

    # Loop-invariant environment info, hoisted out of the loop: the worker
    # should reuse the launcher's interpreter / virtualenv instead of Ray's
    # default environment.
    py_exec = sys.executable
    venv_path = os.environ.get("VIRTUAL_ENV")
    uv_project_path = os.environ.get("UV_PROJECT_ENVIRONMENT")
    print(f"DEBUG: spinup_single_ray_gpu_node_worker: py exec = {py_exec}", flush=True)
    print(f"DEBUG: spinup_single_ray_gpu_node_worker: venv path = {venv_path}", flush=True)
    print(f"DEBUG: spinup_single_ray_gpu_node_worker: uv project path = {uv_project_path}", flush=True)

    for node in nodes:
        node_id = node["node_id"]
        if node_id in spunup_node_ids:
            continue

        worker_runtime_env = {"py_executable": py_exec}
        if venv_path is not None or uv_project_path is not None:
            # Propagate the full launcher environment so the active
            # virtualenv / uv project is honored inside the actor process.
            print("DEBUG: spinup_single_ray_gpu_node_worker: override env vars", flush=True)
            worker_runtime_env["env_vars"] = dict(os.environ)
        if runtime_env is not None:
            # NOTE(review): a caller-supplied "env_vars" entry replaces (does
            # not deep-merge with) the one built above.
            worker_runtime_env |= runtime_env

        worker_options = {
            "num_gpus": num_gpus_per_node,
            # Hard affinity: fail rather than reschedule onto another node.
            "scheduling_strategy": NodeAffinitySchedulingStrategy(
                node_id=node_id,
                soft=False,
            ),
            "runtime_env": worker_runtime_env,
        }
        worker = worker_cls.options(**worker_options).remote()
        spunup_node_ids[node_id] = {
            "worker_cls_name": f"{worker_cls}",
        }
        return worker

    raise RuntimeError(
        f"No available Ray GPU nodes for spinning up {worker_cls}"
    )
101 changes: 78 additions & 23 deletions resources_servers/translation_metricx/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Optional

import datasets
import ray
import transformers
from fastapi import FastAPI
from metricx24.models import MT5ForRegression
Expand All @@ -27,6 +28,60 @@
BaseVerifyResponse,
SimpleResourcesServer,
)
from nemo_gym.ray_utils import (
spinup_single_ray_gpu_node_worker,
)


@ray.remote
class TranslationMetricxModelWorker:
    """Ray actor that hosts a MetricX regression model on a GPU node.

    Model loading is deferred to ``_load_model`` (invoked remotely by the
    resources server) so the actor can be scheduled onto a node first and
    configured afterwards.
    """

    def __init__(self):
        # All state is populated lazily by _load_model(); initialize every
        # attribute up front so unset state is an explicit None instead of a
        # surprising AttributeError. (Debug residue removed here: environment
        # dumps and a blocking `os.system("uv pip freeze")` call.)
        self.model_name = None
        self.device_map = None
        self.output_dir = None
        self.model = None
        self.trainer = None
        self._inputs_device = None

    def _load_model(self, model_name, device_map, output_dir):
        """Load the MetricX model and wrap it in an eval-only HF Trainer.

        Args:
            model_name: Hugging Face model id of the MetricX checkpoint.
            device_map: Device placement forwarded to ``from_pretrained``.
            output_dir: Output directory required by ``TrainingArguments``.

        Returns:
            The torch device the model's first parameter lives on; callers
            should move model inputs to this device.
        """
        self.model_name = model_name
        self.device_map = device_map
        self.output_dir = output_dir

        # Load model with device placement
        model = MT5ForRegression.from_pretrained(
            self.model_name, torch_dtype="auto", device_map=self.device_map
        )
        # Inputs should go to the device where the first layer is; derive it
        # from the first model parameter.
        self._inputs_device = next(model.parameters()).device

        model.eval()
        self.model = model

        # Eval-only Trainer, used purely for its predict() machinery.
        training_args = transformers.TrainingArguments(
            output_dir=output_dir,
            per_device_eval_batch_size=1,
            dataloader_pin_memory=False,
            disable_tqdm=True,  # avoid progress-bar noise in server logs
        )
        self.trainer = transformers.Trainer(
            model=model,
            args=training_args,
        )

        return self._inputs_device

    def predict(self, *args, **kwargs):
        """Proxy to ``transformers.Trainer.predict``; requires a prior _load_model()."""
        return self.trainer.predict(*args, **kwargs)


class TranslationMetricxResourcesServerConfig(BaseResourcesServerConfig):
Expand Down Expand Up @@ -70,35 +125,24 @@ class TranslationMetricxResourcesServer(SimpleResourcesServer):
def model_post_init(self, context: Any) -> None:
    """Initialize tokenizer and MetricX model resources after pydantic init.

    NOTE(review): this span comes from a diff view and interleaves the
    pre-change path (in-process model + Trainer) with the post-change path
    (Ray worker spin-up); in the merged file only one of the two should
    remain — confirm against the applied diff.
    """
    super().model_post_init(context)

    # NOTE(review): debug residue — environment dump plus a blocking
    # `uv pip freeze` shell call; consider removing before merge.
    v = os.environ.get("UV_PROJECT_ENVIRONMENT", None)
    print(f"DEBUG: TranslationMetricxResourcesServer.model_post_init: UV_PROJECT_ENVIRONMENT = {v}", flush=True)
    v = os.environ.get("VIRTUAL_ENV", None)
    print(f"DEBUG: TranslationMetricxResourcesServer.model_post_init: VIRTUAL_ENV = {v}", flush=True)
    print(f"DEBUG: TranslationMetricxResourcesServer.model_post_init: uv pip freeze...", flush=True)
    os.system("uv pip freeze")
    print(f"DEBUG: TranslationMetricxResourcesServer.model_post_init: uv pip freeze: done", flush=True)

    # Load tokenizer (MetricX models use MT5 tokenizers, separate from the model name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.tokenizer_name)
    self._tokenizer = tokenizer

    # Load model with device placement
    model = MT5ForRegression.from_pretrained(
        self.config.metricx_model_name, torch_dtype="auto", device_map=self.config.device_map
    )
    # Inputs should go to the device where the first layer is
    # Get device from the first model parameter
    self._inputs_device = next(model.parameters()).device

    model.eval()
    self._metricx_model = model

    # Ensure output directory exists (following predict.py lines 167-169)
    os.makedirs(self.config.output_dir, exist_ok=True)

    # Create trainer
    training_args = transformers.TrainingArguments(
        output_dir=self.config.output_dir,
        per_device_eval_batch_size=1,
        dataloader_pin_memory=False,
    )
    trainer = transformers.Trainer(
        model=model,
        args=training_args,
    )
    self._metricx_trainer = trainer

    # Post-change path: host the model in a Ray GPU-node worker instead.
    model_workers = [spinup_single_ray_gpu_node_worker(TranslationMetricxModelWorker)]
    self._model_workers = model_workers
    # NOTE(review): resets _inputs_device after it was set above; consistent
    # with the remote worker lazily loading the model on first verify call.
    self._inputs_device = None

def setup_webserver(self) -> FastAPI:
app = super().setup_webserver()
Expand Down Expand Up @@ -133,7 +177,18 @@ def _verify_answer(
) -> tuple[float, str]:
extracted_answer = self._extract_answer(model_response)
ds = self._create_dataset_from_example(extracted_answer, source_text, target_text)
predictions, _, _ = self._metricx_trainer.predict(test_dataset=ds)
if self._inputs_device is None:
print("DEBUG: TranslationMetricxResourcesServer._verify_answer: initial load model...", flush=True)
for model_worker in self._model_workers:
# Load model with device placement
inputs_device = ray.get(model_worker._load_model.remote(
self.config.metricx_model_name,
self.config.device_map,
self.config.output_dir,
))
self._inputs_device = inputs_device
print("DEBUG: TranslationMetricxResourcesServer._verify_answer: initial load model: done", flush=True)
predictions, _, _ = ray.get(self._model_workers[0].predict.remote(test_dataset=ds))
score = float(predictions[0])

# MetricX scores are between 0 and 25, where 25 is worst, so we normalize to 0 to 1 where 0 is worst
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ translation_metricx:
resources_servers:
translation_metricx:
entrypoint: app.py
domain: translation
domain: other
# domain: translation
use_reference: true
metricx_model_name: google/metricx-24-hybrid-xl-v2p6
tokenizer_name: google/mt5-large
Expand All @@ -26,7 +27,8 @@ translation_metricx_simple_agent:
dataset_name: riva_mt_v3_nothinkInSys_train
version: 0.0.1
artifact_fpath: riva_mt_v3_nothinkInSys_train.jsonl
license: NVIDIA Internal Use Only, Do Not Distribute
license: TBD
# license: NVIDIA Internal Use Only, Do Not Distribute
- name: validation
type: validation
jsonl_fpath: resources_servers/translation_metricx/data/riva_mt_v3_nothinkInSys_validation.jsonl
Expand All @@ -35,7 +37,8 @@ translation_metricx_simple_agent:
dataset_name: riva_mt_v3_nothinkInSys_validation
version: 0.0.1
artifact_fpath: riva_mt_v3_nothinkInSys_validation.jsonl
license: NVIDIA Internal Use Only, Do Not Distribute
license: TBD
# license: NVIDIA Internal Use Only, Do Not Distribute
- name: example
type: example
jsonl_fpath: resources_servers/translation_metricx/data/example.jsonl
Expand Down
Loading