Skip to content

Commit eba8f10

Browse files
authored
[PT FE] Support F8 constants (#29313)
### Details:
- *Support F8 constants*
- *Allow to set patch condition externally*

### Tickets:
- *ticket-id*

Signed-off-by: Maxim Vafin <[email protected]>
1 parent 6aec418 commit eba8f10

File tree

3 files changed

+102
-42
lines changed

3 files changed

+102
-42
lines changed

src/bindings/python/src/openvino/frontend/pytorch/patch_model.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,25 +76,30 @@ def unpatch_model(model, orig_forward_name):
7676
"Original exception details:\n%s", error)
7777

7878

79-
def __make_16bit_traceable(model: torch.nn.Module):
79+
def __make_16bit_traceable(model: torch.nn.Module,
80+
orig_forward_name: str = "_openvino_module_extension_patch_orig_forward",
81+
patch_condition=None):
8082
"""
8183
Prepare a 16-bit PyTorch model for tracing with OpenVINO.
8284
- Replace known list of modules with ModuleExtension.
8385
- Convert other modules with weights to FP32.
8486
"""
85-
def patch_condition(module):
86-
supported = [torch.float32, torch.float16, torch.bfloat16]
87-
return (hasattr(module, "weight")
88-
and getattr(module.weight, "dtype", None) in supported)
89-
87+
if patch_condition is None:
88+
def patch_condition(module):
89+
supported = {torch.float32, torch.float16, torch.bfloat16}
90+
weight = getattr(module, "weight", None)
91+
return weight is not None and weight.dtype in supported
92+
93+
def fp32_tensor(*shape):
94+
return torch.full(shape, 0.5, dtype=torch.float32)
95+
9096
extensions = {
9197
torch.nn.Linear: ModuleExtension(
9298
torch.nn.Linear, "ov_ext::linear",
9399
convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
94100
module.weight,
95101
module.bias),
96-
evaluate=lambda module, *args, **kwargs: torch.full(
97-
list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
102+
evaluate=lambda module, *args, **kwargs: fp32_tensor(*args[0].shape[:-1], module.out_features),
98103
condition=patch_condition),
99104
torch.nn.Embedding: ModuleExtension(
100105
torch.nn.Embedding, "ov_ext::embedding",
@@ -103,8 +108,7 @@ def patch_condition(module):
103108
module.padding_idx,
104109
module.scale_grad_by_freq,
105110
module.sparse),
106-
evaluate=lambda module, *args, **kwargs: torch.full(
107-
list(args[1].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
111+
evaluate=lambda module, *args, **kwargs: fp32_tensor(*args[1].shape, module.embedding_dim),
108112
condition=patch_condition),
109113
}
110114
try:
@@ -114,14 +118,12 @@ def patch_condition(module):
114118
convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
115119
module.weight,
116120
module.bias),
117-
evaluate=lambda module, *args, **kwargs: torch.full(
118-
list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
121+
evaluate=lambda module, *args, **kwargs: fp32_tensor(*args[0].shape[:-1], module.nf),
119122
condition=patch_condition)
120123
except ImportError:
121124
pass
122-
patch_model(model, extensions,
123-
"_openvino_module_extension_patch_orig_forward")
124-
dtype_to_patch = [torch.float16, torch.bfloat16]
125+
patch_model(model, extensions, orig_forward_name)
126+
dtype_to_patch = {torch.float16, torch.bfloat16}
125127
for _, module in model.named_modules():
126128
if (module.__class__ not in extensions and
127129
(any(p.dtype in dtype_to_patch for p in module.parameters(False))

src/bindings/python/src/openvino/frontend/pytorch/utils.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,35 @@ def get_type_from_py_type(value):
4949
return OVType.dynamic
5050

5151

52+
F8_DTYPE_MAP = {
53+
torch.float8_e4m3fn: OVType.f8e4m3,
54+
torch.float8_e5m2: OVType.f8e5m2,
55+
}
56+
57+
5258
def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
53-
is_fake_tensor = False
5459
try:
5560
from torch._prims import FakeTensor
56-
is_fake_tensor = isinstance(torch_t, FakeTensor)
57-
except:
61+
if isinstance(torch_t, FakeTensor):
62+
raise AssertionError("`FakeTensor` detected. Infer the "
63+
"model before exporting to avoid this.")
64+
except ImportError:
5865
pass
59-
assert not is_fake_tensor, '`FakeTensor` is found in the graph during conversion. ' \
60-
'In order to avoid `FakeTensor` in the traced model, ' \
61-
'try to infer the model before exporting.'
66+
67+
dtype = torch_t.dtype
6268
torch_t = torch_t.contiguous()
63-
if torch_t.dtype == torch.bfloat16:
69+
if dtype == torch.bfloat16:
6470
# reinterpret bfloat16 data as float16 to allow conversion to numpy
6571
torch_t = torch_t.view(torch.float16)
6672
narr = torch_t.numpy(force=True)
6773
tensor = Tensor(narr, torch_t.shape, OVType.bf16)
6874
ov_const = op.Constant(tensor, shared_memory=shared_memory)
75+
elif dtype in F8_DTYPE_MAP:
76+
# reinterpret f8 data as u8 to allow conversion to numpy
77+
torch_t = torch_t.view(torch.uint8)
78+
narr = torch_t.numpy(force=True)
79+
tensor = Tensor(narr, torch_t.shape, F8_DTYPE_MAP[dtype])
80+
ov_const = op.Constant(tensor, shared_memory=shared_memory)
6981
else:
7082
narr = torch_t.numpy(force=True)
7183
ov_const = op.Constant(narr, shared_memory=shared_memory)
@@ -126,6 +138,8 @@ def graph_has_ops(graph, op_types: list) -> bool:
126138
"float": OVType.f32,
127139
"int": OVType.i64,
128140
"bool": OVType.boolean,
141+
"torch.float8_e4m3fn": OVType.f8e4m3,
142+
"torch.float8_e5m2": OVType.f8e5m2,
129143
"torch.bfloat16": OVType.bf16,
130144
"torch.float16": OVType.f16,
131145
"torch.float32": OVType.f32,

tests/layer_tests/py_frontend_tests/test_torch_decoder.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,54 @@ def test_pytorch_decoder_get_input_type_none():
9696
assert isinstance(nc_decoder.get_input_type(2).value, DecoderType.PyNone)
9797

9898

99+
@pytest.mark.precommit
100+
def test_pytorch_decoder_can_convert_f8_e4m3_tensor():
101+
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
102+
from openvino import PartialShape, Type
103+
104+
class SomeTensor(torch.nn.Module):
105+
def forward(self):
106+
return torch.tensor([1, 2], dtype=torch.float8_e4m3fn)
107+
108+
model = get_scripted_model(SomeTensor())
109+
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
110+
"prim::Constant"]
111+
assert len(consts) > 0
112+
some_const = consts[0]
113+
nc_decoder = TorchScriptPythonDecoder(model, some_const)
114+
ov_const = nc_decoder.as_constant()
115+
assert ov_const is not None
116+
assert len(ov_const) == 1
117+
assert ov_const[0].get_element_type() == Type.f8e4m3
118+
assert ov_const[0].get_partial_shape() == PartialShape([2])
119+
120+
121+
@pytest.mark.precommit
122+
def test_pytorch_decoder_can_convert_f8_e5m2_tensor():
123+
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
124+
from openvino import PartialShape, Type
125+
126+
class SomeTensor(torch.nn.Module):
127+
def forward(self):
128+
return torch.tensor([1, 2], dtype=torch.float8_e5m2)
129+
130+
model = get_scripted_model(SomeTensor())
131+
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
132+
"prim::Constant"]
133+
assert len(consts) > 0
134+
some_const = consts[0]
135+
nc_decoder = TorchScriptPythonDecoder(model, some_const)
136+
ov_const = nc_decoder.as_constant()
137+
assert ov_const is not None
138+
assert len(ov_const) == 1
139+
assert ov_const[0].get_element_type() == Type.f8e5m2
140+
assert ov_const[0].get_partial_shape() == PartialShape([2])
141+
142+
99143
@pytest.mark.precommit
100144
def test_pytorch_decoder_can_convert_fp16_tensor():
101145
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
102-
from openvino.runtime import PartialShape, Type
146+
from openvino import PartialShape, Type
103147

