Status: Closed
Labels: PyTorch (traced), bug
🐞Describing the bug
Converting the Hugging Face Transformers `CLIPTextModel` (the CLIP text encoder) to Core ML fails with `ValueError: Torch var mask.3 not found in context`.
Stack Trace
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-52-dd4c56533b83> in <module>
     11 traced_model = torch.jit.trace(model, [input_ids, attention_mask], strict=False)
     12 
---> 13 model = ct.convert(
     14     traced_model,
     15     convert_to="mlprogram",

12 frames
/usr/local/lib/python3.8/dist-packages/coremltools/converters/mil/frontend/torch/converter.py in __getitem__(self, torch_name)
     86             if torch_name in current_graph:
     87                 return self._current_graph[idx][torch_name]
---> 88         raise ValueError(
     89             "Torch var {} not found in context {}".format(torch_name, self.name)
     90         )

ValueError: Torch var mask.3 not found in context
```
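The last frame shows where the error is raised: the converter keeps a context that maps TorchScript value names (such as `mask.3`) to already-converted values, searching its graph scopes when an op consumes a name. The following is a simplified, hypothetical model of that lookup (the class `TranslationContextSketch` and its methods are our own illustration, not the coremltools API). It reproduces the same message whenever an op reads a name that no previously translated op registered.

```python
class TranslationContextSketch:
    """Hypothetical, simplified model of a converter's name -> value context.

    Each (sub)graph being translated gets its own scope; lookups search from
    the innermost scope outward and raise if no scope holds the name.
    """

    def __init__(self, name=""):
        self.name = name
        self._scopes = [{}]  # outermost scope first

    def push(self):
        # Enter a nested graph (e.g. a loop or if-block body).
        self._scopes.append({})

    def pop(self):
        # Leave the nested graph; its names are no longer visible.
        self._scopes.pop()

    def add(self, torch_name, value):
        # Register the output of a translated op under its TorchScript name.
        self._scopes[-1][torch_name] = value

    def __getitem__(self, torch_name):
        for scope in reversed(self._scopes):
            if torch_name in scope:
                return scope[torch_name]
        raise ValueError(
            "Torch var {} not found in context {}".format(torch_name, self.name)
        )


# Usage: an op consumes "mask.3", but only "mask.1" was ever registered
# (for example, because the op meant to produce "mask.3" was not translated).
ctx = TranslationContextSketch()
ctx.add("mask.1", "converted tensor")
try:
    ctx["mask.3"]
except ValueError as err:
    print(err)  # prints: Torch var mask.3 not found in context
```

Read this way, the error points at a producer/consumer mismatch in the traced graph rather than at the lookup itself.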
To Reproduce
colab: https://colab.research.google.com/drive/1f_mvFftpCdiGDJBZnI18RzsjU0tz-YWT?usp=sharing
```python
# Colab shell commands: install the pre-release coremltools and transformers
! pip install coremltools --pre -U
! pip install transformers

from transformers import AutoTokenizer, CLIPTextModel
import torch
import coremltools as ct

# Load the CLIP text encoder and its tokenizer
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Tokenize two prompts, padded to the same length
inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

# strict=False lets the tracer record the model's dict-like output
traced_model = torch.jit.trace(model, [input_ids, attention_mask], strict=False)

# Fails with: ValueError: Torch var mask.3 not found in context
model = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[
        ct.TensorType(shape=input_ids.shape),
        ct.TensorType(shape=attention_mask.shape),
    ],
)
model.save("CLIPTextModel.mlpackage")
```

System environment:
- coremltools version: latest pre-release (installed with `pip install coremltools --pre -U`)
- OS: Ubuntu (Google Colab)
- Other relevant versions: PyTorch 1.13.1
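The name `mask.3` is a TorchScript value name, so the error alone does not say which tensor it is. Assuming (this is not confirmed in the issue) that it is the causal attention mask the CLIP text encoder builds internally, it helps to know what that tensor should contain: an additive `seq_len x seq_len` mask with 0 on and below the diagonal and a large negative value above it, so each token attends only to itself and earlier tokens. The pure-Python sketch below spells this out; `causal_mask` is a hypothetical helper, and the default fill value assumes float32. In PyTorch the same values can be produced without in-place ops as `torch.full((seq_len, seq_len), fill_value).triu(1)`, which may be worth comparing against the traced graph if in-place mask construction turns out to be the trigger.

```python
# float32 minimum, i.e. the value of torch.finfo(torch.float32).min
FLOAT32_MIN = -3.4028234663852886e38


def causal_mask(seq_len, fill_value=FLOAT32_MIN):
    """Hypothetical helper: additive causal mask as a list of rows.

    Entry [i][j] is 0.0 when j <= i (token i may attend to token j) and
    `fill_value` when j > i (future tokens are masked out). This matches
    keeping only the strictly upper triangle (triu with diagonal=1) of a
    matrix filled with `fill_value`.
    """
    return [
        [fill_value if j > i else 0.0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Usage: a 3-token sequence
for row in causal_mask(3):
    print(row)
```

Building the mask out of place keeps every intermediate tensor a fresh value, which is the property the sketch is meant to highlight.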