diff --git a/README.md b/README.md index c8125b5..ddbfa29 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ vllm serve [[model_name]] --seed 42 --tensor-parallel-size 4 ```bash # Example with vllm model python main.py --tasks mistake_location.yaml --provider completion_api --model_args base_url=http://localhost:8000/v1,model=meta-llama/Llama-3.2-3B-Instruct +# Example with an OpenAI-compatible chat/reasoning model served by vLLM +python main.py --tasks scaffolding_generation.yaml --provider completion_api --model_args base_url=http://localhost:8000/v1,model=Qwen/Qwen3-8B,is_chat=true,max_tokens=4096 --store_traces # Example with OpenAI API python main.py --tasks mistake_correction.yaml --provider completion_api --model_args model=gpt-4o-mini-2024-07-18,api_key= # Example with LearnLM Gemini API @@ -56,6 +58,9 @@ python main.py --tasks student_solution_correctness.yaml --provider gemini --mod - `temperature`: Temperature for sampling. Default is 0.0. - `max_tokens`: Maximum tokens to generate. Default is 2048. - `max_retries`: Maximum retries for the API. Default is 3. + - `--store_traces`: For chat models that emit `<think>...</think>` blocks, save extracted thinking traces to `traces-*.json`. + +For chat/reasoning models, MathTutorBench strips `<think>...</think>` blocks before parsing and scoring the visible answer. In chat mode, task stop tokens are applied after this stripping step instead of being sent to the API request, because completion-style stop tokens such as blank lines can otherwise truncate the hidden reasoning before the final answer is produced. Use a larger `max_tokens` budget for reasoning models so the model can finish both the thinking trace and visible answer. 
The performance of different benchmarked models averaged across tasks for Qwen2.5 family is as follows (using vllm version 0.8.0 on one node with 4x GH200 GPUs): diff --git a/main.py b/main.py index 3c9b380..110da7e 100644 --- a/main.py +++ b/main.py @@ -49,6 +49,8 @@ def main(): help='Output directory for results') parser.add_argument('--debug', action='store_true', help='Enable debug logging') + parser.add_argument('--store_traces', action='store_true', + help='Store thinking traces extracted from blocks') args = parser.parse_args() # Parse model arguments @@ -76,18 +78,29 @@ def main(): predictions = [] targets = [] all_generations = [] + all_traces = [] # Process examples for example in tqdm(task.get_test_examples(), desc=f"Evaluating {task_config.name}"): # Prepare messages with few-shot examples if provided messages = [] example["shots"] = task_config.few_shot_samples + system_prompt = task.get_system_prompt(example) # Get model response - response = model.generate( - messages=messages, - system_prompt=task.get_system_prompt(example), - stop=task_config.stop - ) + if args.store_traces: + response, thinking_trace = model.generate( + messages=messages, + system_prompt=system_prompt, + stop=task_config.stop, + return_trace=True + ) + else: + response = model.generate( + messages=messages, + system_prompt=system_prompt, + stop=task_config.stop + ) + thinking_trace = "" # Parse and store prediction prediction = task.parse_response(response) predictions.append(prediction) @@ -107,6 +120,18 @@ def main(): } all_generations.append(generation) + if args.store_traces and thinking_trace: + all_traces.append({ + "task": task_config.name, + "problem": example.get("question", ""), + "student_solution": example.get("student_solution", ""), + "dialog_history": example.get("dialog_history", ""), + "system_prompt": system_prompt, + "thinking_trace": thinking_trace, + "visible_response": response, + "ground_truth": str(formatted_ground_truth), + }) + # Compute metrics metrics = 
task.compute_metrics(predictions, targets) results[task_config.name] = metrics @@ -125,6 +150,12 @@ def main(): with open(output_file, 'w') as f: json.dump(all_generations, f, indent=2) + if len(all_traces) > 0: + traces_file = output_dir / f"traces-{config.model.split('/')[-1]}-{task_config.name}.json" + with open(traces_file, 'w') as f: + json.dump(all_traces, f, indent=2) + print(f"Thinking traces saved to {traces_file}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/models/completion_api.py b/models/completion_api.py index c7f4e21..62f9c42 100644 --- a/models/completion_api.py +++ b/models/completion_api.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Union import jinja2 +import re import requests from tenacity import retry, stop_after_attempt, wait_exponential import google.generativeai as genai @@ -67,6 +68,18 @@ class BaseLLMAPI(ABC): def __init__(self, config: LLMConfig): self.config = config + def _extract_thinking(self, response: str) -> str: + """Extract hidden reasoning from blocks, if present.""" + match = re.search(r'(.*?)', response, flags=re.DOTALL) + if match: + return match.group(1).strip() + + # Preserve partial traces when generation is cut off before . + match = re.search(r'(.*?)$', response, flags=re.DOTALL) + if match: + return match.group(1).strip() + return "" + def _format_conversation(self, messages: List[Dict]) -> str: """Format messages into a conversation string""" # print("Formatting conversation from messages: " + "|EOM|".join(messages)) @@ -94,6 +107,30 @@ def _format_prompt(self, system_prompt: str, messages: List[Dict]) -> str: # print("Final formatted prompt: " + prompt) return prompt + def _strip_thinking(self, response: str, stop: Optional[List[str]] = None) -> str: + """Remove blocks before applying task stop tokens. + + Chat reasoning models can emit a reasoning block before the visible answer. 
+ Completion-style stop tokens such as "\\n" or "\\n\\n" can appear inside + that reasoning block, so passing them directly to chat completion APIs may + truncate the output before the answer is generated. For chat mode, callers + should request the full response, strip reasoning locally, then apply stop + tokens to the visible answer only. + """ + text = re.sub(r'.*?', '', response, flags=re.DOTALL) + text = re.sub(r'.*$', '', text, flags=re.DOTALL).strip() + + # Some chat models echo completion-style role prefixes from prompts. + text = re.sub(r'^Teacher\s*\([^)]*\)\s*:\s*', '', text) + text = re.sub(r'^Teacher\s*:\s*', '', text).strip() + + if stop: + for token in stop: + idx = text.find(token) + if idx != -1: + text = text[:idx] + return text.strip() + def _format_chat_messages(self, system_prompt: str, messages: List[Dict]) -> List[Dict]: """Format messages for chat completions""" # print("Formatting chat messages with system_prompt: " + system_prompt) @@ -110,8 +147,11 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None pass @abstractmethod - def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str: - """Make a chat request - to be implemented by specific providers""" + def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple: + """Make a chat request - to be implemented by specific providers. + + Returns (visible_response, thinking_trace). + """ pass def generate( @@ -119,18 +159,29 @@ def generate( messages: List[Dict], system_prompt: str, stop: Optional[List[str]] = None, - ) -> str: - """Generate completion using either chat or completion API""" + return_trace: bool = False, + ) -> Union[str, tuple]: + """Generate completion using either chat or completion API. + + When return_trace is true, return (visible_response, thinking_trace). + Non-chat providers return an empty trace. 
+ """ print("==============================================================") print("Generating completion with model: " + self.config.model) try: if self.config.is_chat: formatted_messages = self._format_chat_messages(system_prompt, messages) - return self._make_chat_request(formatted_messages, stop) + response, thinking_trace = self._make_chat_request(formatted_messages, stop) + if return_trace: + return response, thinking_trace + return response else: prompt = self._format_prompt(system_prompt, messages) - return self._make_completion_request(prompt, stop) + response = self._make_completion_request(prompt, stop) + if return_trace: + return response, "" + return response except Exception as e: print("Failed to generate completion after retries: " + str(e)) raise @@ -185,26 +236,30 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None wait=wait_exponential(multiplier=1, min=4, max=10), reraise=True ) - def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str: + def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple: try: # print("Making chat request with messages: " + str(messages)) print("========================(Prompt-chat-start)======================================") # print(messages) print(messages[0]["content"]) print("========================(Prompt-chat-end)======================================") + # Do not pass task stop tokens to chat APIs. Reasoning models can emit + # blocks before their answer, and stop tokens may appear inside + # the reasoning. Strip thinking and apply stops to visible text below. 
response = self.client.chat.completions.create( model=self.config.model, messages=messages, temperature=self.config.temperature, max_tokens=self.config.max_tokens, - stop=stop ) - completion = response.choices[0].message.content + raw_response = response.choices[0].message.content or "" + thinking_trace = self._extract_thinking(raw_response) + completion = self._strip_thinking(raw_response, stop) print("===========================(Response-chat-start)===================================") print(completion) print("===========================(Response-chat-end)===================================") # print("Received chat response: " + completion) - return completion + return completion, thinking_trace except Exception as e: print("Error in chat request: " + str(e)) raise @@ -249,14 +304,14 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None print("Error in completion request: " + str(e)) raise - def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str: + def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple: """Make a chat request to Ollama""" # Ollama doesn't have a separate chat endpoint, so we'll format messages into a prompt formatted_prompt = "" for message in messages: formatted_prompt += f"{message['role']}: {message['content']}\nassistant: " - return self._make_completion_request(formatted_prompt, stop) + return self._make_completion_request(formatted_prompt, stop), "" class GeminiAPI(BaseLLMAPI): @@ -292,7 +347,7 @@ def _make_completion_request(self, prompt: str, stop: Optional[List[str]] = None print("Error in completion request: " + str(e)) raise - def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str: + def _make_chat_request(self, messages: List[Dict], stop: Optional[List[str]] = None) -> tuple: print("========================(Prompt-chat-start)======================================") # print(messages) 
"""Unit tests for ``<think>``-block handling in ``models.completion_api``.

The fixtures wrap reasoning in literal ``<think>...</think>`` tags; without
those delimiters the expected traces and visible answers below cannot be
produced.
"""
import unittest
import sys
import types
from types import SimpleNamespace

# Install lightweight stand-ins for third-party packages that have not been
# imported yet, so importing models.completion_api needs no network clients.
if "tenacity" not in sys.modules:
    tenacity = types.ModuleType("tenacity")
    # The stub decorator returns the wrapped function unchanged (no retries).
    tenacity.retry = lambda *args, **kwargs: (lambda fn: fn)
    tenacity.stop_after_attempt = lambda *args, **kwargs: None
    tenacity.wait_exponential = lambda *args, **kwargs: None
    sys.modules["tenacity"] = tenacity

if "openai" not in sys.modules:
    openai = types.ModuleType("openai")
    openai.OpenAI = object
    openai.api_key = None
    sys.modules["openai"] = openai

if "google.generativeai" not in sys.modules:
    google = types.ModuleType("google")
    generativeai = types.ModuleType("google.generativeai")
    generativeai.configure = lambda *args, **kwargs: None
    generativeai.GenerativeModel = object
    generativeai.types = SimpleNamespace(GenerationConfig=object)
    sys.modules.setdefault("google", google)
    sys.modules["google.generativeai"] = generativeai

from models.completion_api import BaseLLMAPI, CompletionAPI, LLMConfig


class DummyLLM(BaseLLMAPI):
    """Minimal concrete provider that echoes its input instead of calling an API."""

    def _make_completion_request(self, prompt, stop=None):
        return prompt

    def _make_chat_request(self, messages, stop=None):
        return messages[0]["content"], ""


class FakeChatCompletions:
    """Stand-in for ``client.chat.completions`` that records the request kwargs."""

    def __init__(self, content):
        self.content = content
        self.kwargs = None

    def create(self, **kwargs):
        # Keep the kwargs so tests can assert what was (not) sent to the API.
        self.kwargs = kwargs
        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=self.content))]
        )


class TestThinkingModelSupport(unittest.TestCase):
    """Checks trace extraction, reasoning stripping and local stop handling."""

    def make_dummy(self):
        return DummyLLM(LLMConfig(provider="completion_api", model="dummy"))

    def test_extract_complete_thinking_trace(self):
        model = self.make_dummy()

        trace = model._extract_thinking("<think>reason\nstep</think>\nAnswer")

        self.assertEqual(trace, "reason\nstep")

    def test_extract_partial_thinking_trace(self):
        model = self.make_dummy()

        # No closing tag: generation was cut off mid-reasoning.
        trace = model._extract_thinking("<think>unfinished reasoning")

        self.assertEqual(trace, "unfinished reasoning")

    def test_strip_thinking_before_applying_stop_tokens(self):
        model = self.make_dummy()
        raw = "<think>line one\nline two</think>\nTeacher: visible answer\nextra"

        visible = model._strip_thinking(raw, stop=["\n"])

        self.assertEqual(visible, "visible answer")

    def test_completion_chat_does_not_send_stop_tokens_to_api(self):
        # Bypass __init__ so no real client is constructed.
        api = CompletionAPI.__new__(CompletionAPI)
        api.config = SimpleNamespace(model="dummy", temperature=0.0, max_tokens=128)
        fake_chat = FakeChatCompletions(
            "<think>reason\nwith stop</think>\nFinal answer\nignored"
        )
        api.client = SimpleNamespace(chat=SimpleNamespace(completions=fake_chat))

        visible, trace = api._make_chat_request(
            [{"role": "system", "content": "prompt"}],
            stop=["\n"],
        )

        self.assertNotIn("stop", fake_chat.kwargs)
        self.assertEqual(trace, "reason\nwith stop")
        self.assertEqual(visible, "Final answer")

    def test_generate_can_return_trace(self):
        class TraceLLM(DummyLLM):
            def _make_chat_request(self, messages, stop=None):
                return "answer", "reason"

        model = TraceLLM(LLMConfig(provider="completion_api", model="dummy", is_chat=True))

        response, trace = model.generate([], "prompt", return_trace=True)

        self.assertEqual(response, "answer")
        self.assertEqual(trace, "reason")


if __name__ == "__main__":
    unittest.main()