Skip to content

Commit c69f1f4

Browse files
committed
fixes
1 parent 36f0095 commit c69f1f4

File tree

3 files changed: 7 additions (+7) and 5 deletions (-5).

examples/tinker_math_rl/debug_dataset.py

Whitespace-only changes.

examples/tinker_math_rl/test_gsm8k_eval.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def get_gsm8k_input_rows(limit: int = 10) -> List[EvaluationRow]:
2727

2828

2929
@evaluation_test(
30-
input_rows=[gsm8k_input_rows],
30+
input_rows=gsm8k_input_rows,
3131
completion_params=[
3232
{
3333
"max_tokens": 512,
@@ -48,7 +48,9 @@ def test_gsm8k_tinker(row: EvaluationRow) -> EvaluationRow:
4848
else:
4949
model_response = assistant_msgs[-1].content
5050
# The content might be a list of content parts, handle that
51-
if not isinstance(model_response, str):
51+
if model_response is None:
52+
model_response = ""
53+
elif not isinstance(model_response, str):
5254
# Simple join for now if it's a list
5355
model_response = "".join([p.text for p in model_response if hasattr(p, "text")])
5456

examples/tinker_math_rl/train.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
import asyncio
2+
import copy
23
import logging
34
import os
45
import sys
56
from functools import partial
67
from typing import Literal, Any, Optional
78

89
import chz
10+
import datetime
911

1012
# Add tinker-cookbook to path if not installed
1113
# Assuming the directory structure:
@@ -153,7 +155,7 @@ async def cli_main(cli_config: CLIConfig):
153155
# Need to wrap in a factory as expected by Config.evaluator_builders
154156
def create_eval_protocol_evaluator():
155157
return EvalProtocolEvaluator(
156-
rows=eval_rows,
158+
rows=copy.deepcopy(eval_rows),
157159
eval_func=test_gsm8k_tinker,
158160
rollout_processor_cls=TinkerRolloutProcessor,
159161
model_name=cli_config.model_name,
@@ -203,7 +205,5 @@ def create_eval_protocol_evaluator():
203205

204206

205207
if __name__ == "__main__":
206-
from datetime import datetime
207-
208208
cli_config = chz.entrypoint(CLIConfig)
209209
asyncio.run(cli_main(cli_config))

0 commit comments

Comments
 (0)