Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions examples/phe/phe_ts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Prior Labs GmbH 2025.
# Licensed under the Apache License, Version 2.0

"""WARNING: This example may run slowly on CPU-only systems.
For better performance, we recommend running with GPU acceleration.
This example trains multiple TabPFN models, which is computationally intensive.
"""

import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.metrics import (
mean_absolute_error,
)
from sklearn.model_selection import TimeSeriesSplit

from tabpfn import TabPFNRegressor
from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import (
AutoTabPFNRegressor,
)

# ------------------------------------------------------------------
# 1️⃣ Load the digital-currency (Bitcoin) daily time series from OpenML.
dataset = fetch_openml(
    data_id=43563, as_frame=True
)  # ↔ name="Digital-currency---Time-series", version=1
df = dataset.frame.copy()

# Parse the date column, index by it, and order chronologically
# (sequential statements instead of one method chain).
df = df.rename(columns={"Unnamed:_0": "date"})
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date").sort_index()

print("Head of raw data")
print(df.head(), "\n")

# ------------------------------------------------------------------
# 2️⃣ Features & target
y_raw = df["close_SAR"].to_numpy()
scale = y_raw.max()  # used to rescale the target and to convert MAE back to SAR
y = y_raw / scale
X = df.drop(columns=["close_SAR", "close_USD"]).to_numpy()

# ------------------------------------------------------------------
# 3️⃣ Chronological train/test split (last fold = test)
ts = TimeSeriesSplit(n_splits=5)
*_, (train_idx, test_idx) = ts.split(X)
X_tr, y_tr = X[train_idx], y[train_idx]
X_te, y_te = X[test_idx], y[test_idx]


def _report_mae(y_true, y_pred):
    """Print the MAE on the rescaled target and converted back to SAR."""
    rel = mean_absolute_error(y_true, y_pred)
    print(f" MAE (relative): {rel:.4f}")
    print(f" MAE (SAR): {rel * scale:,.2f}")


# ------------------------------------------------------------------
# 4️⃣ Baseline model: TabPFNRegressor
print("🔹 TabPFNRegressor (baseline)")
baseline = TabPFNRegressor()
baseline.fit(X_tr, y_tr)
_report_mae(y_te, baseline.predict(X_te))


# ------------------------------------------------------------------
# AutoTabPFNRegressor without a custom CV splitter (holdout validation)
print("\n🔹 AutoTabPFNRegressor (holdout)")
auto_holdout = AutoTabPFNRegressor(max_time=60 * 3)
auto_holdout.fit(X_tr, y_tr)
_report_mae(y_te, auto_holdout.predict(X_te))


# ------------------------------------------------------------------
# 5️⃣ AutoTabPFNRegressor with TimeSeriesSplit (time-series aware CV)
print("\n🔹 AutoTabPFNRegressor (time-series aware CV)")
auto = AutoTabPFNRegressor(
    max_time=60 * 3,  # quick run
    random_state=42,
    phe_init_args={
        "cv_splitter": ts,
        "validation_method": "cv",
        "n_folds": 5,
        "max_models": 10,
        "n_repeats": 1,
    },
)
auto.fit(X_tr, y_tr)
_report_mae(y_te, auto.predict(X_te))
4 changes: 2 additions & 2 deletions scripts/get_max_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main() -> None:
deps_match = re.search(r"dependencies\s*=\s*\[(.*?)\]", content, re.DOTALL)
if deps_match:
deps = [
d.strip(' "\'')
d.strip(" \"'")
for d in deps_match.group(1).strip().split("\n")
if d.strip()
]
Expand All @@ -37,4 +37,4 @@ def main() -> None:


if __name__ == "__main__":
main()
main()
4 changes: 2 additions & 2 deletions scripts/get_min_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main() -> None:
deps_match = re.search(r"dependencies\s*=\s*\[(.*?)\]", content, re.DOTALL)
if deps_match:
deps = [
d.strip(' "\'')
d.strip(" \"'")
for d in deps_match.group(1).strip().split("\n")
if d.strip()
]
Expand All @@ -31,4 +31,4 @@ def main() -> None:


if __name__ == "__main__":
main()
main()
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit
from sklearn.model_selection import (
BaseCrossValidator,
ShuffleSplit,
StratifiedShuffleSplit,
)