104148
class SomeTensor(torch.nn.Module):
105149
def forward(self):
@@ -121,7 +165,7 @@ def forward(self):
121165
@pytest.mark.precommit
122166
def test_pytorch_decoder_can_convert_bf16_tensor():
123167
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
124-
from openvino.runtime import PartialShape, Type
168+
from openvino import PartialShape, Type
125169

126170
class SomeTensor(torch.nn.Module):
127171
def forward(self):
@@ -143,7 +187,7 @@ def forward(self):
143187
@pytest.mark.precommit
144188
def test_pytorch_decoder_can_convert_fp32_tensor():
145189
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
146-
from openvino.runtime import PartialShape, Type
190+
from openvino import PartialShape, Type
147191

148192
class SomeTensor(torch.nn.Module):
149193
def forward(self):
@@ -165,7 +209,7 @@ def forward(self):
165209
@pytest.mark.precommit
166210
def test_pytorch_decoder_can_convert_fp64_tensor():
167211
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
168-
from openvino.runtime import PartialShape, Type
212+
from openvino import PartialShape, Type
169213

170214
class SomeTensor(torch.nn.Module):
171215
def forward(self):
@@ -187,7 +231,7 @@ def forward(self):
187231
@pytest.mark.precommit
188232
def test_pytorch_decoder_can_convert_bool_tensor():
189233
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
190-
from openvino.runtime import PartialShape, Type
234+
from openvino import PartialShape, Type
191235

192236
class SomeTensor(torch.nn.Module):
193237
def forward(self):
@@ -209,7 +253,7 @@ def forward(self):
209253
@pytest.mark.precommit
210254
def test_pytorch_decoder_can_convert_u8_tensor():
211255
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
212-
from openvino.runtime import PartialShape, Type
256+
from openvino import PartialShape, Type
213257

