Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions src/tabpfn_common_utils/telemetry/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,10 @@ class BaseTelemetryEvent:
"""

# Python version that the SDK is running on
python_version: str = field(default_factory=_get_py_version, init=False)
python_version: Optional[str] = field(default=None, init=False)

# TabPFN version that the SDK is running on
tabpfn_version: str = field(default_factory=_get_sdk_version, init=False)
tabpfn_version: Optional[str] = field(default=None, init=False)

# Timestamp of the event
timestamp: datetime = field(default_factory=_utc_now, init=False)
Expand All @@ -257,18 +257,32 @@ class BaseTelemetryEvent:
extension: Optional[str] = field(default=None, init=False)

# Runtime environment of the platform
runtime_kernel: Optional[str] = field(
default_factory=_get_runtime_kernel, init=False
)
runtime_kernel: Optional[str] = field(default=None, init=False)

# Runtime environment of the platform
runtime_environment: Optional[str] = field(
default_factory=_get_runtime_environment, init=False
)
runtime_environment: Optional[str] = field(default=None, init=False)

# Operating system of the platform
platform_os: str = field(default_factory=_get_platform_os, init=False)

def enrich(self):
"""Enrich the event with additional properties.

We enrich the events with additional properties so we do not
have to call the default factories at initialization time.
"""
if self.runtime_kernel is None:
self.runtime_kernel = _get_runtime_kernel()

if self.runtime_environment is None:
self.runtime_environment = _get_runtime_environment()

if self.python_version is None:
self.python_version = _get_py_version()

if self.tabpfn_version is None:
self.tabpfn_version = _get_sdk_version()

@property
def name(self) -> str:
raise NotImplementedError
Expand Down Expand Up @@ -378,22 +392,22 @@ class ModelCallEvent(BaseTelemetryEvent):
task: Literal["classification", "regression"]

# Version of the PyTorch
torch_version: str = field(default_factory=_get_torch_version, init=False)
torch_version: Optional[str] = field(default=None, init=False)

# Version of the scikit-learn
sklearn_version: str = field(default_factory=_get_sklearn_version, init=False)
sklearn_version: Optional[str] = field(default=None, init=False)

# Version of the NumPy
numpy_version: str = field(default_factory=_get_numpy_version, init=False)
numpy_version: Optional[str] = field(default=None, init=False)

# Version of the Pandas
pandas_version: str = field(default_factory=_get_pandas_version, init=False)
pandas_version: Optional[str] = field(default=None, init=False)

# Version of the AutoGluon
autogluon_version: str = field(default_factory=_get_autogluon_version, init=False)
autogluon_version: Optional[str] = field(default=None, init=False)

# Type of GPU if available
gpu_type: Optional[str] = field(default_factory=_get_gpu_type, init=False)
gpu_type: Optional[str] = field(default=None, init=False)

# Version of the model
model_version: Optional[str] = field(default=None, init=False)
Expand All @@ -410,6 +424,28 @@ class ModelCallEvent(BaseTelemetryEvent):
# Duration of the model call in milliseconds
duration_ms: int = -1

def enrich(self):
"""Enrich the event with additional properties."""
super().enrich()

if self.gpu_type is None:
self.gpu_type = _get_gpu_type()

if self.torch_version is None:
self.torch_version = _get_torch_version()

if self.sklearn_version is None:
self.sklearn_version = _get_sklearn_version()

if self.numpy_version is None:
self.numpy_version = _get_numpy_version()

if self.pandas_version is None:
self.pandas_version = _get_pandas_version()

if self.autogluon_version is None:
self.autogluon_version = _get_autogluon_version()


@dataclass
class FitEvent(ModelCallEvent):
Expand Down
Loading