From 3de0f3c9e57c4009eb411fb21142af98068427ca Mon Sep 17 00:00:00 2001 From: Sarah M Brown Date: Thu, 19 Feb 2026 23:27:48 -0500 Subject: [PATCH] task format for ollama --- benchtools/response.py | 19 +++++++++++++++++++ benchtools/task.py | 36 +++++++++++++++++++++++------------- pyproject.toml | 3 ++- 3 files changed, 44 insertions(+), 14 deletions(-) create mode 100644 benchtools/response.py diff --git a/benchtools/response.py b/benchtools/response.py new file mode 100644 index 0000000..da857ca --- /dev/null +++ b/benchtools/response.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + + +class StringAnswer(BaseModel): + answer: str + +class IntAnswer(BaseModel): + answer: int + +class FloatAnswer(BaseModel): + answer: float + +class StringJustification(BaseModel): + answer: str + justification: str + +class IntJustification(BaseModel): + answer: int + justification: str \ No newline at end of file diff --git a/benchtools/task.py b/benchtools/task.py index 4988e40..83d0f90 100644 --- a/benchtools/task.py +++ b/benchtools/task.py @@ -4,12 +4,14 @@ import yaml # requires pyyaml import pandas as pd from ollama import chat, ChatResponse, Client -from benchtools.logger import init_log_folder, log_interaction +from .logger import init_log_folder, log_interaction from pathlib import PurePath from datasets import load_dataset -from benchtools.runner import BenchRunner +from .runner import BenchRunner +import sys +from .response import StringAnswer, StringJustification, IntAnswer, IntJustification -from benchtools.scorers import scoring_fx_list, contains, exact_match +from .scorers import scoring_fx_list, contains, exact_match from .utils import concatenator_id_generator, selector_id_generator @@ -23,7 +25,8 @@ class Task: def __init__(self, task_name, template, reference=None, scoring_function=None, variant_values = None, storage_type = 'yaml', description = None, - prompt_id_generator_fx = concatenator_id_generator): + prompt_id_generator_fx = concatenator_id_generator, + format='StringAnswer'): """ init a task object from a prompt and reference, and a scoring function. If no scoring function is provided, defaults to exact match. @@ -47,11 +50,15 @@ def __init__(self, task_name, template, reference=None, scoring_function=None, self.template = template self.variant_values = variant_values self.reference = reference + + # set up to name individual prompts if not callable(prompt_id_generator_fx): prompt_id_generator_fx = prompt_id_fx[prompt_id_generator_fx] - self.prompt_id_generator = prompt_id_generator_fx + # setup for response format + mod = sys.modules[__name__] + self.FormatClass = getattr(mod,format) self.storage_type = storage_type if scoring_function: @@ -310,17 +317,20 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N print(f"Couldn't create log directory in {log_dir}...\n{e}") - for prompt_name, sub_task in self.generate_prompts(): + + for prompt_name, prompt in self.generate_prompts(): error = None response = '' try: match runner.runner_type: case "ollama": - completion: ChatResponse = chat(model=runner.model, messages=[ + completion: ChatResponse = chat(model=runner.model, + format = self.FormatClass.model_json_schema(), + messages=[ { 'role': 'user', - 'content':sub_task, + 'content':prompt, }, ]) # print("response: " + response.message.content) @@ -331,16 +341,16 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N client = Client( host=runner.api_url if runner.api_url else "http://localhost:11434", ) - completeion = client.chat( + completion = client.chat( runner.model, messages=[ { "role": "user", - "content": sub_task, + "content": prompt, }, ], ) - response = completeion["message"]["content"] + response = completion["message"]["content"] responses.append(response) case "openai": @@ -352,7 +362,7 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N messages=[ { "role": "user", - "content": sub_task, + "content": prompt, } ], ) @@ -363,7 +373,7 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N return None except Exception as e: error = e - log_interaction(run_log, prompt_name, sub_task, response, str(error)) + log_interaction(run_log, prompt_name, prompt, response, str(error)) diff --git a/pyproject.toml b/pyproject.toml index 6d1d33e..96cb7bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ dependencies = [ "pandas", "datasets", "openai", - "ollama" + "ollama", + "pydantic" ] requires-python = ">=3.10" authors = [