Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions ppsci/arch/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,9 @@ def forward_tensor(self, x):
u = self.embed_u(x)
v = self.embed_v(x)

x = self.linears[0](x)
x = self.acts[0](x)

y = x
skip = None
for i, linear in enumerate(self.linears[1:], 1):
for i, linear in enumerate(self.linears):
y = linear(y)
y = self.acts[i](y)
y = (1 - y) * u + y * v
Expand Down
2 changes: 1 addition & 1 deletion ppsci/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def init_logger(
# add file_handler, output to log_file(if specified), only for rank 0 device
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.dirname(log_file)
if os.path.isdir(log_file_folder):
if len(log_file_folder):
os.makedirs(log_file_folder, exist_ok=True)
file_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
Expand Down
2 changes: 1 addition & 1 deletion ppsci/utils/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def add_edge(u: str, v: str, u_color: str = C_DATA, v_color: str = C_DATA):
graph.layout()
image_path = f"{graph_filename}.png"
dot_path = f"{graph_filename}.dot"
if os.path.isdir(os.path.dirname(image_path)):
if len(os.path.dirname(image_path)):
os.makedirs(os.path.dirname(image_path), exist_ok=True)
graph.draw(image_path, prog="dot")
graph.write(dot_path)
Expand Down
2 changes: 1 addition & 1 deletion ppsci/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def save_tecplot_file(
nx, ny = num_x, num_y
assert nx * ny == nxy, f"nx({nx}) * ny({ny}) != nxy({nxy})"

if os.path.isdir(os.path.dirname(filename)):
if len(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)

if filename.endswith(".dat"):
Expand Down
8 changes: 4 additions & 4 deletions ppsci/visualize/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamps=
value_keys (Tuple[str, ...]): Value keys.
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
"""
if os.path.isdir(os.path.dirname(filename)):
if len(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)
fig, a = plt.subplots(len(value_keys), num_timestamps, squeeze=False)
fig.subplots_adjust(hspace=0.8)
Expand Down Expand Up @@ -188,7 +188,7 @@ def _save_plot_from_2d_array(
xticks (Optional[Tuple[float, ...]]): Tuple of xtick locations. Defaults to None.
yticks (Optional[Tuple[float, ...]]): Tuple of ytick locations. Defaults to None.
"""
if os.path.isdir(os.path.dirname(filename)):
if len(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)

plt.close("all")
Expand Down Expand Up @@ -336,7 +336,7 @@ def _save_plot_from_3d_array(
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
"""
if os.path.isdir(os.path.dirname(filename)):
if len(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)

fig = plt.figure(figsize=(10, 10))
Expand Down Expand Up @@ -482,7 +482,7 @@ def plot_weather(
)
plt.colorbar(mappable=map_, cax=None, ax=None, shrink=0.5, label=colorbar_label)

if os.path.isdir(os.path.dirname(filename)):
if len(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)
fig = plt.figure(facecolor="w", figsize=(7, 7))
ax = fig.add_subplot(2, 1, 1)
Expand Down
4 changes: 2 additions & 2 deletions ppsci/visualize/vtu.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _save_vtu_from_array(filename, coord, value, value_keys, num_timestamps=1):
if coord.shape[1] not in [2, 3]:
raise ValueError(f"ndim of coord({coord.shape[1]}) should be 2 or 3.")

if os.path.isdir(os.path.dirname(filename)):
if len(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)

# discard extension name
Expand Down Expand Up @@ -184,6 +184,6 @@ def save_vtu_to_mesh(
points=points.T, cells=[("vertex", np.arange(npoint).reshape(npoint, 1))]
)
mesh.point_data = {key: data_dict[key] for key in value_keys}
if os.path.isdir(os.path.dirname(filename)):
if len(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)
mesh.write(filename)