Skip to content

Commit 625e69f

Browse files
committed
Fix exception caused by TensorFlow 2.5 KerasTensor API incompatible change (#149)
1 parent 2f431e4 commit 625e69f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

lpot/model/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,14 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
345345
if not isinstance(model, tf.keras.Model):
346346
model = tf.keras.models.load_model(model)
347347
kwargs = dict(zip(model.input_names, model.inputs))
348-
if tf.version.VERSION > '2.2.0':
348+
if tf.version.VERSION > '2.2.0' and tf.version.VERSION < '2.5.0':
349349
from tensorflow.python.keras.engine import keras_tensor
350350
if keras_tensor.keras_tensors_enabled():
351351
for name, tensor in kwargs.items():
352352
kwargs[name] = tensor.type_spec
353+
elif tf.version.VERSION >= '2.5.0':
354+
for name, tensor in kwargs.items():
355+
kwargs[name] = tensor.type_spec
353356
full_model = tf.function(lambda **kwargs: model(kwargs.values()))
354357
concrete_function = full_model.get_concrete_function(**kwargs)
355358
frozen_model = convert_variables_to_constants_v2(concrete_function)

0 commit comments

Comments
 (0)