From cb4fb8e2a41fc5a1f302f06058310c6f03f096f7 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 16 Oct 2024 23:50:49 +0000 Subject: [PATCH 1/6] add back default image resizing --- keras_hub/src/utils/timm/preset_loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index e662444409..1524db8530 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -53,10 +53,11 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): def load_image_converter(self, cls, **kwargs): pretrained_cfg = self.config.get("pretrained_cfg", None) - if not pretrained_cfg: + if not pretrained_cfg or "input_size" not in pretrained_cfg: return None # This assumes the same basic setup for all timm preprocessing, We may # need to extend this as we cover more model types. + input_size = pretrained_cfg["input_size"] mean = pretrained_cfg["mean"] std = pretrained_cfg["std"] scale = [1.0 / 255.0 / s for s in std] @@ -65,6 +66,7 @@ def load_image_converter(self, cls, **kwargs): if interpolation not in ("bilinear", "nearest", "bicubic"): interpolation = "bilinear" # Unsupported interpolation type. return cls( + image_size=input_size[1:], scale=scale, offset=offset, interpolation=interpolation, From 67ba7dd930202d33fa00f9674f4ee15bee789206 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 17 Oct 2024 19:13:16 +0000 Subject: [PATCH 2/6] fix bug in image converter --- keras_hub/src/layers/preprocessing/image_converter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index 89142c469b..2dc48bf1c9 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -149,7 +149,9 @@ def call(self, inputs): if self.image_size is not None: x = self.resizing(x) if self.scale is not None: - x = x * self._expand_non_channel_dims(self.scale, x) + x = ops.cast(x, self.dtype) * self._expand_non_channel_dims( + self.scale, x + ) if self.offset is not None: x = x + self._expand_non_channel_dims(self.offset, x) return x From b5c2565dc46520be22cc9052bd7e23db076c0d7c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 17 Oct 2024 21:47:46 +0000 Subject: [PATCH 3/6] fix paligemma checkpoint conversion file --- .../convert_pali_gemma_checkpoints.py | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py index 15db5d10b4..5ca62e82ec 100644 --- a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py @@ -1,3 +1,21 @@ +""" +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-mix-224.npz \ + --image_size=224 --checkpoint_name=pali_gemma_3b_mix_224 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-mix-448.npz \ + --image_size=448 --checkpoint_name=pali_gemma_3b_mix_428 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-224.npz \ + --image_size=224 --checkpoint_name=pali_gemma_3b_224 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-448.npz \ + --image_size=448 --checkpoint_name=pali_gemma_3b_448 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-896.npz \ + --image_size=896 --checkpoint_name=pali_gemma_3b_896 +""" + import argparse import os @@ -15,6 +33,9 @@ from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( + PaliGemmaTokenizer, +) os.environ["KERAS_BACKEND"] = "jax" @@ -308,15 +329,27 @@ def main(args): pali_gemma_backbone_config = { "vit_num_layers": 27, "vit_hidden_dim": 1152, + "vocabulary_size": 257152, "image_size": args.image_size, + "num_layers": 18, + "num_query_heads": 8, + "num_key_value_heads": 1, + "hidden_dim": 2048, + "intermediate_dim": 32768, + "head_dim": 256, + "vit_patch_size": 14, + "vit_num_heads": 16, } pg_image_converter = PaliGemmaImageConverter( image_size=(args.image_size, args.image_size), scale=1.0 / 127.5, offset=-1, ) + tokenizer = PaliGemmaTokenizer( + proto="vocabulary.spm", + ) pg_presprocessor = PaliGemmaCausalLMPreprocessor( - image_converter=pg_image_converter + tokenizer=tokenizer, image_converter=pg_image_converter ) pg_backbone = PaliGemmaBackbone(**pali_gemma_backbone_config) keras_model = PaliGemmaCausalLM( @@ -325,8 +358,10 @@ def main(args): # This could be from kaggle or provide local dir path weights = np.load(args.weights_path) jax_weights = get_weights_as_numpy(weights, **pali_gemma_backbone_config) - keras_model = convert_pali_gemma_weights( - keras_model, jax_weights["params"], **pali_gemma_backbone_config + keras_model.backbone = convert_pali_gemma_weights( + keras_model.backbone, + jax_weights["params"], + **pali_gemma_backbone_config, ) # Specify preset name keras_model.save_to_preset(args.checkpoint_name) From 2cdbfc2c78bba3a0e01675055d85754ee772e594 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 17 Oct 2024 21:58:25 +0000 Subject: [PATCH 4/6] fix preset name --- tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py index 5ca62e82ec..befb6093cf 100644 --- a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py @@ -4,7 +4,7 @@ --image_size=224 --checkpoint_name=pali_gemma_3b_mix_224 python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ --weights_path=paligemma-3b-mix-448.npz \ - --image_size=448 --checkpoint_name=pali_gemma_3b_mix_428 + --image_size=448 --checkpoint_name=pali_gemma_3b_mix_448 python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ --weights_path=paligemma-3b-pt-224.npz \ --image_size=224 --checkpoint_name=pali_gemma_3b_224 From 54578298c511e8fa80123b8821692a31763b1e06 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 17 Oct 2024 22:02:42 +0000 Subject: [PATCH 5/6] remove debug code --- keras_hub/src/layers/preprocessing/image_converter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index 2dc48bf1c9..37b221c1fe 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -147,11 +147,9 @@ def image_size(self, value): def call(self, inputs): x = inputs if self.image_size is not None: - x = self.resizing(x) + x = self.resizing(inputs) if self.scale is not None: - x = ops.cast(x, self.dtype) * self._expand_non_channel_dims( - self.scale, x - ) + x = x * self._expand_non_channel_dims(self.scale, x) if self.offset is not None: x = x + self._expand_non_channel_dims(self.offset, x) return x From cec3a0572f6f362c5c62cb3253bf573291f40766 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 17 Oct 2024 22:04:08 +0000 Subject: [PATCH 6/6] revert unintended changes --- keras_hub/src/layers/preprocessing/image_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index 37b221c1fe..89142c469b 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -147,7 +147,7 @@ def image_size(self, value): def call(self, inputs): x = inputs if self.image_size is not None: - x = self.resizing(inputs) + x = self.resizing(x) if self.scale is not None: x = x * self._expand_non_channel_dims(self.scale, x) if self.offset is not None: