Skip to content

Commit ebc0f67

Browse files
authored
【PPSCI Export&Infer No.4】laplace2d (#797)
1 parent 2e8deb3 commit ebc0f67

File tree

3 files changed

+88
-1
lines changed

3 files changed

+88
-1
lines changed

docs/zh/examples/laplace2d.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@
1414
python laplace2d.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/laplace2d/laplace2d_pretrained.pdparams
1515
```
1616

17+
=== "模型导出命令"
18+
19+
``` sh
20+
python laplace2d.py mode=export
21+
```
22+
23+
=== "模型推理命令"
24+
25+
``` sh
26+
python laplace2d.py mode=infer
27+
```
28+
1729
| 预训练模型 | 指标 |
1830
|:--| :--|
1931
| [laplace2d_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/laplace2d/laplace2d_pretrained.pdparams) | loss(MSE_Metric): 0.00002<br>MSE.u(MSE_Metric): 0.00002 |

examples/laplace/conf/laplace2d.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ hydra:
2222
# general settings
2323
mode: train # running mode: train/eval
2424
seed: 42
25+
log_freq: 20
2526
output_dir: ${hydra:run.dir}
2627
NPOINT_INTERIOR: 9801
2728
NPOINT_BC: 400
@@ -50,3 +51,20 @@ TRAIN:
5051
EVAL:
5152
pretrained_model_path: null
5253
eval_with_no_grad: true
54+
55+
INFER:
56+
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/laplace2d/laplace2d_pretrained.pdparams
57+
export_path: ./inference/laplace2d
58+
pdmodel_path: ${INFER.export_path}.pdmodel
59+
pdpiparams_path: ${INFER.export_path}.pdiparams
60+
device: gpu
61+
engine: native
62+
precision: fp32
63+
onnx_path: ${INFER.export_path}.onnx
64+
ir_optim: true
65+
min_subgraph_size: 10
66+
gpu_mem: 4000
67+
gpu_id: 0
68+
max_batch_size: 64
69+
num_cpu_threads: 4
70+
batch_size: 64

examples/laplace/laplace2d.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,71 @@ def u_solution_func(out):
205205
solver.visualize()
206206

207207

208+
def export(cfg: DictConfig):
209+
# set model
210+
model = ppsci.arch.MLP(**cfg.MODEL)
211+
212+
# initialize solver
213+
solver = ppsci.solver.Solver(
214+
model,
215+
pretrained_model_path=cfg.INFER.pretrained_model_path,
216+
)
217+
# export model
218+
from paddle.static import InputSpec
219+
220+
input_spec = [
221+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
222+
]
223+
solver.export(input_spec, cfg.INFER.export_path)
224+
225+
226+
def inference(cfg: DictConfig):
227+
from deploy.python_infer import pinn_predictor
228+
229+
predictor = pinn_predictor.PINNPredictor(cfg)
230+
231+
# set geometry
232+
geom = {
233+
"rect": ppsci.geometry.Rectangle(
234+
cfg.DIAGONAL_COORD.xmin, cfg.DIAGONAL_COORD.xmax
235+
)
236+
}
237+
NPOINT_TOTAL = cfg.NPOINT_INTERIOR + cfg.NPOINT_BC
238+
input_dict = geom["rect"].sample_interior(NPOINT_TOTAL, evenly=True)
239+
240+
output_dict = predictor.predict(
241+
{key: input_dict[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
242+
)
243+
244+
# mapping data to cfg.INFER.output_keys
245+
output_dict = {
246+
store_key: output_dict[infer_key]
247+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
248+
}
249+
250+
# save result
251+
ppsci.visualize.save_vtu_from_dict(
252+
"./laplace2d_pred.vtu",
253+
{**input_dict, **output_dict},
254+
input_dict.keys(),
255+
cfg.MODEL.output_keys,
256+
)
257+
258+
208259
@hydra.main(version_base=None, config_path="./conf", config_name="laplace2d.yaml")
209260
def main(cfg: DictConfig):
210261
if cfg.mode == "train":
211262
train(cfg)
212263
elif cfg.mode == "eval":
213264
evaluate(cfg)
265+
elif cfg.mode == "export":
266+
export(cfg)
267+
elif cfg.mode == "infer":
268+
inference(cfg)
214269
else:
215-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
270+
raise ValueError(
271+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
272+
)
216273

217274

218275
if __name__ == "__main__":

0 commit comments

Comments
 (0)