@@ -35,7 +35,9 @@ def split_tensors(
3535
3636 Args:
3737 tensors (List[np.array]): Non-empty tensor list.
38- ratio (float): Split ratio. For example, tensor list A is split to A1 and A2. len(A1) / len(A) = ratio.
38+ ratio (float): Split ratio. For example, tensor list A is split to A1 and A2.
39+ len(A1) / len(A) = ratio.
40+
3941 Returns:
4042 Tuple[List[np.array], List[np.array]]: Split tensors.
4143 """
@@ -77,8 +79,8 @@ def predict_and_save_plot(
7779 min_p = np .min (y [index , 2 , :, :])
7880 max_p = np .max (y [index , 2 , :, :])
7981
80- output = solver .predict ({"input" : x }, return_numpy = True )
81- pred_y = output ["output" ]
82+ pred_y = solver .predict ({"input" : x }, return_numpy = True )
83+ pred_y = pred_y ["output" ]
8284 error = np .abs (y - pred_y )
8385
8486 min_error_u = np .min (error [index , 0 , :, :])
@@ -193,10 +195,7 @@ def predict_and_save_plot(
193195 plt .colorbar (orientation = "horizontal" )
194196 plt .tight_layout ()
195197 plt .show ()
196- plt .savefig (
197- os .path .join (plot_dir , f"cfd_{ index } .png" ),
198- bbox_inches = "tight" ,
199- )
198+ plt .savefig (os .path .join (plot_dir , f"cfd_{ index } .png" ), bbox_inches = "tight" )
200199
201200
202201def train (cfg : DictConfig ):
0 commit comments