Skip to content

Commit d6eb578

Browse files
committed
abc trainer
1 parent 4fa4162 commit d6eb578

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

eval_protocol/training/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from gepa_adapter import GEPATrainer
1+
from gepa_trainer import GEPATrainer
22

33
__all__ = ["GEPATrainer"]
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
from eval_protocol.models import EPParameters, EvaluationRow
1111
from eval_protocol.pytest.types import TestFunction
12-
from eval_protocol.training.gepa_utils import REFLECTION_LM_CONFIGS
12+
from eval_protocol.training.trainer import Trainer
1313
from eval_protocol.training.utils import build_ep_parameters_from_test
1414

1515

16-
class GEPATrainer:
16+
class GEPATrainer(Trainer):
1717
"""
1818
High-level entrypoint for running GEPA-style training against an existing
1919
`@evaluation_test`-decorated function.
@@ -30,7 +30,7 @@ def __init__(self, test_fn: TestFunction) -> None:
3030
Args:
3131
test_fn: The `@evaluation_test`-decorated function defining the eval.
3232
"""
33-
self.test_fn = test_fn
33+
super().__init__(test_fn)
3434
self.ep_params: EPParameters = build_ep_parameters_from_test(test_fn)
3535

3636
self.metric = (

eval_protocol/training/trainer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from abc import ABC, abstractmethod
2+
3+
from eval_protocol.pytest.types import TestFunction
4+
5+
6+
class Trainer(ABC):
7+
def __init__(self, test_fn: TestFunction):
8+
self.test_fn = test_fn
9+
10+
@abstractmethod
11+
def train(self, *args, **kwargs): ...
12+
13+
@abstractmethod
14+
def evaluate(self, *args, **kwargs):
15+
# evaluation logic possibly can be shared since it's EP. TBD
16+
...

0 commit comments

Comments
 (0)