Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ NeMo Gym includes a curated collection of resource servers for training and eval
| math | Library Judge Math | <a href='resources_servers/library_judge_math/configs/library_judge_math.yaml'>resources_servers/library_judge_math/configs/library_judge_math.yaml</a> | Creative Commons Attribution 4.0 International | Train, Validation, Example |
| math | Library Judge Math | <a href='resources_servers/library_judge_math/configs/math_stack_overflow.yaml'>resources_servers/library_judge_math/configs/math_stack_overflow.yaml</a> | Creative Commons Attribution-ShareAlike 4.0 International | Train, Validation |
| math | Python Math Exec | <a href='resources_servers/python_math_exec/configs/python_math_exec.yaml'>resources_servers/python_math_exec/configs/python_math_exec.yaml</a> | Apache 2.0 | Train, Example |
| other | Noop | <a href='resources_servers/noop/configs/noop.yaml'>resources_servers/noop/configs/noop.yaml</a> | None | Example |
<!-- END_RESOURCE_TABLE -->

> [!TIP]
> Each resource server includes example data, configuration files, and tests. See each server's README for details.
> Each resource server includes example data, configuration files, and tests. See each server's README for details.
226 changes: 226 additions & 0 deletions nemo_gym/offline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
from collections import Counter
from contextlib import nullcontext
from itertools import count, product, repeat
from typing import Optional

from pydantic import BaseModel, Field
from tqdm import tqdm
from tqdm.asyncio import tqdm as tqdm_asyncio

from nemo_gym.config_types import BaseNeMoGymCLIConfig, BaseServerConfig
from nemo_gym.server_utils import (
GlobalAIOHTTPAsyncClientConfig,
ServerClient,
get_global_config_dict,
is_global_aiohttp_client_setup,
raise_for_status,
set_global_aiohttp_client,
)


class VerifyOfflineConfig(BaseNeMoGymCLIConfig):
    """
    Perform a batch of offline verification.
    """

    # Name of the resources server whose `/verify` endpoint receives each sample.
    server_name: str = Field(description="The resources server to use for verifying samples.")
    # Each line of this file is parsed with `json.loads` and posted as the request body.
    input_jsonl_fpath: str = Field(
        description="The input data source of samples to verify, in the form of a file path to a jsonl file."
    )
    # Results are appended to this file; with `enable_cache` it is also read back on
    # startup to find samples that have already been verified.
    output_jsonl_fpath: str = Field(description="The output data jsonl file path.")
    # Rows whose row index is >= limit are dropped. Unset (or 0) means no limit.
    limit: Optional[int] = Field(
        default=None, description="Maximum number of samples to load and take from the input dataset."
    )
    # Unset (or 0) keeps only repeat index 0 of each row.
    num_repeats: Optional[int] = Field(
        default=None,
        description="The number of times to repeat each sample to run. Useful if you want to calculate mean@k e.g. mean@4 or mean@16.",
    )
    # Size of the asyncio semaphore guarding requests; unset means no concurrency limit.
    num_samples_in_parallel: Optional[int] = Field(
        default=None, description="Limit the number of concurrent samples running at once."
    )
    # Progress-bar update interval; the helper falls back to 10 when unset and to 1
    # when the value is >= the number of rows.
    tqdm_miniters: Optional[int] = Field(
        default=None,
        description="tqdm miniters.",
    )
    # When set, every input row must carry a `_rollout_cache_key` object with
    # `row_idx` and `rep_idx`; those values replace the positional indices.
    use_rollout_cache: Optional[bool] = Field(
        default=None,
        description="Use rollout cache keys.",
    )
    # When set, each output record is tagged with `_verify_cache_key`, previously
    # written keys are skipped on restart, and HTTP errors are logged instead of raised.
    enable_cache: Optional[bool] = Field(
        default=None,
        description="Enable caching for restartable offline verification.",
    )


