Skip to content

Commit e3d82b2

Browse files
fix deepcfd
1 parent 6b6637e commit e3d82b2

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

examples/deepcfd/deepcfd.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def split_tensors(
3535
3636
Args:
3737
tensors (List[np.array]): Non-empty tensor list.
38-
ratio (float): Split ratio. For example, tensor list A is split to A1 and A2. len(A1) / len(A) = ratio.
38+
ratio (float): Split ratio. For example, tensor list A is split to A1 and A2.
39+
len(A1) / len(A) = ratio.
40+
3941
Returns:
4042
Tuple[List[np.array], List[np.array]]: Split tensors.
4143
"""
@@ -77,8 +79,8 @@ def predict_and_save_plot(
7779
min_p = np.min(y[index, 2, :, :])
7880
max_p = np.max(y[index, 2, :, :])
7981

80-
output = solver.predict({"input": x}, return_numpy=True)
81-
pred_y = output["output"]
82+
pred_y = solver.predict({"input": x}, return_numpy=True)
83+
pred_y = pred_y["output"]
8284
error = np.abs(y - pred_y)
8385

8486
min_error_u = np.min(error[index, 0, :, :])
@@ -193,10 +195,7 @@ def predict_and_save_plot(
193195
plt.colorbar(orientation="horizontal")
194196
plt.tight_layout()
195197
plt.show()
196-
plt.savefig(
197-
os.path.join(plot_dir, f"cfd_{index}.png"),
198-
bbox_inches="tight",
199-
)
198+
plt.savefig(os.path.join(plot_dir, f"cfd_{index}.png"), bbox_inches="tight")
200199

201200

202201
def train(cfg: DictConfig):

0 commit comments

Comments
 (0)