diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 613d22083ad4..991b2ec0f8b9 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -168,7 +168,7 @@ def call(self, inputs, training=None, mask=None): if mask is None: masks = [None] * len(inputs) else: - masks = self._flatten_to_reference_inputs(mask) + masks = tree.flatten(mask) for x, mask in zip(inputs, masks): if mask is not None: x._keras_mask = mask @@ -205,8 +205,21 @@ def output_shape(self): def _assert_input_compatibility(self, *args): return super(Model, self)._assert_input_compatibility(*args) - def _flatten_to_reference_inputs(self, inputs): - return tree.flatten(inputs) + def _maybe_warn_inputs_struct_mismatch(self, inputs): + try: + tree.assert_same_structure( + inputs, self._inputs_struct, check_types=False + ) + except: + model_inputs_struct = tree.map_structure( + lambda x: x.name, self._inputs_struct + ) + inputs_struct = tree.map_structure(lambda x: "*", inputs) + warnings.warn( + "The structure of `inputs` doesn't match the expected " + f"structure: {model_inputs_struct}. " + f"Received: the structure of inputs={inputs_struct}" + ) def _convert_inputs_to_tensors(self, flat_inputs): converted = [] @@ -254,7 +267,8 @@ def _adjust_input_rank(self, flat_inputs): return adjusted def _standardize_inputs(self, inputs): - flat_inputs = self._flatten_to_reference_inputs(inputs) + self._maybe_warn_inputs_struct_mismatch(inputs) + flat_inputs = tree.flatten(inputs) flat_inputs = self._convert_inputs_to_tensors(flat_inputs) return self._adjust_input_rank(flat_inputs) diff --git a/keras/src/models/functional_test.py b/keras/src/models/functional_test.py index 553d352ac555..f5cbb229e087 100644 --- a/keras/src/models/functional_test.py +++ b/keras/src/models/functional_test.py @@ -1,5 +1,4 @@ import os -import warnings import numpy as np import pytest @@ -162,9 +161,8 @@ def test_input_dict_with_extra_field(self): model = Functional({"a": input_a}, outputs) - # Eager call - with warnings.catch_warnings(): - warnings.simplefilter("error") + with pytest.warns() as record: + # Eager call in_val = { "a": np.random.random((2, 3)), "b": np.random.random((2, 1)), @@ -172,14 +170,17 @@ def test_input_dict_with_extra_field(self): out_val = model(in_val) self.assertEqual(out_val.shape, (2, 3)) - with warnings.catch_warnings(): - warnings.simplefilter("error") # Symbolic call input_a_2 = Input(shape=(3,), batch_size=2) input_b_2 = Input(shape=(1,), batch_size=2) in_val = {"a": input_a_2, "b": input_b_2} out_val = model(in_val) self.assertEqual(out_val.shape, (2, 3)) + self.assertLen(record, 1) + self.assertStartsWith( + str(record[0].message), + r"The structure of `inputs` doesn't match the expected structure:", + ) @parameterized.named_parameters( ("list", list), @@ -495,6 +496,20 @@ def compute_output_shape(self, x_shape): self.assertAllClose(out, np.ones((2, 2))) # Note: it's not intended to work in symbolic mode (yet). + def test_warning_for_mismatched_inputs_structure(self): + i1 = Input((2,)) + i2 = Input((2,)) + outputs = layers.Add()([i1, i2]) + model = Model({"i1": i1, "i2": i2}, outputs) + + with pytest.warns() as record: + model([np.ones((2, 2)), np.zeros((2, 2))]) + self.assertLen(record, 1) + self.assertStartsWith( + str(record[0].message), + r"The structure of `inputs` doesn't match the expected structure:", + ) + def test_for_functional_in_sequential(self): # Test for a v3.4.1 regression. if backend.image_data_format() == "channels_first":