Commit f04c673

feat: Improved metrics in ART (#609)
1 parent b6c4a38 commit f04c673

30 files changed: +4522 −289 lines changed

dev/yes-no-maybe-metrics.py

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
"""Yes-no-maybe metrics demo for the LocalBackend `model.train()` path.

This keeps the same prompt family, rollout structure, and reward ordering as
`dev/yes-no-maybe.py` while adding explicit metrics-taxonomy instrumentation
for actor/eval timing and data metrics, relying on LocalBackend for automatic
step wall-time and GPU cost logging.
"""

from __future__ import annotations

import asyncio
from itertools import permutations
import os
import time

from dotenv import load_dotenv
import openai

try:
    import unsloth  # noqa: F401
except ImportError:
    pass

import art
from art.local import LocalBackend


async def create_chat_completion(
    client: openai.AsyncOpenAI,
    *,
    model_name: str,
    messages: art.Messages,
    max_tokens: int,
    timeout: float,
) -> openai.types.chat.chat_completion.ChatCompletion:
    return await client.chat.completions.create(
        messages=messages,
        model=model_name,
        max_tokens=max_tokens,
        timeout=timeout,
    )


def with_quotes(word: str) -> str:
    return f"'{word}'"


def build_prompts() -> list[str]:
    # Every prompt variant: two prefixes x quoted/unquoted x all 3- and 2-word
    # permutations of "yes"/"no"/"maybe" (quoting only applies to the 3-word form).
    return [
        f"{prefix} with {', '.join([with_quotes(word) if use_quotes else word for word in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in (
            list(permutation)
            for length in [3, 2]
            for permutation in permutations(["yes", "no", "maybe"], length)
        )
    ]


def reward_for_answer(content: str | None) -> float:
    # Reward ordering: "maybe" > "no" > "yes" > anything else.
    if content == "yes":
        return 0.5
    if content == "no":
        return 0.75
    if content == "maybe":
        return 1.0
    return 0.0


def scenario_id_for_prompt(prompt: str) -> str:
    return prompt.replace(" ", "_").replace("'", "")


def response_total_tokens(
    response: openai.types.chat.chat_completion.ChatCompletion,
) -> int:
    usage = response.usage
    if usage is None:
        return 0
    prompt_tokens = int(usage.prompt_tokens or 0)
    completion_tokens = int(usage.completion_tokens or 0)
    return prompt_tokens + completion_tokens


def total_actor_tokens(groups: list[art.TrajectoryGroup]) -> int:
    return sum(
        int(trajectory.metadata.get("actor_total_tokens", 0) or 0)
        for group in groups
        for trajectory in group.trajectories
    )


async def rollout(
    client: openai.AsyncOpenAI,
    model: art.TrainableModel,
    prompt: str,
    *,
    max_tokens: int,
    timeout: float,
) -> art.Trajectory:
    messages: art.Messages = [{"role": "user", "content": prompt}]
    chat_completion = await create_chat_completion(
        client,
        model_name=model.get_inference_name(),
        messages=messages,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    choice = chat_completion.choices[0]
    content = choice.message.content
    reward = reward_for_answer(content)
    return art.Trajectory(
        messages_and_choices=[*messages, choice],
        reward=reward,
        metadata={
            "scenario_id": scenario_id_for_prompt(prompt),
            "actor_total_tokens": response_total_tokens(chat_completion),
        },
        metrics={
            "valid_answer": reward > 0.0,
        },
    )


async def evaluate(
    client: openai.AsyncOpenAI,
    model: art.TrainableModel,
    prompts: list[str],
    *,
    max_tokens: int,
    timeout: float,
) -> list[art.TrajectoryGroup]:
    # One single-trajectory group per eval prompt.
    groups = await art.gather_trajectory_groups(
        art.TrajectoryGroup(
            [
                rollout(
                    client,
                    model,
                    prompt,
                    max_tokens=max_tokens,
                    timeout=timeout,
                )
            ],
            metadata={"scenario_id": scenario_id_for_prompt(prompt)},
        )
        for prompt in prompts
    )
    return groups


def print_history_summary(model: art.TrainableModel) -> None:
    history_path = (
        model.base_path + f"/{model.project}/models/{model.name}/history.jsonl"
    )
    print(f"History: {history_path}")


def build_internal_config() -> art.dev.InternalModelConfig:
    return art.dev.InternalModelConfig(
        engine_args=art.dev.EngineArgs(
            gpu_memory_utilization=float(
                os.environ.get("GPU_MEMORY_UTILIZATION", "0.85")
            ),
            max_model_len=int(os.environ.get("MAX_MODEL_LEN", "4096")),
        )
    )


async def main() -> None:
    load_dotenv()

    backend = LocalBackend()
    base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
    project = os.environ.get("PROJECT", "yes-no-maybe-metrics")
    model = art.TrainableModel(
        name=os.environ.get("MODEL_NAME", f"yes-no-maybe-metrics-{int(time.time())}"),
        project=project,
        base_model=base_model,
        report_metrics=["wandb"],
        _internal_config=build_internal_config(),
    )
    try:
        await model.register(backend)

        prompts = build_prompts()
        eval_prompts = prompts[: int(os.environ.get("EVAL_PROMPTS", "12"))]
        openai_client = model.openai_client()
        max_steps = int(os.environ.get("NUM_STEPS", "20"))
        rollouts_per_prompt = int(os.environ.get("ROLLOUTS_PER_PROMPT", "32"))
        max_tokens = int(os.environ.get("MAX_TOKENS", "100"))
        timeout = float(os.environ.get("TIMEOUT", "100"))
        eval_every_n_steps = int(os.environ.get("EVAL_EVERY_N_STEPS", "1"))
        learning_rate = float(os.environ.get("LEARNING_RATE", "1e-4"))

        start_step = await model.get_step()
        for offset in range(max_steps):
            current_step = start_step + offset

            # Periodic eval pass, timed under the "eval" metrics context.
            if (
                eval_every_n_steps > 0
                and (current_step - start_step) % eval_every_n_steps == 0
            ):
                eval_builder = model.metrics_builder("eval")
                with eval_builder.activate_context():
                    with eval_builder.measure("time/step_eval_s"):
                        val_groups = await evaluate(
                            openai_client,
                            model,
                            eval_prompts,
                            max_tokens=max_tokens,
                            timeout=timeout,
                        )
                    eval_builder.add_data(
                        step_actor_tokens=total_actor_tokens(val_groups)
                    )
                await model.log(val_groups, split="val", step=current_step)

            # Training rollouts, timed under the "train" metrics context.
            train_builder = model.metrics_builder("train")
            with train_builder.activate_context():
                with train_builder.measure("time/step_actor_s"):
                    train_groups = await art.gather_trajectory_groups(
                        (
                            art.TrajectoryGroup(
                                rollout(
                                    openai_client,
                                    model,
                                    prompt,
                                    max_tokens=max_tokens,
                                    timeout=timeout,
                                )
                                for _ in range(rollouts_per_prompt)
                            )
                            for prompt in prompts
                        )
                    )
                train_builder.add_data(
                    step_actor_tokens=total_actor_tokens(train_groups)
                )
            result = await backend.train(
                model,
                train_groups,
                learning_rate=learning_rate,
            )

            await model.log(
                split="train",
                step=result.step,
                trajectories=train_groups,
                metrics=result.metrics,
            )
            print(f"step {result.step} complete")

        print_history_summary(model)
    finally:
        await backend.close()


if __name__ == "__main__":
    asyncio.run(main())
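The script ends by printing where the per-model history.jsonl lives. A minimal sketch for loading that file offline, assuming only that it is standard JSON Lines (read_history is a hypothetical helper, and the record keys inside depend on the ART version):

import json


def read_history(path: str) -> list[dict]:
    # Load every JSON record from a history.jsonl file written during training.
    records: list[dict] = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                records.append(json.loads(line))
    return records


# The path layout mirrors print_history_summary() above:
# history = read_history(
#     model.base_path + f"/{model.project}/models/{model.name}/history.jsonl"
# )

Everything else about a run is controlled by the environment variables read in main() (BASE_MODEL, MODEL_NAME, PROJECT, NUM_STEPS, ROLLOUTS_PER_PROMPT, MAX_TOKENS, TIMEOUT, EVAL_PROMPTS, EVAL_EVERY_N_STEPS, LEARNING_RATE, GPU_MEMORY_UTILIZATION, MAX_MODEL_LEN), so repeat runs only need different exports.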

docs/docs.json

Lines changed: 2 additions & 1 deletion
@@ -67,6 +67,7 @@
         "features/checkpoint-forking",
         "features/checkpoint-deletion",
         "features/additional-histories",
+        "features/tracking-metrics",
         "features/mcp-rl"
       ]
     },
@@ -106,4 +107,4 @@
     "bluesky": "https://bsky.app/profile/openpipe.bsky.social",
     "github": "https://github.com/openpipe/ART"
   }
-}
+}
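One way to sanity-check a nav registration like this after editing docs.json; a minimal sketch that assumes only that docs/docs.json is read from the repo root and that page slugs appear as string leaves somewhere in the JSON (iter_pages is a hypothetical helper, not part of the docs tooling):

import json


def iter_pages(node):
    # Yield every string leaf in a nested dict/list JSON structure.
    if isinstance(node, dict):
        for value in node.values():
            yield from iter_pages(value)
    elif isinstance(node, list):
        for item in node:
            yield from iter_pages(item)
    elif isinstance(node, str):
        yield node


with open("docs/docs.json") as f:
    docs = json.load(f)

assert "features/tracking-metrics" in set(iter_pages(docs))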
