diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index d315622fcc6..1dbbe0f83de 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -336,7 +336,9 @@ def long_short_backtest( shift=shift, ) - _pred_dates = pred.index.get_level_values(level="datetime") + # Resolve positionally to survive duplicate "datetime" level names (#1909). + _dt_level = pred.index.names.index("datetime") + _pred_dates = pred.index.get_level_values(_dt_level) predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max()) trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift)) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index a6cace3730f..0a8bae3ff0e 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -673,7 +673,10 @@ def config(self, **kwargs): def setup_data(self, **kwargs): super().setup_data(**kwargs) # make sure the calendar is updated to latest when loading data from new config - cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() + _raw = self.handler.fetch(col_set=self.handler.CS_RAW) + # Resolve positionally for duplicate "datetime" level names (#1909).
+ _dt_level = _raw.index.names.index("datetime") + cal = _raw.index.get_level_values(_dt_level).unique() self.cal = sorted(cal) @staticmethod diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 1670a6538ef..c09ec6f2698 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -80,7 +80,12 @@ class RollingEnsemble(Ensemble): def __call__(self, ensemble_dict: dict) -> pd.DataFrame: get_module_logger("RollingEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}") artifact_list = list(ensemble_dict.values()) - artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) + + def _min_dt(x: pd.DataFrame) -> pd.Timestamp: + # Resolve positionally for duplicate "datetime" level names (#1909). + return x.index.get_level_values(x.index.names.index("datetime")).min() + + artifact_list.sort(key=_min_dt) artifact = pd.concat(artifact_list) # If there are duplicated predition, use the latest perdiction artifact = artifact[~artifact.index.duplicated(keep="last")] diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 2a94ebd555b..5fc68b5a10a 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -520,7 +520,10 @@ def split_pred(pred, number=None, split_date=None): """ if number is None and split_date is None: raise ValueError("`number` and `split date` cannot both be None") - dates = sorted(pred.index.get_level_values("datetime").unique()) + # Resolve the "datetime" level positionally so duplicate level names + # (which some handler chains produce, see #1909) don't crash this path.
+ _dt_level = pred.index.names.index("datetime") + dates = sorted(pred.index.get_level_values(_dt_level).unique()) dates = list(map(pd.Timestamp, dates)) if split_date is None: date_left_end = dates[number - 1] diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 09e96d444f2..d2bd1e941d1 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -277,7 +277,9 @@ def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write signals = prepare_func(self.get_collector()()) old_signals = self.signals if old_signals is not None and not over_write: - old_max = old_signals.index.get_level_values("datetime").max() + # Resolve positionally for duplicate "datetime" level names (#1909). + _dt_level = old_signals.index.names.index("datetime") + old_max = old_signals.index.get_level_values(_dt_level).max() new_signals = signals.loc[old_max:] signals = pd.concat([old_signals, new_signals], axis=0) else: @@ -379,4 +381,6 @@ def delay_prepare(self, model_kwargs={}, signal_kwargs={}): f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models." ) need_prepare = False - signals_time = self.signals.index.get_level_values("datetime").max() + # Resolve positionally for duplicate "datetime" level names (#1909). + _dt_level = self.signals.index.names.index("datetime") + signals_time = self.signals.index.get_level_values(_dt_level).max() diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 5047a1bd25e..62bf70a1c56 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -173,7 +173,10 @@ def __init__( if from_date is None: # dropna is for being compatible to some data with future information(e.g. 
label) # The recent label data should be updated together - self.last_end = self.old_data.dropna().index.get_level_values("datetime").max() + _old_clean = self.old_data.dropna() + # Resolve the "datetime" level positionally for duplicate names (#1909). + _dt_level = _old_clean.index.names.index("datetime") + self.last_end = _old_clean.index.get_level_values(_dt_level).max() else: self.last_end = get_date_by_shift(from_date, -1, align="right") @@ -259,7 +262,9 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame: def _replace_range(data, new_data): - dates = new_data.index.get_level_values("datetime") + # Resolve the "datetime" level positionally for duplicate names (#1909). + _dt_level = new_data.index.names.index("datetime") + dates = new_data.index.get_level_values(_dt_level) data = data.sort_index() data = data.drop(data.loc[dates.min() : dates.max()].index) cb_data = pd.concat([data, new_data], axis=0) diff --git a/tests/misc/test_datetime_level_lookup_helpers.py b/tests/misc/test_datetime_level_lookup_helpers.py new file mode 100644 index 00000000000..9bcf70f78bd --- /dev/null +++ b/tests/misc/test_datetime_level_lookup_helpers.py @@ -0,0 +1,49 @@ +"""Follow-up coverage for the #1909 duplicate-``datetime`` regression. + +After #1909, ``PortAnaRecord`` resolves the ``datetime`` level positionally. +Several other call sites still relied on the name-based lookup and would +crash with the same ``ValueError`` on the same MultiIndex shape. This test +covers the helpers that drive those sites so the broader fix doesn't +regress silently if the call sites are later refactored.
+""" + +import unittest + +import numpy as np +import pandas as pd + +from qlib.utils import split_pred + + +def _dup_dt_index(n: int = 6) -> pd.MultiIndex: + dates = pd.date_range("2024-01-01", periods=n, freq="D") + instruments = [f"i{i % 2}" for i in range(n)] + return pd.MultiIndex.from_arrays( + [dates, instruments, dates], + names=["datetime", "instrument", "datetime"], + ) + + +class TestDatetimeLevelLookupHelpers(unittest.TestCase): + def test_split_pred_handles_duplicate_datetime_level(self) -> None: + # Sanity: the index name lookup that the old code path used does + # raise on this shape — split_pred must not rely on it. + idx = _dup_dt_index(6) + with self.assertRaises(ValueError): + idx.get_level_values("datetime") + + pred = pd.DataFrame({"score": np.arange(6, dtype=float)}, index=idx) + pred_left, pred_right = split_pred(pred, number=2) + + # Left half should contain the earliest two distinct dates, right + # half should contain the rest. The exact slicing semantics are + # the same as the unique-name case; we only assert sizes here so + # the assertion stays meaningful even if sort behavior on the + # duplicate index changes between pandas versions. + self.assertGreater(len(pred_left), 0) + self.assertGreater(len(pred_right), 0) + self.assertEqual(len(pred_left) + len(pred_right), len(pred)) + + +if __name__ == "__main__": + unittest.main()