From 6f1200a85c94ceb5ddd887bd405e22d7cfb24356 Mon Sep 17 00:00:00 2001 From: huangqi27 <975674032@qq.com> Date: Fri, 2 Jan 2026 14:43:19 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20M4=20dataset=20sequence=20length=20mismat?= =?UTF-8?q?ch=EF=BC=88=E9=81=BF=E5=85=8Dm4=E6=95=B0=E6=8D=AE=E9=95=BF?= =?UTF-8?q?=E5=BA=A6=E4=B8=8D=E8=83=BD=E6=AD=A3=E5=B8=B8=E9=80=82=E9=85=8D?= =?UTF-8?q?=E9=A1=B9=E7=9B=AE=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_provider/data_loader.py | 21 ++++++++++++++++----- data_provider/m4.py | 27 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/data_provider/data_loader.py b/data_provider/data_loader.py index ea28fdccd..d2ef39eda 100644 --- a/data_provider/data_loader.py +++ b/data_provider/data_loader.py @@ -361,11 +361,22 @@ def __read_data__(self): dataset = M4Dataset.load(training=True, dataset_file=self.root_path) else: dataset = M4Dataset.load(training=False, dataset_file=self.root_path) - training_values = np.array( - [v[~np.isnan(v)] for v in - dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies - self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]]) - self.timeseries = [ts for ts in training_values] + # training_values = np.array( + # [v[~np.isnan(v)] for v in + # dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies + # self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]]) + # self.timeseries = [ts for ts in training_values] + mask = dataset.groups == self.seasonal_patterns + raw_sequences = [v[~np.isnan(v)] for v in dataset.values[mask]] # 移除NaN值 + self.ids = np.array([i for i in dataset.ids[mask]]) + # 确定目标长度(训练集用seq_len,测试集用M4定义的预测长度) + from data_provider.m4 import M4Meta, pad_sequences # 导入必要的类和函数 + if self.flag == 'train': + target_len = self.seq_len # 训练集使用模型输入序列长度 + else: + target_len = 
def pad_sequences(sequences, target_len, mode='edge'):
    """
    Bring every sequence to exactly ``target_len`` values by padding or truncating.

    Sequences shorter than ``target_len`` are padded at the end; sequences
    that are already long enough keep only their first ``target_len`` values.

    :param sequences: iterable of 1-D array-likes whose lengths may differ
    :param target_len: length of every row in the result
    :param mode: 'edge' repeats the last value of the sequence; any other
                 value (e.g. 'zero') pads with zeros
    :return: np.ndarray holding one row of length ``target_len`` per sequence
    """
    padded = []
    for seq in sequences:
        seq = np.asarray(seq)
        missing = target_len - len(seq)
        if missing <= 0:
            # Long enough already: keep the leading target_len values.
            padded.append(seq[:target_len])
            continue
        # np.pad(mode='edge') raises ValueError on an empty axis because there
        # is no last value to repeat. That happens when a series has no values
        # left (e.g. it was all NaN before filtering), so fall back to zeros
        # instead of failing the whole batch.
        if mode == 'edge' and len(seq) > 0:
            padded.append(np.pad(seq, (0, missing), mode='edge'))
        else:
            padded.append(
                np.pad(seq, (0, missing), mode='constant', constant_values=0)
            )
    return np.array(padded)