Skip to content

Commit e3f71a9

Browse files
revert random permutation from numpy to paddle (#792)
1 parent 7e0cf12 commit e3f71a9

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

examples/epnn/conf/epnn.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,11 @@ TRAIN:
4141
iters_per_epoch: 1
4242
save_freq: 50
4343
eval_during_train: true
44-
eval_with_no_grad: true
4544
lr_scheduler:
4645
epochs: ${TRAIN.epochs}
4746
iters_per_epoch: ${TRAIN.iters_per_epoch}
4847
gamma: 0.97
49-
decay_steps: 1
48+
decay_steps: 10000000
5049
pretrained_model_path: null
5150
checkpoint_path: null
5251

examples/epnn/epnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _transform_in_stress(_in):
114114
save_freq=cfg.TRAIN.save_freq,
115115
eval_during_train=cfg.TRAIN.eval_during_train,
116116
validator=validator_pde,
117-
eval_with_no_grad=cfg.TRAIN.eval_with_no_grad,
117+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
118118
)
119119

120120
# train model

examples/epnn/functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ def get(self, epochs=1):
248248
label_dict_train = {"dummy_loss": []}
249249
label_dict_val = {"dummy_loss": []}
250250
for i in range(epochs):
251-
shuffled_indices = np.random.permutation(self.data_state.x_train.shape[0])
251+
shuffled_indices = paddle.randperm(
252+
n=self.data_state.x_train.shape[0]
253+
).numpy()
252254
input_dict_train["state_x"].append(
253255
self.data_state.x_train[shuffled_indices[0 : self.itrain]]
254256
)
@@ -263,7 +265,7 @@ def get(self, epochs=1):
263265
)
264266
label_dict_train["dummy_loss"].append(0.0)
265267

266-
shuffled_indices = np.random.permutation(self.data_state.x_valid.shape[0])
268+
shuffled_indices = paddle.randperm(n=self.data_state.x_valid.shape[0]).numpy()
267269
input_dict_val["state_x"].append(
268270
self.data_state.x_valid[shuffled_indices[0 : self.itrain]]
269271
)
@@ -296,7 +298,7 @@ def __init__(self, dataset_path, train_p=0.6, cross_valid_p=0.2, test_p=0.2):
296298
def get_shuffled_data(self):
297299
# Need to set the seed, otherwise the loss will not match the precision
298300
ppsci.utils.misc.set_random_seed(seed=10)
299-
shuffled_indices = np.random.permutation(self.x.shape[0])
301+
shuffled_indices = paddle.randperm(n=self.x.shape[0]).numpy()
300302
n_train = math.floor(self.train_p * self.x.shape[0])
301303
n_cross_valid = math.floor(self.cross_valid_p * self.x.shape[0])
302304
n_test = math.floor(self.test_p * self.x.shape[0])

0 commit comments

Comments
 (0)