class VerifyOfflineHelper(BaseModel):  # pragma: no cover
    """Run a batch of offline `/verify` requests against a resources server."""

    async def run_from_config(self, config: VerifyOfflineConfig):
        """Verify every selected row of the input JSONL and append results to the output JSONL.

        Steps:
          1. Load the input rows as ``(row_idx, rep_idx, item)`` tuples, applying
             ``limit`` / ``num_repeats`` (or the ``_rollout_cache_key`` indices
             when ``use_rollout_cache`` is set).
          2. When ``enable_cache`` is set, read the existing output file and skip
             every ``(row_idx, rep_idx)`` already recorded there.
          3. POST each remaining row to ``/verify`` on ``config.server_name``,
             optionally bounded by ``num_samples_in_parallel``.
          4. Append each response as one JSON line and print the per-sample mean
             of every numeric top-level field.

        Args:
            config: Parsed CLI configuration for this run.

        Raises:
            Whatever ``raise_for_status`` raises for a failed request when
            ``enable_cache`` is not set. With ``enable_cache`` set, failed
            requests are logged and skipped so the run can be restarted later.
        """
        range_iterator = count()
        if config.limit and not config.use_rollout_cache:
            range_iterator = range(config.limit)
            print(f"Limiting the number of rows to {config.limit}!")

        def _load_row(row: tuple) -> tuple:
            # Input shape is ((row_idx, raw_line), rep_idx) as produced by product(zip(...), ...).
            (row_idx, raw_line), rep_idx = row
            item = json.loads(raw_line)
            if config.use_rollout_cache:
                # The indices recorded with the rollout override the positional ones.
                assert "_rollout_cache_key" in item
                item_cache_key = item["_rollout_cache_key"]
                row_idx = item_cache_key["row_idx"]
                rep_idx = item_cache_key["rep_idx"]
            return row_idx, rep_idx, item

        def _postfilter_row(row: tuple) -> bool:
            # Needed because cache-key indices bypass the positional limit/repeat iterators.
            row_idx, rep_idx, _ = row
            if config.limit and not row_idx < config.limit:
                return False
            if config.num_repeats:
                return rep_idx < config.num_repeats
            return rep_idx == 0

        print("Reading input dataset rows...", flush=True)
        with open(config.input_jsonl_fpath) as input_dataset:
            if config.num_repeats and not config.use_rollout_cache:
                repeat_iterator = range(config.num_repeats)
            else:
                repeat_iterator = repeat(0, 1)
            rows = [
                row
                for row in map(
                    _load_row,
                    product(zip(range_iterator, input_dataset), repeat_iterator),
                )
                if _postfilter_row(row)
            ]
        if config.num_repeats:
            print(f"Found {len(rows) // config.num_repeats} rows!")
            print(f"Including {config.num_repeats} repeats per original row, found {len(rows)} total repeated rows!")
            print("(Repeating rows in an interleaved pattern from abc to aabbcc)")
        else:
            print(f"Found {len(rows)} rows!")

        # nullcontext() stands in for "no concurrency limit" in the `async with` below.
        semaphore = nullcontext()
        if config.num_samples_in_parallel:
            semaphore = asyncio.Semaphore(config.num_samples_in_parallel)

        server_client = self.setup_server_client()

        tqdm_miniters = config.tqdm_miniters
        if tqdm_miniters is None:
            tqdm_miniters = 10
        if tqdm_miniters >= len(rows):
            tqdm_miniters = 1
        if tqdm_miniters > 1:
            print(
                f"The tqdm progress bar will only update every {tqdm_miniters} samples that finish to ensure that you are not being spammed."
            )

        cache_key_set = set()

        if config.enable_cache:
            print("Reading cached verifications...", flush=True)
            try:
                with open(config.output_jsonl_fpath, "r") as f:
                    for line in tqdm(f, total=len(rows), miniters=tqdm_miniters):
                        item = json.loads(line)
                        assert "_verify_cache_key" in item
                        item_cache_key = item["_verify_cache_key"]
                        row_idx = item_cache_key["row_idx"]
                        rep_idx = item_cache_key["rep_idx"]
                        cache_key_set.add((row_idx, rep_idx))
            except OSError:
                # Best effort: an unreadable or missing output file means nothing is cached yet.
                pass
            print(f"Found {len(cache_key_set)} cached verifications.", flush=True)

        print("Starting offline verification...", flush=True)

        metrics = Counter()
        # Number of responses folded into `metrics`. This is the correct denominator
        # for the sample mean: cached rows and failed requests never contribute.
        num_verified = 0
        write_lock = asyncio.Lock()

        def _filter_row(row: tuple) -> bool:
            row_idx, rep_idx, _ = row
            if config.enable_cache and (row_idx, rep_idx) in cache_key_set:
                return False
            return True

        # The context manager guarantees the output file is closed even when a
        # request raises part-way through the batch.
        with open(config.output_jsonl_fpath, "a") as write_file:

            async def _post_coroutine(row: tuple) -> None:
                nonlocal num_verified
                row_idx, rep_idx, request = row
                async with semaphore:
                    response = await server_client.post(
                        server_name=config.server_name, url_path="/verify", json=request
                    )
                    if config.enable_cache:
                        try:
                            await raise_for_status(response)
                        except Exception as e:
                            # In restartable mode a failed sample is logged and left
                            # unrecorded so that a later run retries it.
                            print(f"HTTP error during rollout (row={row_idx} rep={rep_idx}): {e}", flush=True)
                            return
                    else:
                        await raise_for_status(response)
                    result = await response.json()
                    # NOTE(review): this drops the reward from the verification response,
                    # so it is neither written nor averaged — confirm this is intended
                    # (it may have been meant to strip a stale reward from `request`).
                    result.pop("reward", None)
                    if config.enable_cache:
                        assert "_verify_cache_key" not in result
                        result["_verify_cache_key"] = {
                            "row_idx": row_idx,
                            "rep_idx": rep_idx,
                        }
                    async with write_lock:
                        print(json.dumps(result), file=write_file, flush=True)
                        metrics.update({k: v for k, v in result.items() if isinstance(v, (int, float))})
                        num_verified += 1

            await tqdm_asyncio.gather(
                *map(_post_coroutine, filter(_filter_row, rows)),
                desc="Verifying",
                miniters=tqdm_miniters,
            )

        print("Done offline verification.", flush=True)

        # `metrics` is only non-empty when at least one result was recorded, so
        # num_verified is >= 1 whenever a division actually happens.
        avg_metrics = {k: v / num_verified for k, v in metrics.items()}

        if avg_metrics:
            print(f"Metrics (sample mean): {json.dumps(avg_metrics, indent=4)}", flush=True)

    def setup_server_client(self, head_server_config: Optional[BaseServerConfig] = None) -> ServerClient:
        """Build a ``ServerClient`` from the global config and ensure the global aiohttp client exists.

        Args:
            head_server_config: Optional head server config forwarded to
                ``ServerClient.load_from_global_config``.

        Returns:
            The loaded ``ServerClient``.
        """
        server_client = ServerClient.load_from_global_config(head_server_config)

        # We set this rollout global aiohttp client to use the same max connections as the underlying head server global config.
        if not is_global_aiohttp_client_setup():
            set_global_aiohttp_client(
                cfg=GlobalAIOHTTPAsyncClientConfig.model_validate(server_client.global_config_dict)
            )

        return server_client