214258
class SomeTensor(torch.nn.Module):
215259
def forward(self):
@@ -231,7 +275,7 @@ def forward(self):
231275
@pytest.mark.precommit
232276
def test_pytorch_decoder_can_convert_i8_tensor():
233277
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
234-
from openvino.runtime import PartialShape, Type
278+
from openvino import PartialShape, Type
235279

236280
class SomeTensor(torch.nn.Module):
237281
def forward(self):
@@ -253,7 +297,7 @@ def forward(self):
253297
@pytest.mark.precommit
254298
def test_pytorch_decoder_can_convert_i16_tensor():
255299
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
256-
from openvino.runtime import PartialShape, Type
300+
from openvino import PartialShape, Type
257301

258302
class SomeTensor(torch.nn.Module):
259303
def forward(self):
@@ -275,7 +319,7 @@ def forward(self):
275319
@pytest.mark.precommit
276320
def test_pytorch_decoder_can_convert_i32_tensor():
277321
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
278-
from openvino.runtime import PartialShape, Type
322+
from openvino import PartialShape, Type
279323

280324
class SomeTensor(torch.nn.Module):
281325
def forward(self):
@@ -297,7 +341,7 @@ def forward(self):
297341
@pytest.mark.precommit
298342
def test_pytorch_decoder_can_convert_i64_tensor():
299343
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
300-
from openvino.runtime import PartialShape, Type
344+
from openvino import PartialShape, Type
301345

302346
class SomeTensor(torch.nn.Module):
303347
def forward(self):
@@ -337,7 +381,7 @@ def forward(self):
337381
@pytest.mark.precommit
338382
def test_pytorch_decoder_can_convert_int_list():
339383
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
340-
from openvino.runtime import PartialShape, Type
384+
from openvino import PartialShape, Type
341385

342386
class ListConst(torch.nn.Module):
343387
def forward(self):
@@ -360,7 +404,7 @@ def forward(self):
360404
@pytest.mark.precommit
361405
def test_pytorch_decoder_can_convert_float_list():
362406
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
363-
from openvino.runtime import PartialShape, Type
407+
from openvino import PartialShape, Type
364408

365409
class ListConst(torch.nn.Module):
366410
def forward(self):
@@ -383,7 +427,7 @@ def forward(self):
383427
@pytest.mark.precommit
384428
def test_pytorch_decoder_can_convert_bool_list():
385429
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
386-
from openvino.runtime import PartialShape, Type
430+
from openvino import PartialShape, Type
387431

388432
class ListConst(torch.nn.Module):
389433
def forward(self):
@@ -406,7 +450,7 @@ def forward(self):
406450
@pytest.mark.precommit
407451
def test_pytorch_decoder_can_convert_int_tuple():
408452
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
409-
from openvino.runtime import PartialShape, Type
453+
from openvino import PartialShape, Type
410454

411455
class ListConst(torch.nn.Module):
412456
def forward(self):
@@ -429,7 +473,7 @@ def forward(self):
429473
@pytest.mark.precommit
430474
def test_pytorch_decoder_can_convert_float_tuple():
431475
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
432-
from openvino.runtime import PartialShape, Type
476+
from openvino import PartialShape, Type
433477

434478
class ListConst(torch.nn.Module):
435479
def forward(self):
@@ -452,7 +496,7 @@ def forward(self):
452496
@pytest.mark.precommit
453497
def test_pytorch_decoder_can_convert_bool_tuple():
454498
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
455-
from openvino.runtime import PartialShape, Type
499+
from openvino import PartialShape, Type
456500

457501
class ListConst(torch.nn.Module):
458502
def forward(self):
@@ -475,7 +519,7 @@ def forward(self):
475519
@pytest.mark.precommit
476520
def test_pytorch_decoder_can_convert_empty_list():
477521
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
478-
from openvino.runtime import PartialShape, Type
522+
from openvino import PartialShape, Type
479523

480524
class aten_roll(torch.nn.Module):
481525
def __init__(self, shifts):
@@ -503,7 +547,7 @@ def forward(self, x):
503547
@pytest.mark.precommit
504548
def test_pytorch_decoder_can_convert_int_scalar_tensor():
505549
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
506-
from openvino.runtime import PartialShape, Type
550+
from openvino import PartialShape, Type
507551

508552
class SomeTensor(torch.nn.Module):
509553
def __init__(self) -> None:
@@ -534,7 +578,7 @@ def forward(self):
534578
@pytest.mark.precommit
535579
def test_pytorch_decoder_can_convert_float_scalar_tensor():
536580
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
537-
from openvino.runtime import PartialShape, Type
581+
from openvino import PartialShape, Type
538582

539583
class SomeTensor(torch.nn.Module):
540584
def __init__(self) -> None:
@@ -565,7 +609,7 @@ def forward(self):
565609
@pytest.mark.precommit
566610
def test_pytorch_decoder_can_convert_tensor_list():
567611
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
568-
from openvino.runtime import PartialShape, Type
612+
from openvino import PartialShape, Type
569613
from typing import List, Optional
570614

571615
class SomeTensor(torch.nn.Module):

0 commit comments

Comments
 (0)