Conversation

@fxmarty
Contributor

@fxmarty fxmarty commented Aug 21, 2023

The ONNX export of SAM will fail with the next transformers release due to huggingface/transformers#25074, which unintentionally triggers a bug in PyTorch that has since been fixed (pytorch/pytorch#100429).

Note as well that there is an additional, not yet fixed PyTorch bug for repeat_interleave (see pytorch/pytorch#107591) that makes it impossible to export SAM on a CUDA device (and thus in fp16) following huggingface/transformers#25074 (so for transformers>=4.32). To export on a CUDA device, transformers<=4.31 is required; alternatively, we need to wait for a fix in PyTorch.

This PR also adds the notion of a variant for the ONNX export (#1299). The idea is that for the same model and task, one may want different exported ONNX models. Currently, the existing variants would be with-past, without-past, and monolith for encoder/decoder models. Here the motivation is that for SAM we either want to export the whole model as a standalone ONNX file, or split it into a vision encoder on one side and a prompt encoder + mask decoder on the other.

A description of the variant is displayed during the export:

Using the export variant monolith. Available variants are:
    - monolith: All the SAM model components are exported as a single model.onnx.
    - split: The vision encoder is exported as a separate vision_encoder.onnx, and the prompt encoder and mask decoder are exported as a prompt_mask.onnx. This allows encoding the image only once for multiple point queries.

If this approach is validated, we could move from the -with-past and monolith config behavior to variants.

Fixes #1078. By default, SAM is exported with the split variant.
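
For illustration, the export can then be run as follows (a sketch using the --variant flag added in this PR; output directory names are arbitrary):

optimum-cli export onnx -m facebook/sam-vit-base --variant split sam_onnx_split
optimum-cli export onnx -m facebook/sam-vit-base --variant monolith sam_onnx_monolith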

Potentially breaking: the model patcher is now used during the export validation, so we have to check that the tests still pass.

@xenova
Contributor

xenova commented Aug 21, 2023

Amazing! Testing now :)

@xenova
Contributor

xenova commented Aug 21, 2023

Installed the branch with:

!pip install git+https://github.com/fxmarty/optimum.git@sam-vision-encoder-onnx --upgrade

Ran into this issue when exporting:

Command:

!optimum-cli export onnx -m facebook/sam-vit-base out

Versions:
onnx: 1.14.0
torch: 2.0.0
onnxruntime: 1.15.1

Log:

Framework not specified. Using pt to export to ONNX.
Automatic task detection to feature-extraction (possible synonyms are: default, mask-generation, sentence-similarity).
Using framework PyTorch: 2.0.0+cu117
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:137: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_channels != self.num_channels:
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:141: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if height != self.image_size[0] or width != self.image_size[1]:
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:771: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_rel_dist = int(2 * max(q_size, k_size) - 1)
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:771: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_rel_dist = int(2 * max(q_size, k_size) - 1)
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:781: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:782: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:783: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:1384: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:652: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:502: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if sparse_prompt_embeddings.sum().item() != 0:
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:242: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  attn = attn / math.sqrt(c_per_head)
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:548: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if multimask_output:
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "/usr/local/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/optimum/commands/optimum_cli.py", line 163, in main
    service.run()
  File "/usr/local/lib/python3.10/dist-packages/optimum/commands/export/onnx.py", line 219, in run
    main_export(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/__main__.py", line 446, in main_export
    _, onnx_outputs = export_models(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 760, in export_models
    export(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 863, in export
    export_output = export_pytorch(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 580, in export_pytorch
    onnx_export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 665, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1891, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset9.py", line 4256, in repeat_interleave
    return symbolic_helper._onnx_opset_unsupported_detailed(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_helper.py", line 657, in _onnx_opset_unsupported_detailed
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of repeat_interleave in opset 9. Unsupported along dimension with unknown input size. Please try opset version 13.  [Caused by the value '3703 defined in (%3703 : Float(*, *, *, *, strides=[1048576, 4096, 64, 1], requires_grad=0, device=cpu) = onnx::Add(%3376, %3672), scope: transformers.models.sam.modeling_sam.SamModel::/transformers.models.sam.modeling_sam.SamMaskDecoder::mask_decoder # /usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:509:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Add'.] 
    (node defined in /usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py(509): forward
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py(1488): _slow_forward
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py(1501): _call_impl
/usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py(1400): forward
/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/model_patcher.py(102): patched_forward
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py(1488): _slow_forward
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py(1501): _call_impl
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py(118): wrapper
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py(127): forward
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py(1501): _call_impl
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py(1268): _get_trace_graph
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py(893): _trace_and_get_graph_from_model
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py(989): _create_jit_graph
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py(1113): _model_to_graph
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py(1548): _export
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py(506): export
/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py(580): export_pytorch
/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py(863): export
/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py(760): export_models
/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/__main__.py(446): main_export
/usr/local/lib/python3.10/dist-packages/optimum/commands/export/onnx.py(219): run
/usr/local/lib/python3.10/dist-packages/optimum/commands/optimum_cli.py(163): main
/usr/local/bin/optimum-cli(8): <module>
)

    Inputs:
        #0: 3376 defined in (%3376 : Float(*, 256, *, *, strides=[1048576, 4096, 64, 1], requires_grad=0, device=cpu) = onnx::Add(%3373, %3375), scope: transformers.models.sam.modeling_sam.SamModel::/transformers.models.sam.modeling_sam.SamVisionEncoder::vision_encoder/transformers.models.sam.modeling_sam.SamVisionNeck::neck/transformers.models.sam.modeling_sam.SamLayerNorm::layer_norm2 # /usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:190:0
    )  (type 'Tensor')
        #1: 3672 defined in (%3672 : Float(*, *, *, *, strides=[0, 1, 0, 0], requires_grad=1, device=cpu) = onnx::Expand(%3655, %3671), scope: transformers.models.sam.modeling_sam.SamModel::/transformers.models.sam.modeling_sam.SamPromptEncoder::prompt_encoder # /usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:717:0
    )  (type 'Tensor')
    Outputs:
        #0: 3703 defined in (%3703 : Float(*, *, *, *, strides=[1048576, 4096, 64, 1], requires_grad=0, device=cpu) = onnx::Add(%3376, %3672), scope: transformers.models.sam.modeling_sam.SamModel::/transformers.models.sam.modeling_sam.SamMaskDecoder::mask_decoder # /usr/local/lib/python3.10/dist-packages/transformers/models/sam/modeling_sam.py:509:0
    )  (type 'Tensor')

Am I missing something? 👀 I skimmed over the code and it should be using opset 13

https://github.com/huggingface/optimum/pull/1301/files#diff-cf7d52b6c0ee70bc45f40c2c40ac99b49dbf02f2fdb0a141e9483f764126e05eR1218

@fxmarty
Contributor Author

fxmarty commented Aug 22, 2023

Which version of transformers are you using? If >4.31, could you update to PyTorch nightly (cf. my comment above)? Are you sure you checked out the correct branch?
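
For reference, the nightly can be installed with, for example (CPU build shown; see pytorch.org for the matching CUDA index URLs):

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu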

Member

@michaelbenayoun michaelbenayoun left a comment

The variant approach seems nice and more generic.

Comment on lines +94 to +99
optional_group.add_argument(
    "--variant",
    type=str,
    default="default",
    help=("Select a variant of the model to export."),
)
Member

Maybe add a set of possible choices here.

Contributor Author

@fxmarty fxmarty Aug 22, 2023

It would be a bit tricky given that the choices are dynamic (dependent on the onnx config).
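
For illustration, a post-parse check could look roughly like this (a sketch with hypothetical attribute names, since the valid set is only known once the ONNX config is resolved):

# Hypothetical: the resolved ONNX config exposes its supported variants as a
# mapping {variant_name: description}, so the check happens after task/model
# resolution rather than in argparse.
available_variants = onnx_config.VARIANTS
if variant != "default" and variant not in available_variants:
    raise ValueError(
        f"The variant {variant} is not supported for this model and task. "
        f"Available variants: {', '.join(available_variants)}."
    )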

    monolith: bool,
    custom_onnx_configs: Dict,
    custom_architecture: bool,
    _variant: str,
Member

Why make it a protected parameter?

Contributor Author

I'm thinking of keeping it "private" for now, and supporting it properly once we fully move to this API instead of -with-past, monolith, etc.

@xenova
Contributor

xenova commented Aug 22, 2023

Which version of transformers are you using?

Sorry, for some reason I forgot to mention: installed from source

could you update to PyTorch nightly (cf my comment above)?

Will try that 👍 For the testing I just downgraded to PyTorch v2.0.0.

Are you sure you checked out the correct branch?

I believe so, as I ran

!pip install git+https://github.com/fxmarty/optimum.git@sam-vision-encoder-onnx --upgrade

Will do more testing today.

@fxmarty
Contributor Author

fxmarty commented Aug 22, 2023

@xenova transformers from source + PyTorch 2.0 does not work for SAM ONNX export (you should get a meaningful error message actually) due to the issue detailed in my first post.

Edit: something like this:

2023-08-22 18:11:25.472260: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-22 18:11:25.499068: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Framework not specified. Using pt to export to ONNX.
Automatic task detection to feature-extraction (possible synonyms are: default, mask-generation, sentence-similarity).
Using the export variant split. Available variants are:
        - monolith: All the SAM model components are exported as a single model.onnx.
        - split: The vision encoder is exported as a separate vision_encoder.onnx, and the prompt encoder and mask decoder are exported as a prompt_mask.onnx. This allows encoding the image only once for multiple point queries.
Traceback (most recent call last):
  File "/home/fxmarty/anaconda3/envs/hf-inf/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
  File "/home/fxmarty/hf_internship/optimum/optimum/commands/optimum_cli.py", line 163, in main
    service.run()
  File "/home/fxmarty/hf_internship/optimum/optimum/commands/export/onnx.py", line 225, in run
    main_export(
  File "/home/fxmarty/hf_internship/optimum/optimum/exporters/onnx/__main__.py", line 460, in main_export
    _, onnx_outputs = export_models(
  File "/home/fxmarty/hf_internship/optimum/optimum/exporters/onnx/convert.py", line 763, in export_models
    export(
  File "/home/fxmarty/hf_internship/optimum/optimum/exporters/onnx/convert.py", line 855, in export
    raise MinimumVersionError(
optimum.exporters.error_utils.MinimumVersionError: Unsupported PyTorch version for this model. Minimum required is 2.0.99, got: 2.0.1+cu117

@xenova
Contributor

xenova commented Aug 23, 2023

After fixing some versions, I got it exported! 🥳

Validating ONNX model sam_onnx/vision_encoder.onnx...
	-[✓] ONNX model output names match reference model (image_positional_embeddings, image_embeddings)
	- Validating ONNX Model output "image_embeddings":
		-[✓] (2, 256, 64, 64) matches (2, 256, 64, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "image_positional_embeddings":
		-[✓] (2, 256, 64, 64) matches (2, 256, 64, 64)
		-[✓] all values close (atol: 1e-05)
Validating ONNX model sam_onnx/prompt_mask.onnx...
	-[✓] ONNX model output names match reference model (pred_masks, iou_scores)
	- Validating ONNX Model output "iou_scores":
		-[✓] (2, 3, 3) matches (2, 3, 3)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "pred_masks":
		-[✓] (2, 3, 3, 256, 256) matches (2, 3, 3, 256, 256)
		-[x] values not close enough, max diff: 0.003204345703125 (atol: 1e-05)
The ONNX export succeeded with the warning: The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance 1e-05:
- pred_masks: max diff = 0.003204345703125.

There is a noticeable difference in the validation of pred_masks though (3e-3); not sure if it will cause any issues.

Let me try quantizing for use in transformers.js 🔥
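
For reference, here is a minimal sketch of how the two exported files could be driven with ONNX Runtime. The output names come from the validation log above; the input names, shapes, and dtypes are assumptions based on SamModel's signature and may not match the exported graphs exactly:

import numpy as np
import onnxruntime as ort

# One session per exported component (paths as produced by the export above).
vision_encoder = ort.InferenceSession("sam_onnx/vision_encoder.onnx")
prompt_mask = ort.InferenceSession("sam_onnx/prompt_mask.onnx")

# Encode the image once ("pixel_values" as the input name is an assumption).
pixel_values = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
enc = dict(
    zip(
        [o.name for o in vision_encoder.get_outputs()],
        vision_encoder.run(None, {"pixel_values": pixel_values}),
    )
)

# Reuse the image embeddings for any number of point queries.
input_points = np.array([[[[450.0, 600.0]]]], dtype=np.float32)  # (batch, point_batch, nb_points, 2)
input_labels = np.ones((1, 1, 1), dtype=np.int64)
pred_masks, iou_scores = prompt_mask.run(
    ["pred_masks", "iou_scores"],
    {
        "image_embeddings": enc["image_embeddings"],
        "image_positional_embeddings": enc["image_positional_embeddings"],
        "input_points": input_points,
        "input_labels": input_labels,
    },
)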

Contributor

@xenova xenova left a comment

I re-exported, but it looks like the config.json has changed quite a bit (see diff):

[screenshot of the config.json diff]

else:
    # We use the model patcher to patch their forward method.
    models_for_export["vision_encoder"] = model
    models_for_export["prompt_mask"] = model
Contributor

@xenova xenova Aug 23, 2023

Any particular reason behind the naming btw?

What about "vision_encoder" and "mask_decoder"? I'd say this aligns better with the original naming and the HF module names:

[screenshot of the SAM module names]

Contributor

Or, is prompt_mask a combination of the prompt_encoder and mask_decoder?

Contributor Author

Yes, maybe I should name it more explicitly prompt_encoder_mask_decoder?

Contributor

That could help I guess (albeit a bit verbose 😅).

Also, something you've probably thought about already: is there a situation where the user may want to use the prompt encoder and mask decoder separately? I haven't needed to worry about this (yet), but just asking in case.

Contributor Author

I guess when we want to encode a point for several different images?

Contributor Author

I'll leave as is and just rename to prompt_encoder_mask_decoder.onnx.

Contributor

I guess when we want to encode a point for several different images?

Indeed, quite a rare use case 😅

I'll leave as is and just rename to prompt_encoder_mask_decoder.onnx.

Sounds good!

@xenova
Contributor

xenova commented Aug 24, 2023

Good news: got the export working in Transformers.js! 🥳

I replicated the demo code from here to test it. I'm still trying to figure out the input and output format, but it looks like it matches the Python version quite well! 🤣

(quantized output)
[output screenshot]

python version
[output screenshot]

@fxmarty
Contributor Author

fxmarty commented Aug 24, 2023

Awesome, thank you for giving it a try! So it means this PR is probably good to merge.

I re-exported, but it looks like the config.json has changed quite a bit (see diff):

Good catch! I'll fix that.

@fxmarty
Contributor Author

fxmarty commented Aug 24, 2023

@xenova It seems that huggingface/transformers#25237 broke the nested config loading.

Edit: it is not really a bug in transformers, but a behavior that I find unintuitive. Basically, saved configs are not exhaustive: only values differing from the defaults are serialized, due to a hard-coded use_diff=True when saving a config. The PR linked above treats subconfigs as PretrainedConfig, hence the change of behavior. I don't think there is anything I can do about it here.
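
A small illustration of this behavior with the public PretrainedConfig API (just a sketch; which keys get dropped depends on the model's defaults):

from transformers import SamConfig

config = SamConfig.from_pretrained("facebook/sam-vit-base")
# save_pretrained() hard-codes use_diff=True, so only values differing from the
# defaults end up in the saved config.json; use_diff=False would produce the
# full nested configuration.
print(len(config.to_json_string(use_diff=True)))   # serialized diff (what gets saved)
print(len(config.to_json_string(use_diff=False)))  # full config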

@xenova
Contributor

xenova commented Aug 24, 2023

Ah, I see 😅. Well, I'll be able to work around this in transformers.js, but I thought I should bring it up.

We can also just wait for the transformers team to respond to that issue before merging this PR.


Development

Successfully merging this pull request may close these issues.

- [SAM] Split encoder and mask decoder into separate .onnx files
- [ONNX] torch.repeat_interleave export failed
