Skip to content

Commit 88f50a1

Browse files
authored
Add TensorFlow implementation of EfficientFormer (#22620)
* Add tf code for efficientformer * Fix return dict bug - return last hidden state after last stage * Fix corresponding return dict bug * Override test tol * Change default values of training to False * Set training to default False X3 * Rm axis from ln * Set init in dense projection * Rm debug stuff * Make style; all tests pass. * Modify year to 2023 * Fix attention biases codes * Update the shape list logic * Add a batch norm eps config * Remove extract comments in test files * Add conditional attn and hidden states return for serving output * Change channel dim checking logic * Add exception for withteacher model in training mode * Revert layer count for now * Add layer count for conditional layer naming * Transpose for conv happens only in main layer * Make tests smaller * Make style * Update doc * Rm from_pt * Change to actual expect image class label * Remove stray print in tests * Update image processor test * Remove the old serving output logic * Make style * Make style * Complete test
1 parent 9fea71b commit 88f50a1

File tree

12 files changed

+1537
-23
lines changed

12 files changed

+1537
-23
lines changed

docs/source/en/index.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ Flax), PyTorch, and/or TensorFlow.
313313
| DonutSwin | | | | | |
314314
| DPR | | | | | |
315315
| DPT | | | | | |
316-
| EfficientFormer | | | | | |
316+
| EfficientFormer | | | | | |
317317
| EfficientNet | | | | | |
318318
| ELECTRA | | | | | |
319319
| Encoder decoder | | | | | |

docs/source/en/model_doc/efficientformer.mdx

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work pr
3737
reach extremely low latency on mobile devices while maintaining high performance.*
3838