from tabpfn_extensions.scoring.scoring_utils import (
CLF_LABEL_METRICS,
Expand Down Expand Up @@ -46,6 +50,7 @@ def __init__(
n_folds: int | None = None,
time_limit: int | None = None,
score_metric: str | None = None,
cv_splitter: BaseCrossValidator | None = None,
):
"""Abstract validation utilities for sklearn-like models.

Expand Down Expand Up @@ -77,6 +82,11 @@ def __init__(
self._estimators = estimators.copy() # internal copy for early stopping.
self.validation_method = validation_method
self.holdout_fraction = holdout_fraction
self.cv_splitter = cv_splitter

if self.cv_splitter:
self.n_repeats = 1
self.n_folds = self.cv_splitter.get_n_splits()

@property
@abstractmethod
Expand Down Expand Up @@ -179,47 +189,86 @@ def _get_split(

def yield_cv_data_iterator(
    self,
    X: np.ndarray,
    y: np.ndarray,
) -> tuple[int, int, int, BaseEstimator, list[int], list[int], bool, bool]:
    """Yields data for cross-validation, handling both default and custom splitters.

    Each yielded tuple contains: the model index, the split index, the repeat
    number (1-based), the estimator object, the train/test index arrays, and
    two flags telling the caller whether to check for model-level and
    repeat-level early stopping after consuming this split.

    When a custom `cv_splitter` is provided, this method assumes that the user
    wants to apply the same fold structure to each estimator. Therefore, the
    splits are generated once and reused for every model in `self.estimators`.
    In that case repeats are hardcoded to 1 and `n_folds`/`n_repeats` are
    ignored.
    """
    n_models = len(self.estimators)

    # TODO: Update documentation to make user aware that self.n_folds and
    #  self.n_repeats are not used when a custom cv_splitter is provided.
    if self.cv_splitter:
        logger.info(
            f"Using provided CV splitter: {self.cv_splitter.__class__.__name__}"
        )
        logger.info(f"Ignoring n_folds parameter: {self.n_folds}")
        logger.info(
            f"Ignoring n_repeats parameter: {self.n_repeats}, hardcoded to 1 in splitter"
        )

        # Generate splits once and reuse for all models. Counting the
        # materialized splits (instead of calling get_n_splits() without
        # arguments) also works for splitters whose get_n_splits() requires
        # the data, e.g. LeaveOneOut.
        splits = list(self.cv_splitter.split(X, y))
        n_folds_per_rep = len(splits)

        for model_i in range(n_models):
            for fold_i, (train_index, test_index) in enumerate(splits):
                is_last_fold = (fold_i + 1) == n_folds_per_rep
                is_last_model = (model_i + 1) == n_models
                yield (
                    model_i,
                    fold_i,
                    1,  # Custom splitters don't have a native concept of repeats
                    self.estimators[model_i][1],
                    train_index,
                    test_index,
                    is_last_fold,
                    # Repeat-level early stopping: only after the last fold
                    # of the last model.
                    is_last_fold and is_last_model,
                )
    else:
        holdout_validation = self._is_holdout
        # Holdout validation uses a single pseudo-fold per repeat.
        _folds = self.n_folds if not holdout_validation else 1

        for repeat_i in range(self.n_repeats):
            for model_i in range(n_models):
                for fold_i in range(_folds):
                    split_i = fold_i + repeat_i * _folds
                    # Fixed: the original chained assignment
                    # `train_index, test_index = _get_split = self._get_split(...)`
                    # bound a useless `_get_split` local; removed.
                    train_index, test_index = self._get_split(
                        X=X,
                        y=y,
                        fold_i=fold_i,
                        repeat_i=repeat_i,
                        holdout_validation=holdout_validation,
                    )
                    check_for_model_early_stopping = False
                    check_for_repeat_early_stopping = False

                    if (fold_i + 1) == _folds:  # check after last fold of model
                        check_for_model_early_stopping = True
                        if (
                            model_i + 1
                        ) == n_models:  # check after last fold of last model
                            check_for_repeat_early_stopping = True

                    logger.info(
                        f"Yield data for model {self.estimators[model_i][0]} and split {split_i} (repeat={repeat_i + 1}).",
                    )
                    yield (
                        model_i,
                        split_i,
                        repeat_i + 1,
                        self.estimators[model_i][1],
                        train_index,
                        test_index,
                        check_for_model_early_stopping,
                        check_for_repeat_early_stopping,
                    )

def time_limit_reached(self) -> bool:
"""Check if the time limit for execution has been reached.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from sklearn.base import BaseEstimator
from sklearn.model_selection import BaseCrossValidator

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -277,6 +278,7 @@ def __init__(
n_iterations: int = 50,
silo_top_n: int = 25,
model_family_per_estimator: list[str] | None = None,
cv_splitter: BaseCrossValidator | None = None,
):
"""Post Hoc Ensemble Classifier Specialized for TabPFN models.

Expand Down Expand Up @@ -310,6 +312,7 @@ def __init__(
time_limit=time_limit,
holdout_fraction=holdout_fraction,
validation_method=validation_method,
cv_splitter=cv_splitter,
)

def _build_ensemble(self, base_models, weights):
Expand All @@ -334,6 +337,7 @@ def __init__(
n_iterations: int = 50,
silo_top_n: int = 25,
model_family_per_estimator: list[str] | None = None,
cv_splitter: BaseCrossValidator | None = None,
):
"""Post Hoc Ensemble Regressor Specialized for TabPFN models.

Expand Down Expand Up @@ -367,6 +371,7 @@ def __init__(
time_limit=time_limit,
holdout_fraction=holdout_fraction,
validation_method=validation_method,
cv_splitter=cv_splitter,
)

def _build_ensemble(self, base_models, weights):
Expand Down
8 changes: 7 additions & 1 deletion src/tabpfn_extensions/post_hoc_ensembles/pfn_phe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import warnings
from enum import Enum
from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np
from sklearn.base import BaseEstimator
Expand All @@ -26,6 +26,9 @@
GreedyWeightedEnsembleRegressor,
)

if TYPE_CHECKING:
from sklearn.model_selection import BaseCrossValidator

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
Expand Down Expand Up @@ -114,6 +117,7 @@ def __init__(
holdout_fraction: float = 0.33,
ges_n_iterations: int = 25,
ignore_pretraining_limits: bool = False,
cv_splitter: BaseCrossValidator | None = None,
) -> None:
"""Builds a PostHocEnsembleConfig with default values for the given parameters.

Expand Down Expand Up @@ -154,6 +158,7 @@ def __init__(
self.bm_random_state = bm_random_state
self.ges_random_state = ges_random_state
self.ignore_pretraining_limits = ignore_pretraining_limits
self.cv_splitter = cv_splitter

# Model Source
self.tabpfn_base_model_source = tabpfn_base_model_source
Expand Down Expand Up @@ -331,6 +336,7 @@ def fit(
validation_method=self.validation_method,
holdout_fraction=self.holdout_fraction,
model_family_per_estimator=model_family_per_estimator,
cv_splitter=self.cv_splitter,
)

self._ens_model.fit(X, y)
Expand Down