From 476f3ce8a5b31a6a2ed3eb650033dd406d5f69f7 Mon Sep 17 00:00:00 2001 From: lx Date: Sun, 15 Jun 2025 16:41:09 +0800 Subject: [PATCH 1/2] 'add-san' --- baselines/SAN/ETTm2.py | 139 +++++++++++++++++++++++++++++++++ baselines/SAN/Electricity.py | 139 +++++++++++++++++++++++++++++++++ baselines/SAN/Weather.py | 139 +++++++++++++++++++++++++++++++++ baselines/SAN/arch/__init__.py | 1 + baselines/SAN/arch/dlinear.py | 98 +++++++++++++++++++++++ baselines/SAN/arch/san_arch.py | 104 ++++++++++++++++++++++++ baselines/SAN/loss/__init__.py | 1 + baselines/SAN/loss/loss.py | 25 ++++++ 8 files changed, 646 insertions(+) create mode 100644 baselines/SAN/ETTm2.py create mode 100644 baselines/SAN/Electricity.py create mode 100644 baselines/SAN/Weather.py create mode 100644 baselines/SAN/arch/__init__.py create mode 100644 baselines/SAN/arch/dlinear.py create mode 100644 baselines/SAN/arch/san_arch.py create mode 100644 baselines/SAN/loss/__init__.py create mode 100644 baselines/SAN/loss/loss.py diff --git a/baselines/SAN/ETTm2.py b/baselines/SAN/ETTm2.py new file mode 100644 index 00000000..687f9d46 --- /dev/null +++ b/baselines/SAN/ETTm2.py @@ -0,0 +1,139 @@ +import os +import sys +from easydict import EasyDict +sys.path.append(os.path.abspath(__file__ + '/../../..')) + +from basicts.metrics import masked_mae, masked_mse +from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner +from basicts.scaler import ZScoreScaler +from basicts.utils import get_regular_settings + +from .arch import SAN +from .loss import san_loss + +############################## Hot Parameters ############################## +# Dataset & Metrics configuration +DATA_NAME = 'ETTm2' # Dataset name +regular_settings = get_regular_settings(DATA_NAME) +INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence +OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence +TRAIN_VAL_TEST_RATIO = 
regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios +NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data +RESCALE = regular_settings['RESCALE'] # Whether to rescale the data +NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data +# Model architecture and parameters +MODEL_ARCH = SAN +MODEL_PARAM = { + "seq_len": INPUT_LEN, + "pred_len": OUTPUT_LEN, + "individual": False, + "enc_in": 7, + "period_len": 24, + "station_pretrain_epoch": 5, +} +NUM_EPOCHS = 50 + +############################## General Configuration ############################## +CFG = EasyDict() +# General settings +CFG.DESCRIPTION = 'An Example Config' +CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) +# Runner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner + +############################## Dataset Configuration ############################## +CFG.DATASET = EasyDict() +# Dataset settings +CFG.DATASET.NAME = DATA_NAME +CFG.DATASET.TYPE = TimeSeriesForecastingDataset +CFG.DATASET.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, + 'input_len': INPUT_LEN, + 'output_len': OUTPUT_LEN, + # 'mode' is automatically set by the runner +}) + +############################## Scaler Configuration ############################## +CFG.SCALER = EasyDict() +# Scaler settings +CFG.SCALER.TYPE = ZScoreScaler # Scaler class +CFG.SCALER.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_ratio': TRAIN_VAL_TEST_RATIO[0], + 'norm_each_channel': NORM_EACH_CHANNEL, + 'rescale': RESCALE, +}) + +############################## Model Configuration ############################## +CFG.MODEL = EasyDict() +# Model settings +CFG.MODEL.NAME = MODEL_ARCH.__name__ +CFG.MODEL.ARCH = MODEL_ARCH +CFG.MODEL.PARAM = MODEL_PARAM +CFG.MODEL.FORWARD_FEATURES = [0] +CFG.MODEL.TARGET_FEATURES = [0] + +############################## Metrics Configuration ############################## + +CFG.METRICS = 
EasyDict() +# Metrics settings +CFG.METRICS.FUNCS = EasyDict({ + 'MAE': masked_mae, + 'MSE': masked_mse, + }) +CFG.METRICS.TARGET = 'MSE' +CFG.METRICS.NULL_VAL = NULL_VAL + +############################## Training Configuration ############################## +CFG.TRAIN = EasyDict() +CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + MODEL_ARCH.__name__, + '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) +) +CFG.TRAIN.LOSS = san_loss +# Optimizer settings +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM = { + "lr": 0.001 +} + +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = "SANWarmupMultiStepLR" +CFG.TRAIN.LR_SCHEDULER.PARAM = { + "warmup_lr": 0.0001, + "warmup_epochs": 5, + "milestones": [6, 30], +} + +CFG.TRAIN.CLIP_GRAD_PARAM = { + 'max_norm': 5.0 +} +# Train data loader settings +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 64 +CFG.TRAIN.DATA.SHUFFLE = True + +############################## Validation Configuration ############################## +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.BATCH_SIZE = 64 + +############################## Test Configuration ############################## +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +CFG.TEST.DATA = EasyDict() +CFG.TEST.DATA.BATCH_SIZE = 64 + +############################## Evaluation Configuration ############################## + +CFG.EVAL = EasyDict() + +# Evaluation parameters +CFG.EVAL.HORIZONS = [12, 24, 48, 96] +CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. 
Default: True diff --git a/baselines/SAN/Electricity.py b/baselines/SAN/Electricity.py new file mode 100644 index 00000000..4f4aade2 --- /dev/null +++ b/baselines/SAN/Electricity.py @@ -0,0 +1,139 @@ +import os +import sys +from easydict import EasyDict +sys.path.append(os.path.abspath(__file__ + '/../../..')) + +from basicts.metrics import masked_mae, masked_mse +from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner +from basicts.scaler import ZScoreScaler +from basicts.utils import get_regular_settings + +from .arch import SAN +from .loss import san_loss + +############################## Hot Parameters ############################## +# Dataset & Metrics configuration +DATA_NAME = 'Electricity' # Dataset name +regular_settings = get_regular_settings(DATA_NAME) +INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence +OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence +TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios +NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data +RESCALE = regular_settings['RESCALE'] # Whether to rescale the data +NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data +# Model architecture and parameters +MODEL_ARCH = SAN +MODEL_PARAM = { + "seq_len": INPUT_LEN, + "pred_len": OUTPUT_LEN, + "individual": False, + "enc_in": 321, + "period_len": 24, + "station_pretrain_epoch": 5, +} +NUM_EPOCHS = 50 + +############################## General Configuration ############################## +CFG = EasyDict() +# General settings +CFG.DESCRIPTION = 'An Example Config' +CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) +# Runner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner + +############################## Dataset Configuration ############################## +CFG.DATASET = EasyDict() +# Dataset settings +CFG.DATASET.NAME = DATA_NAME 
+CFG.DATASET.TYPE = TimeSeriesForecastingDataset +CFG.DATASET.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, + 'input_len': INPUT_LEN, + 'output_len': OUTPUT_LEN, + # 'mode' is automatically set by the runner +}) + +############################## Scaler Configuration ############################## +CFG.SCALER = EasyDict() +# Scaler settings +CFG.SCALER.TYPE = ZScoreScaler # Scaler class +CFG.SCALER.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_ratio': TRAIN_VAL_TEST_RATIO[0], + 'norm_each_channel': NORM_EACH_CHANNEL, + 'rescale': RESCALE, +}) + +############################## Model Configuration ############################## +CFG.MODEL = EasyDict() +# Model settings +CFG.MODEL.NAME = MODEL_ARCH.__name__ +CFG.MODEL.ARCH = MODEL_ARCH +CFG.MODEL.PARAM = MODEL_PARAM +CFG.MODEL.FORWARD_FEATURES = [0] +CFG.MODEL.TARGET_FEATURES = [0] + +############################## Metrics Configuration ############################## + +CFG.METRICS = EasyDict() +# Metrics settings +CFG.METRICS.FUNCS = EasyDict({ + 'MAE': masked_mae, + 'MSE': masked_mse, + }) +CFG.METRICS.TARGET = 'MSE' +CFG.METRICS.NULL_VAL = NULL_VAL + +############################## Training Configuration ############################## +CFG.TRAIN = EasyDict() +CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + MODEL_ARCH.__name__, + '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) +) +CFG.TRAIN.LOSS = san_loss +# Optimizer settings +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM = { + "lr": 0.001 +} + +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = "SANWarmupMultiStepLR" +CFG.TRAIN.LR_SCHEDULER.PARAM = { + "warmup_lr": 0.0001, + "warmup_epochs": 5, + "milestones": [6, 30], +} + +CFG.TRAIN.CLIP_GRAD_PARAM = { + 'max_norm': 5.0 +} +# Train data loader settings +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 64 
+CFG.TRAIN.DATA.SHUFFLE = True + +############################## Validation Configuration ############################## +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.BATCH_SIZE = 64 + +############################## Test Configuration ############################## +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +CFG.TEST.DATA = EasyDict() +CFG.TEST.DATA.BATCH_SIZE = 64 + +############################## Evaluation Configuration ############################## + +CFG.EVAL = EasyDict() + +# Evaluation parameters +CFG.EVAL.HORIZONS = [12, 24, 48, 96] +CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True diff --git a/baselines/SAN/Weather.py b/baselines/SAN/Weather.py new file mode 100644 index 00000000..2634c77d --- /dev/null +++ b/baselines/SAN/Weather.py @@ -0,0 +1,139 @@ +import os +import sys +from easydict import EasyDict +sys.path.append(os.path.abspath(__file__ + '/../../..')) + +from basicts.metrics import masked_mae, masked_mse +from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner +from basicts.scaler import ZScoreScaler +from basicts.utils import get_regular_settings + +from .arch import SAN +from .loss import san_loss + +############################## Hot Parameters ############################## +# Dataset & Metrics configuration +DATA_NAME = 'Weather' # Dataset name +regular_settings = get_regular_settings(DATA_NAME) +INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence +OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence +TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios +NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data +RESCALE = regular_settings['RESCALE'] # Whether to rescale the data +NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data +# Model architecture and parameters 
+MODEL_ARCH = SAN +MODEL_PARAM = { + "seq_len": INPUT_LEN, + "pred_len": OUTPUT_LEN, + "individual": False, + "enc_in": 21, + "period_len": 24, + "station_pretrain_epoch": 5, +} +NUM_EPOCHS = 50 + +############################## General Configuration ############################## +CFG = EasyDict() +# General settings +CFG.DESCRIPTION = 'An Example Config' +CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) +# Runner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner + +############################## Dataset Configuration ############################## +CFG.DATASET = EasyDict() +# Dataset settings +CFG.DATASET.NAME = DATA_NAME +CFG.DATASET.TYPE = TimeSeriesForecastingDataset +CFG.DATASET.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, + 'input_len': INPUT_LEN, + 'output_len': OUTPUT_LEN, + # 'mode' is automatically set by the runner +}) + +############################## Scaler Configuration ############################## +CFG.SCALER = EasyDict() +# Scaler settings +CFG.SCALER.TYPE = ZScoreScaler # Scaler class +CFG.SCALER.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_ratio': TRAIN_VAL_TEST_RATIO[0], + 'norm_each_channel': NORM_EACH_CHANNEL, + 'rescale': RESCALE, +}) + +############################## Model Configuration ############################## +CFG.MODEL = EasyDict() +# Model settings +CFG.MODEL.NAME = MODEL_ARCH.__name__ +CFG.MODEL.ARCH = MODEL_ARCH +CFG.MODEL.PARAM = MODEL_PARAM +CFG.MODEL.FORWARD_FEATURES = [0] +CFG.MODEL.TARGET_FEATURES = [0] + +############################## Metrics Configuration ############################## + +CFG.METRICS = EasyDict() +# Metrics settings +CFG.METRICS.FUNCS = EasyDict({ + 'MAE': masked_mae, + 'MSE': masked_mse, + }) +CFG.METRICS.TARGET = 'MSE' +CFG.METRICS.NULL_VAL = NULL_VAL + +############################## Training Configuration ############################## +CFG.TRAIN = EasyDict() +CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS +CFG.TRAIN.CKPT_SAVE_DIR = 
os.path.join( + 'checkpoints', + MODEL_ARCH.__name__, + '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) +) +CFG.TRAIN.LOSS = san_loss +# Optimizer settings +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM = { + "lr": 0.001 +} + +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = "SANWarmupMultiStepLR" +CFG.TRAIN.LR_SCHEDULER.PARAM = { + "warmup_lr": 0.0001, + "warmup_epochs": 5, + "milestones": [6, 30], +} + +CFG.TRAIN.CLIP_GRAD_PARAM = { + 'max_norm': 5.0 +} +# Train data loader settings +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 64 +CFG.TRAIN.DATA.SHUFFLE = True + +############################## Validation Configuration ############################## +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.BATCH_SIZE = 64 + +############################## Test Configuration ############################## +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +CFG.TEST.DATA = EasyDict() +CFG.TEST.DATA.BATCH_SIZE = 64 + +############################## Evaluation Configuration ############################## + +CFG.EVAL = EasyDict() + +# Evaluation parameters +CFG.EVAL.HORIZONS = [12, 24, 48, 96] +CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. 
Default: True diff --git a/baselines/SAN/arch/__init__.py b/baselines/SAN/arch/__init__.py new file mode 100644 index 00000000..569f8e64 --- /dev/null +++ b/baselines/SAN/arch/__init__.py @@ -0,0 +1 @@ +from .san_arch import SAN \ No newline at end of file diff --git a/baselines/SAN/arch/dlinear.py b/baselines/SAN/arch/dlinear.py new file mode 100644 index 00000000..a8a2c582 --- /dev/null +++ b/baselines/SAN/arch/dlinear.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn + + +class moving_avg(nn.Module): + """Moving average block to highlight the trend of time series""" + + def __init__(self, kernel_size, stride): + super(moving_avg, self).__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, + stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class series_decomp(nn.Module): + """Series decomposition block""" + + def __init__(self, kernel_size): + super(series_decomp, self).__init__() + self.moving_avg = moving_avg(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + + +class DLinear(nn.Module): + """ + Paper: Are Transformers Effective for Time Series Forecasting? 
+ Link: https://arxiv.org/abs/2205.13504 + Official Code: https://github.com/cure-lab/DLinear + Venue: AAAI 2023 + Task: Long-term Time Series Forecasting + """ + def __init__(self, **model_args): + super(DLinear, self).__init__() + self.seq_len = model_args["seq_len"] + self.pred_len = model_args["pred_len"] + + # Decomposition kernel size + kernel_size = 25 + self.decompsition = series_decomp(kernel_size) + self.individual = model_args["individual"] + self.channels = model_args["enc_in"] + + if self.individual: + self.Linear_Seasonal = nn.ModuleList() + self.Linear_Trend = nn.ModuleList() + + for i in range(self.channels): + self.Linear_Seasonal.append( + nn.Linear(self.seq_len, self.pred_len)) + self.Linear_Trend.append( + nn.Linear(self.seq_len, self.pred_len)) + + else: + self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len) + self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Feed forward of DLinear. + + Args: + x (torch.Tensor): history data with shape [B, L, N] + + Returns: + torch.Tensor: prediction with shape [B, L, N] + """ + + seasonal_init, trend_init = self.decompsition(x) + seasonal_init, trend_init = seasonal_init.permute( + 0, 2, 1), trend_init.permute(0, 2, 1) + if self.individual: + seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size( + 1), self.pred_len], dtype=seasonal_init.dtype).to(seasonal_init.device) + trend_output = torch.zeros([trend_init.size(0), trend_init.size( + 1), self.pred_len], dtype=trend_init.dtype).to(trend_init.device) + for i in range(self.channels): + seasonal_output[:, i, :] = self.Linear_Seasonal[i]( + seasonal_init[:, i, :]) + trend_output[:, i, :] = self.Linear_Trend[i]( + trend_init[:, i, :]) + else: + seasonal_output = self.Linear_Seasonal(seasonal_init) + trend_output = self.Linear_Trend(trend_init) + + prediction = seasonal_output + trend_output + return prediction.permute(0, 2, 1) # [B, L, N] diff --git
a/baselines/SAN/arch/san_arch.py b/baselines/SAN/arch/san_arch.py new file mode 100644 index 00000000..5b49f01d --- /dev/null +++ b/baselines/SAN/arch/san_arch.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +from .dlinear import DLinear +from argparse import Namespace +import pdb + +class MLP(nn.Module): + def __init__(self, configs, mode): + super(MLP, self).__init__() + configs = Namespace(**configs) + self.seq_len = configs.seq_len // configs.period_len + self.pred_len = int(configs.pred_len / configs.period_len) + self.channels = configs.enc_in + self.period_len = configs.period_len + self.mode = mode + if mode == 'std': + self.final_activation = nn.ReLU() + else: + self.final_activation = nn.Identity() + self.input = nn.Linear(self.seq_len, 512) + self.input_raw = nn.Linear(self.seq_len * self.period_len, 512) + self.activation = nn.ReLU() if mode == 'std' else nn.Tanh() + self.output = nn.Linear(1024, self.pred_len) + + def forward(self, x, x_raw): + x, x_raw = x.permute(0, 2, 1), x_raw.permute(0, 2, 1) + x = self.input(x) + x_raw = self.input_raw(x_raw) + x = torch.cat([x, x_raw], dim=-1) + x = self.output(self.activation(x)) + x = self.final_activation(x) + return x.permute(0, 2, 1) + + +class SAN(nn.Module): + """ + Paper: Adaptive Normalization for Non-stationary Time Series Forecasting: A Temporal Slice Perspective + Link: https://openreview.net/forum?id=5BqDSw8r5j + Official Code: https://github.com/icantnamemyself/SAN + Venue: NIPS 2023 + Task: Long-term Time Series Forecasting + """ + def __init__(self, **model_args): + super(SAN, self).__init__() + self.seq_len = model_args["seq_len"] + self.pred_len = model_args["pred_len"] + self.period_len = model_args["period_len"] + self.station_pretrain_epoch = model_args["station_pretrain_epoch"] + + self.channels = model_args["enc_in"] + self.seq_len_new = int(self.seq_len / self.period_len) + self.pred_len_new = int(self.pred_len / self.period_len) + self.epsilon = 1e-5 + self.weight = 
nn.Parameter(torch.ones(2, self.channels)) + + self.backbone = DLinear(**model_args) + self.model = MLP(model_args, mode='mean') + self.model_std = MLP(model_args, mode='std') + + def normalize(self, x): + bs, length, dim = x.shape # (B, L, N) + x = x.reshape(bs, -1, self.period_len, dim) + mean = torch.mean(x, dim=-2, keepdim=True) + std = torch.std(x, dim=-2, keepdim=True) + norm_x = (x - mean) / (std + self.epsilon) + x = x.reshape(bs, length, dim) + mean_all = torch.mean(x, dim=1, keepdim=True) + + outputs_mean = self.model(mean.squeeze(2) - mean_all, x - mean_all) * self.weight[0] + mean_all * self.weight[1] + outputs_std = self.model_std(std.squeeze(2), x) + outputs = torch.cat([outputs_mean, outputs_std], dim=-1) + return norm_x.reshape(bs, length, dim), outputs[:, -self.pred_len_new:, :] + + def de_normalize(self, y, station_pred): + bs, length, dim = y.shape + y = y.reshape(bs, -1, self.period_len, dim) + mean = station_pred[:, :, :self.channels].unsqueeze(2) + std = station_pred[:, :, self.channels:].unsqueeze(2) + output = y * (std + self.epsilon) + mean + return output.reshape(bs, length, dim) + + def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor: + """Feed forward of DLinear. 
+ + Args: + history_data (torch.Tensor): history data with shape [B, L, N, C] + + Returns: + dict: 'prediction' with shape [B, L, N, 1] plus auxiliary terms consumed by san_loss + """ + + assert history_data.shape[-1] == 1 # only use the target feature + target = future_data[..., 0] + x = history_data[..., 0] # B, L, N + x, statistics_pred = self.normalize(x) + y = self.backbone(x) + y = self.de_normalize(y, statistics_pred) + + return {"prediction": y.unsqueeze(-1), + "statistics_pred": statistics_pred, + "period_len": self.period_len, "epoch": epoch, + "station_pretrain_epoch": self.station_pretrain_epoch, + "train": train} + diff --git a/baselines/SAN/loss/__init__.py b/baselines/SAN/loss/__init__.py new file mode 100644 index 00000000..c367817b --- /dev/null +++ b/baselines/SAN/loss/__init__.py @@ -0,0 +1 @@ +from .loss import san_loss \ No newline at end of file diff --git a/baselines/SAN/loss/loss.py b/baselines/SAN/loss/loss.py new file mode 100644 index 00000000..6d9bd9d2 --- /dev/null +++ b/baselines/SAN/loss/loss.py @@ -0,0 +1,25 @@ +import torch +import numpy as np +from basicts.metrics import masked_mse +import pdb + + +def station_loss(y, statistics_pred, period_len): + bs, length, dim = y.shape + y = y.reshape(bs, -1, period_len, dim) + mean = torch.mean(y, dim=2) + std = torch.std(y, dim=2) + station_true = torch.cat([mean, std], dim=-1) + loss = masked_mse(statistics_pred, station_true) + return loss + +def san_loss(prediction, target, statistics_pred, period_len, epoch, station_pretrain_epoch, train): + if train: + if epoch + 1 <= station_pretrain_epoch: + return station_loss(target.squeeze(-1), statistics_pred, period_len) + + else: + return masked_mse(prediction, target) + else: + return masked_mse(prediction, target) + From 9d54307c7214e7ad35f6ca749c9d64aa1dd50f5d Mon Sep 17 00:00:00 2001 From: lx Date: Sun, 15 Jun 2025 16:45:35 +0800 Subject: [PATCH 2/2] 'update-san-lr_sch' --- basicts/runners/optim/lr_schedulers.py | 49 +++++++++++++++++++++++++- 1 file changed, 48 insertions(+),
1 deletion(-) diff --git a/basicts/runners/optim/lr_schedulers.py b/basicts/runners/optim/lr_schedulers.py index 6bafe92c..abed1fd6 100644 --- a/basicts/runners/optim/lr_schedulers.py +++ b/basicts/runners/optim/lr_schedulers.py @@ -6,7 +6,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR -__all__ = ['CosineWarmup', 'CosineWarmupRestarts'] +__all__ = ['CosineWarmup', 'CosineWarmupRestarts', 'SANWarmupMultiStepLR'] class CosineWarmup(LambdaLR): @@ -92,3 +92,50 @@ def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda( if progress >= 1.0: return 0.0 return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + +class SANWarmupMultiStepLR(LambdaLR): + """ + A learning rate scheduler that uses a fixed learning rate during a warmup phase, + then switches to the original learning rate and follows a MultiStepLR schedule. + + Args: + optimizer (`torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + warmup_epochs (`int`): + The number of epochs for the warmup phase. + warmup_lr (`float`): + The fixed learning rate during the warmup phase. + milestones (`list[int]`): + List of epoch indices for MultiStepLR. Must be increasing. + gamma (`float`, optional, default=0.1): + Multiplicative factor of learning rate decay. + last_epoch (`int`, optional, default=-1): + The index of the last epoch when resuming training. 
+ """ + def __init__(self, optimizer: Optimizer, warmup_epochs: int, warmup_lr: float, milestones: list, gamma: float = 0.1, last_epoch: int = -1): + self.milestones = milestones + self.gamma = gamma + base_lr = optimizer.defaults['lr'] # 原始学习率 + # lr_lambda = lambda epoch: self._get_lr_lambda(epoch, warmup_epochs, warmup_lr, base_lr, milestones, gamma) + lr_lambda = partial( + self._get_lr_lambda, + warmup_epochs=warmup_epochs, + warmup_lr=warmup_lr, + base_lr=base_lr, + milestones=milestones, + gamma=gamma + ) + super().__init__(optimizer, lr_lambda, last_epoch) + + @staticmethod + def _get_lr_lambda(epoch: int, warmup_epochs: int, warmup_lr: float, base_lr: float, milestones: list, gamma: float): + if epoch +1 <= warmup_epochs: + # Warmup phase + return warmup_lr / base_lr # From SAN: * (0.5 ** ((epoch - 1) // 1)) + # MultiStepLR phase: decay learning rate at milestones + adjusted_lr = 1.0 + for milestone in milestones: + if epoch >= milestone: + adjusted_lr *= gamma + return adjusted_lr