[Example] Add STAFNet Model for Air Quality Prediction#1070
[Example] Add STAFNet Model for Air Quality Prediction#1070dylan-yin wants to merge 57 commits intoPaddlePaddle:developfrom
Conversation
HydrogenSulfate
left a comment
There was a problem hiding this comment.
整体项目请使用pre-commit格式化一边
| @@ -0,0 +1,136 @@ | |||
| hydra: | |||
There was a problem hiding this comment.
配置文件开头请加上以下字段:
PaddleScience/examples/ldc/conf/ldc_2d_Re3200_piratenet.yaml
Lines 1 to 9 in fad6927
| config: | ||
| override_dirname: | ||
| exclude_keys: | ||
| - TRAIN.checkpoint_path | ||
| - TRAIN.pretrained_model_path | ||
| - EVAL.pretrained_model_path | ||
| - mode | ||
| - output_dir | ||
| - log_freq |
| STAFNet_DATA_PATH: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl" # | ||
| DATASET: | ||
| label_keys: ["label"] | ||
| data_dir: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl" | ||
| STAFNet_DATA_args: { | ||
| "data_dir": "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl", | ||
| "batch_size": 1, | ||
| "shuffle": True, | ||
| "num_workers": 0, | ||
| "training": True | ||
| } | ||
|
|
||
|
|
||
|
|
||
| # "data_dir": "data/2020-2023_new/train_data.pkl", | ||
| # "batch_size": 32, | ||
| # "shuffle": True, | ||
| # "num_workers": 0, | ||
| # "training": True | ||
| # model settings | ||
| # MODEL: # | ||
|
|
There was a problem hiding this comment.
建议改为相对路径,以./data/...开头即可
| # "data_dir": "data/2020-2023_new/train_data.pkl", | ||
| # "batch_size": 32, | ||
| # "shuffle": True, | ||
| # "num_workers": 0, | ||
| # "training": True | ||
| # model settings | ||
| # MODEL: # |
| # configs: { | ||
| # "task_name": "forecast", | ||
| # "output_attention": False, | ||
| # "seq_len": 72, | ||
| # "label_len": 24, | ||
| # "pred_len": 48, | ||
|
|
||
| # "aq_gat_node_features" : 7, | ||
| # "aq_gat_node_num": 35, | ||
|
|
||
| # "mete_gat_node_features" : 7, | ||
| # "mete_gat_node_num": 18, | ||
|
|
||
| # "gat_hidden_dim": 32, | ||
| # "gat_edge_dim": 3, | ||
| # "gat_embed_dim": 32, | ||
|
|
||
| # "e_layers": 1, | ||
| # "enc_in": 7, | ||
| # "dec_in": 7, | ||
| # "c_out": 7, | ||
| # "d_model": 16 , | ||
| # "embed": "fixed", | ||
| # "freq": "t", | ||
| # "dropout": 0.05, | ||
| # "factor": 3, | ||
| # "n_heads": 4, | ||
|
|
||
| # "d_ff": 32 , | ||
| # "num_kernels": 6, | ||
| # "top_k": 4 | ||
| # } |
|
|
||
|
|
||
|
|
||
|
|
| # set random seed for reproducibility | ||
| ppsci.utils.misc.set_random_seed(42) | ||
| # set output directory | ||
| OUTPUT_DIR = "./output_example" | ||
| # initialize logger | ||
| logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
There was a problem hiding this comment.
这个可以删除,output_dir会由ppsci.utils.callbacks.InitCallback自动创建:
PaddleScience/ppsci/utils/callbacks.py
Lines 90 to 96 in fad6927
| from typing import Tuple | ||
|
|
||
| class Inception_Block_V1(paddle.nn.Layer): | ||
|
|
| output_dir: ${hydra:run.dir} | ||
| log_freq: 20 | ||
| # dataset setting | ||
| STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" # |
There was a problem hiding this comment.
- 这里的路径是否能改成相对路径?比如
./dataset/train_data.pkl,其余的路径字段也是,建议改为相对路径,并去掉用户名 - STAFNet_DATA_PATH是否应该放到DATASET字段下?
|
|
||
|
|
||
| MODEL: | ||
| input_keys: ["aq_train_data","mete_train_data",] |
There was a problem hiding this comment.
| input_keys: ["aq_train_data","mete_train_data",] | |
| input_keys: [aq_train_data, mete_train_data] |
|
|
||
| MODEL: | ||
| input_keys: ["aq_train_data","mete_train_data",] | ||
| output_keys: ["label"] |
There was a problem hiding this comment.
| output_keys: ["label"] | |
| output_keys: [label] |
| checkpoint_path: null | ||
|
|
||
| EVAL: | ||
| eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" |
There was a problem hiding this comment.
| eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" | |
| eval_data_path: ./dataset/val_data.pkl |
| STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" # | ||
| DATASET: | ||
| label_keys: ["label"] | ||
| data_dir: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" |
There was a problem hiding this comment.
- data_dir为什么是具体文件路径而不是某个文件夹路径?
- 此处的路径是否跟STAFNet_DATA_PATH重复了?
| cfg.TRAIN.epochs, | ||
| ITERS_PER_EPOCH, | ||
| eval_during_train=cfg.TRAIN.eval_during_train, | ||
| seed=cfg.seed, |
There was a problem hiding this comment.
| seed=cfg.seed, |
| """ | ||
| Validate after training an epoch | ||
|
|
||
| :param epoch: Integer, current training epoch. | ||
| :return: A log that contains information about validation | ||
| """ |
There was a problem hiding this comment.
| """ | |
| Validate after training an epoch | |
| :param epoch: Integer, current training epoch. | |
| :return: A log that contains information about validation | |
| """ |
| "sampler": { | ||
| "name": "BatchSampler", | ||
| "drop_last": False, | ||
| "shuffle": True, | ||
| }, |
There was a problem hiding this comment.
| "sampler": { | |
| "name": "BatchSampler", | |
| "drop_last": False, | |
| "shuffle": True, | |
| }, |
| # set random seed for reproducibility | ||
| ppsci.utils.misc.set_random_seed(42) | ||
| # set output directory | ||
| OUTPUT_DIR = "./output_example" | ||
| # initialize logger | ||
| logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
There was a problem hiding this comment.
| # set random seed for reproducibility | |
| ppsci.utils.misc.set_random_seed(42) | |
| # set output directory | |
| OUTPUT_DIR = "./output_example" | |
| # initialize logger | |
| logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
| OUTPUT_DIR = "./output_example" | ||
| # initialize logger | ||
| logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") | ||
| multiprocessing.set_start_method("spawn") |
There was a problem hiding this comment.
这句代码是什么作用?paddle的多卡训练不需要这样吧?
There was a problem hiding this comment.
我这边如果不加 multiprocessing.set_start_method("spawn"),会出现cuda error(3)
|
|
||
| ```` | ||
| ``` sh | ||
| python stafnet.py TRAIN_DIR="Your train dataset path" eval_data_path="Your evaluate dataset path" |
There was a problem hiding this comment.
python stafnet.py DATASET.data_dir="Your train dataset path" EVAL.eval_data_path="Your evaluate dataset path"
|
|
||
| ```` | ||
| ``` sh | ||
| python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl" |
There was a problem hiding this comment.
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl -P ./dataset/
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams"
|
|
||
| 针对空气质量预测提出了许多研究。早期的方法侧重于学习单个观测站观测数据的时间模式,而放弃了观测站之间的空间关系。最近,由于图神经网络(GNN)在处理非欧几里得图结构方面的有效性,越来越多的方法采用 GNN 来模拟空间依赖关系。这些方法将车站位置作为上下文特征,隐含地建立空间依赖关系模型,没有充分利用车站位置和车站之间关系所包含的宝贵空间信息。此外,现有的时空 GNN 缺乏在错位图中融合多个特征的能力。因此,大多数方法都需要额外的插值算法,以便在早期阶段将气象特征与 AQ 特征进行对齐和连接。这种方法消除了空气质量站和气象站之间的空间和结构信息,还可能引入噪声导致误差累积。此外,在空气质量预测中利用多周期性的问题仍未得到探索。 | ||
|
|
||
| 该案例研究时空图网络网络在空气质量预测方向上的应用。 |
| ``` | ||
| --8<-- | ||
| examples/stafnet/stafnet.py:11 | ||
| --8<-- | ||
| ``` |
There was a problem hiding this comment.
代码不显示,参考其他案例文档,在第一行添加py linenums="11" title="examples/stafnet/stafnet.py",下面的代码块均修改一下
|
|
||
| ### 3.7 评估器构建 | ||
|
|
||
| 在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](https://github.com/PaddlePaddle/PaddleScience/blob/develop/docs/zh/examples/cfdgcn.md#34) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。 |
There was a problem hiding this comment.
构建过程与 [3.6 约束构建](#36) 类似
|
|
||
| ## 5. 参考资料 | ||
|
|
||
| - [STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction]([STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction | SpringerLink](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22)) |
There was a problem hiding this comment.
[STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22)
| output_attention: True | ||
| seq_len: 72 | ||
| pred_len: 48 | ||
| aq_gat_node_features: 7 | ||
| aq_gat_node_num: 35 | ||
| mete_gat_node_features: 7 | ||
| mete_gat_node_num: 18 | ||
| gat_hidden_dim: 32 | ||
| gat_edge_dim: 3 | ||
| e_layers: 2 | ||
| enc_in: 7 | ||
| dec_in: 7 | ||
| c_out: 7 | ||
| d_model: 32 | ||
| embed: "fixed" | ||
| freq: "t" | ||
| dropout: 0.05 | ||
| factor: 3 | ||
| n_heads: 4 | ||
| d_ff: 64 | ||
| num_kernels: 6 | ||
| top_k: 4 |
There was a problem hiding this comment.
output_attention: true
seq_len: 72
pred_len: 48
aq_gat_node_features: 7
aq_gat_node_num: 35
mete_gat_node_features: 7
mete_gat_node_num: 18
gat_hidden_dim: 32
gat_edge_dim: 3
e_layers: 1
enc_in: 7
dec_in: 7
c_out: 7
d_model: 16
embed: fixed
freq: t
dropout: 0.05
factor: 3
n_heads: 4
d_ff: 32
num_kernels: 6
top_k: 4
| checkpoint_path: null | ||
|
|
||
| EVAL: | ||
| eval_data_path: ./dataset/new_val_data.pkl |
| from omegaconf import DictConfig | ||
| import hydra | ||
| import paddle | ||
| from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn |
There was a problem hiding this comment.
from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn去掉吧,并没有用到
| DATASET: | ||
| label_keys: [label] | ||
| data_dir: ./dataset/train_data.pkl | ||
|
|
| pretrained_model_path: null | ||
| compute_metric_by_batch: false | ||
| eval_with_no_grad: true | ||
| batch_size: 32 No newline at end of file |
| ``` sh | ||
| wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl -P ./dataset/ | ||
| python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" | ||
| python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl" |
| --8<-- | ||
| examples/stafnet/stafnet.py:11 | ||
| --8<-- |
There was a problem hiding this comment.
插入的代码还需再检查一下,插入的代码和文字说明不太对应
|
|
||
| ### 3.7 评估器构建 | ||
|
|
||
| 在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](https://github.com/PaddlePaddle/PaddleScience/blob/develop/docs/zh/examples/stafnet.md#36) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。 |
|
|
||
| ```py linenums="10" title="examples/stafnet/stafnet.py" | ||
| --8<-- | ||
| examples/stafnet/stafnet.py:10 |
There was a problem hiding this comment.
examples/stafnet/stafnet.py:10:10
|
|
||
| ## 4. 完整代码 | ||
|
|
||
| ```python py linenums="1" title="examples/stafnet/stafnet.py" |
|
|
||
| ``` py linenums="62" title="examples/stafnet/stafnet.py" | ||
| --8<-- | ||
| examples/stafnet/stafnet.py:62 |
There was a problem hiding this comment.
examples/stafnet/stafnet.py:62:62
|
|
||
| ``` py linenums="55" title="examples/stafnet/stafnet.py" | ||
| --8<-- | ||
| examples/stafnet/stafnet.py:55 |
There was a problem hiding this comment.
examples/stafnet/stafnet.py:55:55
|
|
||
| import numpy as np | ||
| import paddle | ||
| from pgl.nn.conv import GATv2Conv |
There was a problem hiding this comment.
还麻烦在docs/zh/api/data/dataset.md加上这个类




PR types
PR changes
Describe