3939
This model was contributed by [novice03](https://huggingface.co/novice03) and [Bearnardd](https://huggingface.co/Bearnardd).
40-
The original code can be found [here](https://github.com/snap-research/EfficientFormer).
40+
The original code can be found [here](https://github.com/snap-research/EfficientFormer). The TensorFlow version of this model was added by [D-Roberts](https://huggingface.co/D-Roberts).
4141

4242
## Documentation resources
4343

@@ -66,3 +66,18 @@ The original code can be found [here](https://github.com/snap-research/Efficient
6666

6767
[[autodoc]] EfficientFormerForImageClassificationWithTeacher
6868
- forward
69+
70+
## TFEfficientFormerModel
71+
72+
[[autodoc]] TFEfficientFormerModel
73+
- call
74+
75+
## TFEfficientFormerForImageClassification
76+
77+
[[autodoc]] TFEfficientFormerForImageClassification
78+
- call
79+
80+
## TFEfficientFormerForImageClassificationWithTeacher
81+
82+
[[autodoc]] TFEfficientFormerForImageClassificationWithTeacher
83+
- call

src/transformers/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3142,6 +3142,15 @@
31423142
"TFDPRReader",
31433143
]
31443144
)
3145+
_import_structure["models.efficientformer"].extend(
3146+
[
3147+
"TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
3148+
"TFEfficientFormerForImageClassification",
3149+
"TFEfficientFormerForImageClassificationWithTeacher",
3150+
"TFEfficientFormerModel",
3151+
"TFEfficientFormerPreTrainedModel",
3152+
]
3153+
)
31453154
_import_structure["models.electra"].extend(
31463155
[
31473156
"TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -6471,6 +6480,13 @@
64716480
TFDPRQuestionEncoder,
64726481
TFDPRReader,
64736482
)
6483+
from .models.efficientformer import (
6484+
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
6485+
TFEfficientFormerForImageClassification,
6486+
TFEfficientFormerForImageClassificationWithTeacher,
6487+
TFEfficientFormerModel,
6488+
TFEfficientFormerPreTrainedModel,
6489+
)
64746490
from .models.electra import (
64756491
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
64766492
TFElectraForMaskedLM,

src/transformers/models/auto/modeling_tf_auto.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
("deit", "TFDeiTModel"),
4848
("distilbert", "TFDistilBertModel"),
4949
("dpr", "TFDPRQuestionEncoder"),
50+
("efficientformer", "TFEfficientFormerModel"),
5051
("electra", "TFElectraModel"),
5152
("esm", "TFEsmModel"),
5253
("flaubert", "TFFlaubertModel"),
@@ -202,6 +203,10 @@
202203
("cvt", "TFCvtForImageClassification"),
203204
("data2vec-vision", "TFData2VecVisionForImageClassification"),
204205
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
206+
(
207+
"efficientformer",
208+
("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
209+
),
205210
("mobilevit", "TFMobileViTForImageClassification"),
206211
("regnet", "TFRegNetForImageClassification"),
207212
("resnet", "TFResNetForImageClassification"),

src/transformers/models/efficientformer/__init__.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
# limitations under the License.
1414
from typing import TYPE_CHECKING
1515

16-
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
16+
from ...utils import (
17+
OptionalDependencyNotAvailable,
18+
_LazyModule,
19+
is_tf_available,
20+
is_torch_available,
21+
is_vision_available,
22+
)
1723

1824

1925
_import_structure = {
@@ -45,6 +51,20 @@
4551
"EfficientFormerPreTrainedModel",
4652
]
4753

54+
try:
55+
if not is_tf_available():
56+
raise OptionalDependencyNotAvailable()
57+
except OptionalDependencyNotAvailable:
58+
pass
59+
else:
60+
_import_structure["modeling_tf_efficientformer"] = [
61+
"TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
62+
"TFEfficientFormerForImageClassification",
63+
"TFEfficientFormerForImageClassificationWithTeacher",
64+
"TFEfficientFormerModel",
65+
"TFEfficientFormerPreTrainedModel",
66+
]
67+
4868
if TYPE_CHECKING:
4969
from .configuration_efficientformer import EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, EfficientFormerConfig
5070

@@ -69,6 +89,19 @@
6989
EfficientFormerModel,
7090
EfficientFormerPreTrainedModel,
7191
)
92+
try:
93+
if not is_tf_available():
94+
raise OptionalDependencyNotAvailable()
95+
except OptionalDependencyNotAvailable:
96+
pass
97+
else:
98+
from .modeling_tf_efficientformer import (
99+
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
100+
TFEfficientFormerForImageClassification,
101+
TFEfficientFormerForImageClassificationWithTeacher,
102+
TFEfficientFormerModel,
103+
TFEfficientFormerPreTrainedModel,
104+
)
72105

73106
else:
74107
import sys

src/transformers/models/efficientformer/configuration_efficientformer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class EfficientFormerConfig(PretrainedConfig):
5252
The size of the key in meta3D block.
5353
attention_ratio (`int`, *optional*, defaults to 4):
5454
Ratio of the dimension of the query and value to the dimension of the key in MSHA block
55-
resolution (`int`, *optional*, defaults to 5)
55+
resolution (`int`, *optional*, defaults to 7)
5656
Size of each patch
5757
num_hidden_layers (`int`, *optional*, defaults to 5):
5858
Number of hidden layers in the Transformer encoder.
@@ -91,6 +91,8 @@ class EfficientFormerConfig(PretrainedConfig):
9191
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
9292
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
9393
The epsilon used by the layer normalization layers.
94+
image_size (`int`, *optional*, defaults to `224`):
95+
The size (resolution) of each image.
9496
9597
Example:
9698
@@ -136,6 +138,8 @@ def __init__(
136138
hidden_act: str = "gelu",
137139
initializer_range: float = 0.02,
138140
layer_norm_eps: float = 1e-12,
141+
image_size: int = 224,
142+
batch_norm_eps: float = 1e-05,
139143
**kwargs,
140144
) -> None:
141145
super().__init__(**kwargs)
@@ -165,3 +169,5 @@ def __init__(
165169
self.distillation = distillation
166170
self.use_layer_scale = use_layer_scale
167171
self.layer_scale_init_value = layer_scale_init_value
172+
self.image_size = image_size
173+
self.batch_norm_eps = batch_norm_eps

src/transformers/models/efficientformer/modeling_efficientformer.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
# Base docstring
4545
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
46-
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
46+
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]
4747

4848
# Image classification docstring
4949
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
@@ -73,7 +73,7 @@ def __init__(self, config: EfficientFormerConfig, num_channels: int, embed_dim:
7373
stride=config.downsample_stride,
7474
padding=config.downsample_pad,
7575
)
76-
self.norm = nn.BatchNorm2d(embed_dim) if apply_norm else nn.Identity()
76+
self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps) if apply_norm else nn.Identity()
7777

7878
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
7979
batch_size, num_channels, height, width = pixel_values.shape
@@ -157,10 +157,10 @@ def __init__(self, config: EfficientFormerConfig, out_channels: int):
157157
super().__init__()
158158

159159
self.convolution1 = nn.Conv2d(config.num_channels, out_channels // 2, kernel_size=3, stride=2, padding=1)
160-
self.batchnorm_before = nn.BatchNorm2d(out_channels // 2)
160+
self.batchnorm_before = nn.BatchNorm2d(out_channels // 2, eps=config.batch_norm_eps)
161161

162162
self.convolution2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=3, stride=2, padding=1)
163-
self.batchnorm_after = nn.BatchNorm2d(out_channels)
163+
self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)
164164

165165
self.activation = nn.ReLU()
166166

@@ -224,24 +224,24 @@ def __init__(
224224
hidden_features = hidden_features or in_features
225225

226226
self.convolution1 = nn.Conv2d(in_features, hidden_features, 1)
227-
self.actvation = ACT2FN[config.hidden_act]
227+
self.activation = ACT2FN[config.hidden_act]
228228
self.convolution2 = nn.Conv2d(hidden_features, out_features, 1)
229229
self.dropout = nn.Dropout(drop)
230230

231-
self.batchnorm_before = nn.BatchNorm2d(hidden_features)
232-
self.batchnorm_after = nn.BatchNorm2d(out_features)
231+
self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps)
232+
self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps)
233233

234234
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
235235
hidden_state = self.convolution1(hidden_state)
236236
hidden_state = self.batchnorm_before(hidden_state)
237237

238-
hidden_state = self.actvation(hidden_state)
238+
hidden_state = self.activation(hidden_state)
239239
hidden_state = self.dropout(hidden_state)
240240
hidden_state = self.convolution2(hidden_state)
241241

242242
hidden_state = self.batchnorm_after(hidden_state)
243-
244243
hidden_state = self.dropout(hidden_state)
244+
245245
return hidden_state
246246

247247

@@ -266,7 +266,7 @@ def drop_path(input, drop_prob: float = 0.0, training: bool = False):
266266
return output
267267

268268

269-
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit
269+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->EfficientFormer
270270
class EfficientFormerDropPath(nn.Module):
271271
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
272272

@@ -301,8 +301,10 @@ def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0
301301
attention_ratio=config.attention_ratio,
302302
resolution=config.resolution,
303303
)
304-
self.layernorm1 = nn.LayerNorm(dim)
305-
self.layernorm2 = nn.LayerNorm(dim)
304+
305+
self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
306+
self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
307+
306308
mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
307309
self.mlp = EfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim)
308310

@@ -346,15 +348,20 @@ def __init__(self, config: EfficientFormerConfig):
346348

347349
def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
348350
all_attention_outputs = () if output_attentions else None
351+
349352
for layer_module in self.blocks:
350353
if isinstance(hidden_states, tuple):
351354
hidden_states = hidden_states[0]
355+
352356
hidden_states = layer_module(hidden_states, output_attentions)
357+
353358
if output_attentions:
354359
all_attention_outputs = all_attention_outputs + (hidden_states[1],)
360+
355361
if output_attentions:
356362
outputs = (hidden_states[0],) + all_attention_outputs
357363
return outputs
364+
358365
return hidden_states
359366

360367

@@ -379,6 +386,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
379386

380387
if self.use_layer_scale:
381388
layer_output = hidden_states + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * outputs)
389+
382390
layer_output = layer_output + self.drop_path(
383391
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(layer_output)
384392
)
@@ -398,6 +406,7 @@ def __init__(self, config: EfficientFormerConfig, stage_idx: int):
398406
drop_paths = [
399407
config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers)
400408
]
409+
401410
self.blocks = nn.ModuleList(
402411
[
403412
EfficientFormerMeta4D(config, config.hidden_sizes[stage_idx], drop_path=drop_path)
@@ -446,6 +455,7 @@ def __init__(self, config: EfficientFormerConfig):
446455
for i in range(num_intermediate_stages)
447456
]
448457
intermediate_stages = []
458+
449459
for i in range(num_intermediate_stages):
450460
intermediate_stages.append(EfficientFormerIntermediateStage(config, i))
451461
if downsamples[i]:
@@ -475,14 +485,15 @@ def forward(
475485
all_hidden_states = all_hidden_states + (hidden_states,)
476486

477487
layer_output = self.last_stage(hidden_states, output_attentions=output_attentions)
488+
478489
if output_attentions:
479490
all_self_attentions = all_self_attentions + layer_output[1:]
480491

481492
if output_hidden_states:
482493
all_hidden_states = all_hidden_states + (layer_output[0],)
483494

484495
if not return_dict:
485-
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
496+
return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None)
486497

487498
return BaseModelOutput(
488499
last_hidden_state=layer_output[0],

0 commit comments

Comments
 (0)