diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py
index 01ed2b6bfc..49aaa98475 100644
--- a/keras_hub/src/models/preprocessor.py
+++ b/keras_hub/src/models/preprocessor.py
@@ -177,6 +177,39 @@ def from_preset(
         cls = find_subclass(preset, cls, backbone_cls)
         return loader.load_preprocessor(cls, **kwargs)
 
+    @classmethod
+    def _add_missing_kwargs(cls, loader, kwargs):
+        """Fill in required kwargs when loading from preset.
+
+        This is a private method called when loading a preprocessing layer
+        that was not directly saved in the preset. It should fill in all
+        kwargs required to call the class constructor. For almost all
+        preprocessors, the only required args are `tokenizer`,
+        `image_converter`, and `audio_converter`, but this can be
+        overridden, e.g. for a preprocessor with multiple tokenizers for
+        different encoders."""
+        if "tokenizer" not in kwargs and cls.tokenizer_cls:
+            kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls)
+        if "audio_converter" not in kwargs and cls.audio_converter_cls:
+            kwargs["audio_converter"] = loader.load_audio_converter(
+                cls.audio_converter_cls
+            )
+        if "image_converter" not in kwargs and cls.image_converter_cls:
+            kwargs["image_converter"] = loader.load_image_converter(
+                cls.image_converter_cls
+            )
+        return kwargs
+
+    def load_preset_assets(self, preset):
+        """Load all static assets needed by the preprocessing layer.
+
+        Args:
+            preset: The path to the local model preset directory.
+        """
+        for layer in self._flatten_layers(include_self=False):
+            if hasattr(layer, "load_preset_assets"):
+                layer.load_preset_assets(preset)
+
     def save_to_preset(self, preset_dir):
         """Save preprocessor to a preset directory.
 
@@ -188,9 +221,6 @@ def save_to_preset(self, preset_dir):
             preset_dir,
             config_file=PREPROCESSOR_CONFIG_FILE,
         )
-        if self.tokenizer:
-            self.tokenizer.save_to_preset(preset_dir)
-        if self.audio_converter:
-            self.audio_converter.save_to_preset(preset_dir)
-        if self.image_converter:
-            self.image_converter.save_to_preset(preset_dir)
+        for layer in self._flatten_layers(include_self=False):
+            if hasattr(layer, "save_to_preset"):
+                layer.save_to_preset(preset_dir)
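The docstring above names multi-tokenizer preprocessors as the motivating case for overriding `_add_missing_kwargs`. A minimal sketch of such an override, assuming a hypothetical `DualTokenizerPreprocessor` whose constructor takes `encoder_tokenizer` and `decoder_tokenizer` args and per-tokenizer config names (none of these names appear in the diff):

```python
from keras_hub.src.models.preprocessor import Preprocessor


# Hypothetical override, for illustration only. Each tokenizer is loaded
# from its own config file rather than the default "tokenizer.json",
# using the new `config_name` argument on `loader.load_tokenizer`.
class DualTokenizerPreprocessor(Preprocessor):
    @classmethod
    def _add_missing_kwargs(cls, loader, kwargs):
        if "encoder_tokenizer" not in kwargs:
            kwargs["encoder_tokenizer"] = loader.load_tokenizer(
                cls.tokenizer_cls, config_name="encoder_tokenizer.json"
            )
        if "decoder_tokenizer" not in kwargs:
            kwargs["decoder_tokenizer"] = loader.load_tokenizer(
                cls.tokenizer_cls, config_name="decoder_tokenizer.json"
            )
        return kwargs
```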
f"{attr}_id", self.token_to_id(token)) + def get_config(self): + config = super().get_config() + config.update( + { + "config_name": self.config_name, + } + ) + return config + def save_to_preset(self, preset_dir): """Save tokenizer to a preset directory. Args: preset_dir: The path to the local model preset directory. """ - save_serialized_object( - self, - preset_dir, - config_file=TOKENIZER_CONFIG_FILE, - ) - save_tokenizer_assets(self, preset_dir) + save_serialized_object(self, preset_dir, config_file=self.config_name) + subdir = self.config_name.split(".")[0] + asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir) + os.makedirs(asset_dir, exist_ok=True) + self.save_assets(asset_dir) @preprocessing_function def call(self, inputs, *args, training=None, **kwargs): @@ -207,11 +214,11 @@ def call(self, inputs, *args, training=None, **kwargs): def load_preset_assets(self, preset): asset_path = None for asset in self.file_assets: - asset_path = get_file( - preset, os.path.join(TOKENIZER_ASSET_DIR, asset) - ) - tokenizer_asset_dir = os.path.dirname(asset_path) - self.load_assets(tokenizer_asset_dir) + subdir = self.config_name.split(".")[0] + preset_path = os.path.join(ASSET_DIR, subdir, asset) + asset_path = get_file(preset, preset_path) + tokenizer_config_name = os.path.dirname(asset_path) + self.load_assets(tokenizer_config_name) @classproperty def presets(cls): @@ -222,6 +229,7 @@ def presets(cls): def from_preset( cls, preset, + config_name="tokenizer.json", **kwargs, ): """Instantiate a `keras_hub.models.Tokenizer` from a model preset. @@ -267,4 +275,4 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from backbone_cls = loader.check_backbone_class() if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_tokenizer(cls, **kwargs) + return loader.load_tokenizer(cls, config_name, **kwargs) diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 6f368f2d10..58fa00c78b 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -55,7 +55,8 @@ GS_SCHEME = "gs" HF_SCHEME = "hf" -TOKENIZER_ASSET_DIR = "assets/tokenizer" +ASSET_DIR = "assets" +TOKENIZER_ASSET_DIR = f"{ASSET_DIR}/tokenizer" # Config file names. CONFIG_FILE = "config.json" @@ -307,13 +308,6 @@ def make_preset_dir(preset): os.makedirs(preset, exist_ok=True) -def save_tokenizer_assets(tokenizer, preset): - if tokenizer: - asset_dir = os.path.join(preset, TOKENIZER_ASSET_DIR) - os.makedirs(asset_dir, exist_ok=True) - tokenizer.save_assets(asset_dir) - - def save_serialized_object( layer, preset, @@ -345,37 +339,6 @@ def save_metadata(layer, preset): metadata_file.write(json.dumps(metadata, indent=4)) -def _validate_tokenizer(preset): - if not check_file_exists(preset, TOKENIZER_CONFIG_FILE): - return - config_path = get_file(preset, TOKENIZER_CONFIG_FILE) - try: - with open(config_path, encoding="utf-8") as config_file: - config = json.load(config_file) - except Exception as e: - raise ValueError( - f"Tokenizer config file `{config_path}` is an invalid json file. " - f"Error message: {e}" - ) - layer = keras.saving.deserialize_keras_object(config) - - for asset in layer.file_assets: - asset_path = get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset)) - if not os.path.exists(asset_path): - tokenizer_asset_dir = os.path.dirname(asset_path) - raise FileNotFoundError( - f"Asset `{asset}` doesn't exist in the tokenizer asset direcotry" - f" `{tokenizer_asset_dir}`." 
diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index 6f368f2d10..58fa00c78b 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -55,7 +55,8 @@
 GS_SCHEME = "gs"
 HF_SCHEME = "hf"
 
-TOKENIZER_ASSET_DIR = "assets/tokenizer"
+ASSET_DIR = "assets"
+TOKENIZER_ASSET_DIR = f"{ASSET_DIR}/tokenizer"
 
 # Config file names.
 CONFIG_FILE = "config.json"
@@ -307,13 +308,6 @@ def make_preset_dir(preset):
     os.makedirs(preset, exist_ok=True)
 
 
-def save_tokenizer_assets(tokenizer, preset):
-    if tokenizer:
-        asset_dir = os.path.join(preset, TOKENIZER_ASSET_DIR)
-        os.makedirs(asset_dir, exist_ok=True)
-        tokenizer.save_assets(asset_dir)
-
-
 def save_serialized_object(
     layer,
     preset,
@@ -345,37 +339,6 @@ def save_metadata(layer, preset):
         metadata_file.write(json.dumps(metadata, indent=4))
 
 
-def _validate_tokenizer(preset):
-    if not check_file_exists(preset, TOKENIZER_CONFIG_FILE):
-        return
-    config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
-    try:
-        with open(config_path, encoding="utf-8") as config_file:
-            config = json.load(config_file)
-    except Exception as e:
-        raise ValueError(
-            f"Tokenizer config file `{config_path}` is an invalid json file. "
-            f"Error message: {e}"
-        )
-    layer = keras.saving.deserialize_keras_object(config)
-
-    for asset in layer.file_assets:
-        asset_path = get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
-        if not os.path.exists(asset_path):
-            tokenizer_asset_dir = os.path.dirname(asset_path)
-            raise FileNotFoundError(
-                f"Asset `{asset}` doesn't exist in the tokenizer asset direcotry"
-                f" `{tokenizer_asset_dir}`."
-            )
-    config_dir = os.path.dirname(config_path)
-    asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)
-
-    tokenizer = get_tokenizer(layer)
-    if not tokenizer:
-        raise ValueError(f"Model or layer `{layer}` is missing tokenizer.")
-    tokenizer.load_assets(asset_dir)
-
-
 def _validate_backbone(preset):
     config_path = os.path.join(preset, CONFIG_FILE)
     if not os.path.exists(config_path):
@@ -493,7 +456,6 @@ def upload_preset(
         raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")
 
     _validate_backbone(preset)
-    _validate_tokenizer(preset)
 
     if uri.startswith(KAGGLE_PREFIX):
         if kagglehub is None:
@@ -665,7 +627,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         """Load the backbone model from the preset."""
         raise NotImplementedError
 
-    def load_tokenizer(self, cls, **kwargs):
+    def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs):
         """Load a tokenizer layer from the preset."""
         raise NotImplementedError
 
@@ -703,16 +665,7 @@ def load_preprocessor(self, cls, **kwargs):
         arguments. This allow us to support transformers checkpoints by
         only converting the backbone and tokenizer.
         """
-        if "tokenizer" not in kwargs and cls.tokenizer_cls:
-            kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls)
-        if "audio_converter" not in kwargs and cls.audio_converter_cls:
-            kwargs["audio_converter"] = self.load_audio_converter(
-                cls.audio_converter_cls
-            )
-        if "image_converter" not in kwargs and cls.image_converter_cls:
-            kwargs["image_converter"] = self.load_image_converter(
-                cls.image_converter_cls
-            )
+        kwargs = cls._add_missing_kwargs(self, kwargs)
         return cls(**kwargs)
 
 
@@ -727,8 +680,8 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
         return backbone
 
-    def load_tokenizer(self, cls, **kwargs):
-        tokenizer_config = load_json(self.preset, TOKENIZER_CONFIG_FILE)
+    def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs):
+        tokenizer_config = load_json(self.preset, config_name)
         tokenizer = load_serialized_object(tokenizer_config, **kwargs)
         tokenizer.load_preset_assets(self.preset)
         return tokenizer
@@ -755,8 +708,8 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             )
         # We found a `task.json` with a complete config for our class.
         task = load_serialized_object(task_config, **kwargs)
-        if task.preprocessor and task.preprocessor.tokenizer:
-            task.preprocessor.tokenizer.load_preset_assets(self.preset)
+        if task.preprocessor:
+            task.preprocessor.load_preset_assets(self.preset)
         if load_weights:
             has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
             if has_task_weights and load_task_weights:
@@ -779,5 +732,5 @@ def load_preprocessor(self, cls, **kwargs):
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
         preprocessor = load_serialized_object(preprocessor_json, **kwargs)
-        preprocessor.tokenizer.load_preset_assets(self.preset)
+        preprocessor.load_preset_assets(self.preset)
         return preprocessor
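With `config_name` plumbed through the loader, `KerasPresetLoader.load_tokenizer` reduces to three steps. A rough paraphrase of the code above, using the existing `load_json`/`load_serialized_object` helpers from `preset_utils.py` and an illustrative config name:

```python
from keras_hub.src.utils.preset_utils import load_json
from keras_hub.src.utils.preset_utils import load_serialized_object

preset = "./my_preset"  # illustrative local preset directory

# "decoder_tokenizer.json" is an illustrative config_name, not part of the diff.
config = load_json(preset, "decoder_tokenizer.json")  # read the chosen config file
tokenizer = load_serialized_object(config)            # rebuild the layer from config
tokenizer.load_preset_assets(preset)                  # fetch assets/decoder_tokenizer/*
```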
diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py
index 593792d08e..a88cec246b 100644
--- a/keras_hub/src/utils/transformers/preset_loader.py
+++ b/keras_hub/src/utils/transformers/preset_loader.py
@@ -69,7 +69,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         self.converter.convert_weights(backbone, loader, self.config)
         return backbone
 
-    def load_tokenizer(self, cls, **kwargs):
+    def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs):
         return self.converter.convert_tokenizer(cls, self.preset, **kwargs)
 
     def load_image_converter(self, cls, **kwargs):
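Note that the transformers loader accepts `config_name` only to match the base `PresetLoader` signature; its body ignores the argument, since the tokenizer is converted from the Hugging Face checkpoint files rather than read from a keras_hub config. A quick usage sketch (the model handle is illustrative):

```python
import keras_hub

# config_name is accepted but unused on the transformers path; the
# converter rebuilds the tokenizer from the HF checkpoint files.
tokenizer = keras_hub.models.Tokenizer.from_preset(
    "hf://google-bert/bert-base-uncased"
)
```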