Skip to content

Commit f8b2574

Browse files
Better TF docstring types (#23477)
* Rework TF type hints to use | None instead of Optional[] for tf.Tensor
* Rework TF type hints to use | None instead of Optional[] for tf.Tensor
* Don't forget the imports
* Add the imports to tests too
* make fixup
* Refactor tests that depended on get_type_hints
* Better test refactor
* Fix an old hidden bug in the test_keras_fit input creation code
* Fix for the Deit tests
1 parent 767e6b5 commit f8b2574

File tree

139 files changed

+2907
-2621
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+2907
-2621
lines changed

src/transformers/modeling_tf_outputs.py

Lines changed: 96 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import warnings
1618
from dataclasses import dataclass
1719
from typing import List, Optional, Tuple
@@ -43,8 +45,8 @@ class TFBaseModelOutput(ModelOutput):
4345
"""
4446

4547
last_hidden_state: tf.Tensor = None
46-
hidden_states: Optional[Tuple[tf.Tensor]] = None
47-
attentions: Optional[Tuple[tf.Tensor]] = None
48+
hidden_states: Tuple[tf.Tensor] | None = None
49+
attentions: Tuple[tf.Tensor] | None = None
4850

4951

5052
@dataclass
@@ -96,8 +98,8 @@ class TFBaseModelOutputWithPooling(ModelOutput):
9698

9799
last_hidden_state: tf.Tensor = None
98100
pooler_output: tf.Tensor = None
99-
hidden_states: Optional[Tuple[tf.Tensor]] = None
100-
attentions: Optional[Tuple[tf.Tensor]] = None
101+
hidden_states: Tuple[tf.Tensor] | None = None
102+
attentions: Tuple[tf.Tensor] | None = None
101103

102104

103105
@dataclass
@@ -164,10 +166,10 @@ class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
164166

165167
last_hidden_state: tf.Tensor = None
166168
pooler_output: tf.Tensor = None
167-
past_key_values: Optional[List[tf.Tensor]] = None
168-
hidden_states: Optional[Tuple[tf.Tensor]] = None
169-
attentions: Optional[Tuple[tf.Tensor]] = None
170-
cross_attentions: Optional[Tuple[tf.Tensor]] = None
169+
past_key_values: List[tf.Tensor] | None = None
170+
hidden_states: Tuple[tf.Tensor] | None = None
171+
attentions: Tuple[tf.Tensor] | None = None
172+
cross_attentions: Tuple[tf.Tensor] | None = None
171173

172174

173175
@dataclass
@@ -201,9 +203,9 @@ class TFBaseModelOutputWithPast(ModelOutput):
201203
"""
202204

203205
last_hidden_state: tf.Tensor = None
204-
past_key_values: Optional[List[tf.Tensor]] = None
205-
hidden_states: Optional[Tuple[tf.Tensor]] = None
206-
attentions: Optional[Tuple[tf.Tensor]] = None
206+
past_key_values: List[tf.Tensor] | None = None
207+
hidden_states: Tuple[tf.Tensor] | None = None
208+
attentions: Tuple[tf.Tensor] | None = None
207209

208210

209211
@dataclass
@@ -234,9 +236,9 @@ class TFBaseModelOutputWithCrossAttentions(ModelOutput):
234236
"""
235237

236238
last_hidden_state: tf.Tensor = None
237-
hidden_states: Optional[Tuple[tf.Tensor]] = None
238-
attentions: Optional[Tuple[tf.Tensor]] = None
239-
cross_attentions: Optional[Tuple[tf.Tensor]] = None
239+
hidden_states: Tuple[tf.Tensor] | None = None
240+
attentions: Tuple[tf.Tensor] | None = None
241+
cross_attentions: Tuple[tf.Tensor] | None = None
240242

241243

242244
@dataclass
@@ -276,10 +278,10 @@ class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
276278
"""
277279

278280
last_hidden_state: tf.Tensor = None
279-
past_key_values: Optional[List[tf.Tensor]] = None
280-
hidden_states: Optional[Tuple[tf.Tensor]] = None
281-
attentions: Optional[Tuple[tf.Tensor]] = None
282-
cross_attentions: Optional[Tuple[tf.Tensor]] = None
281+
past_key_values: List[tf.Tensor] | None = None
282+
hidden_states: Tuple[tf.Tensor] | None = None
283+
attentions: Tuple[tf.Tensor] | None = None
284+
cross_attentions: Tuple[tf.Tensor] | None = None
283285

284286

285287
@dataclass
@@ -333,13 +335,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
333335
"""
334336

335337
last_hidden_state: tf.Tensor = None
336-
past_key_values: Optional[List[tf.Tensor]] = None
337-
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
338-
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
339-
cross_attentions: Optional[Tuple[tf.Tensor]] = None
340-
encoder_last_hidden_state: Optional[tf.Tensor] = None
341-
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
342-
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
338+
past_key_values: List[tf.Tensor] | None = None
339+
decoder_hidden_states: Tuple[tf.Tensor] | None = None
340+
decoder_attentions: Tuple[tf.Tensor] | None = None
341+
cross_attentions: Tuple[tf.Tensor] | None = None
342+
encoder_last_hidden_state: tf.Tensor | None = None
343+
encoder_hidden_states: Tuple[tf.Tensor] | None = None
344+
encoder_attentions: Tuple[tf.Tensor] | None = None
343345

344346

345347
@dataclass
@@ -365,10 +367,10 @@ class TFCausalLMOutput(ModelOutput):
365367
heads.
366368
"""
367369

368-
loss: Optional[tf.Tensor] = None
370+
loss: tf.Tensor | None = None
369371
logits: tf.Tensor = None
370-
hidden_states: Optional[Tuple[tf.Tensor]] = None
371-
attentions: Optional[Tuple[tf.Tensor]] = None
372+
hidden_states: Tuple[tf.Tensor] | None = None
373+
attentions: Tuple[tf.Tensor] | None = None
372374

373375

374376
@dataclass
@@ -400,11 +402,11 @@ class TFCausalLMOutputWithPast(ModelOutput):
400402
heads.
401403
"""
402404

403-
loss: Optional[tf.Tensor] = None
405+
loss: tf.Tensor | None = None
404406
logits: tf.Tensor = None
405-
past_key_values: Optional[List[tf.Tensor]] = None
406-
hidden_states: Optional[Tuple[tf.Tensor]] = None
407-
attentions: Optional[Tuple[tf.Tensor]] = None
407+
past_key_values: List[tf.Tensor] | None = None
408+
hidden_states: Tuple[tf.Tensor] | None = None
409+
attentions: Tuple[tf.Tensor] | None = None
408410

409411

410412
@dataclass
@@ -442,12 +444,12 @@ class TFCausalLMOutputWithCrossAttentions(ModelOutput):
442444
`past_key_values` input) to speed up sequential decoding.
443445
"""
444446

445-
loss: Optional[tf.Tensor] = None
447+
loss: tf.Tensor | None = None
446448
logits: tf.Tensor = None
447-
past_key_values: Optional[List[tf.Tensor]] = None
448-
hidden_states: Optional[Tuple[tf.Tensor]] = None
449-
attentions: Optional[Tuple[tf.Tensor]] = None
450-
cross_attentions: Optional[Tuple[tf.Tensor]] = None
449+
past_key_values: List[tf.Tensor] | None = None
450+
hidden_states: Tuple[tf.Tensor] | None = None
451+
attentions: Tuple[tf.Tensor] | None = None
452+
cross_attentions: Tuple[tf.Tensor] | None = None
451453

452454

453455
@dataclass
@@ -473,10 +475,10 @@ class TFMaskedLMOutput(ModelOutput):
473475
heads.
474476
"""
475477

476-
loss: Optional[tf.Tensor] = None
478+
loss: tf.Tensor | None = None
477479
logits: tf.Tensor = None
478-
hidden_states: Optional[Tuple[tf.Tensor]] = None
479-
attentions: Optional[Tuple[tf.Tensor]] = None
480+
hidden_states: Tuple[tf.Tensor] | None = None
481+
attentions: Tuple[tf.Tensor] | None = None
480482

481483

482484
@dataclass
@@ -527,15 +529,15 @@ class TFSeq2SeqLMOutput(ModelOutput):
527529
self-attention heads.
528530
"""
529531

530-
loss: Optional[tf.Tensor] = None
532+
loss: tf.Tensor | None = None
531533
logits: tf.Tensor = None
532-
past_key_values: Optional[List[tf.Tensor]] = None
533-
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
534-
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
535-
cross_attentions: Optional[Tuple[tf.Tensor]] = None
536-
encoder_last_hidden_state: Optional[tf.Tensor] = None
537-
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
538-
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
534+
past_key_values: List[tf.Tensor] | None = None
535+
decoder_hidden_states: Tuple[tf.Tensor] | None = None
536+
decoder_attentions: Tuple[tf.Tensor] | None = None
537+
cross_attentions: Tuple[tf.Tensor] | None = None
538+
encoder_last_hidden_state: tf.Tensor | None = None
539+
encoder_hidden_states: Tuple[tf.Tensor] | None = None
540+
encoder_attentions: Tuple[tf.Tensor] | None = None
539541

540542

541543
@dataclass
@@ -562,10 +564,10 @@ class TFNextSentencePredictorOutput(ModelOutput):
562564
heads.
563565
"""
564566

565-
loss: Optional[tf.Tensor] = None
567+
loss: tf.Tensor | None = None
566568
logits: tf.Tensor = None
567-
hidden_states: Optional[Tuple[tf.Tensor]] = None
568-
attentions: Optional[Tuple[tf.Tensor]] = None
569+
hidden_states: Tuple[tf.Tensor] | None = None
570+
attentions: Tuple[tf.Tensor] | None = None
569571

570572

571573
@dataclass
@@ -591,10 +593,10 @@ class TFSequenceClassifierOutput(ModelOutput):
591593
heads.
592594
"""
593595

594-
loss: Optional[tf.Tensor] = None
596+
loss: tf.Tensor | None = None
595597
logits: tf.Tensor = None
596-
hidden_states: Optional[Tuple[tf.Tensor]] = None
597-
attentions: Optional[Tuple[tf.Tensor]] = None
598+
hidden_states: Tuple[tf.Tensor] | None = None
599+
attentions: Tuple[tf.Tensor] | None = None
598600

599601

600602
@dataclass
@@ -642,15 +644,15 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
642644
self-attention heads.
643645
"""
644646

645-
loss: Optional[tf.Tensor] = None
647+
loss: tf.Tensor | None = None
646648
logits: tf.Tensor = None
647-
past_key_values: Optional[List[tf.Tensor]] = None
648-
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
649-
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
650-
cross_attentions: Optional[Tuple[tf.Tensor]] = None
651-
encoder_last_hidden_state: Optional[tf.Tensor] = None
652-
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
653-
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
649+
past_key_values: List[tf.Tensor] | None = None
650+
decoder_hidden_states: Tuple[tf.Tensor] | None = None
651+
decoder_attentions: Tuple[tf.Tensor] | None = None
652+
cross_attentions: Tuple[tf.Tensor] | None = None
653+
encoder_last_hidden_state: tf.Tensor | None = None
654+
encoder_hidden_states: Tuple[tf.Tensor] | None = None
655+
encoder_attentions: Tuple[tf.Tensor] | None = None
654656

655657

656658
@dataclass
@@ -684,10 +686,10 @@ class TFSemanticSegmenterOutput(ModelOutput):
684686
heads.
685687
"""
686688

687-
loss: Optional[tf.Tensor] = None
689+
loss: tf.Tensor | None = None
688690
logits: tf.Tensor = None
689-
hidden_states: Optional[Tuple[tf.Tensor]] = None
690-
attentions: Optional[Tuple[tf.Tensor]] = None
691+
hidden_states: Tuple[tf.Tensor] | None = None
692+
attentions: Tuple[tf.Tensor] | None = None
691693

692694

693695
@dataclass
@@ -716,9 +718,9 @@ class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
716718
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
717719
"""
718720

719-
loss: Optional[tf.Tensor] = None
721+
loss: tf.Tensor | None = None
720722
logits: tf.Tensor = None
721-
hidden_states: Optional[Tuple[tf.Tensor]] = None
723+
hidden_states: Tuple[tf.Tensor] | None = None
722724

723725

724726
@dataclass
@@ -742,10 +744,10 @@ class TFImageClassifierOutput(ModelOutput):
742744
heads.
743745
"""
744746

745-
loss: Optional[tf.Tensor] = None
747+
loss: tf.Tensor | None = None
746748
logits: tf.Tensor = None
747-
hidden_states: Optional[Tuple[tf.Tensor]] = None
748-
attentions: Optional[Tuple[tf.Tensor]] = None
749+
hidden_states: Tuple[tf.Tensor] | None = None
750+
attentions: Tuple[tf.Tensor] | None = None
749751

750752

751753
@dataclass
@@ -773,10 +775,10 @@ class TFMultipleChoiceModelOutput(ModelOutput):
773775
heads.
774776
"""
775777

776-
loss: Optional[tf.Tensor] = None
778+
loss: tf.Tensor | None = None
777779
logits: tf.Tensor = None
778-
hidden_states: Optional[Tuple[tf.Tensor]] = None
779-
attentions: Optional[Tuple[tf.Tensor]] = None
780+
hidden_states: Tuple[tf.Tensor] | None = None
781+
attentions: Tuple[tf.Tensor] | None = None
780782

781783

782784
@dataclass
@@ -802,10 +804,10 @@ class TFTokenClassifierOutput(ModelOutput):
802804
heads.
803805
"""
804806

805-
loss: Optional[tf.Tensor] = None
807+
loss: tf.Tensor | None = None
806808
logits: tf.Tensor = None
807-
hidden_states: Optional[Tuple[tf.Tensor]] = None
808-
attentions: Optional[Tuple[tf.Tensor]] = None
809+
hidden_states: Tuple[tf.Tensor] | None = None
810+
attentions: Tuple[tf.Tensor] | None = None
809811

810812

811813
@dataclass
@@ -833,11 +835,11 @@ class TFQuestionAnsweringModelOutput(ModelOutput):
833835
heads.
834836
"""
835837

836-
loss: Optional[tf.Tensor] = None
838+
loss: tf.Tensor | None = None
837839
start_logits: tf.Tensor = None
838840
end_logits: tf.Tensor = None
839-
hidden_states: Optional[Tuple[tf.Tensor]] = None
840-
attentions: Optional[Tuple[tf.Tensor]] = None
841+
hidden_states: Tuple[tf.Tensor] | None = None
842+
attentions: Tuple[tf.Tensor] | None = None
841843

842844

843845
@dataclass
@@ -884,15 +886,15 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
884886
self-attention heads.
885887
"""
886888

887-
loss: Optional[tf.Tensor] = None
889+
loss: tf.Tensor | None = None
888890
start_logits: tf.Tensor = None
889891
end_logits: tf.Tensor = None
890-
past_key_values: Optional[List[tf.Tensor]] = None
891-
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
892-
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
893-
encoder_last_hidden_state: Optional[tf.Tensor] = None
894-
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
895-
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
892+
past_key_values: List[tf.Tensor] | None = None
893+
decoder_hidden_states: Tuple[tf.Tensor] | None = None
894+
decoder_attentions: Tuple[tf.Tensor] | None = None
895+
encoder_last_hidden_state: tf.Tensor | None = None
896+
encoder_hidden_states: Tuple[tf.Tensor] | None = None
897+
encoder_attentions: Tuple[tf.Tensor] | None = None
896898

897899

898900
@dataclass
@@ -924,11 +926,11 @@ class TFSequenceClassifierOutputWithPast(ModelOutput):
924926
heads.
925927
"""
926928

927-
loss: Optional[tf.Tensor] = None
929+
loss: tf.Tensor | None = None
928930
logits: tf.Tensor = None
929-
past_key_values: Optional[List[tf.Tensor]] = None
930-
hidden_states: Optional[Tuple[tf.Tensor]] = None
931-
attentions: Optional[Tuple[tf.Tensor]] = None
931+
past_key_values: List[tf.Tensor] | None = None
932+
hidden_states: Tuple[tf.Tensor] | None = None
933+
attentions: Tuple[tf.Tensor] | None = None
932934

933935

934936
@dataclass
@@ -947,7 +949,7 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
947949
feature maps) of the model at the output of each stage.
948950
"""
949951

950-
loss: Optional[tf.Tensor] = None
952+
loss: tf.Tensor | None = None
951953
logits: tf.Tensor = None
952954
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
953955

@@ -974,10 +976,10 @@ class TFMaskedImageModelingOutput(ModelOutput):
974976
heads.
975977
"""
976978

977-
loss: Optional[tf.Tensor] = None
979+
loss: tf.Tensor | None = None
978980
reconstruction: tf.Tensor = None
979-
hidden_states: Optional[Tuple[tf.Tensor]] = None
980-
attentions: Optional[Tuple[tf.Tensor]] = None
981+
hidden_states: Tuple[tf.Tensor] | None = None
982+
attentions: Tuple[tf.Tensor] | None = None
981983

982984
@property
983985
def logits(self):

0 commit comments

Comments
 (0)