[Example] Add preformer for precipitation nowcasting #976
[Example] Add preformer for precipitation nowcasting #976 — EricKing19 wants to merge 9 commits into PaddlePaddle:develop
Conversation
|
Thanks for your contribution! |
HydrogenSulfate
left a comment
There was a problem hiding this comment.
感谢提交PR,有几处小问题麻烦看一下
|
|
||
| ``` sh | ||
| # 模型训练 | ||
| python examples/preformer/train.py |
There was a problem hiding this comment.
| python examples/preformer/train.py | |
| python train.py |
|
|
||
| ``` sh | ||
| # 模型评估 | ||
| python examples/preformer/train.py mode=eval |
There was a problem hiding this comment.
| python examples/preformer/train.py mode=eval | |
| python train.py mode=eval |
| # set random seed for reproducibility | ||
| ppsci.utils.misc.set_random_seed(cfg.seed) | ||
| # initialize logger | ||
| logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") | ||
|
|
| "num_replicas": NUM_GPUS_PER_NODE, | ||
| "rank": dist.get_rank() % NUM_GPUS_PER_NODE, |
There was a problem hiding this comment.
这两个参数应该不需要,并且paddlescience也没有对应的处理逻辑,默认会根据环境中设置的卡数自动设置
| mon = str("0") + mon | ||
| day = str(self.time_table[idxs].timetuple().tm_mday) | ||
| if len(day) == 1: | ||
| day = str("0") + day |
There was a problem hiding this comment.
str("0") 是否可以直接写成 "0"?下同
| r_data = np.load( | ||
| os.path.join(self.file_path, year, "r_" + year + mon + day + hour + ".npy") | ||
| ) | ||
| t_data = np.load( | ||
| os.path.join(self.file_path, year, "t_" + year + mon + day + hour + ".npy") | ||
| ) | ||
| u_data = np.load( | ||
| os.path.join(self.file_path, year, "u_" + year + mon + day + hour + ".npy") | ||
| ) | ||
| v_data = np.load( | ||
| os.path.join(self.file_path, year, "v_" + year + mon + day + hour + ".npy") | ||
| ) |
There was a problem hiding this comment.
可以直接使用f-string化简字符串拼接的写法
| hydra: | ||
| run: | ||
| # dynamic output directory according to running time and override name | ||
| dir: outputs_preformer | ||
| job: | ||
| name: ${mode} # name of logfile | ||
| chdir: false # keep current working directory unchanged | ||
| config: | ||
| override_dirname: | ||
| exclude_keys: | ||
| - TRAIN.checkpoint_path | ||
| - TRAIN.trained_model_path | ||
| - EVAL.trained_model_path | ||
| - mode | ||
| - output_dir | ||
| - log_freq | ||
| sweep: | ||
| # output directory for multirun | ||
| dir: ${hydra.run.dir} | ||
| subdir: ./ | ||
|
|
There was a problem hiding this comment.
| hydra: | |
| run: | |
| # dynamic output directory according to running time and override name | |
| dir: outputs_preformer | |
| job: | |
| name: ${mode} # name of logfile | |
| chdir: false # keep current working directory unchanged | |
| config: | |
| override_dirname: | |
| exclude_keys: | |
| - TRAIN.checkpoint_path | |
| - TRAIN.trained_model_path | |
| - EVAL.trained_model_path | |
| - mode | |
| - output_dir | |
| - log_freq | |
| sweep: | |
| # output directory for multirun | |
| dir: ${hydra.run.dir} | |
| subdir: ./ | |
| defaults: | |
| - ppsci_default | |
| - TRAIN: train_default | |
| - TRAIN/ema: ema_default | |
| - TRAIN/swa: swa_default | |
| - EVAL: eval_default | |
| - INFER: infer_default | |
| - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | |
| - _self_ | |
| hydra: | |
| run: | |
| # dynamic output directory according to running time and override name | |
| dir: outputs_preformer | |
| job: | |
| name: ${mode} # name of logfile | |
| chdir: false # keep current working directory unchanged | |
| sweep: | |
| # output directory for multirun | |
| dir: ${hydra.run.dir} | |
| subdir: ./ | |
|
|
||
| # model settings | ||
| MODEL: | ||
| afno: |
| afno: | ||
| input_keys: ["input"] | ||
| output_keys: ["output"] | ||
| shape_in: [6, 12, IMG_H, IMG_W] |
There was a problem hiding this comment.
| shape_in: [6, 12, IMG_H, IMG_W] | |
| shape_in: | |
| - 6 | |
| - 12 | |
| - ${IMG_H} | |
| - ${IMG_W} | |
|
@EricKing19 标题已经修改过了,原先的merge code of upstream不太合适 |
| 案例中使用了预处理的 PEMSD4 和 PEMSD8 数据集。PEMSD4 为旧金山湾区交通数据,选取 29 条道路上 307 个传感器记录的交通数据,时间为 2018 年 1 月至 2 月。PEMSD8 为圣贝纳迪诺 8 条道路上 170 个检测器收集的交通数据,时间为 2016 年 7 月至 8 月。 | ||
|
|
||
| 两个数据集均被保存为 N x T x 1 的矩阵,记录了相应交通节点与时间的流量数据,其中 N 为交通节点数量,T 为时间序列长度。两个数据集分别按照 7:2:1 划分为训练集、验证集和测试集。案例中预先计算了流量数据的均值与标准差,用于后续的归一化操作。 |
There was a problem hiding this comment.
该案例是关于降水的,这个数据集好像是交通的,数据集与代码不一致
| 开始训练、评估前,请下载数据集文件 | ||
|
|
||
| 开始评估前,请下载或训练生成预训练模型 | ||
|
|
There was a problem hiding this comment.
可以稍微介绍一下数据集的准备过程吗?比如如何下载和解压后的文件组织形式?
| === "模型训练命令" | ||
|
|
||
| ``` sh | ||
| # 模型训练 |
There was a problem hiding this comment.
删除这个注释,上面这个标签已经说明了这是模型训练命令了
| === "模型评估命令" | ||
|
|
||
| ``` sh | ||
| # 模型评估 |
|
|
||
| ``` sh | ||
| # 模型评估 | ||
| python train.py mode=eval |
There was a problem hiding this comment.
这里麻烦提供一下您训练好的预训练模型文件(.pdparams文件即可),我们上传到bce上,这样就能通过在命令里直接指定预训练模型url直接下载并在评估前自动加载权重,不需要额外的手动下载了
| #### 3.2.6 模型导出 | ||
|
|
||
| 通过设置 `ppsci.solver.Solver` 中的 `eval_during_train` 和 `eval_freq` 参数,可以自动保存在验证集上效果最优的模型参数。 | ||
|
|
||
| ``` py linenums="100" title="examples/preformer/train.py" | ||
| --8<-- | ||
| examples/preformer/train.py:158:158 | ||
| --8<-- | ||
| ``` | ||
|
|
There was a problem hiding this comment.
- 模型导出章节可以不用出现在文章中,删除
- 请补充模型导出和推理的函数 def export 和 def inference 到 examples/preformer/main.py 中,参考:PaddleScience/examples/allen_cahn/allen_cahn_piratenet.py(Lines 235 to 269 in 83f6739)
- 模型导出和模型推理执行命令请添加到文档开头处的"=== "模型评估命令""后面
| return latent | ||
|
|
||
|
|
||
| class Mid_Xnet(nn.Layer): |
There was a problem hiding this comment.
Mid_Xnet建议改为MidXNet,命名更规范
| def forward(self, hid, enc1=None): | ||
| for i in range(0, len(self.dec)): | ||
| hid = self.dec[i](hid) | ||
| # Y = self.dec[-1](torch.cat([hid, enc1], dim=1)) |
| for m in range(self.sq_length): | ||
| x.append(self.load_data(global_idx + m)) | ||
| for n in range(self.sq_length): | ||
| # y.append(self.load_data(global_idx+n)) |
| # y.append(self.load_data(global_idx+n)) | ||
| y.append(self.precipitation["tp"][global_idx + self.sq_length + n]) | ||
| # x = self.Normalize(x) | ||
| x, y = self.RandomCrop(x, y) |
There was a problem hiding this comment.
self.RandomCrop是否应该是self._random_crop?
| def _random_crop(self, x, y): | ||
| if isinstance(self.size, numbers.Number): | ||
| self.size = (int(self.size), int(self.size)) | ||
| th, tw = self.size | ||
| h, w = y[0].shape[-2], y[0].shape[-1] | ||
| x1 = random.randint(0, w - tw) | ||
| y1 = random.randint(0, h - th) | ||
|
|
||
| for i in range(len(x)): | ||
| x[i] = self.crop(x[i], y1, x1, y1 + th, x1 + tw) | ||
| for i in range(len(y)): | ||
| y[i] = self.crop(y[i], y1, x1, y1 + th, x1 + tw) | ||
|
|
||
| return x, y | ||
|
|
||
| def crop(self, im, x_start, y_start, x_end, y_end): | ||
| if len(im.shape) == 3: | ||
| return im[:, x_start:x_end, y_start:y_end] | ||
| else: | ||
| return im[x_start:x_end, y_start:y_end] |
There was a problem hiding this comment.
非公开方法前面建议加上下划线:
| def _random_crop(self, x, y): | |
| if isinstance(self.size, numbers.Number): | |
| self.size = (int(self.size), int(self.size)) | |
| th, tw = self.size | |
| h, w = y[0].shape[-2], y[0].shape[-1] | |
| x1 = random.randint(0, w - tw) | |
| y1 = random.randint(0, h - th) | |
| for i in range(len(x)): | |
| x[i] = self.crop(x[i], y1, x1, y1 + th, x1 + tw) | |
| for i in range(len(y)): | |
| y[i] = self.crop(y[i], y1, x1, y1 + th, x1 + tw) | |
| return x, y | |
| def crop(self, im, x_start, y_start, x_end, y_end): | |
| if len(im.shape) == 3: | |
| return im[:, x_start:x_end, y_start:y_end] | |
| else: | |
| return im[x_start:x_end, y_start:y_end] | |
| def _random_crop(self, x, y): | |
| if isinstance(self.size, numbers.Number): | |
| self.size = (int(self.size), int(self.size)) | |
| th, tw = self.size | |
| h, w = y[0].shape[-2], y[0].shape[-1] | |
| x1 = random.randint(0, w - tw) | |
| y1 = random.randint(0, h - th) | |
| for i in range(len(x)): | |
| x[i] = self._crop(x[i], y1, x1, y1 + th, x1 + tw) | |
| for i in range(len(y)): | |
| y[i] = self._crop(y[i], y1, x1, y1 + th, x1 + tw) | |
| return x, y | |
| def _crop(self, im, x_start, y_start, x_end, y_end): | |
| if len(im.shape) == 3: | |
| return im[:, x_start:x_end, y_start:y_end] | |
| else: | |
| return im[x_start:x_end, y_start:y_end] |
|
@EricKing19 顺带解决一下冲突问题 |

PR types
Others
PR changes
Others
Describe
add Preformer model for precipitation nowcasting
add docs for Preformer
add examples for Preformer