4 changes: 3 additions & 1 deletion keras_hub/src/models/llama/llama_causal_lm.py
```diff
@@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
```

**Review comment (Member):** Do we need to add this comment? I think the PR description is clear enough.

**Reply (Contributor Author):** People were getting it wrong roughly 50% of the time. It would be worth checking whether a less error-prone syntax exists; "input_spec" would have been ideal, but for some reason it also holds the flattened inputs.
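The distinction at the heart of this change: `Model.input` preserves the structure the inputs were declared with (a dict for KerasHub backbones), while `Model.inputs` is the flattened list, as the added comment states. A minimal sketch, not from this repo, with hypothetical input names:

```python
import keras
from keras import ops

# A toy "backbone" with dict inputs, mimicking a KerasHub backbone.
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")
x = keras.layers.Embedding(100, 8)(token_ids)
# Multiply by the mask so both inputs are connected to the output.
x = x * ops.expand_dims(ops.cast(padding_mask, x.dtype), axis=-1)
backbone = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=x,
)

print(type(backbone.input))   # dict: the full input structure
print(type(backbone.inputs))  # list: flattened, ordering not obvious
```

Building the task model from `backbone.inputs` silently turns a dict-input model into a list-input model, so the dict produced by the preprocessor no longer matches the structure the model expects. The same one-line fix is applied to each model below.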
4 changes: 3 additions & 1 deletion keras_hub/src/models/mistral/mistral_causal_lm.py
```diff
@@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
```
4 changes: 3 additions & 1 deletion keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py
```diff
@@ -110,7 +110,9 @@ def __init__(
         self.backbone = backbone

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_state = backbone(inputs=inputs)
         outputs = backbone.token_embedding(hidden_state, reverse=True)
         outputs = outputs[:, backbone.image_sequence_length :, :]
```
4 changes: 3 additions & 1 deletion keras_hub/src/models/phi3/phi3_causal_lm.py
```diff
@@ -41,7 +41,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
```
9 changes: 9 additions & 0 deletions keras_hub/src/tests/test_case.py
```diff
@@ -569,6 +569,15 @@ def run_task_test(
         ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size)
         x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data)

+        # Test that the tree structure output by the preprocessor
+        # matches what the model expects.
+        preprocessed_data = preprocessor(*train_data)[0]
+        tree.assert_same_structure(
+            preprocessed_data,
+            task._inputs_struct,
+            check_types=False,
+        )
+
         # Test predict.
         output = task.predict(x)
         if expected_output_shape is not None:
```
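What the new check catches in isolation: `assert_same_structure` compares only the nesting, dict keys, and leaf count, never the leaf values. A standalone sketch, assuming the `tree` module imported in `test_case.py` is `keras.tree` or an API-compatible library:

```python
import numpy as np
from keras import tree  # assumed; test_case.py's `tree` import may differ

preprocessed = {
    "token_ids": np.zeros((2, 4), dtype="int32"),
    "padding_mask": np.ones((2, 4), dtype="int32"),
}

# Same keys and nesting: passes. Leaf values and dtypes are not compared.
tree.assert_same_structure(
    preprocessed, {"token_ids": 0, "padding_mask": 0}, check_types=False
)

# Different nesting (a key is missing): raises an exception.
tree.assert_same_structure(preprocessed, {"token_ids": 0}, check_types=False)
```

With the fix above, `task._inputs_struct` is the backbone's dict structure, so any future regression back to flattened inputs fails this assertion immediately.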
20 changes: 18 additions & 2 deletions keras_hub/src/tokenizers/byte_pair_tokenizer_test.py
```diff
@@ -1,5 +1,4 @@
 import keras
-import pytest
 import tensorflow as tf

 from keras_hub.src.tests.test_case import TestCase
@@ -15,7 +14,6 @@
 )


-@pytest.mark.large
 class BytePairTokenizerTest(TestCase):
     def setUp(self):
         super().setUp()
@@ -111,6 +109,24 @@ def test_whitespace_split(self):
         encoded = self.tokenizer(input_data)
         self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29])

+        # This is important for Llama3, which uses the \n\n sequence in chat
+        # templates: \n\n must be tokenized as a single token.
+        input_data = "Hello\n\nHello"
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, [31414, 50140, 31414])
+
+        input_data = "Hello\n\n\n\nHello"
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, [31414, 50140, 50140, 31414])
+
+        input_data = "Hello\n\n"
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, [31414, 50140])
+
+        input_data = "Hello\n\n\n\n"
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, [31414, 50140, 50140])
+
     def test_special_whitespace(self):
         input_data = "\xa0 \xa0 \x3000 s"
         encoded = self.tokenizer(input_data)
```
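For context on the Llama3 comment in the test: Llama-3-style chat templates terminate every message header with a double newline, so `\n\n` appears in every prompt. A hypothetical fragment in that format (special-token names as commonly documented for Llama 3, not taken from this repo):

```python
# If the pre-tokenizer split "\n\n" into two "\n" tokens, every chat prompt
# would tokenize differently from the sequences the model saw in training.
prompt = (
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Hello!<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
```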