Skip to content

Commit fe301ec

Browse files
authored
Merge branch 'develop' into bubble_net
2 parents f940c53 + bee3a10 commit fe301ec

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

ppsci/utils/misc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ class Timer(ContextDecorator):
210210
... w = sum(range(0, 10))
211211
>>> func() # doctest: +SKIP
212212
213+
>>> timer = misc.Timer("cost_of_func", auto_print=False)
214+
>>> timer.start()
215+
>>> func()
216+
>>> timer.end()
217+
>>> print(f"time cost of 'cost_of_func' is {timer.interval:.2f}")
213218
"""
214219

215220
interval: float # Time cost for code within Timer context
@@ -220,10 +225,31 @@ def __init__(self, name: str = "Timer", auto_print: bool = True):
220225
self.auto_print = auto_print
221226

222227
def __enter__(self):
228+
paddle.device.synchronize()
223229
self.start_time = time.perf_counter()
224230
return self
225231

226232
def __exit__(self, type, value, traceback):
233+
paddle.device.synchronize()
234+
self.end_time = time.perf_counter()
235+
self.interval = self.end_time - self.start_time
236+
if self.auto_print:
237+
logger.message(f"{self.name}.time_cost = {self.interval:.2f} s")
238+
239+
def start(self, name: str = "Timer"):
240+
"""Push a new timer context.
241+
242+
Args:
243+
name (str, optional): Name of code block to be clocked. Defaults to "Timer".
244+
"""
245+
paddle.device.synchronize()
246+
self.start_time = time.perf_counter()
247+
248+
def end(self):
249+
"""
250+
End current timer context and print time cost.
251+
"""
252+
paddle.device.synchronize()
227253
self.end_time = time.perf_counter()
228254
self.interval = self.end_time - self.start_time
229255
if self.auto_print:

ppsci/utils/save_load.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,18 @@ def load_checkpoint(
146146
equation_dict = paddle.load(f"{path}.pdeqn")
147147

148148
# set state dict
149-
model.set_state_dict(param_dict)
149+
missing_keys, unexpected_keys = model.set_state_dict(param_dict)
150+
if missing_keys:
151+
logger.warning(
152+
f"There are missing keys when loading checkpoint: {missing_keys}, "
153+
"and corresponding parameters will be initialized by default."
154+
)
155+
if unexpected_keys:
156+
logger.warning(
157+
f"There are redundant keys: {unexpected_keys}, "
158+
"and corresponding weights will be ignored."
159+
)
160+
150161
optimizer.set_state_dict(optim_dict)
151162
if grad_scaler is not None:
152163
grad_scaler.load_state_dict(scaler_dict)

ppsci/utils/symbolic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
sp.sign: paddle.sign,
102102
sp.ceiling: paddle.ceil,
103103
sp.floor: paddle.floor,
104-
# NOTE: sp.Add and sp.Mul is not included here for un-alignment with sympy
104+
# NOTE: sp.Add and sp.Mul is not included here for un-alignment with paddle
105105
# and are implemented manually in 'OperatorNode._add_operator_func' and
106106
# 'OperatorNode._mul_operator_func'
107107
}
@@ -711,15 +711,15 @@ def lambdify(
711711
such as 'momentum_x'. Defaults to None.
712712
create_graph (bool, optional): Whether to create the gradient graphs of
713713
the computing process. When it is True, higher order derivatives are
714-
supported to compute; when it is False, the gradient graphs of the
714+
supported to compute. When it is False, the gradient graphs of the
715715
computing process would be discarded. Defaults to True.
716716
retain_graph (Optional[bool]): Whether to retain the forward graph which
717717
is used to calculate the gradient. When it is True, the graph would
718718
be retained, in which way users can calculate backward twice for the
719719
same graph. When it is False, the graph would be freed. Defaults to None,
720720
which means it is equal to `create_graph`.
721721
fuse_derivative (bool, optional): Whether to fuse the derivative nodes.
722-
for example, if `expr` is 'Derivative(u, x) + Derivative(u, y)'
722+
For example, if `expr` is 'Derivative(u, x) + Derivative(u, y)'
723723
It will compute grad(u, x) + grad(u, y) if fuse_derivative=False,
724724
else will compute sum(grad(u, [x, y])) if fuse_derivative=True as is more
725725
efficient in backward-graph. Defaults to False, as it is experimental so not

0 commit comments

Comments (0)