From 2224faae5a4533219486da09fe6d80cb01be215e Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 1 Apr 2024 15:29:25 +0000 Subject: [PATCH 1/2] replace os.path.isdir with len --- ppsci/arch/mlp.py | 5 +---- ppsci/utils/logger.py | 2 +- ppsci/utils/symbolic.py | 2 +- ppsci/utils/writer.py | 2 +- ppsci/visualize/plot.py | 8 ++++---- ppsci/visualize/vtu.py | 4 ++-- 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py index a532f458b0..2cf8a3887a 100644 --- a/ppsci/arch/mlp.py +++ b/ppsci/arch/mlp.py @@ -300,12 +300,9 @@ def forward_tensor(self, x): u = self.embed_u(x) v = self.embed_v(x) - x = self.linears[0](x) - x = self.acts[0](x) - y = x skip = None - for i, linear in enumerate(self.linears[1:], 1): + for i, linear in enumerate(self.linears, 1): y = linear(y) y = self.acts[i](y) y = (1 - y) * u + y * v diff --git a/ppsci/utils/logger.py b/ppsci/utils/logger.py index bdd7a0e034..f8c7ac4e86 100644 --- a/ppsci/utils/logger.py +++ b/ppsci/utils/logger.py @@ -104,7 +104,7 @@ def init_logger( # add file_handler, output to log_file(if specified), only for rank 0 device if log_file is not None and dist.get_rank() == 0: log_file_folder = os.path.dirname(log_file) - if os.path.isdir(log_file_folder): + if len(log_file_folder): os.makedirs(log_file_folder, exist_ok=True) file_formatter = logging.Formatter( "[%(asctime)s] %(name)s %(levelname)s: %(message)s", diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index af129caf81..5a2f1e99c7 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -615,7 +615,7 @@ def add_edge(u: str, v: str, u_color: str = C_DATA, v_color: str = C_DATA): graph.layout() image_path = f"{graph_filename}.png" dot_path = f"{graph_filename}.dot" - if os.path.isdir(os.path.dirname(image_path)): + if len(os.path.dirname(image_path)): os.makedirs(os.path.dirname(image_path), exist_ok=True) graph.draw(image_path, prog="dot") graph.write(dot_path) diff --git a/ppsci/utils/writer.py b/ppsci/utils/writer.py index 1c3650f51e..eb7a6d55f3 100644 --- a/ppsci/utils/writer.py +++ b/ppsci/utils/writer.py @@ -182,7 +182,7 @@ def save_tecplot_file( nx, ny = num_x, num_y assert nx * ny == nxy, f"nx({nx}) * ny({ny}) != nxy({nxy})" - if os.path.isdir(os.path.dirname(filename)): + if len(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename), exist_ok=True) if filename.endswith(".dat"): diff --git a/ppsci/visualize/plot.py b/ppsci/visualize/plot.py index f5a3bb2f09..f221eeea28 100644 --- a/ppsci/visualize/plot.py +++ b/ppsci/visualize/plot.py @@ -80,7 +80,7 @@ def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamps= value_keys (Tuple[str, ...]): Value keys. num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1. """ - if os.path.isdir(os.path.dirname(filename)): + if len(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename), exist_ok=True) fig, a = plt.subplots(len(value_keys), num_timestamps, squeeze=False) fig.subplots_adjust(hspace=0.8) @@ -188,7 +188,7 @@ def _save_plot_from_2d_array( xticks (Optional[Tuple[float, ...]]): Tuple of xtick locations. Defaults to None. yticks (Optional[Tuple[float, ...]]): Tuple of ytick locations. Defaults to None. """ - if os.path.isdir(os.path.dirname(filename)): + if len(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename), exist_ok=True) plt.close("all") @@ -336,7 +336,7 @@ def _save_plot_from_3d_array( visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v"). num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1. """ - if os.path.isdir(os.path.dirname(filename)): + if len(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename), exist_ok=True) fig = plt.figure(figsize=(10, 10)) @@ -482,7 +482,7 @@ def plot_weather( ) plt.colorbar(mappable=map_, cax=None, ax=None, shrink=0.5, label=colorbar_label) - if os.path.isdir(os.path.dirname(filename)): + if len(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename), exist_ok=True) fig = plt.figure(facecolor="w", figsize=(7, 7)) ax = fig.add_subplot(2, 1, 1) diff --git a/ppsci/visualize/vtu.py b/ppsci/visualize/vtu.py index fb4dc5857a..2971f48293 100644 --- a/ppsci/visualize/vtu.py +++ b/ppsci/visualize/vtu.py @@ -53,7 +53,7 @@ def _save_vtu_from_array(filename, coord, value, value_keys, num_timestamps=1): if coord.shape[1] not in [2, 3]: raise ValueError(f"ndim of coord({coord.shape[1]}) should be 2 or 3.") - if os.path.isdir(os.path.dirname(filename)): + if len(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename), exist_ok=True) # discard extension name @@ -184,6 +184,6 @@ def save_vtu_to_mesh( points=points.T, cells=[("vertex", np.arange(npoint).reshape(npoint, 1))] ) mesh.point_data = {key: data_dict[key] for key in value_keys} - if os.path.isdir(os.path.dirname(filename)): + if len(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename), exist_ok=True) mesh.write(filename) From fc80f6536737dc08e7e1fd0a2c0dec1f577729cd Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 2 Apr 2024 02:39:35 +0000 Subject: [PATCH 2/2] enumerate from 0 --- ppsci/arch/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py index 2cf8a3887a..edd9513c3a 100644 --- a/ppsci/arch/mlp.py +++ b/ppsci/arch/mlp.py @@ -302,7 +302,7 @@ def forward_tensor(self, x): y = x skip = None - for i, linear in enumerate(self.linears, 1): + for i, linear in enumerate(self.linears): y = linear(y) y = self.acts[i](y) y = (1 - y) * u + y * v