Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions keras/src/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def call(self, inputs, training=None, mask=None):
if mask is None:
masks = [None] * len(inputs)
else:
masks = self._flatten_to_reference_inputs(mask)
masks = tree.flatten(mask)
for x, mask in zip(inputs, masks):
if mask is not None:
x._keras_mask = mask
Expand Down Expand Up @@ -205,8 +205,21 @@ def output_shape(self):
def _assert_input_compatibility(self, *args):
    # Deliberately start the MRO lookup *after* `Model`:
    # `super(Model, self)` skips `Model` (and this class) so the
    # compatibility check dispatches to the implementation above `Model`
    # in the MRO. Do not simplify to a plain `super()` call.
    return super(Model, self)._assert_input_compatibility(*args)

def _flatten_to_reference_inputs(self, inputs):
return tree.flatten(inputs)
def _maybe_warn_inputs_struct_mismatch(self, inputs):
    """Warn when the structure of `inputs` differs from the model's inputs.

    Compares the nested structure of `inputs` against
    `self._inputs_struct` (ignoring container types) and, on mismatch,
    emits a `UserWarning` showing the expected structure (by input name)
    next to a wildcard rendering of what was received. Never raises.

    Args:
        inputs: The possibly nested inputs passed to the model call.
    """
    try:
        tree.assert_same_structure(
            inputs, self._inputs_struct, check_types=False
        )
    except Exception:
        # A bare `except:` would also swallow KeyboardInterrupt and
        # SystemExit; `Exception` still covers the ValueError/TypeError
        # that structure assertions raise on mismatch.
        model_inputs_struct = tree.map_structure(
            lambda x: x.name, self._inputs_struct
        )
        # Render the received structure shape-only ("*" per leaf) since
        # leaves may be arbitrary tensors/arrays without useful names.
        inputs_struct = tree.map_structure(lambda x: "*", inputs)
        warnings.warn(
            "The structure of `inputs` doesn't match the expected "
            f"structure: {model_inputs_struct}. "
            f"Received: the structure of inputs={inputs_struct}"
        )

def _convert_inputs_to_tensors(self, flat_inputs):
converted = []
Expand Down Expand Up @@ -254,7 +267,8 @@ def _adjust_input_rank(self, flat_inputs):
return adjusted

def _standardize_inputs(self, inputs):
    """Flatten `inputs` and adapt each leaf to the model's expectations.

    Warns (without raising) if the nested structure of `inputs` differs
    from the model's declared input structure, then flattens, converts
    each leaf to a tensor, and adjusts ranks to match the inputs.

    Args:
        inputs: The possibly nested inputs passed to the model call.

    Returns:
        A flat list of rank-adjusted tensors.
    """
    # NOTE: the pre-refactor `self._flatten_to_reference_inputs(inputs)`
    # call was a dead assignment (immediately overwritten below) left
    # over from the diff; `tree.flatten` is the single flattening step.
    self._maybe_warn_inputs_struct_mismatch(inputs)
    flat_inputs = tree.flatten(inputs)
    flat_inputs = self._convert_inputs_to_tensors(flat_inputs)
    return self._adjust_input_rank(flat_inputs)

Expand Down
27 changes: 21 additions & 6 deletions keras/src/models/functional_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -162,24 +161,26 @@ def test_input_dict_with_extra_field(self):

model = Functional({"a": input_a}, outputs)

# Eager call
with warnings.catch_warnings():
warnings.simplefilter("error")
with pytest.warns() as record:
# Eager call
in_val = {
"a": np.random.random((2, 3)),
"b": np.random.random((2, 1)),
}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 3))

with warnings.catch_warnings():
warnings.simplefilter("error")
# Symbolic call
input_a_2 = Input(shape=(3,), batch_size=2)
input_b_2 = Input(shape=(1,), batch_size=2)
in_val = {"a": input_a_2, "b": input_b_2}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 3))
self.assertLen(record, 1)
self.assertStartsWith(
str(record[0].message),
r"The structure of `inputs` doesn't match the expected structure:",
)

@parameterized.named_parameters(
("list", list),
Expand Down Expand Up @@ -495,6 +496,20 @@ def compute_output_shape(self, x_shape):
self.assertAllClose(out, np.ones((2, 2)))
# Note: it's not intended to work in symbolic mode (yet).

def test_warning_for_mismatched_inputs_structure(self):
    # Build a model whose inputs are declared as a dict, then call it
    # with a list so the input structures disagree.
    input_a = Input((2,))
    input_b = Input((2,))
    summed = layers.Add()([input_a, input_b])
    model = Model({"i1": input_a, "i2": input_b}, summed)

    mismatched = [np.ones((2, 2)), np.zeros((2, 2))]
    with pytest.warns() as record:
        model(mismatched)

    # Exactly one structure-mismatch warning should be emitted, with the
    # expected message prefix.
    self.assertLen(record, 1)
    message = str(record[0].message)
    self.assertStartsWith(
        message,
        r"The structure of `inputs` doesn't match the expected structure:",
    )

def test_for_functional_in_sequential(self):
# Test for a v3.4.1 regression.
if backend.image_data_format() == "channels_first":
Expand Down