Skip to content

Commit 6c41483

Browse files
Rocketknight1amyerobertssgugger
authored andcommitted
TF port of the Segment Anything Model (SAM) (huggingface#22970)
* First commit * Add auto-translation with GPT-4 * make fixup * Add a functional layernorm for TF * Add all the auxiliary imports etc. * Add the extra processor and tests * rebase to main * Add all the needed fixes to the GPT code * make fixup * Make convolutions channels-last so they run on CPU * make fixup * Fix final issues * Fix other models affected by test change * Clarify comment on the sparse_prompt_embeddings check * Refactor functional_layernorm, use shape_list in place of .shape in some places * Remove deprecated torch-alike code * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <[email protected]> * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <[email protected]> * Refactor processor with common methods and separated private methods * make fixup * Quietly delete the file that didn't do anything (sorry Sylvain) * Refactor the processor tests into one file * make fixup * Clean up some unnecessary indirection * Fix TF mask postprocessing * Add more processor equivalence tests * Refactor generate_crop_boxes to use framework-neutral np code * Make the serving output correctly conditional * Fix error message line length * Use dict keys rather than indices internally in both TF and PT SAM call/forward * Return dicts internally in the call/forward methods * Revert changes to common tests and just override check_pt_tf_outputs * Revert changes to other model tests * Clarify comments for functional layernorm * Add missing transpose from PT code * Removed unused copied from in PT code * Remove overrides for tests that don't exist in TF * Fix transpose and update tests for PT and TF to check pred_masks * Add training flag * Update tests to use TF checkpoints * Update index.mdx * Add missing cross-test decorator * Remove optional extra asterisks * Revert return_dict changes in PT code * Update src/transformers/models/sam/modeling_tf_sam.py Co-authored-by: Sylvain Gugger <[email protected]> * Remove None return annotations on init methods * Update tests/models/sam/test_processor_sam.py Co-authored-by: amyeroberts <[email protected]> * Fix input_boxes shapes * make fixup --------- Co-authored-by: amyeroberts <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]>
1 parent d5f6a80 commit 6c41483

File tree

14 files changed

+2940
-44
lines changed

14 files changed

+2940
-44
lines changed

docs/source/en/index.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
399399
| RoCBert | | | | | |
400400
| RoFormer | | | | | |
401401
| RWKV | | | | | |
402-
| SAM | | | | | |
402+
| SAM | | | | | |
403403
| SegFormer | | | | | |
404404
| SEW | | | | | |
405405
| SEW-D | | | | | |

docs/source/en/model_doc/sam.mdx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,9 @@ Resources:
9999

100100
[[autodoc]] SamModel
101101
- forward
102+
103+
104+
## TFSamModel
105+
106+
[[autodoc]] TFSamModel
107+
- call

src/transformers/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3406,6 +3406,13 @@
34063406
"TFRoFormerPreTrainedModel",
34073407
]
34083408
)
3409+
_import_structure["models.sam"].extend(
3410+
[
3411+
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
3412+
"TFSamModel",
3413+
"TFSamPreTrainedModel",
3414+
]
3415+
)
34093416
_import_structure["models.segformer"].extend(
34103417
[
34113418
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -6660,6 +6667,11 @@
66606667
TFRoFormerModel,
66616668
TFRoFormerPreTrainedModel,
66626669
)
6670+
from .models.sam import (
6671+
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
6672+
TFSamModel,
6673+
TFSamPreTrainedModel,
6674+
)
66636675
from .models.segformer import (
66646676
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
66656677
TFSegformerDecodeHead,

src/transformers/models/auto/modeling_tf_auto.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
("roberta", "TFRobertaModel"),
7777
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
7878
("roformer", "TFRoFormerModel"),
79+
("sam", "TFSamModel"),
7980
("segformer", "TFSegformerModel"),
8081
("speech_to_text", "TFSpeech2TextModel"),
8182
("swin", "TFSwinModel"),
@@ -426,6 +427,11 @@
426427
("mobilebert", "TFMobileBertForNextSentencePrediction"),
427428
]
428429
)
430+
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
431+
[
432+
("sam", "TFSamModel"),
433+
]
434+
)
429435

430436
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
431437
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
@@ -476,6 +482,14 @@
476482
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
477483
)
478484

485+
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
486+
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
487+
)
488+
489+
490+
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
491+
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
492+
479493

480494
class TFAutoModel(_BaseAutoModelClass):
481495
_model_mapping = TF_MODEL_MAPPING

src/transformers/models/sam/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
# limitations under the License.
1414
from typing import TYPE_CHECKING
1515

16-
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
16+
from ...utils import (
17+
OptionalDependencyNotAvailable,
18+
_LazyModule,
19+
is_tf_available,
20+
is_torch_available,
21+
is_vision_available,
22+
)
1723

1824

1925
_import_structure = {
@@ -39,6 +45,17 @@
3945
"SamModel",
4046
"SamPreTrainedModel",
4147
]
48+
try:
49+
if not is_tf_available():
50+
raise OptionalDependencyNotAvailable()
51+
except OptionalDependencyNotAvailable:
52+
pass
53+
else:
54+
_import_structure["modeling_tf_sam"] = [
55+
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
56+
"TFSamModel",
57+
"TFSamPreTrainedModel",
58+
]
4259
try:
4360
if not is_vision_available():
4461
raise OptionalDependencyNotAvailable()
@@ -66,6 +83,14 @@
6683
else:
6784
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
6885

86+
try:
87+
if not is_tf_available():
88+
raise OptionalDependencyNotAvailable()
89+
except OptionalDependencyNotAvailable:
90+
pass
91+
else:
92+
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
93+
6994
try:
7095
if not is_vision_available():
7196
raise OptionalDependencyNotAvailable()

0 commit comments

Comments
 (0)