@@ -8,8 +8,7 @@
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, softwareFuchs.Mader.Luchs
-
+# Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
@@ -304,17 +303,19 @@ def main():
          'weight_decay': args.weight_decay},
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
     ]
+    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = get_linear_schedule_with_warmup(optimizer,
-                                                num_warmup_steps=int(len(train_dataloader)*args.num_train_epochs*0.1),
-                                                num_training_steps=len(train_dataloader)*args.num_train_epochs)
+    scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                num_warmup_steps=int(t_total*0.1),
+                                                num_training_steps=t_total)
 
     #Start training
     logger.info("***** Running training *****")
     logger.info("  Num examples = %d", len(train_examples))
     logger.info("  Batch size = %d", args.train_batch_size)
     logger.info("  Num epoch = %d", args.num_train_epochs)
 
+
     model.train()
     dev_dataset = {}
     nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_bleu, best_loss = 0, 0, 0, 0, 0, 1e6
@@ -515,6 +516,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
-
-
+    main()
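
The substantive change is in the `@@ -304` hunk: the schedule length was previously `len(train_dataloader) * args.num_train_epochs`, which counts batches, but with gradient accumulation the optimizer (and hence `scheduler.step()`) only fires once every `args.gradient_accumulation_steps` batches, so warmup stretched out by that factor and the linear decay never completed within training. The new `t_total` counts actual optimizer steps. Below is a minimal, self-contained sketch of the corrected wiring; the toy model, dataloader, and hyperparameter values are placeholders chosen for illustration, and `torch.optim.AdamW` stands in for the `AdamW` import this script uses.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import get_linear_schedule_with_warmup

# Placeholder data and model standing in for the script's train_dataloader
# and seq2seq model; the values are illustrative only.
dataset = TensorDataset(torch.randn(64, 10))
train_dataloader = DataLoader(dataset, batch_size=8)  # 8 batches per epoch
model = torch.nn.Linear(10, 2)

gradient_accumulation_steps = 4
num_train_epochs = 3

# Count optimizer steps, not batches: each optimizer.step() consumes
# gradient_accumulation_steps batches, so the schedule must be shortened
# by that factor or warmup/decay finish far too late.
t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(t_total * 0.1),  # warm up over the first 10% of steps
    num_training_steps=t_total,
)

# In the training loop, scheduler.step() is then called once per optimizer
# step, i.e. once every gradient_accumulation_steps batches.
```

With `t_total` defined this way, the 10% warmup fraction and the decay horizon stay correct for any accumulation setting, since both are expressed in the same units the scheduler actually advances in.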