Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion qlib/contrib/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def long_short_backtest(
shift=shift,
)

_pred_dates = pred.index.get_level_values(level="datetime")
# Resolve positionally to survive duplicate "datetime" level names (#1909).
_dt_level = pred.index.names.index("datetime")
_pred_dates = pred.index.get_level_values(_dt_level)
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))

Expand Down
5 changes: 4 additions & 1 deletion qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,10 @@ def config(self, **kwargs):
def setup_data(self, **kwargs):
super().setup_data(**kwargs)
# make sure the calendar is updated to latest when loading data from new config
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
_raw = self.handler.fetch(col_set=self.handler.CS_RAW)
# Resolve positionally for duplicate "datetime" level names (#1909).
_dt_level = _raw.index.names.index("datetime")
cal = _raw.index.get_level_values(_dt_level).unique()
self.cal = sorted(cal)

@staticmethod
Expand Down
7 changes: 6 additions & 1 deletion qlib/model/ens/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ class RollingEnsemble(Ensemble):
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
get_module_logger("RollingEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}")
artifact_list = list(ensemble_dict.values())
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())

def _min_dt(x: pd.DataFrame) -> pd.Timestamp:
# Resolve positionally for duplicate "datetime" level names (#1909).
return x.index.get_level_values(x.index.names.index("datetime")).min()

artifact_list.sort(key=_min_dt)
artifact = pd.concat(artifact_list)
# If there are duplicated predictions, use the latest prediction
artifact = artifact[~artifact.index.duplicated(keep="last")]
Expand Down
5 changes: 4 additions & 1 deletion qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,10 @@ def split_pred(pred, number=None, split_date=None):
"""
if number is None and split_date is None:
raise ValueError("`number` and `split date` cannot both be None")
dates = sorted(pred.index.get_level_values("datetime").unique())
# Resolve the "datetime" level positionally so duplicate level names
# (which some handler chains produce, see #1909) don't crash this path.
_dt_level = pred.index.names.index("datetime")
dates = sorted(pred.index.get_level_values(_dt_level).unique())
dates = list(map(pd.Timestamp, dates))
if split_date is None:
date_left_end = dates[number - 1]
Expand Down
8 changes: 6 additions & 2 deletions qlib/workflow/online/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write
signals = prepare_func(self.get_collector()())
old_signals = self.signals
if old_signals is not None and not over_write:
old_max = old_signals.index.get_level_values("datetime").max()
# Resolve positionally for duplicate "datetime" level names (#1909).
_dt_level = old_signals.index.names.index("datetime")
old_max = old_signals.index.get_level_values(_dt_level).max()
new_signals = signals.loc[old_max:]
signals = pd.concat([old_signals, new_signals], axis=0)
else:
Expand Down Expand Up @@ -379,4 +381,6 @@ def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)
need_prepare = False
signals_time = self.signals.index.get_level_values("datetime").max()
# Resolve positionally for duplicate "datetime" level names (#1909).
_dt_level = self.signals.index.names.index("datetime")
signals_time = self.signals.index.get_level_values(_dt_level).max()
9 changes: 7 additions & 2 deletions qlib/workflow/online/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ def __init__(
if from_date is None:
# dropna is for being compatible to some data with future information(e.g. label)
# The recent label data should be updated together
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
_old_clean = self.old_data.dropna()
# Resolve the "datetime" level positionally for duplicate names (#1909).
_dt_level = _old_clean.index.names.index("datetime")
self.last_end = _old_clean.index.get_level_values(_dt_level).max()
else:
self.last_end = get_date_by_shift(from_date, -1, align="right")

Expand Down Expand Up @@ -259,7 +262,9 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame:


def _replace_range(data, new_data):
dates = new_data.index.get_level_values("datetime")
# Resolve the "datetime" level positionally for duplicate names (#1909).
_dt_level = new_data.index.names.index("datetime")
dates = new_data.index.get_level_values(_dt_level)
data = data.sort_index()
data = data.drop(data.loc[dates.min() : dates.max()].index)
cb_data = pd.concat([data, new_data], axis=0)
Expand Down
49 changes: 49 additions & 0 deletions tests/misc/test_datetime_level_lookup_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Follow-up coverage for the #1909 duplicate-``datetime`` regression.

After #1909, ``PortAnaRecord`` resolves the ``datetime`` level positionally.
Several other call sites still relied on the name-based lookup and would
crash with the same ``ValueError`` on the same MultiIndex shape. This test
covers the helpers that drive those sites so the broader fix doesn't
regress silently if the call sites are later refactored.
"""

import unittest

import numpy as np
import pandas as pd

from qlib.utils import split_pred


def _dup_dt_index(n: int = 6) -> pd.MultiIndex:
dates = pd.date_range("2024-01-01", periods=n, freq="D")
instruments = [f"i{i % 2}" for i in range(n)]
return pd.MultiIndex.from_arrays(
[dates, instruments, dates],
names=["datetime", "instrument", "datetime"],
)


class TestDatetimeLevelLookupHelpers(unittest.TestCase):
    """Regression coverage for helpers hit by the duplicate-``datetime`` index (#1909)."""

    def test_split_pred_handles_duplicate_datetime_level(self) -> None:
        idx = _dup_dt_index(6)

        # Sanity check first: the name-based lookup the pre-fix code used is
        # ambiguous on this index shape and must raise — split_pred cannot
        # rely on it.
        with self.assertRaises(ValueError):
            idx.get_level_values("datetime")

        frame = pd.DataFrame({"score": np.arange(6, dtype=float)}, index=idx)
        left_part, right_part = split_pred(frame, number=2)

        # Assert only that the split yields two non-empty pieces that
        # together cover the input. Keeping the check this coarse makes it
        # robust to pandas-version differences in how a duplicate-named
        # index sorts/slices, while still catching the original crash.
        self.assertTrue(len(left_part) > 0)
        self.assertTrue(len(right_part) > 0)
        self.assertEqual(len(frame), len(left_part) + len(right_part))


# Allow running this test module directly (outside a pytest/CI invocation).
if __name__ == "__main__":
    unittest.main()