diff --git a/docs/zh/examples/tempoGAN.md b/docs/zh/examples/tempoGAN.md index 15c3c1df26..d7628ac8e3 100644 --- a/docs/zh/examples/tempoGAN.md +++ b/docs/zh/examples/tempoGAN.md @@ -110,25 +110,25 @@ examples/tempoGAN/tempoGAN.py:57:76 Generator 的输入为低密度流体数据的插值,而数据集中保存的为原始的低密度流体数据,因此需要进行一个插值的 transform。 -``` py linenums="269" +``` py linenums="270" --8<-- -examples/tempoGAN/functions.py:269:274 +examples/tempoGAN/functions.py:270:275 --8<-- ``` Discriminator 和 Discriminator_tempo 对输入的 transform 更为复杂,分别为: -``` py linenums="359" +``` py linenums="360" --8<-- -examples/tempoGAN/functions.py:359:393 +examples/tempoGAN/functions.py:360:394 --8<-- ``` 其中: -``` py linenums="368" +``` py linenums="369" --8<-- -examples/tempoGAN/functions.py:368:368 +examples/tempoGAN/functions.py:369:369 --8<-- ``` @@ -239,9 +239,9 @@ examples/tempoGAN/tempoGAN.py:205:244 因为 GAN 网络训练的特性,本问题不使用 PaddleScience 中内置的可视化器,而是自定义了一个用于实现推理的函数,该函数读取验证集数据,得到推理结果并将结果以图片形式保存下来,在训练过程中按照一定间隔调用该函数即可在训练过程中监控训练效果。 -``` py linenums="153" +``` py linenums="154" --8<-- -examples/tempoGAN/functions.py:153:229 +examples/tempoGAN/functions.py:154:230 --8<-- ``` @@ -253,9 +253,9 @@ examples/tempoGAN/functions.py:153:229 Generator 的 loss 提供了 l1 loss、l2 loss、输出经过 Discriminator 判断的 loss 和 输出经过 Discriminator_tempo 判断的 loss。这些 loss 是否存在根据权重参数控制,若某一项 loss 的权重参数为 0,则表示训练中不添加该 loss 项。 -``` py linenums="276" +``` py linenums="277" --8<-- -examples/tempoGAN/functions.py:276:345 +examples/tempoGAN/functions.py:277:346 --8<-- ``` @@ -263,9 +263,9 @@ examples/tempoGAN/functions.py:276:345 Discriminator 为判别器,它的作用是判断数据为真数据还是假数据,因此它的 loss 为 Generator 产生的数据应当判断为假而产生的 loss 和 目标值数据应当判断为真而产生的 loss。 -``` py linenums="395" +``` py linenums="396" --8<-- -examples/tempoGAN/functions.py:395:409 +examples/tempoGAN/functions.py:396:410 --8<-- ``` @@ -273,9 +273,9 @@ examples/tempoGAN/functions.py:395:409 Discriminator_tempo 的 loss 构成 与 Discriminator 相同,只是所需数据不同。 -``` py linenums="411" +``` py linenums="412" --8<-- -examples/tempoGAN/functions.py:411:427 +examples/tempoGAN/functions.py:412:428 --8<-- ``` @@ -283,9 +283,9 @@ examples/tempoGAN/functions.py:411:427 本问题提供了一种输入数据处理方法,将输入的流体密度数据随机裁剪一块,然后进行密度值判断,若裁剪下来的块密度值低于阈值则重新裁剪,直到密度满足条件或裁剪次数达到阈值。这样做主要是为了减少训练所需的显存,同时对裁剪下来的块密度值的判断保证了块中信息的丰富程度。[参数和超参数设定](#34)中 `tile_ratio` 表示原始尺寸是块的尺寸的几倍,即若`tile_ratio` 为 2,裁剪下来的块的大小为整张原始图片的四分之一。 -``` py linenums="430" +``` py linenums="431" --8<-- -examples/tempoGAN/functions.py:430:488 +examples/tempoGAN/functions.py:431:489 --8<-- ``` diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py index 648a3a91ab..2d94dc6443 100644 --- a/examples/fsi/viv.py +++ b/examples/fsi/viv.py @@ -201,6 +201,7 @@ def __init__(self, model, func): self.func = func def forward(self, x): + x = {**x} model_out = self.model(x) func_out = self.func(x) return {**model_out, "f": func_out} diff --git a/examples/tempoGAN/functions.py b/examples/tempoGAN/functions.py index 6c578ce686..7e8d8e1830 100644 --- a/examples/tempoGAN/functions.py +++ b/examples/tempoGAN/functions.py @@ -64,11 +64,12 @@ def reshape_input(input_dict: Dict[str, paddle.Tensor]) -> Dict[str, paddle.Tens Returns: Dict[str, paddle.Tensor]: reshaped data dict. """ + out_dict = {} for key in input_dict: input = input_dict[key] N, C, H, W = input.shape - input_dict[key] = paddle.reshape(input, [N * C, 1, H, W]) - return input_dict + out_dict[key] = paddle.reshape(input, [N * C, 1, H, W]) + return out_dict def dereshape_input( diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py index f2cb0a2872..0959976dd0 100644 --- a/ppsci/solver/solver.py +++ b/ppsci/solver/solver.py @@ -858,7 +858,7 @@ def predict( @misc.run_on_eval_mode def export( self, - input_spec: List["InputSpec"], + input_spec: List[Dict[str, InputSpec]], export_path: str, with_onnx: bool = False, skip_prune_program: bool = False, @@ -870,8 +870,8 @@ def export( Convert model to static graph model and export to files. Args: - input_spec (List[InputSpec]): InputSpec describes the signature information - of the model input. + input_spec (List[Dict[str, InputSpec]]): InputSpec describes the signature + information of the model input. export_path (str): The path prefix to save model. with_onnx (bool, optional): Whether to export model into onnx after paddle inference models are exported. Defaults to False.