Skip to content

Commit a0f8537

Browse files
committed
add
1 parent 5ee89eb commit a0f8537

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ def gt(context, node):
953953
context.add(greater)
954954

955955

956-
@register_torch_op(torch_alias=["t", "numpy_t"])
956+
@register_torch_op(torch_alias=["t", "numpy_t", "transpose_copy"])
957957
def transpose(context, node):
958958
assert len(node.outputs) == 1
959959
inputs = _get_inputs(context, node)

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7001,6 +7001,20 @@ def test(self, compute_unit, backend, frontend, shape, dims):
70017001
)
70027002

70037003

7004+
class TestTransposeCopy(TorchBaseTest):
7005+
@pytest.mark.parametrize(
7006+
"compute_unit, backend, frontend, shape, dims",
7007+
itertools.product(
7008+
compute_units, backends, frontends, COMMON_SHAPES, [(0, 1), (-2, -1), (1, 0), (-1, -2)]
7009+
),
7010+
)
7011+
def test(self, compute_unit, backend, frontend, shape, dims):
7012+
model = ModuleWrapper(function=torch.transpose_copy, kwargs={"dim0": dims[0], "dim1": dims[1]})
7013+
self.run_compare_torch(
7014+
shape, model, compute_unit=compute_unit, backend=backend, frontend=frontend
7015+
)
7016+
7017+
70047018
class TestTo(TorchBaseTest):
70057019
@pytest.mark.parametrize(
70067020
"compute_unit, backend, frontend",

0 commit comments

Comments
 (0)