
Conversation

@erenup
Contributor

@erenup erenup commented Aug 11, 2019

Pytorch-transformers! Nice work!
This PR refactors the old run_swag.py.

Motivation:

I have seen the SWAG PR #951 and the related issue #931.
According to @thomwolf's comments on that PR, I think it's necessary to adopt the code style of run_squad.py in run_swag.py so that we can easily take advantage of the new, powerful pytorch_transformers.

Changes:

I refactored the old run_swag.py following run_squad.py and tested it with the bert-base-uncased pretrained model on a Tesla P100.

Tests:

export SWAG_DIR=/path/to/SWAG
python -m torch.distributed.launch --nproc_per_node 1 run_swag.py \
  --train_file $SWAG_DIR/train.csv \
  --predict_file $SWAG_DIR/val.csv \
  --model_type bert \
  --model_name_or_path bert-base-uncased \
  --max_seq_length 80 \
  --do_train \
  --do_eval \
  --do_lower_case \
  --output_dir ../models/swag_output \
  --per_gpu_train_batch_size 32 \
  --per_gpu_eval_batch_size 32 \
  --learning_rate 2e-5 \
  --gradient_accumulation_steps 2 \
  --num_train_epochs 3.0 \
  --logging_steps 200 \
  --save_steps 200

Results:

eval_accuracy = 0.8016595021493552
eval_loss = 0.5581122178810473

I have also tested --fp16, and the accuracy is 0.801.
Other args that have been tested: --evaluate_during_training, --eval_all_checkpoints, --overwrite_output_dir, --overwrite_cache.
Things that have not been tested: multi-GPU and distributed training, since I only have one GPU and one computer.

Questions:

It seems the performance is worse than the pytorch-pretrained-bert results. Is this gap in results (0.82 vs. 0.86) normal?

Future work:

I think it would be good to add a multiple-choice model for XLNet, since there are many multiple-choice datasets such as RACE.
Thank you all!

@codecov-io

codecov-io commented Aug 11, 2019

Codecov Report

Merging #1004 into master will decrease coverage by 0.39%.
The diff coverage is 20.75%.

Impacted file tree graph

@@            Coverage Diff            @@
##           master    #1004     +/-   ##
=========================================
- Coverage   81.16%   80.77%   -0.4%     
=========================================
  Files          57       57             
  Lines        8039     8092     +53     
=========================================
+ Hits         6525     6536     +11     
- Misses       1514     1556     +42
Impacted Files Coverage Δ
pytorch_transformers/modeling_xlnet.py 74.52% <16%> (-2.9%) ⬇️
pytorch_transformers/modeling_roberta.py 64.96% <25%> (-10.27%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e768f23...8960988. Read the comment docs.

@erenup
Contributor Author

erenup commented Aug 30, 2019

run_multiple_choice.py and utils_multiple_choice.py with roberta and xlnet have been tested on RACE, SWAG, and ARC Challenge.

  1. roberta-large: RACE dev 0.84, SWAG dev 0.88, ARC Challenge 0.65
  2. xlnet-large: RACE dev 0.81, ARC Challenge 0.63

@thomwolf
Member

thomwolf commented Aug 30, 2019

This looks really great. Thanks for updating and testing this script @erenup

A few questions and remarks:

  • Do we still need to keep run_swag now that there is a run_multiple_choice?
  • There should be docstrings for the new classes; can you add them, taking inspiration from the other models' docstrings?
  • Do you want to add an example of how to use the script in the docs? For instance, you could add a section here with the commands you used to run the script and the results you got with these commands for each model (good for later reference).

@erenup
Contributor Author

erenup commented Aug 30, 2019

@thomwolf Thank you!

  • The SWAG dataset is treated as one of the multiple-choice datasets and has a corresponding data processor in utils_multiple_choice.py, so I think run_swag will no longer be needed. It's also easy to add a new data processor for other multiple-choice datasets in utils_multiple_choice.py (see the sketch after this list).
  • Docstrings will be added soon.
  • Sure, I'd like to add an example on how to use run_multiple_choice.
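
As a rough illustration only, here is a minimal sketch of what such a new processor could look like, assuming the DataProcessor and InputExample interfaces in utils_multiple_choice.py; the task, file names, and CSV columns below are hypothetical:

# Minimal sketch, assuming the DataProcessor / InputExample interfaces in
# utils_multiple_choice.py; the task name, file names, and CSV columns are
# hypothetical.
import csv
import os

from utils_multiple_choice import DataProcessor, InputExample


class MyTaskProcessor(DataProcessor):
    """Hypothetical processor for a new 4-way multiple-choice dataset."""

    def get_train_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, "train.csv"), "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, "dev.csv"), "dev")

    def get_labels(self):
        return ["0", "1", "2", "3"]

    def _create_examples(self, path, set_type):
        examples = []
        with open(path, "r", encoding="utf-8") as f:
            for i, row in enumerate(csv.DictReader(f)):
                examples.append(
                    InputExample(
                        example_id=f"{set_type}-{i}",
                        question=row["question"],
                        contexts=[row["context"]] * 4,  # same context for every option
                        endings=[row[f"option{j}"] for j in range(4)],
                        label=row["label"],
                    )
                )
        return examples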

@erenup
Contributor Author

erenup commented Sep 16, 2019

Hi @thomwolf, docstrings for the multiple-choice models have been added, and an example for run_multiple_choice.py has been added to the examples README. Thank you.


tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
    scheduler.step()  # Update learning rate schedule
Member


PyTorch scheduler.step() should be called after optimizer.step() (see pytorch/pytorch#20124)
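
For reference, a minimal self-contained sketch of the recommended ordering (a toy model and a constant LR schedule, not the script's actual code):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
gradient_accumulation_steps = 2

for step in range(8):
    loss = model(torch.randn(3, 4)).sum()
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()       # update parameters first
        scheduler.step()       # then advance the learning rate schedule
        optimizer.zero_grad()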

@thomwolf
Member

OK, this looks clean and almost ready to merge; I just added a quick comment about something to fix in the code (the order of the calls to step).

A few things for the merge, since we have re-organized the examples folder. Can you:

  • move run_swag to examples/contrib
  • move your run_multiple_choice scripts to the main examples folder?

erenup and others added 6 commits September 18, 2019 21:13
…hoice_merge

# Conflicts:
#	examples/contrib/run_swag.py
@erenup
Contributor Author

erenup commented Sep 18, 2019

Hi @thomwolf. I have moved run_multiple_choice.py and utils_multiple_choice.py to examples, moved run_swag.py to examples/contrib, and moved scheduler.step() after optimizer.step(). I have also tested examples/contrib/run_swag.py on the current pytorch-transformers; run_swag.py gets a normal result of 0.809 on dev with the bert-base-uncased model. Thank you.

@thomwolf
Member

Awesome, thanks a lot for this contribution @erenup 🔥
Merging now

@thomwolf thomwolf merged commit 0d1dad6 into huggingface:master Sep 18, 2019
@PantherYan

run_multiple_choice.py and utils_multiple_choice.py with roberta and xlnet have been tested on RACE, SWAG, and ARC Challenge.

  1. roberta-large: RACE dev 0.84, SWAG dev 0.88, ARC Challenge 0.65
  2. xlnet-large: RACE dev 0.81, ARC Challenge 0.63

Could you share your run configuration for the RACE and ARC datasets?
On SWAG, I could get 0.82 following the suggested settings.
On RACE, the best performance I get is 0.62 (max length 256, lr 1e-6, gradient accumulation 8, etc.); the loss overfits easily.
But on ARC, during data processing, it shows an error like this:

line 638, in create_examples
contexts=[options[0]["para"].replace("_", ""), options[1]["para"].replace("_", ""),

KeyError: 'para'

(I have checked the raw data; the options items have no 'para'.)
Could you give me a hint on how to convert the ARC dataset?
Thank you!

@erenup
Contributor Author

erenup commented Oct 28, 2019

Hi @PantherYan,
For RACE, I checked my parameters. I ran RACE on 4 P40 GPUs with roberta-large:
Namespace(adam_epsilon=1e-08, cache_dir='', config_name='', data_dir='data/RACE/', device=device(type='cuda'), do_eval=True, do_lower_case=True, do_test=False, do_train=True, eval_all_checkpoints=False, evaluate_during_training=False, fp16=False, fp16_opt_level='O1', gradient_accumulation_steps=3, learning_rate=1e-05, local_rank=-1, logging_steps=50, max_grad_norm=1.0, max_seq_length=384, max_steps=-1, model_name_or_path='roberta-large', model_type='roberta', n_gpu=4, no_cuda=False, num_train_epochs=5.0, output_dir='models_bert/race_large', overwrite_cache=False, overwrite_output_dir=False, per_gpu_eval_batch_size=2, per_gpu_train_batch_size=2, save_steps=2000, seed=42, server_ip='', server_port='', task_name='race', tokenizer_name='', train_batch_size=8, warmup_steps=0, weight_decay=0.0). You can give it a try.
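
As a rough sketch of how those flags combine, assuming the usual run_squad-style loop where the optimizer steps once every gradient_accumulation_steps batches:

per_gpu_train_batch_size = 2
n_gpu = 4
gradient_accumulation_steps = 3

train_batch_size = per_gpu_train_batch_size * n_gpu                    # 8 examples per forward pass
effective_batch_size = train_batch_size * gradient_accumulation_steps  # 24 examples per optimizer step
print(train_batch_size, effective_batch_size)                          # 8 24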

For ARC, you need to ask AI2 for the retrieved text named para for the corresponding task (ARC Challenge, ARC Easy, OpenBookQA); you can find more details on this page.
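
Purely for illustration, the KeyError above suggests each answer option in the processed ARC data needs a retrieved paragraph under a "para" key, roughly as sketched below; every field other than "para" is an assumption here, not something confirmed in this thread:

# Hypothetical shape of one processed ARC example; only "para" is implied by
# the KeyError above, the other field names are assumptions.
example = {
    "id": "example-1",
    "question": {
        "stem": "Which gas do plants absorb from the atmosphere?",
        "choices": [
            {"label": "A", "text": "oxygen",         "para": "retrieved supporting text ..."},
            {"label": "B", "text": "carbon dioxide", "para": "retrieved supporting text ..."},
            {"label": "C", "text": "nitrogen",       "para": "retrieved supporting text ..."},
            {"label": "D", "text": "hydrogen",       "para": "retrieved supporting text ..."},
        ],
    },
    "answerKey": "B",
}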

@PantherYan

Hi, @PantherYan
For RACE, I checked my parameters. I ran RACE on 4 P40 GPUs with roberta-large:
Namespace(adam_epsilon=1e-08, cache_dir='', config_name='', data_dir='data/RACE/', device=device(type='cuda'), do_eval=True, do_lower_case=True, do_test=False, do_train=True, eval_all_checkpoints=False, evaluate_during_training=False, fp16=False, fp16_opt_level='O1', gradient_accumulation_steps=3, learning_rate=1e-05, local_rank=-1, logging_steps=50, max_grad_norm=1.0, max_seq_length=384, max_steps=-1, model_name_or_path='roberta-large', model_type='roberta', n_gpu=4, no_cuda=False, num_train_epochs=5.0, output_dir='models_bert/race_large', overwrite_cache=False, overwrite_output_dir=False, per_gpu_eval_batch_size=2, per_gpu_train_batch_size=2, save_steps=2000, seed=42, server_ip='', server_port='', task_name='race', tokenizer_name='', train_batch_size=8, warmup_steps=0, weight_decay=0.0), you can have a try.

For ARC, you need to ask AI2 for the retrieved text named para for the corresponding task (ARC Challenge, ARC Easy, OpenBookQA); you can find more details on this page.

Thanks a lot for your prompt reply! Much appreciated!
It seems to be a TensorFlow-version setting; I will try it with PyTorch. I only have 4 2080 Ti GPUs (11 GB each). Do the max length, batch size, or model size (like roberta-base) influence the performance significantly? I will run a comparison and post it.

For ARC: thanks, I have written an email to AI2 for help.

Thank you!

@PantherYan

Hi, @PantherYan
For RACE, I checked my parameters. I ran RACE on 4 P40 GPUs with roberta-large:
Namespace(adam_epsilon=1e-08, cache_dir='', config_name='', data_dir='data/RACE/', device=device(type='cuda'), do_eval=True, do_lower_case=True, do_test=False, do_train=True, eval_all_checkpoints=False, evaluate_during_training=False, fp16=False, fp16_opt_level='O1', gradient_accumulation_steps=3, learning_rate=1e-05, local_rank=-1, logging_steps=50, max_grad_norm=1.0, max_seq_length=384, max_steps=-1, model_name_or_path='roberta-large', model_type='roberta', n_gpu=4, no_cuda=False, num_train_epochs=5.0, output_dir='models_bert/race_large', overwrite_cache=False, overwrite_output_dir=False, per_gpu_eval_batch_size=2, per_gpu_train_batch_size=2, save_steps=2000, seed=42, server_ip='', server_port='', task_name='race', tokenizer_name='', train_batch_size=8, warmup_steps=0, weight_decay=0.0), you can have a try.

Thank you for sharing your training configuration to guide us.

I used the PyTorch backend and strictly followed your configuration, except that I used roberta-base, and my batch size is 2 (per_gpu_train_batch_size) * 4 (gpu_num), which you set as train_batch_size=8. In other words, your batch_size setting is 8 and mine is 2.

-------- Here is my accuracy on the test dataset: 69.36, loss 0.8339.
Did the batch size influence my test performance, or has the loss simply not converged enough?

data/nlp/MCQA/RACE/cached_test_roberta-base_384_race
11/01/2019 01:49:55 - INFO - main - ***** Running evaluation *****
11/01/2019 01:49:55 - INFO - main - Num examples = 4934
11/01/2019 01:49:55 - INFO - main - Batch size = 8
11/01/2019 01:53:38 - INFO - main - ***** Eval results is test:True *****
11/01/2019 01:53:38 - INFO - main - eval_acc = 0.6945683015808675
11/01/2019 01:53:38 - INFO - main - eval_loss = 0.8386425418383782
11/01/2019 01:53:38 - INFO - main - best steps of eval acc is the following checkpoints: 13000

Below are my training logs:

11/01/2019 00:31:22 - INFO - transformers.configuration_utils - Configuration saved in models_race/roberta-base/checkpoint-12000/config.json
11/01/2019 00:31:23 - INFO - transformers.modeling_utils - Model weights saved in models_race/roberta-base/checkpoint-12000/pytorch_model.bin
11/01/2019 00:31:23 - INFO - main - Saving model checkpoint to models_race/roberta-base/checkpoint-12000
11/01/2019 01:12:20 - INFO - main - Loading features from cached file /workspace/data/nlp/MCQA/RACE/cached_dev_roberta-base_384_race
11/01/2019 01:12:22 - INFO - main - ***** Running evaluation *****
11/01/2019 01:12:22 - INFO - main - Num examples = 4887
11/01/2019 01:12:22 - INFO - main - Batch size = 8
11/01/2019 01:16:00 - INFO - main - ***** Eval results is test:False *****
11/01/2019 01:16:00 - INFO - main - eval_acc = 0.7086146920401064
11/01/2019 01:16:00 - INFO - main - eval_loss = 0.8062708838591306
11/01/2019 01:16:00 - INFO - main - Loading features from cached file /workspace/data/nlp/MCQA/RACE/cached_test_roberta-base_384_race
11/01/2019 01:16:02 - INFO - main - ***** Running evaluation *****
11/01/2019 01:16:02 - INFO - main - Num examples = 4934
11/01/2019 01:16:02 - INFO - main - Batch size = 8
11/01/2019 01:19:42 - INFO - main - ***** Eval results is test:True *****
11/01/2019 01:19:42 - INFO - main - eval_acc = 0.6935549250101337
11/01/2019 01:19:42 - INFO - main - eval_loss = 0.8339384843925892
11/01/2019 01:19:42 - INFO - main - test acc: 0.6935549250101337, loss: 0.8339384843925892, global steps: 13000
11/01/2019 01:19:42 - INFO - main - Average loss: 0.6908835964873433 at global step: 13000
11/01/2019 01:19:42 - INFO - transformers.configuration_utils - Configuration saved in models_race/roberta-base/checkpoint-13000/config.json
11/01/2019 01:19:43 - INFO - transformers.modeling_utils - Model weights saved in models_race/roberta-base/checkpoint-13000/pytorch_model.bin
11/01/2019 01:19:43 - INFO - main - Saving model checkpoint to models_race/roberta-base/checkpoint-13000
11/01/2019 01:49:44 - INFO - main - global_step = 13730, average loss = 0.8482715931345925

@erenup Could you share your training loss and test loss after 5 epochs?
I have run it several times and the accuracy is still around 70%. Is the difference due to the roberta-large model or the batch size?
Looking forward to your reply.
Thank you!

@erenup
Contributor Author

erenup commented Nov 1, 2019

Hi @PantherYan, I did not run the RACE dataset with roberta-base. In my experience, I think the results for RACE with roberta-base make sense, since BERT-large can only reach about 71~72. You can check the leaderboard for reference.

@PantherYan

Hi @PantherYan, I did not run the RACE dataset with roberta-base. In my experience, I think the results for RACE with roberta-base make sense, since BERT-large can only reach about 71~72. You can check the leaderboard for reference.

@erenup
I appreciate your quick reply.
Thank you!

@PantherYan

@erenup
You are nice!

@zhangSchnee

run_multiple_choice.py and utils_multiple_choice.py with roberta and xlnet have been tested on RACE, SWAG, and ARC Challenge.

  1. roberta-large: RACE dev 0.84, SWAG dev 0.88, ARC Challenge 0.65
  2. xlnet-large: RACE dev 0.81, ARC Challenge 0.63

Could you share your run configuration for the RACE and ARC datasets?
On SWAG, I could get 0.82 following the suggested settings.
On RACE, the best performance I get is 0.62 (max length 256, lr 1e-6, gradient accumulation 8, etc.); the loss overfits easily.
But on ARC, during data processing, it shows an error like this:

line 638, in create_examples contexts=[options[0]["para"].replace("_", ""), options[1]["para"].replace("_", ""),

KeyError: 'para'

(I have checked the raw data; the options items have no 'para'.)
Could you give me a hint on how to convert the ARC dataset?
Thank you!

I also met the problem of the missing "para" item. Have you found a method for converting the raw corpus?
Thank you!

@erenup
Contributor Author

erenup commented Nov 16, 2019

Please see @PantherYan's comments above and my replies.
