From 5779256b1b49f2f661def24dcaf22db16f74dd8c Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sun, 29 Mar 2020 01:13:35 +0100
Subject: [PATCH] Update usage doc regarding generate fn

---
 docs/source/usage.rst | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/docs/source/usage.rst b/docs/source/usage.rst
index 8fb7a447279b..6e53af18491a 100644
--- a/docs/source/usage.rst
+++ b/docs/source/usage.rst
@@ -420,7 +420,7 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
     sequence = f"Hugging Face is based in DUMBO, New York City, and is"
 
     input = tokenizer.encode(sequence, return_tensors="pt")
-    generated = model.generate(input, max_length=50)
+    generated = model.generate(input, max_length=50, do_sample=True)
 
     resulting_string = tokenizer.decode(generated.tolist()[0])
     print(resulting_string)
@@ -432,14 +432,10 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
     model = TFAutoModelWithLMHead.from_pretrained("gpt2")
 
     sequence = f"Hugging Face is based in DUMBO, New York City, and is"
 
-    generated = tokenizer.encode(sequence)
-
-    for i in range(50):
-        predictions = model(tf.constant([generated]))[0]
-        token = tf.argmax(predictions[0], axis=1)[-1].numpy()
-        generated += [token]
+    input = tokenizer.encode(sequence, return_tensors="tf")
+    generated = model.generate(input, max_length=50, do_sample=True)
 
-    resulting_string = tokenizer.decode(generated)
+    resulting_string = tokenizer.decode(generated.tolist()[0])
     print(resulting_string)
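
For reference, a minimal runnable sketch of the PyTorch snippet as it reads after this patch. The import, tokenizer, and model setup lines are assumed from the surrounding usage.rst and do not appear in the hunk itself; only the encode/generate/decode lines are taken from the diff. With do_sample=True, generate() samples the continuation rather than decoding greedily, which is the behavior the removed hand-rolled argmax loop implemented for the TensorFlow variant.

    # Assumed setup (not part of the hunk): model and tokenizer for gpt2
    from transformers import AutoModelWithLMHead, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelWithLMHead.from_pretrained("gpt2")

    sequence = f"Hugging Face is based in DUMBO, New York City, and is"

    # encode() returns a PyTorch tensor with return_tensors="pt";
    # do_sample=True makes generate() sample instead of taking the argmax
    input = tokenizer.encode(sequence, return_tensors="pt")
    generated = model.generate(input, max_length=50, do_sample=True)

    # generated has shape (1, sequence_length); decode the first sequence
    resulting_string = tokenizer.decode(generated.tolist()[0])
    print(resulting_string)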