-
Couldn't load subscription status.
- Fork 19.6k
Description
The code below runs in Tensorflow 2.11 (keras 2) but not in tf-nightly (Keras 3.4.1 ). I think Keras 3 doesn't map inputs by dict key
Epoch 1/10
Traceback (most recent call last):
File "/home/wangx286/rnn-base-caller/base_caller/scripts/example_metric.py", line 32, in
model.fit({'before': x_train, 'after': y_train}, epochs=10, batch_size=32)
File "/home/wangx286/miniconda3/envs/tf216/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/wangx286/miniconda3/envs/tf216/lib/python3.10/site-packages/keras/src/models/functional.py", line 244, in _adjust_input_rank
raise ValueError(
ValueError: Exception encountered when calling Functional.call().
Invalid input shape for input Tensor("data_1:0", shape=(None,), dtype=float32). Expected shape (None, 20), but input has incompatible shape (None,)
Arguments received by Functional.call():
• inputs={'before': 'tf.Tensor(shape=(None, 20), dtype=float32)', 'after': 'tf.Tensor(shape=(None,), dtype=float32)'}
• training=True
• mask={'before': 'None', 'after': 'None'}
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
# Define the model using the Functional API
x = tf.keras.Input(shape=(20,), name="before", dtype=tf.float32)
y = tf.keras.Input(shape=(), name="after", dtype=tf.float32)
tmp = Dense(64, activation='relu')(x)
outputs = Dense(1, activation='sigmoid')(tmp)
class DummyLossLayer(tf.keras.layers.Layer):
def call(self, *x):
self.add_loss(tf.keras.losses.BinaryCrossentropy(from_logits=True)(*x))
return x
outputs, _ = DummyLossLayer()(outputs, y)
model = Model(inputs=[x, y], outputs=outputs)
# Compile the model with the custom metric
model.compile(optimizer='adam')
# Dummy data for demonstration
x_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000,)).astype(np.float32)
# Train the model
model.fit({'before': x_train, 'after': y_train}, epochs=10, batch_size=32)