Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/tabpfn_common_utils/telemetry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
set_extension,
get_current_extension,
set_model_config,
set_init_params,
get_init_params,
)

# Public exports
Expand All @@ -21,4 +23,6 @@
"set_extension",
"get_current_extension",
"set_model_config",
"set_init_params",
"get_init_params",
]
4 changes: 4 additions & 0 deletions src/tabpfn_common_utils/telemetry/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
set_extension,
get_current_extension,
set_model_config,
set_init_params,
get_init_params,
)

# Public exports
Expand All @@ -32,4 +34,6 @@
"set_extension",
"get_current_extension",
"set_model_config",
"set_init_params",
"get_init_params",
]
38 changes: 37 additions & 1 deletion src/tabpfn_common_utils/telemetry/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dataclasses import dataclass
from pathlib import Path
from functools import wraps
from typing import Any, Callable, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union

from .events import FitEvent, PredictEvent
from .service import capture_event
Expand Down Expand Up @@ -85,6 +85,38 @@ def get_model_config() -> Optional[Tuple[str, str]]:
return None


def set_init_params(
params: Dict[str, Any],
) -> Optional[contextvars.Token[Optional[str]]]:
"""Set the initial parameters of the model.

Args:
params: The initial parameters of the model.
"""
try:
token = json.dumps(params)
tok = _get_context_var("tabpfn_model_init_params").set(token)
return tok
except Exception:
return None
Comment thread
safaricd marked this conversation as resolved.


def get_init_params() -> Optional[Dict[str, Any]]:
"""Get the initial parameters of the model.

Returns:
The initial parameters of the model.
"""
token = _get_context_var("tabpfn_model_init_params").get()
if token is None:
return None

try:
return json.loads(token)
except Exception:
return None
Comment thread
safaricd marked this conversation as resolved.


def get_current_extension() -> Optional[str]:
"""Get the current extension.

Expand Down Expand Up @@ -383,6 +415,10 @@ def _send_model_called_event(call_info: _ModelCallInfo, duration_ms: int) -> Non
event.model_path = model_path
event.model_version = model_version

# Set the model init params for fit
if isinstance(event, FitEvent):
event.init_params = get_init_params()

except TypeError as e:
logger.debug(f"Event creation failed: {e}")
return
Expand Down
5 changes: 4 additions & 1 deletion src/tabpfn_common_utils/telemetry/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, asdict, field
from datetime import datetime, timezone
from functools import lru_cache
from typing import Any, Literal, Optional
from typing import Any, Dict, Literal, Optional
from .runtime import get_execution_context
from .state import get_property, set_property

Expand Down Expand Up @@ -373,6 +373,9 @@ class FitEvent(ModelCallEvent):
Event emitted when a model is fit.
"""

# Initial parameters of the model
init_params: Optional[Dict[str, Any]] = field(default=None, init=False)

@property
def name(self) -> str:
return "fit_called"
Expand Down