Skip to content

Commit 45b68c6

Browse files
committed
Merge remote-tracking branch 'origin/SAN'
2 parents 3fab5d3 + d7abdfc commit 45b68c6

9 files changed

Lines changed: 694 additions & 0 deletions

File tree

baselines/SAN/ETTm2.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
import sys
3+
from easydict import EasyDict
4+
sys.path.append(os.path.abspath(__file__ + '/../../..'))
5+
6+
from basicts.metrics import masked_mae, masked_mse
7+
from basicts.data import TimeSeriesForecastingDataset
8+
from basicts.runners import SimpleTimeSeriesForecastingRunner
9+
from basicts.scaler import ZScoreScaler
10+
from basicts.utils import get_regular_settings
11+
12+
from .arch import SAN
13+
from .loss import san_loss
14+
15+
############################## Hot Parameters ##############################
16+
# Dataset & Metrics configuration
17+
DATA_NAME = 'ETTm2' # Dataset name
18+
regular_settings = get_regular_settings(DATA_NAME)
19+
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence
20+
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence
21+
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
22+
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
23+
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
24+
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
25+
# Model architecture and parameters
26+
MODEL_ARCH = SAN
27+
MODEL_PARAM = {
28+
"seq_len": INPUT_LEN,
29+
"pred_len": OUTPUT_LEN,
30+
"individual": False,
31+
"enc_in": 7,
32+
"period_len": 24,
33+
"station_pretrain_epoch": 5,
34+
}
35+
NUM_EPOCHS = 30
36+
37+
############################## General Configuration ##############################
38+
CFG = EasyDict()
39+
# General settings
40+
CFG.DESCRIPTION = 'An Example Config'
41+
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
42+
# Runner
43+
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
44+
45+
############################## Dataset Configuration ##############################
46+
CFG.DATASET = EasyDict()
47+
# Dataset settings
48+
CFG.DATASET.NAME = DATA_NAME
49+
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
50+
CFG.DATASET.PARAM = EasyDict({
51+
'dataset_name': DATA_NAME,
52+
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
53+
'input_len': INPUT_LEN,
54+
'output_len': OUTPUT_LEN,
55+
# 'mode' is automatically set by the runner
56+
})
57+
58+
############################## Scaler Configuration ##############################
59+
CFG.SCALER = EasyDict()
60+
# Scaler settings
61+
CFG.SCALER.TYPE = ZScoreScaler # Scaler class
62+
CFG.SCALER.PARAM = EasyDict({
63+
'dataset_name': DATA_NAME,
64+
'train_ratio': TRAIN_VAL_TEST_RATIO[0],
65+
'norm_each_channel': NORM_EACH_CHANNEL,
66+
'rescale': RESCALE,
67+
})
68+
69+
############################## Model Configuration ##############################
70+
CFG.MODEL = EasyDict()
71+
# Model settings
72+
CFG.MODEL.NAME = MODEL_ARCH.__name__
73+
CFG.MODEL.ARCH = MODEL_ARCH
74+
CFG.MODEL.PARAM = MODEL_PARAM
75+
CFG.MODEL.FORWARD_FEATURES = [0]
76+
CFG.MODEL.TARGET_FEATURES = [0]
77+
78+
############################## Metrics Configuration ##############################
79+
80+
CFG.METRICS = EasyDict()
81+
# Metrics settings
82+
CFG.METRICS.FUNCS = EasyDict({
83+
'MAE': masked_mae,
84+
'MSE': masked_mse
85+
86+
})
87+
CFG.METRICS.TARGET = 'MSE'
88+
CFG.METRICS.NULL_VAL = NULL_VAL
89+
90+
############################## Training Configuration ##############################
91+
CFG.TRAIN = EasyDict()
92+
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
93+
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
94+
'checkpoints',
95+
MODEL_ARCH.__name__,
96+
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
97+
)
98+
CFG.TRAIN.LOSS = san_loss
99+
# Optimizer settings
100+
CFG.TRAIN.OPTIM = EasyDict()
101+
CFG.TRAIN.OPTIM.TYPE = "Adam"
102+
CFG.TRAIN.OPTIM.PARAM = {
103+
"lr": 0.001,
104+
"weight_decay": 0.0001,
105+
}
106+
107+
CFG.TRAIN.LR_SCHEDULER = EasyDict()
108+
CFG.TRAIN.LR_SCHEDULER.TYPE = "SANWarmupMultiStepLR"
109+
CFG.TRAIN.LR_SCHEDULER.PARAM = {
110+
"warmup_lr": 0.0001,
111+
"warmup_epochs": 5,
112+
"milestones": [10, 25],
113+
}
114+
115+
CFG.TRAIN.CLIP_GRAD_PARAM = {
116+
'max_norm': 5.0
117+
}
118+
# Train data loader settings
119+
CFG.TRAIN.DATA = EasyDict()
120+
CFG.TRAIN.DATA.BATCH_SIZE = 64
121+
CFG.TRAIN.DATA.SHUFFLE = True
122+
123+
############################## Validation Configuration ##############################
124+
CFG.VAL = EasyDict()
125+
CFG.VAL.INTERVAL = 1
126+
CFG.VAL.DATA = EasyDict()
127+
CFG.VAL.DATA.BATCH_SIZE = 64
128+
129+
############################## Test Configuration ##############################
130+
CFG.TEST = EasyDict()
131+
CFG.TEST.INTERVAL = 1
132+
CFG.TEST.DATA = EasyDict()
133+
CFG.TEST.DATA.BATCH_SIZE = 64
134+
135+
############################## Evaluation Configuration ##############################
136+
137+
CFG.EVAL = EasyDict()
138+
139+
# Evaluation parameters
140+
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True