def verify_offline():  # pragma: no cover
    """CLI entry point: validate the global config and run one offline verification batch."""
    cli_config = VerifyOfflineConfig.model_validate(get_global_config_dict())
    runner = VerifyOfflineHelper()
    batch = runner.run_from_config(cli_config)

    asyncio.run(batch)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ ng_init_resources_server = "nemo_gym.cli:init_resources_server"
# Rollout collection
nemo_gym_collect_rollouts = "nemo_gym.rollout_collection:collect_rollouts"
ng_collect_rollouts = "nemo_gym.rollout_collection:collect_rollouts"
ng_verify_offline = "nemo_gym.offline:verify_offline"

# Dataset management
nemo_gym_upload_dataset_to_gitlab = "nemo_gym.gitlab_utils:upload_jsonl_dataset_cli"
Expand Down
28 changes: 28 additions & 0 deletions resources_servers/noop/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Description
This is a no-op resource server, which may be a helpful utility for collecting rollouts while deferring verification.

# Example usage

## Running servers
The following are example commands for running this resource server, along with the simple agent and a vLLM model:
```bash
config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml, \
resources_servers/noop/configs/noop.yaml"
ng_run "+config_paths=[$config_paths]"
```

Then, rollouts can be collected using a command such as the following:
```bash
ng_collect_rollouts \
+agent_name=noop_simple_agent \
+input_jsonl_fpath=your_rollout_input_dataset.jsonl \
+output_jsonl_fpath=your_rollout_output_responses.jsonl \
+limit=10
```

