Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/flax/_tests_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ conllu
nltk
rouge-score
seqeval
tensorboard
tensorboard
evaluate >= 0.2.0
5 changes: 3 additions & 2 deletions examples/flax/image-captioning/run_image_captioning_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset, load_metric
from datasets import Dataset, load_dataset
from PIL import Image
from tqdm import tqdm

import evaluate
import jax
import jax.numpy as jnp
import optax
Expand Down Expand Up @@ -811,7 +812,7 @@ def blockwise_data_loader(
yield batch

# Metric
metric = load_metric("rouge")
metric = evaluate.load("rouge")

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
Expand Down
5 changes: 3 additions & 2 deletions examples/flax/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@

import datasets
import numpy as np
from datasets import load_dataset, load_metric
from datasets import load_dataset
from tqdm import tqdm

import evaluate
import jax
import jax.numpy as jnp
import optax
Expand Down Expand Up @@ -776,7 +777,7 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")

def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
Expand Down
5 changes: 3 additions & 2 deletions examples/flax/summarization/run_summarization_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset, load_metric
from datasets import Dataset, load_dataset
from tqdm import tqdm

import evaluate
import jax
import jax.numpy as jnp
import optax
Expand Down Expand Up @@ -656,7 +657,7 @@ def preprocess_function(examples):
)

# Metric
metric = load_metric("rouge")
metric = evaluate.load("rouge")

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
Expand Down
7 changes: 4 additions & 3 deletions examples/flax/text-classification/run_flax_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@

import datasets
import numpy as np
from datasets import load_dataset, load_metric
from datasets import load_dataset
from tqdm import tqdm

import evaluate
import jax
import jax.numpy as jnp
import optax
Expand Down Expand Up @@ -570,9 +571,9 @@ def eval_step(state, batch):
p_eval_step = jax.pmap(eval_step, axis_name="batch")

if data_args.task_name is not None:
metric = load_metric("glue", data_args.task_name)
metric = evaluate.load("glue", data_args.task_name)
else:
metric = load_metric("accuracy")
metric = evaluate.load("accuracy")

logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0
Expand Down
5 changes: 3 additions & 2 deletions examples/flax/token-classification/run_flax_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@

import datasets
import numpy as np
from datasets import ClassLabel, load_dataset, load_metric
from datasets import ClassLabel, load_dataset
from tqdm import tqdm

import evaluate
import jax
import jax.numpy as jnp
import optax
Expand Down Expand Up @@ -646,7 +647,7 @@ def eval_step(state, batch):

p_eval_step = jax.pmap(eval_step, axis_name="batch")

metric = load_metric("seqeval")
metric = evaluate.load("seqeval")

def get_labels(y_pred, y_true):
# Transform predictions and references tensos to numpy arrays
Expand Down