Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ vllm serve [[model_name]] --seed 42 --tensor-parallel-size 4
```bash
# Example with vllm model
python main.py --tasks mistake_location.yaml --provider completion_api --model_args base_url=http://localhost:8000/v1,model=meta-llama/Llama-3.2-3B-Instruct
# Example with an OpenAI-compatible chat/reasoning model served by vLLM
python main.py --tasks scaffolding_generation.yaml --provider completion_api --model_args base_url=http://localhost:8000/v1,model=Qwen/Qwen3-8B,is_chat=true,max_tokens=4096 --store_traces
# Example with OpenAI API
python main.py --tasks mistake_correction.yaml --provider completion_api --model_args model=gpt-4o-mini-2024-07-18,api_key=<API_KEY>
# Example with LearnLM Gemini API
Expand Down Expand Up @@ -56,6 +58,9 @@ python main.py --tasks student_solution_correctness.yaml --provider gemini --mod
- `temperature`: Temperature for sampling. Default is 0.0.
- `max_tokens`: Maximum tokens to generate. Default is 2048.
- `max_retries`: Maximum retries for the API. Default is 3.
- `--store_traces`: For chat models that emit `<think>...</think>` blocks, save extracted thinking traces to `traces-*.json`.

For chat/reasoning models, MathTutorBench strips `<think>...</think>` blocks before parsing and scoring the visible answer. In chat mode, task stop tokens are applied after this stripping step instead of being sent to the API request, because completion-style stop tokens such as blank lines can otherwise truncate the hidden reasoning before the final answer is produced. Use a larger `max_tokens` budget for reasoning models so the model can finish both the thinking trace and visible answer.

The performance of different benchmarked models averaged across tasks for Qwen2.5 family is as follows (using vllm version 0.8.0 on one node with 4x GH200 GPUs):

Expand Down
43 changes: 37 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def main():
help='Output directory for results')
parser.add_argument('--debug', action='store_true',
help='Enable debug logging')
parser.add_argument('--store_traces', action='store_true',
help='Store thinking traces extracted from <think> blocks')
args = parser.parse_args()

# Parse model arguments
Expand Down Expand Up @@ -76,18 +78,29 @@ def main():
predictions = []
targets = []
all_generations = []
all_traces = []

# Process examples
for example in tqdm(task.get_test_examples(), desc=f"Evaluating {task_config.name}"):
# Prepare messages with few-shot examples if provided
messages = []
example["shots"] = task_config.few_shot_samples
system_prompt = task.get_system_prompt(example)
# Get model response
response = model.generate(
messages=messages,
system_prompt=task.get_system_prompt(example),
stop=task_config.stop
)
if args.store_traces:
response, thinking_trace = model.generate(
messages=messages,
system_prompt=system_prompt,
stop=task_config.stop,
return_trace=True
)
else:
response = model.generate(
messages=messages,
system_prompt=system_prompt,
stop=task_config.stop
)
thinking_trace = ""
# Parse and store prediction
prediction = task.parse_response(response)
predictions.append(prediction)
Expand All @@ -107,6 +120,18 @@ def main():
}
all_generations.append(generation)

if args.store_traces and thinking_trace:
all_traces.append({
"task": task_config.name,
"problem": example.get("question", ""),
"student_solution": example.get("student_solution", ""),
"dialog_history": example.get("dialog_history", ""),
"system_prompt": system_prompt,
"thinking_trace": thinking_trace,
"visible_response": response,
"ground_truth": str(formatted_ground_truth),
})

# Compute metrics
metrics = task.compute_metrics(predictions, targets)
results[task_config.name] = metrics
Expand All @@ -125,6 +150,12 @@ def main():
with open(output_file, 'w') as f:
json.dump(all_generations, f, indent=2)

if len(all_traces) > 0:
traces_file = output_dir / f"traces-{config.model.split('/')[-1]}-{task_config.name}.json"
with open(traces_file, 'w') as f:
json.dump(all_traces, f, indent=2)
print(f"Thinking traces saved to {traces_file}")


if __name__ == "__main__":
main()
main()
83 changes: 69 additions & 14 deletions models/completion_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import jinja2
import re
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
import google.generativeai as genai
Expand Down Expand Up @@ -67,6 +68,18 @@ class BaseLLMAPI(ABC):
def __init__(self, config: LLMConfig):
self.config = config

def _extract_thinking(self, response: str) -> str:
"""Extract hidden reasoning from <think> blocks, if present."""
match = re.search(r'<think>(.*?)</think>', response, flags=re.DOTALL)
if match:
return match.group(1).strip()

# Preserve partial traces when generation is cut off before </think>.
match = re.search(r'<think>(.*?)$', response, flags=re.DOTALL)
if match:
return match.group(1).strip()
return ""

def _format_conversation(self, messages: List[Dict]) -> str:
"""Format messages into a conversation string"""
# print("Formatting conversation from messages: " + "|EOM|".join(messages))
Expand Down Expand Up @@ -94,6 +107,30 @@ def _format_prompt(self, system_prompt: str, messages: List[Dict]) -> str:
# print("Final formatted prompt: " + prompt)
return prompt

def _strip_thinking(self, response: str, stop: Optional[List[str]] = None) -> str:
"""Remove <think> blocks before applying task stop tokens.

Chat reasoning models can emit a reasoning block before the visible answer.
Completion-style stop tokens such as "\\n" or "\\n\\n" can appear inside
that reasoning block, so passing them directly to chat completion APIs may
truncate the output before the answer is generated. For chat mode, callers
should request the full response, strip reasoning locally, then apply stop
tokens to the visible answer only.
"""
text = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
text = re.sub(r'<think>.*$', '', text, flags=re.DOTALL).strip()

# Some chat models echo completion-style role prefixes from prompts.
text = re.sub(r'^Teacher\s*\([^)]*\)\s*:\s*', '', text)
text = re.sub(r'^Teacher\s*:\s*', '', text).strip()

if stop:
for token in stop:
idx = text.find(token)
if idx != -1:
text = text[:idx]
return text.strip()

def _format_chat_messages(self, system_prompt: str, messages: List[Dict]) -> List[Dict]:
"""Format messages for chat completions"""
# print("Formatting chat messages with system_prompt: " + system_prompt)
Expand All @@ -110,27 +147,41 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None
pass

@abstractmethod
def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str:
"""Make a chat request - to be implemented by specific providers"""
def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple:
"""Make a chat request - to be implemented by specific providers.

Returns (visible_response, thinking_trace).
"""
pass

def generate(
self,
messages: List[Dict],
system_prompt: str,
stop: Optional[List[str]] = None,
) -> str:
"""Generate completion using either chat or completion API"""
return_trace: bool = False,
) -> Union[str, tuple]:
"""Generate completion using either chat or completion API.

When return_trace is true, return (visible_response, thinking_trace).
Non-chat providers return an empty trace.
"""
print("==============================================================")
print("Generating completion with model: " + self.config.model)

try:
if self.config.is_chat:
formatted_messages = self._format_chat_messages(system_prompt, messages)
return self._make_chat_request(formatted_messages, stop)
response, thinking_trace = self._make_chat_request(formatted_messages, stop)
if return_trace:
return response, thinking_trace
return response
else:
prompt = self._format_prompt(system_prompt, messages)
return self._make_completion_request(prompt, stop)
response = self._make_completion_request(prompt, stop)
if return_trace:
return response, ""
return response
except Exception as e:
print("Failed to generate completion after retries: " + str(e))
raise
Expand Down Expand Up @@ -185,26 +236,30 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None
wait=wait_exponential(multiplier=1, min=4, max=10),
reraise=True
)
def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str:
def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple:
try:
# print("Making chat request with messages: " + str(messages))
print("========================(Prompt-chat-start)======================================")
# print(messages)
print(messages[0]["content"])
print("========================(Prompt-chat-end)======================================")
# Do not pass task stop tokens to chat APIs. Reasoning models can emit
# <think> blocks before their answer, and stop tokens may appear inside
# the reasoning. Strip thinking and apply stops to visible text below.
response = self.client.chat.completions.create(
model=self.config.model,
messages=messages,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
stop=stop
)
completion = response.choices[0].message.content
raw_response = response.choices[0].message.content or ""
thinking_trace = self._extract_thinking(raw_response)
completion = self._strip_thinking(raw_response, stop)
print("===========================(Response-chat-start)===================================")
print(completion)
print("===========================(Response-chat-end)===================================")
# print("Received chat response: " + completion)
return completion
return completion, thinking_trace
except Exception as e:
print("Error in chat request: " + str(e))
raise
Expand Down Expand Up @@ -249,14 +304,14 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None
print("Error in completion request: " + str(e))
raise

def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str:
def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple:
"""Make a chat request to Ollama"""
# Ollama doesn't have a separate chat endpoint, so we'll format messages into a prompt
formatted_prompt = ""
for message in messages:
formatted_prompt += f"{message['role']}: {message['content']}\nassistant: "

return self._make_completion_request(formatted_prompt, stop)
return self._make_completion_request(formatted_prompt, stop), ""


class GeminiAPI(BaseLLMAPI):
Expand Down Expand Up @@ -292,7 +347,7 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None
print("Error in completion request: " + str(e))
raise

def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str:
def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple:
print("========================(Prompt-chat-start)======================================")
# print(messages)
print(messages[0]["content"])
Expand All @@ -315,4 +370,4 @@ def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = N
print("===========================(Response-chat-start)===================================")
print(completion)
print("===========================(Response-chat-end)===================================")
return completion
return completion, ""
106 changes: 106 additions & 0 deletions tests/test_thinking_model_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import unittest
import sys
import types
from types import SimpleNamespace

if "tenacity" not in sys.modules:
tenacity = types.ModuleType("tenacity")
tenacity.retry = lambda *args, **kwargs: (lambda fn: fn)
tenacity.stop_after_attempt = lambda *args, **kwargs: None
tenacity.wait_exponential = lambda *args, **kwargs: None
sys.modules["tenacity"] = tenacity

if "openai" not in sys.modules:
openai = types.ModuleType("openai")
openai.OpenAI = object
openai.api_key = None
sys.modules["openai"] = openai

if "google.generativeai" not in sys.modules:
google = types.ModuleType("google")
generativeai = types.ModuleType("google.generativeai")
generativeai.configure = lambda *args, **kwargs: None
generativeai.GenerativeModel = object
generativeai.types = SimpleNamespace(GenerationConfig=object)
sys.modules.setdefault("google", google)
sys.modules["google.generativeai"] = generativeai

from models.completion_api import BaseLLMAPI, CompletionAPI, LLMConfig


class DummyLLM(BaseLLMAPI):
def _make_completion_request(self, prompt, stop=None):
return prompt

def _make_chat_request(self, messages, stop=None):
return messages[0]["content"], ""


class FakeChatCompletions:
def __init__(self, content):
self.content = content
self.kwargs = None

def create(self, **kwargs):
self.kwargs = kwargs
return SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content=self.content))]
)


class TestThinkingModelSupport(unittest.TestCase):
def make_dummy(self):
return DummyLLM(LLMConfig(provider="completion_api", model="dummy"))

def test_extract_complete_thinking_trace(self):
model = self.make_dummy()

trace = model._extract_thinking("<think>reason\nstep</think>\nAnswer")

self.assertEqual(trace, "reason\nstep")

def test_extract_partial_thinking_trace(self):
model = self.make_dummy()

trace = model._extract_thinking("<think>unfinished reasoning")

self.assertEqual(trace, "unfinished reasoning")

def test_strip_thinking_before_applying_stop_tokens(self):
model = self.make_dummy()
raw = "<think>line one\nline two</think>\nTeacher: visible answer\nextra"

visible = model._strip_thinking(raw, stop=["\n"])

self.assertEqual(visible, "visible answer")

def test_completion_chat_does_not_send_stop_tokens_to_api(self):
api = CompletionAPI.__new__(CompletionAPI)
api.config = SimpleNamespace(model="dummy", temperature=0.0, max_tokens=128)
fake_chat = FakeChatCompletions("<think>reason\nwith stop</think>\nFinal answer\nignored")
api.client = SimpleNamespace(chat=SimpleNamespace(completions=fake_chat))

visible, trace = api._make_chat_request(
[{"role": "system", "content": "prompt"}],
stop=["\n"],
)

self.assertNotIn("stop", fake_chat.kwargs)
self.assertEqual(trace, "reason\nwith stop")
self.assertEqual(visible, "Final answer")

def test_generate_can_return_trace(self):
class TraceLLM(DummyLLM):
def _make_chat_request(self, messages, stop=None):
return "answer", "reason"

model = TraceLLM(LLMConfig(provider="completion_api", model="dummy", is_chat=True))

response, trace = model.generate([], "prompt", return_trace=True)

self.assertEqual(response, "answer")
self.assertEqual(trace, "reason")


if __name__ == "__main__":
unittest.main()