Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions benchtools/response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pydantic import BaseModel


class StringAnswer(BaseModel):
answer: str

class IntAnswer(BaseModel):
answer: int

class FloatAnswer(BaseModel):
answer: float

class StringJustification(BaseModel):
answer: str
justification: str

class IntJustification(BaseModel):
answer: int
justification: str
36 changes: 23 additions & 13 deletions benchtools/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import yaml # requires pyyaml
import pandas as pd
from ollama import chat, ChatResponse, Client
from benchtools.logger import init_log_folder, log_interaction
from .logger import init_log_folder, log_interaction
from pathlib import PurePath
from datasets import load_dataset
from benchtools.runner import BenchRunner
from .runner import BenchRunner
import sys
from .response import StringAnswer, StringJustification, IntAnswer, IntJustification

from benchtools.scorers import scoring_fx_list, contains, exact_match
from .scorers import scoring_fx_list, contains, exact_match

from .utils import concatenator_id_generator, selector_id_generator

Expand All @@ -23,7 +25,8 @@ class Task:

def __init__(self, task_name, template, reference=None, scoring_function=None,
variant_values = None, storage_type = 'yaml', description = None,
prompt_id_generator_fx = concatenator_id_generator):
prompt_id_generator_fx = concatenator_id_generator,
format='StringAnswer'):
"""
init a task object from a prompt and reference, and a scoring function. If no scoring function is provided, defaults to exact match.

Expand All @@ -47,11 +50,15 @@ def __init__(self, task_name, template, reference=None, scoring_function=None,
self.template = template
self.variant_values = variant_values
self.reference = reference

# set up to name individual prompts
if not callable(prompt_id_generator_fx):
prompt_id_generator_fx = prompt_id_fx[prompt_id_generator_fx]

self.prompt_id_generator = prompt_id_generator_fx

# setup for response format
mod = sys.modules[__name__]
self.FormatClass = getattr(mod,format)

self.storage_type = storage_type
if scoring_function:
Expand Down Expand Up @@ -310,17 +317,20 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N
print(f"Couldn't create log directory in {log_dir}...\n{e}")


for prompt_name, sub_task in self.generate_prompts():

for prompt_name, prompt in self.generate_prompts():

error = None
response = ''
try:
match runner.runner_type:
case "ollama":
completion: ChatResponse = chat(model=runner.model, messages=[
completion: ChatResponse = chat(model=runner.model,
format = self.FormatClass.model_json_schema(),
messages=[
{
'role': 'user',
'content':sub_task,
'content':prompt,
},
])
# print("response: " + response.message.content)
Expand All @@ -331,16 +341,16 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N
client = Client(
host=runner.api_url if runner.api_url else "http://localhost:11434",
)
completeion = client.chat(
completion = client.chat(
runner.model,
messages=[
{
"role": "user",
"content": sub_task,
"content": prompt,
},
],
)
response = completeion["message"]["content"]
response = completion["message"]["content"]
responses.append(response)

case "openai":
Expand All @@ -352,7 +362,7 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N
messages=[
{
"role": "user",
"content": sub_task,
"content": prompt,
}
],
)
Expand All @@ -363,7 +373,7 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N
return None
except Exception as e:
error = e
log_interaction(run_log, prompt_name, sub_task, response, str(error))
log_interaction(run_log, prompt_name, prompt, response, str(error))



Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ dependencies = [
"pandas",
"datasets",
"openai",
"ollama"
"ollama",
"pydantic"
]
requires-python = ">=3.10"
authors = [
Expand Down