# Licensing information
Code: Apache 2.0
Data: Apache 2.0

Dependencies
- nemo_gym: Apache 2.0
59 changes: 59 additions & 0 deletions resources_servers/noop/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi import FastAPI

from nemo_gym.base_resources_server import (
BaseResourcesServerConfig,
BaseRunRequest,
BaseVerifyRequest,
BaseVerifyResponse,
SimpleResourcesServer,
)


class NoopResourcesServerConfig(BaseResourcesServerConfig):
    # Configuration for the no-op resources server.
    name: str = "noop"
    # Constant reward that NoopResourcesServer.verify returns for every request.
    reward: float = 1.0


class NoopRunRequest(BaseRunRequest):  # pragma: no cover
    # Run request for the no-op server; adds nothing to BaseRunRequest.
    pass


class NoopVerifyRequest(NoopRunRequest, BaseVerifyRequest):  # pragma: no cover
    # Verify request for the no-op server; combines the two bases without adding fields.
    pass


class NoopVerifyResponse(BaseVerifyResponse):  # pragma: no cover
    # Verify response for the no-op server; adds nothing to BaseVerifyResponse.
    pass


class NoopResourcesServer(SimpleResourcesServer):
    # Resources server that performs no real verification: every request is
    # echoed back with the reward configured on `config`.
    config: NoopResourcesServerConfig

    def setup_webserver(self) -> FastAPI:
        # Nothing to add on top of the app built by the base class.
        return super().setup_webserver()

    async def verify(self, body: NoopVerifyRequest) -> NoopVerifyResponse:
        # Echo the request fields and attach the constant configured reward.
        return NoopVerifyResponse(**body.model_dump(), reward=self.config.reward)


# Start the no-op resources server when this file is executed as a script.
if __name__ == "__main__":
    NoopResourcesServer.run_webserver()
20 changes: 20 additions & 0 deletions resources_servers/noop/configs/noop.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
noop:
resources_servers:
noop:
entrypoint: app.py
domain: other
reward: 1.0
noop_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
resources_server:
type: resources_servers
name: noop
model_server:
type: responses_api_models
name: policy_model
datasets:
- name: example
type: example
jsonl_fpath: resources_servers/noop/data/example.jsonl
5 changes: 5 additions & 0 deletions resources_servers/noop/data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*train.jsonl
*validation.jsonl
*train_prepare.jsonl
*validation_prepare.jsonl
*example_prepare.jsonl
5 changes: 5 additions & 0 deletions resources_servers/noop/data/example.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{"responses_create_params": {"input": [{"content": "You are a helpful personal assistant that aims to be helpful and reduce any pain points the user has.", "role": "developer"}, {"content": "what's it like in sf?", "role": "user"}], "tools": [{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": ""}}, "required": ["city"], "additionalProperties": false}, "strict": true, "type": "function", "description": ""}]}}
{"responses_create_params": {"input": [{"content": "You are a helpful personal assistant that aims to be helpful and reduce any pain points the user has.", "role": "developer"}, {"content": "going out in sf tn", "role": "user"}], "tools": [{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": ""}}, "required": ["city"], "additionalProperties": false}, "strict": true, "type": "function", "description": ""}]}}
{"responses_create_params": {"input": [{"content": "You are a helpful personal assistant that aims to be helpful and reduce any pain points the user has.", "role": "developer"}, {"content": "humidity in sf", "role": "user"}], "tools": [{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": ""}}, "required": ["city"], "additionalProperties": false}, "strict": true, "type": "function", "description": ""}]}}
{"responses_create_params": {"input": [{"content": "You are a helpful personal assistant that aims to be helpful and reduce any pain points the user has.", "role": "developer"}, {"content": "how's the outside?", "role": "user"}], "tools": [{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": ""}}, "required": ["city"], "additionalProperties": false}, "strict": true, "type": "function", "description": ""}]}}
{"responses_create_params": {"input": [{"content": "You are a helpful personal assistant that aims to be helpful and reduce any pain points the user has.", "role": "developer"}, {"content": "get the weather for 3 cities in the us", "role": "user"}], "tools": [{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": ""}}, "required": ["city"], "additionalProperties": false}, "strict": true, "type": "function", "description": ""}]}}
Loading