Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ccflow/evaluators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .common import *
from .reporting import *
from .retry import *
128 changes: 24 additions & 104 deletions ccflow/evaluators/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import itertools
import logging
import time
from contextlib import nullcontext
from datetime import timedelta
from pprint import pformat
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Callable, Dict, List, Optional, Set, Union

from pydantic import Field, PrivateAttr, field_validator
from pydantic import Field, PrivateAttr
from typing_extensions import override

from ..base import BaseModel, make_lazy_result
Expand All @@ -19,11 +14,14 @@
ResultType,
TransparentModelEvaluationContext,
)
from ..utils.reporting import FormatConfig, LoggingPolicy
from .reporting import ReportingEvaluator, _descriptor

__all__ = [
"cache_key",
"combine_evaluators",
"FallbackEvaluator",
"FormatConfig",
"LazyEvaluator",
"LoggingEvaluator",
"MemoryCacheEvaluator",
Expand Down Expand Up @@ -124,106 +122,23 @@ def make_result():
return make_lazy_result(context.model.result_type, make_result)


class FormatConfig(BaseModel):
"""Configuration for formatting the result of the evaluation.

This is used by the LoggingEvaluator to control how the result is formatted.
"""

arrow_as_polars: bool = Field(
False,
description="Whether to convert pyarrow tables to polars tables for formatting, as arrow formatting does not work well with large tables or provide control over options",
)
pformat_config: Dict[str, Any] = Field({}, description="pformat config to use for formatting data")
polars_config: Dict[str, Any] = Field({}, description="polars config to use for formatting polars frames")
pandas_config: Dict[str, Any] = Field({}, description="pandas config to use for formatting pandas objects")


class LoggingEvaluator(EvaluatorBase):
class LoggingEvaluator(ReportingEvaluator, LoggingPolicy):
"""Evaluator that logs information about evaluating the callable.

It logs start and end times, the model name, and the context."""

log_level: int = Field(logging.DEBUG, description="The log level for start/end of evaluation")
verbose: bool = Field(True, description="Whether to output the model definition as part of logging")
log_result: bool = Field(False, description="Whether to log the result of the evaluation")
format_config: FormatConfig = Field(FormatConfig(), description="Configuration for formatting the result of the evaluation if log_result=True")

def is_transparent(self, context: ModelEvaluationContext) -> bool:
return True

@field_validator("log_level", mode="before")
@classmethod
def _validate_log_level(cls, v: Union[int, str]) -> int:
"""Validate that the log level is a valid logging level."""
if isinstance(v, str):
return getattr(logging, v.upper(), "")
return v
It logs start and end times, the model name, and the context. This is the *default* evaluator
when no other is configured. It is now a thin combination of :class:`ReportingEvaluator` (span /
contextvar correlation, optional structured events when a ``reporter`` is set) and
:class:`~ccflow.utils.reporting.LoggingPolicy` (the actual log output, preserved exactly), so it
also participates in the reporting span tree.
"""

@override
def __call__(self, context: ModelEvaluationContext) -> ResultType:
model_name = context.model.meta.name or context.model.__class__.__name__
log_level = context.options.get("log_level", self.log_level)
verbose = context.options.get("verbose", self.verbose)
log.log(log_level, "[%s]: Start evaluation of %s on %s.", model_name, context.fn, context.context)
if verbose:
log.log(log_level, "[%s]: %s", model_name, context.model)
start = time.time()
result = None
try:
result = context()
return result
finally:
end = time.time()
if self.log_result and result is not None:
log.log(
log_level,
self._format_result(result),
model_name,
context.fn,
context.context,
)
log.log(
log_level,
"[%s]: End evaluation of %s on %s (time elapsed: %s).",
model_name,
context.fn,
context.context,
timedelta(seconds=end - start),
)

def _format_result(self, result: ResultType) -> str:
"""Handle formatting of the result"""
# Add special formatting for eager table/data frame types embedded in the results
import pyarrow as pa

result_dict = result.model_dump(by_alias=True)
for k, v in result_dict.items():
try:
if self.format_config.arrow_as_polars and isinstance(v, pa.Table):
import polars as pl # Only import polars if needed

result_dict[k] = pl.from_arrow(v)
except TypeError:
pass

if self.format_config.polars_config: # Control formatting of polars tables if set
import polars as pl # Only import polars if needed

polars_context = pl.Config(**self.format_config.polars_config)
else:
polars_context = nullcontext()

if self.format_config.pandas_config: # Control formatting of pandas tables if set
import pandas as pd

pandas_context = pd.option_context(*itertools.chain.from_iterable(self.format_config.pandas_config.items()))
else:
pandas_context = nullcontext()

with polars_context, pandas_context:
msg_str = "[%s]: Result of %s on %s:\n"
return f"{msg_str}{pformat(result_dict, **self.format_config.pformat_config)}"
return self._run_with_reporting(
context,
extra={"model": context.model, "raw_context": context.context, "options": dict(context.options)},
**_descriptor(context),
)


def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel]) -> bytes:
Expand Down Expand Up @@ -304,7 +219,12 @@ def __deepcopy__(self, memo):


class CallableModelGraph(BaseModel):
"""Class to hold a "graph" """
"""Dependency graph of callable-model evaluation contexts.

``graph`` maps each node's cache key to the set of its dependency cache keys, ``ids`` maps cache
keys back to their :class:`ModelEvaluationContext`, and ``root_id`` is the cache key of the node
the graph was built from.
"""

graph: Dict[bytes, Set[bytes]]
ids: Dict[bytes, ModelEvaluationContext]
Expand Down Expand Up @@ -357,7 +277,7 @@ def __call__(self, context: ModelEvaluationContext) -> ResultType:
import graphlib

# If we are evaluating deps, or if we have already started using the graph evaluator further up the call tree,
# no not apply it any further
# do not apply it any further
if self._is_evaluating:
return context()
self._is_evaluating = True
Expand Down
Loading
Loading