Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions docs/zh/examples/tempoGAN.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,25 @@ examples/tempoGAN/tempoGAN.py:57:76

Generator 的输入为低密度流体数据的插值,而数据集中保存的为原始的低密度流体数据,因此需要进行一个插值的 transform。

``` py linenums="269"
``` py linenums="270"
--8<--
examples/tempoGAN/functions.py:269:274
examples/tempoGAN/functions.py:270:275
--8<--
```

Discriminator 和 Discriminator_tempo 对输入的 transform 更为复杂,分别为:

``` py linenums="359"
``` py linenums="360"
--8<--
examples/tempoGAN/functions.py:359:393
examples/tempoGAN/functions.py:360:394
--8<--
```

其中:

``` py linenums="368"
``` py linenums="369"
--8<--
examples/tempoGAN/functions.py:368:368
examples/tempoGAN/functions.py:369:369
--8<--
```

Expand Down Expand Up @@ -239,9 +239,9 @@ examples/tempoGAN/tempoGAN.py:205:244

因为 GAN 网络训练的特性,本问题不使用 PaddleScience 中内置的可视化器,而是自定义了一个用于实现推理的函数,该函数读取验证集数据,得到推理结果并将结果以图片形式保存下来,在训练过程中按照一定间隔调用该函数即可在训练过程中监控训练效果。

``` py linenums="153"
``` py linenums="154"
--8<--
examples/tempoGAN/functions.py:153:229
examples/tempoGAN/functions.py:154:230
--8<--
```

Expand All @@ -253,39 +253,39 @@ examples/tempoGAN/functions.py:153:229

Generator 的 loss 提供了 l1 loss、l2 loss、输出经过 Discriminator 判断的 loss 和 输出经过 Discriminator_tempo 判断的 loss。这些 loss 是否存在根据权重参数控制,若某一项 loss 的权重参数为 0,则表示训练中不添加该 loss 项。

``` py linenums="276"
``` py linenums="277"
--8<--
examples/tempoGAN/functions.py:276:345
examples/tempoGAN/functions.py:277:346
--8<--
```

#### 3.8.2 Discriminator 的 loss

Discriminator 为判别器,它的作用是判断数据为真数据还是假数据,因此它的 loss 为 Generator 产生的数据应当判断为假而产生的 loss 和 目标值数据应当判断为真而产生的 loss。

``` py linenums="395"
``` py linenums="396"
--8<--
examples/tempoGAN/functions.py:395:409
examples/tempoGAN/functions.py:396:410
--8<--
```

#### 3.8.3 Discriminator_tempo 的 loss

Discriminator_tempo 的 loss 构成 与 Discriminator 相同,只是所需数据不同。

``` py linenums="411"
``` py linenums="412"
--8<--
examples/tempoGAN/functions.py:411:427
examples/tempoGAN/functions.py:412:428
--8<--
```

#### 3.8.4 自定义 data transform

本问题提供了一种输入数据处理方法,将输入的流体密度数据随机裁剪一块,然后进行密度值判断,若裁剪下来的块密度值低于阈值则重新裁剪,直到密度满足条件或裁剪次数达到阈值。这样做主要是为了减少训练所需的显存,同时对裁剪下来的块密度值的判断保证了块中信息的丰富程度。[参数和超参数设定](#34)中 `tile_ratio` 表示原始尺寸是块的尺寸的几倍,即若`tile_ratio` 为 2,裁剪下来的块的大小为整张原始图片的四分之一。

``` py linenums="430"
``` py linenums="431"
--8<--
examples/tempoGAN/functions.py:430:488
examples/tempoGAN/functions.py:431:489
--8<--
```

Expand Down
1 change: 1 addition & 0 deletions examples/fsi/viv.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(self, model, func):
self.func = func

def forward(self, x):
x = {**x}
model_out = self.model(x)
func_out = self.func(x)
return {**model_out, "f": func_out}
Expand Down
5 changes: 3 additions & 2 deletions examples/tempoGAN/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ def reshape_input(input_dict: Dict[str, paddle.Tensor]) -> Dict[str, paddle.Tens
Returns:
Dict[str, paddle.Tensor]: reshaped data dict.
"""
out_dict = {}
for key in input_dict:
input = input_dict[key]
N, C, H, W = input.shape
input_dict[key] = paddle.reshape(input, [N * C, 1, H, W])
return input_dict
out_dict[key] = paddle.reshape(input, [N * C, 1, H, W])
return out_dict


def dereshape_input(
Expand Down
6 changes: 3 additions & 3 deletions ppsci/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def predict(
@misc.run_on_eval_mode
def export(
self,
input_spec: List["InputSpec"],
input_spec: List[Dict[str, InputSpec]],
export_path: str,
with_onnx: bool = False,
skip_prune_program: bool = False,
Expand All @@ -870,8 +870,8 @@ def export(
Convert model to static graph model and export to files.

Args:
input_spec (List[InputSpec]): InputSpec describes the signature information
of the model input.
input_spec (List[Dict[str, InputSpec]]): InputSpec describes the signature
information of the model input.
export_path (str): The path prefix to save model.
with_onnx (bool, optional): Whether to export model into onnx after
paddle inference models are exported. Defaults to False.
Expand Down