Conversation

@jysohn23 jysohn23 commented Apr 1, 2020

Also, instead of calling `eval_loss.item()` every time, do the summation with tensors on device.
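A minimal sketch of that change (variable names assumed, not the PR's actual code): keep the running eval loss as a device tensor and convert to a Python number once after the loop, rather than forcing a device-to-host sync with `.item()` on every batch.

```python
import torch

eval_loss = torch.zeros(())  # would live on the TPU/GPU device in practice
batch_losses = [torch.tensor(0.5), torch.tensor(0.25), torch.tensor(0.25)]
for loss in batch_losses:        # stand-in for the real eval loop
    eval_loss += loss.detach()   # tensor += tensor: no host sync per batch
avg_eval_loss = eval_loss.item() / len(batch_losses)  # single sync here
print(avg_eval_loss)
```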
@jysohn23 jysohn23 requested a review from taylanbil April 1, 2020 23:34
@taylanbil taylanbil left a comment


did you test this e2e?

```python
results.update(result)
results['eval_loss'] = eval_loss.item()

# Average all metrics from each shard
```


what are some of the metrics? does it make sense to avg all of them? some metrics may be additive.

Author


f1, accuracy, eval_loss, acc_and_f1 (avg of the two)

I checked they all make sense averaged.

Author


Yep ignore that :D Not additive as discussed. Will update PR.


I'd be surprised if global f1 == np.mean(local f1s). That's probably not true, let's verify on paper what the true formula to get global f1 is.
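Indeed, a quick counterexample with made-up per-shard confusion counts shows that the mean of local F1 scores differs from the F1 computed on pooled counts:

```python
# f1 = 2*tp / (2*tp + fp + fn); counts below are hypothetical
def f1(tp, fp, fn):
    return 2 * tp / (2 * tp + fp + fn)

shards = [(9, 1, 0), (1, 0, 9)]                 # (tp, fp, fn) per shard
mean_of_local_f1 = sum(f1(*s) for s in shards) / len(shards)
tp, fp, fn = (sum(c) for c in zip(*shards))     # pool counts globally
global_f1 = f1(tp, fp, fn)
print(mean_of_local_f1, global_f1)  # ~0.565 vs ~0.667 -- not equal
```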

Author


Updated to sync the pred/label tensors directly instead (these shouldn't be that big for finetuning tasks; a single 0/1 integer per example). This way we don't need custom aggregation per metric and don't touch upstream core code.
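A sketch of the gather-then-compute approach (the `mesh_reduce` below is only a stand-in imitating torch_xla's `xm.mesh_reduce`; data is hypothetical): concatenate every shard's predictions and labels, then compute the metric once on the full set.

```python
def mesh_reduce(per_shard, reduce_fn):
    # in torch_xla each core contributes its own value; here we simply
    # apply the reduce function to a list of all shards' values
    return reduce_fn(per_shard)

concat = lambda xs: [v for shard in xs for v in shard]
shard_preds  = [[1, 0, 1], [0, 0, 1]]   # hypothetical 0/1 predictions
shard_labels = [[1, 0, 0], [0, 1, 1]]
preds  = mesh_reduce(shard_preds, concat)
labels = mesh_reduce(shard_labels, concat)
accuracy = sum(p == l for p, l in zip(preds, labels)) / len(labels)
print(accuracy)  # 4 of 6 correct
```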

```python
logger.info("  %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
tb_writer.add_scalar(key, result[key])
if xm.is_master_ordinal():
```


is everything being logged here already on cpu?

Author


Yes.


maybe let's add a comment? It's a subtle point that can be missed by code readers.


@jysohn23 jysohn23 left a comment


Yeah, e2e tested.


@jysohn23 jysohn23 requested a review from taylanbil April 1, 2020 23:46
As brought up during review, some metrics like f1 cannot be aggregated
via averaging. GLUE task metrics depend largely on the dataset, so
instead we sync the prediction and label tensors so that the metrics can
be computed accurately from those.
@jysohn23 jysohn23 merged commit 6e20572 into pytorch-tpu:tpu Apr 2, 2020
@jysohn23 jysohn23 deleted the tpu branch April 2, 2020 16:44
jysohn23 added a commit to jysohn23/transformers that referenced this pull request Apr 10, 2020
* Initial commit to get BERT + run_glue.py on TPU

* Add README section for TPU and address comments.

* Cleanup TPU bits from run_glue.py (pytorch-tpu#3)

TPU runner is currently implemented in:
https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py.

We plan to upstream this directly into `huggingface/transformers`
(either `master` or `tpu`) branch once it's been more thoroughly tested.

* Cleanup TPU bits from run_glue.py

TPU runner is currently implemented in:
https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py.

We plan to upstream this directly into `huggingface/transformers`
(either `master` or `tpu`) branch once it's been more thoroughly tested.

* No need to call `xm.mark_step()` explicitly (pytorch-tpu#4)

Since for gradient accumulation we're accumulating over batches from a
`ParallelLoader` instance, which marks the step itself on `next()`.

* Resolve R/W conflicts from multiprocessing (pytorch-tpu#5)

* Add XLNet in list of models for `run_glue_tpu.py` (pytorch-tpu#6)

* Add RoBERTa to list of models in TPU GLUE (pytorch-tpu#7)

* Add RoBERTa and DistilBert to list of models in TPU GLUE (pytorch-tpu#8)

* Use barriers to reduce duplicate work/resources (pytorch-tpu#9)

* Shard eval dataset and aggregate eval metrics (pytorch-tpu#10)

* Shard eval dataset and aggregate eval metrics

Also, instead of calling `eval_loss.item()` every time do summation with
tensors on device.

* Change defaultdict to float

* Reduce the pred, label tensors instead of metrics

As brought up during review, some metrics like f1 cannot be aggregated
via averaging. GLUE task metrics depend largely on the dataset, so
instead we sync the prediction and label tensors so that the metrics can
be computed accurately from those.

* Only use tb_writer from master (pytorch-tpu#11)

* Apply huggingface black code formatting

* Style

* Remove `--do_lower_case` as example uses cased

* Add option to specify tensorboard logdir

This is needed for our testing framework, which checks regressions
against key metrics written by the summary writer.

* Using configuration for `xla_device`

* Prefix TPU specific comments.

* num_cores clarification and namespace eval metrics

* Cache features file under `args.cache_dir`

Instead of under `args.data_dir`. This is needed as our test infra uses
data_dir with a read-only filesystem.

* Rename `run_glue_tpu` to `run_tpu_glue`

Co-authored-by: LysandreJik <[email protected]>