Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/tabpfn_common_utils/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from .core.events import DatasetEvent, FitEvent, PingEvent, PredictEvent
from .core.events import DatasetEvent, FitEvent, ModelLoadEvent, PingEvent, PredictEvent
from .core.service import ProductTelemetry, capture_event
from .core.decorators import (
track_model_call,
Expand All @@ -14,6 +14,7 @@
# Public exports
__all__ = [
"DatasetEvent",
"ModelLoadEvent",
"FitEvent",
"PingEvent",
"PredictEvent",
Expand Down
2 changes: 2 additions & 0 deletions src/tabpfn_common_utils/telemetry/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .events import (
BaseTelemetryEvent,
ModelLoadEvent,
PingEvent,
DatasetEvent,
FitEvent,
Expand All @@ -25,6 +26,7 @@
"BaseTelemetryEvent",
"PingEvent",
"DatasetEvent",
"ModelLoadEvent",
"FitEvent",
"PredictEvent",
"ProductTelemetry",
Expand Down
28 changes: 28 additions & 0 deletions src/tabpfn_common_utils/telemetry/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,34 @@ def name(self) -> str:
return "session"


@dataclass
class ModelLoadEvent(BaseTelemetryEvent):
Comment thread
noahho marked this conversation as resolved.
"""
Event emitted when a model is loaded.
"""

# Status of the model download attempt
status: Literal["success", "failed"]

# Install ID of the user
install_id: str = field(default_factory=_get_install_id, init=False)
Comment thread
safaricd marked this conversation as resolved.

# Name of the model, may be a HuggingFace repo ID
model_name: Optional[str] = field(default=None)

# Failure reason if the model download failed
failure_reason: Optional[str] = field(default=None)

def __post_init__(self):
"""Post-init hook to ensure data integrity."""
if self.status == "success":
self.failure_reason = None

@property
def name(self) -> str:
return "model_load"

Comment thread
safaricd marked this conversation as resolved.

@dataclass
class PingEvent(BaseTelemetryEvent):
"""
Expand Down
60 changes: 60 additions & 0 deletions tests/telemetry/core/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
BaseTelemetryEvent,
DatasetEvent,
FitEvent,
ModelLoadEvent,
PingEvent,
PredictEvent,
_get_py_version,
Expand Down Expand Up @@ -362,6 +363,65 @@ def test_ping_event_minimal_structure(self):
assert property in expected_attrs


class TestModelLoadEvent:
"""Test ModelLoadEvent class"""

def test_model_load_event_initialization(self):
"""Test ModelLoadEvent initialization with required status"""
event = ModelLoadEvent(status="success")

assert event.status == "success"
assert event.name == "model_load"
assert event.failure_reason is None
assert event.model_name is None

def test_model_load_event_with_failed_status(self):
"""Test ModelLoadEvent with failed status"""
event = ModelLoadEvent(
status="failed", failure_reason="Network error", model_name="test-model"
)

assert event.status == "failed"
assert event.failure_reason == "Network error"
assert event.model_name == "test-model"

def test_model_load_event_post_init_clears_failure_reason_on_success(self):
"""Test that __post_init__ clears failure_reason when status is success"""
# This tests the __post_init__ behavior
event = ModelLoadEvent(status="success", failure_reason="should be cleared")

assert event.status == "success"
assert event.failure_reason is None

def test_model_load_event_inherits_base_properties(self):
"""Test that ModelLoadEvent inherits base telemetry properties"""
event = ModelLoadEvent(status="success")

assert isinstance(event.python_version, str)
assert isinstance(event.tabpfn_version, str)
assert isinstance(event.timestamp, datetime)
assert event.source == "sdk"
assert isinstance(event.install_id, str)

def test_model_load_event_properties_method(self):
"""Test ModelLoadEvent properties method"""
event = ModelLoadEvent(
status="failed",
model_name="test-model",
failure_reason="Download timeout",
)

props = event.properties

assert "name" not in props
assert props["status"] == "failed"
assert props["model_name"] == "test-model"
assert props["failure_reason"] == "Download timeout"
assert "python_version" in props
assert "tabpfn_version" in props
assert "install_id" in props


class TestEventIntegration:
"""Integration tests for all event types"""

Expand Down