baselines/SAN/Electricity.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import os
2+
import sys
3+
from easydict import EasyDict
4+
sys.path.append(os.path.abspath(__file__ + '/../../..'))
5+
6+
from basicts.metrics import masked_mae, masked_mse
7+
from basicts.data import TimeSeriesForecastingDataset
8+
from basicts.runners import SimpleTimeSeriesForecastingRunner
9+
from basicts.scaler import ZScoreScaler
10+
from basicts.utils import get_regular_settings
11+
12+
from .arch import SAN
13+
from .loss import san_loss
14+
15+
############################## Hot Parameters ##############################
16+
# Dataset & Metrics configuration
17+
DATA_NAME = 'Electricity' # Dataset name
18+
regular_settings = get_regular_settings(DATA_NAME)
19+
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence
20+
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence
21+
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
22+
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
23+
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
24+
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
25+
# Model architecture and parameters
26+
MODEL_ARCH = SAN
27+
MODEL_PARAM = {
28+
"seq_len": INPUT_LEN,
29+
"pred_len": OUTPUT_LEN,
30+
"individual": False,
31+
"enc_in": 321,
32+
"period_len": 24,
33+
"station_pretrain_epoch": 5,
34+
}
35+
NUM_EPOCHS = 30
36+
37+
############################## General Configuration ##############################
38+
CFG = EasyDict()
39+
# General settings
40+
CFG.DESCRIPTION = 'An Example Config'
41+
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
42+
# Runner
43+
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
44+
45+
############################## Dataset Configuration ##############################
46+
CFG.DATASET = EasyDict()
47+
# Dataset settings
48+
CFG.DATASET.NAME = DATA_NAME
49+
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
50+
CFG.DATASET.PARAM = EasyDict({
51+
'dataset_name': DATA_NAME,
52+
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
53+
'input_len': INPUT_LEN,
54+
'output_len': OUTPUT_LEN,
55+
# 'mode' is automatically set by the runner
56+
})
57+
58+
############################## Scaler Configuration ##############################
59+
CFG.SCALER = EasyDict()
60+
# Scaler settings
61+
CFG.SCALER.TYPE = ZScoreScaler # Scaler class
62+
CFG.SCALER.PARAM = EasyDict({
63+
'dataset_name': DATA_NAME,
64+
'train_ratio': TRAIN_VAL_TEST_RATIO[0],
65+
'norm_each_channel': NORM_EACH_CHANNEL,
66+
'rescale': RESCALE,
67+
})
68+
69+
############################## Model Configuration ##############################
70+
CFG.MODEL = EasyDict()
71+
# Model settings
72+
CFG.MODEL.NAME = MODEL_ARCH.__name__
73+
CFG.MODEL.ARCH = MODEL_ARCH
74+
CFG.MODEL.PARAM = MODEL_PARAM
75+
CFG.MODEL.FORWARD_FEATURES = [0]
76+
CFG.MODEL.TARGET_FEATURES = [0]
77+
78+
############################## Metrics Configuration ##############################
79+
80+
CFG.METRICS = EasyDict()
81+
# Metrics settings
82+
CFG.METRICS.FUNCS = EasyDict({
83+
'MAE': masked_mae,
84+
'MSE': masked_mse
85+
})
86+
CFG.METRICS.TARGET = 'MSE'
87+
CFG.METRICS.NULL_VAL = NULL_VAL
88+
89+
############################## Training Configuration ##############################
90+
CFG.TRAIN = EasyDict()
91+
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
92+
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
93+
'checkpoints',
94+
MODEL_ARCH.__name__,
95+
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
96+
)
97+
CFG.TRAIN.LOSS = san_loss
98+
# Optimizer settings
99+
CFG.TRAIN.OPTIM = EasyDict()
100+
CFG.TRAIN.OPTIM.TYPE = "Adam"
101+
CFG.TRAIN.OPTIM.PARAM = {
102+
"lr": 0.001,
103+
"weight_decay": 0.0001,
104+
}
105+
106+
CFG.TRAIN.LR_SCHEDULER = EasyDict()
107+
CFG.TRAIN.LR_SCHEDULER.TYPE = "SANWarmupMultiStepLR"
108+
CFG.TRAIN.LR_SCHEDULER.PARAM = {
109+
"warmup_lr": 0.0001,
110+
"warmup_epochs": 5,
111+
"milestones": [10, 25],
112+
}
113+
114+
CFG.TRAIN.CLIP_GRAD_PARAM = {
115+
'max_norm': 5.0
116+
}
117+
# Train data loader settings
118+
CFG.TRAIN.DATA = EasyDict()
119+
CFG.TRAIN.DATA.BATCH_SIZE = 64
120+
CFG.TRAIN.DATA.SHUFFLE = True
121+
122+
############################## Validation Configuration ##############################
123+
CFG.VAL = EasyDict()
124+
CFG.VAL.INTERVAL = 1
125+
CFG.VAL.DATA = EasyDict()
126+
CFG.VAL.DATA.BATCH_SIZE = 64
127+
128+
############################## Test Configuration ##############################
129+
CFG.TEST = EasyDict()
130+
CFG.TEST.INTERVAL = 1
131+
CFG.TEST.DATA = EasyDict()
132+
CFG.TEST.DATA.BATCH_SIZE = 64
133+
134+
############################## Evaluation Configuration ##############################
135+
136+
CFG.EVAL = EasyDict()
137+
138+
# Evaluation parameters
139+
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True

0 commit comments

Comments
 (0)