Skip to content

Commit c1cff55

Browse files
Register VGG presets. (#1935)
* register vgg preset * nit * nit * nit
1 parent 566f8ae commit c1cff55

File tree

7 files changed

+40
-21
lines changed

7 files changed

+40
-21
lines changed

keras_hub/api/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
5353
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
5454
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
55-
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageConverter
55+
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
5656
from keras_hub.src.models.whisper.whisper_audio_converter import (
5757
WhisperAudioConverter,
5858
)

keras_hub/api/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@
299299
from keras_hub.src.models.text_to_image import TextToImage
300300
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
301301
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
302-
from keras_hub.src.models.vgg.vgg_image_classifier import (
302+
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
303303
VGGImageClassifierPreprocessor,
304304
)
305305
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
2+
from keras_hub.src.models.vgg.vgg_presets import backbone_presets
3+
from keras_hub.src.utils.preset_utils import register_presets
4+
5+
register_presets(backbone_presets, VGGBackbone)

keras_hub/src/models/vgg/vgg_image_classifier.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,12 @@
11
import keras
22

33
from keras_hub.src.api_export import keras_hub_export
4-
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
54
from keras_hub.src.models.image_classifier import ImageClassifier
6-
from keras_hub.src.models.image_classifier_preprocessor import (
7-
ImageClassifierPreprocessor,
8-
)
95
from keras_hub.src.models.task import Task
106
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
11-
12-
13-
@keras_hub_export("keras_hub.layers.VGGImageConverter")
14-
class VGGImageConverter(ImageConverter):
15-
backbone_cls = VGGBackbone
16-
17-
18-
@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
19-
class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
20-
backbone_cls = VGGBackbone
21-
image_converter_cls = VGGImageConverter
7+
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
8+
VGGImageClassifierPreprocessor,
9+
)
2210

2311

2412
@keras_hub_export("keras_hub.models.VGGImageClassifier")
@@ -211,6 +199,7 @@ def __init__(
211199
self.pooling = pooling
212200
self.pooling_hidden_dim = pooling_hidden_dim
213201
self.dropout = dropout
202+
self.preprocessor = preprocessor
214203

215204
def get_config(self):
216205
# Backbone serialized in `super`
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
from keras_hub.src.models.image_classifier_preprocessor import (
3+
ImageClassifierPreprocessor,
4+
)
5+
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
6+
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
7+
8+
9+
@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
10+
class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
11+
backbone_cls = VGGBackbone
12+
image_converter_cls = VGGImageConverter

keras_hub/src/models/vgg/vgg_image_classifier_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,40 @@
33

44
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
55
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
6+
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
7+
VGGImageClassifierPreprocessor,
8+
)
9+
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
610
from keras_hub.src.tests.test_case import TestCase
711

812

913
class VGGImageClassifierTest(TestCase):
1014
def setUp(self):
1115
# Setup model.
1216
self.images = np.ones((2, 8, 8, 3), dtype="float32")
13-
self.labels = [0, 3]
17+
self.labels = [0, 1]
1418
self.backbone = VGGBackbone(
1519
stackwise_num_repeats=[2, 4, 4],
1620
stackwise_num_filters=[2, 16, 16],
1721
image_shape=(8, 8, 3),
1822
)
23+
image_converter = VGGImageConverter(image_size=(8, 8))
24+
self.preprocessor = VGGImageClassifierPreprocessor(
25+
image_converter=image_converter,
26+
)
1927
self.init_kwargs = {
2028
"backbone": self.backbone,
2129
"num_classes": 2,
2230
"activation": "softmax",
2331
"pooling": "flatten",
32+
"preprocessor": self.preprocessor,
2433
}
2534
self.train_data = (
2635
self.images,
2736
self.labels,
2837
)
2938

3039
def test_classifier_basics(self):
31-
pytest.skip(
32-
reason="TODO: enable after preprocessor flow is figured out"
33-
)
3440
self.run_task_test(
3541
cls=VGGImageClassifier,
3642
init_kwargs=self.init_kwargs,
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
3+
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
4+
5+
6+
@keras_hub_export("keras_hub.layers.VGGImageConverter")
7+
class VGGImageConverter(ImageConverter):
8+
backbone_cls = VGGBackbone

0 commit comments

Comments
 (0)