Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions data_provider/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,22 @@ def __read_data__(self):
dataset = M4Dataset.load(training=True, dataset_file=self.root_path)
else:
dataset = M4Dataset.load(training=False, dataset_file=self.root_path)
training_values = np.array(
[v[~np.isnan(v)] for v in
dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies
self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]])
self.timeseries = [ts for ts in training_values]
# training_values = np.array(
# [v[~np.isnan(v)] for v in
# dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies
# self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]])
# self.timeseries = [ts for ts in training_values]
mask = dataset.groups == self.seasonal_patterns
raw_sequences = [v[~np.isnan(v)] for v in dataset.values[mask]] # 移除NaN值
self.ids = np.array([i for i in dataset.ids[mask]])
# 确定目标长度(训练集用seq_len,测试集用M4定义的预测长度)
from data_provider.m4 import M4Meta, pad_sequences # 导入必要的类和函数
if self.flag == 'train':
target_len = self.seq_len # 训练集使用模型输入序列长度
else:
target_len = M4Meta.horizons_map[self.seasonal_patterns] # 测试集使用M4标准长度
# 统一序列长度(填充/截断)
self.timeseries = pad_sequences(raw_sequences, target_len, mode='edge')

def __getitem__(self, index):
insample = np.zeros((self.seq_len, 1))
Expand Down
27 changes: 27 additions & 0 deletions data_provider/m4.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,30 @@ def load_m4_info() -> pd.DataFrame:
:return: Pandas DataFrame of M4Info.
"""
# return pd.read_csv(INFO_FILE_PATH)


def pad_sequences(sequences, target_len, mode='edge'):
"""
统一序列长度(填充或截断)
:param sequences: 原始序列列表(长度可能不同)
:param target_len: 目标长度
:param mode: 填充模式,'edge' 用最后一个值填充,'zero' 用0填充
:return: 长度统一的序列数组
"""
padded = []
for seq in sequences:
seq_len = len(seq)
if seq_len < target_len:
# 填充到目标长度
pad_length = target_len - seq_len
if mode == 'edge':
# 用最后一个值填充(更适合时间序列)
padded_seq = np.pad(seq, (0, pad_length), mode='edge')
else:
# 用0填充
padded_seq = np.pad(seq, (0, pad_length), mode='constant', constant_values=0)
else:
# 截断到目标长度
padded_seq = seq[:target_len]
padded.append(padded_seq)
return np.array(padded)