Skip to content

Commit c81584a

Browse files
committed
Fix regression in regression (#11785)
* Fix regression in regression * Add test
1 parent 265c26e commit c81584a

File tree

15 files changed

+65
-15
lines changed

15 files changed

+65
-15
lines changed

src/transformers/models/albert/modeling_albert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,10 @@ def forward(
10371037

10381038
if self.config.problem_type == "regression":
10391039
loss_fct = MSELoss()
1040-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
1040+
if self.num_labels == 1:
1041+
loss = loss_fct(logits.squeeze(), labels.squeeze())
1042+
else:
1043+
loss = loss_fct(logits, labels)
10411044
elif self.config.problem_type == "single_label_classification":
10421045
loss_fct = CrossEntropyLoss()
10431046
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/bert/modeling_bert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1528,7 +1528,10 @@ def forward(
15281528

15291529
if self.config.problem_type == "regression":
15301530
loss_fct = MSELoss()
1531-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
1531+
if self.num_labels == 1:
1532+
loss = loss_fct(logits.squeeze(), labels.squeeze())
1533+
else:
1534+
loss = loss_fct(logits, labels)
15321535
elif self.config.problem_type == "single_label_classification":
15331536
loss_fct = CrossEntropyLoss()
15341537
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/big_bird/modeling_big_bird.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2671,7 +2671,10 @@ def forward(
26712671

26722672
if self.config.problem_type == "regression":
26732673
loss_fct = MSELoss()
2674-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
2674+
if self.num_labels == 1:
2675+
loss = loss_fct(logits.squeeze(), labels.squeeze())
2676+
else:
2677+
loss = loss_fct(logits, labels)
26752678
elif self.config.problem_type == "single_label_classification":
26762679
loss_fct = CrossEntropyLoss()
26772680
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/convbert/modeling_convbert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,10 @@ def forward(
10231023

10241024
if self.config.problem_type == "regression":
10251025
loss_fct = MSELoss()
1026-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
1026+
if self.num_labels == 1:
1027+
loss = loss_fct(logits.squeeze(), labels.squeeze())
1028+
else:
1029+
loss = loss_fct(logits, labels)
10271030
elif self.config.problem_type == "single_label_classification":
10281031
loss_fct = CrossEntropyLoss()
10291032
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,10 @@ def forward(
642642

643643
if self.config.problem_type == "regression":
644644
loss_fct = MSELoss()
645-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
645+
if self.num_labels == 1:
646+
loss = loss_fct(logits.squeeze(), labels.squeeze())
647+
else:
648+
loss = loss_fct(logits, labels)
646649
elif self.config.problem_type == "single_label_classification":
647650
loss_fct = CrossEntropyLoss()
648651
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/electra/modeling_electra.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,10 @@ def forward(
964964

965965
if self.config.problem_type == "regression":
966966
loss_fct = MSELoss()
967-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
967+
if self.num_labels == 1:
968+
loss = loss_fct(logits.squeeze(), labels.squeeze())
969+
else:
970+
loss = loss_fct(logits, labels)
968971
elif self.config.problem_type == "single_label_classification":
969972
loss_fct = CrossEntropyLoss()
970973
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/funnel/modeling_funnel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1298,7 +1298,10 @@ def forward(
12981298

12991299
if self.config.problem_type == "regression":
13001300
loss_fct = MSELoss()
1301-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
1301+
if self.num_labels == 1:
1302+
loss = loss_fct(logits.squeeze(), labels.squeeze())
1303+
else:
1304+
loss = loss_fct(logits, labels)
13021305
elif self.config.problem_type == "single_label_classification":
13031306
loss_fct = CrossEntropyLoss()
13041307
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/longformer/modeling_longformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1872,7 +1872,10 @@ def forward(
18721872

18731873
if self.config.problem_type == "regression":
18741874
loss_fct = MSELoss()
1875-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
1875+
if self.num_labels == 1:
1876+
loss = loss_fct(logits.squeeze(), labels.squeeze())
1877+
else:
1878+
loss = loss_fct(logits, labels)
18761879
elif self.config.problem_type == "single_label_classification":
18771880
loss_fct = CrossEntropyLoss()
18781881
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/mobilebert/modeling_mobilebert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1279,7 +1279,10 @@ def forward(
12791279

12801280
if self.config.problem_type == "regression":
12811281
loss_fct = MSELoss()
1282-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
1282+
if self.num_labels == 1:
1283+
loss = loss_fct(logits.squeeze(), labels.squeeze())
1284+
else:
1285+
loss = loss_fct(logits, labels)
12831286
elif self.config.problem_type == "single_label_classification":
12841287
loss_fct = CrossEntropyLoss()
12851288
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

src/transformers/models/reformer/modeling_reformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2445,7 +2445,10 @@ def forward(
24452445

24462446
if self.config.problem_type == "regression":
24472447
loss_fct = MSELoss()
2448-
loss = loss_fct(logits.view(-1, self.num_labels), labels)
2448+
if self.num_labels == 1:
2449+
loss = loss_fct(logits.squeeze(), labels.squeeze())
2450+
else:
2451+
loss = loss_fct(logits, labels)
24492452
elif self.config.problem_type == "single_label_classification":
24502453
loss_fct = CrossEntropyLoss()
24512454
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

0 commit comments

Comments
 (0)