Skip to content

Commit 4e2bfaf

Browse files
wip
1 parent df84881 commit 4e2bfaf

File tree

14 files changed

+240
-127
lines changed

14 files changed

+240
-127
lines changed

invokeai/app/api/dependencies.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import asyncio
44
from logging import Logger
5+
from pathlib import Path
56

67
import torch
8+
from torchvision.datasets.clevr import json
79

810
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
911
from invokeai.app.services.board_images.board_images_default import BoardImagesService
@@ -187,6 +189,17 @@ def initialize(
187189
)
188190

189191
ApiDependencies.invoker = Invoker(services)
192+
all_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
193+
for m in all_models:
194+
path = Path(m.path)
195+
if path.is_absolute():
196+
continue
197+
198+
metadata_path = config.models_path / m.key / "__metadata__.json"
199+
print(f"Writing metadata for model {m.name} to {metadata_path}")
200+
content = {"source": m.source, "expected_config_attrs": m.model_dump(), "notes": ""}
201+
content_json = json.dumps(content, indent=2)
202+
metadata_path.write_text(content_json)
190203
db.clean()
191204

192205
@staticmethod

invokeai/backend/model_manager/configs/factory.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,8 @@ def from_model_on_disk(
408408
results[class_name] = e
409409
logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
410410

411+
# Extract just the successful matches
412+
# NOTE: This will include Unknown_Config matches, which we will handle later.
411413
matches = [r for r in results.values() if isinstance(r, Config_Base)]
412414

413415
if not matches:
@@ -443,7 +445,10 @@ def sort_key(m: AnyModelConfig) -> int:
443445

444446
matches.sort(key=sort_key)
445447

446-
if len(matches) > 1:
448+
# Warn if we have multiple non-unknown matches
449+
has_unknown = any(isinstance(m, Unknown_Config) for m in matches)
450+
real_match_count = len(matches) - (1 if has_unknown else 0)
451+
if real_match_count > 1:
447452
logger.warning(
448453
f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
449454
)
@@ -457,15 +462,7 @@ def sort_key(m: AnyModelConfig) -> int:
457462
# Now do any post-processing needed for specific model types/bases/etc.
458463
match instance.type:
459464
case ModelType.Main:
460-
match instance.base:
461-
case BaseModelType.StableDiffusion1:
462-
instance.default_settings = MainModelDefaultSettings(width=512, height=512)
463-
case BaseModelType.StableDiffusion2:
464-
instance.default_settings = MainModelDefaultSettings(width=768, height=768)
465-
case BaseModelType.StableDiffusionXL:
466-
instance.default_settings = MainModelDefaultSettings(width=1024, height=1024)
467-
case _:
468-
pass
465+
instance.default_settings = MainModelDefaultSettings.from_base(instance.base)
469466
case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
470467
instance.default_settings = ControlAdapterDefaultSettings.from_model_name(instance.name)
471468
case ModelType.LoRA:

invokeai/backend/model_manager/configs/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def _validate_base(cls, mod: ModelOnDisk) -> None:
152152
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
153153
# First rule out ControlLoRA and Diffusers LoRA
154154
flux_format = _get_flux_lora_format(mod)
155-
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
156-
raise NotAMatchError("model looks like ControlLoRA or Diffusers LoRA")
155+
if flux_format in [FluxLoRAFormat.Control]:
156+
raise NotAMatchError("model looks like Control LoRA")
157157

158158
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
159159
# Some main models have these keys, likely due to the creator merging in a LoRA.

invokeai/backend/model_manager/configs/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ class MainModelDefaultSettings(BaseModel):
5151

5252
model_config = ConfigDict(extra="forbid")
5353

54+
@classmethod
55+
def from_base(cls, base: BaseModelType) -> Self | None:
56+
match base:
57+
case BaseModelType.StableDiffusion1:
58+
return cls(width=512, height=512)
59+
case BaseModelType.StableDiffusion2:
60+
return cls(width=768, height=768)
61+
case BaseModelType.StableDiffusionXL:
62+
return cls(width=1024, height=1024)
63+
case _:
64+
# TODO(psyche): Do we want defaults for other base types?
65+
return None
66+
5467

5568
class Main_Config_Base(ABC, BaseModel):
5669
type: Literal[ModelType.Main] = Field(default=ModelType.Main)

tests/app/services/model_install/test_model_install.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@
3232
URLModelSource,
3333
)
3434
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
35-
from invokeai.backend.model_manager.config import (
35+
from invokeai.backend.model_manager.taxonomy import (
3636
BaseModelType,
37-
InvalidModelConfigException,
3837
ModelFormat,
3938
ModelRepoVariant,
4039
ModelType,
@@ -71,7 +70,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
7170

7271
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
7372
key = None
74-
with pytest.raises((ValidationError, InvalidModelConfigException)):
73+
with pytest.raises(ValidationError):
7574
key = mm2_installer.register_path(
7675
embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora"))
7776
)

tests/app/services/model_records/test_model_records_sql.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,24 @@
1616
UnknownModelException,
1717
)
1818
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
19-
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
20-
from invokeai.backend.model_manager.config import (
21-
ControlAdapterDefaultSettings,
22-
MainDiffusersConfig,
19+
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
20+
from invokeai.backend.model_manager.configs.main import (
21+
Main_Diffusers_SD1_Config,
22+
Main_Diffusers_SD2_Config,
23+
Main_Diffusers_SDXL_Config,
2324
MainModelDefaultSettings,
24-
TI_File_Config,
25-
VAEDiffusersConfig,
2625
)
27-
from invokeai.backend.model_manager.taxonomy import ModelSourceType
26+
from invokeai.backend.model_manager.configs.textual_inversion import TI_File_SD1_Config
27+
from invokeai.backend.model_manager.configs.vae import VAE_Diffusers_SD1_Config
28+
from invokeai.backend.model_manager.taxonomy import (
29+
BaseModelType,
30+
ModelFormat,
31+
ModelRepoVariant,
32+
ModelSourceType,
33+
ModelType,
34+
ModelVariantType,
35+
SchedulerPredictionType,
36+
)
2837
from invokeai.backend.util.logging import InvokeAILogger
2938
from tests.fixtures.sqlite_database import create_mock_sqlite_database
3039

@@ -40,8 +49,8 @@ def store(
4049
return ModelRecordServiceSQL(db, logger)
4150

4251

43-
def example_ti_config(key: Optional[str] = None) -> TI_File_Config:
44-
config = TI_File_Config(
52+
def example_ti_config(key: Optional[str] = None) -> TI_File_SD1_Config:
53+
config = TI_File_SD1_Config(
4554
source="test/source/",
4655
source_type=ModelSourceType.Path,
4756
path="/tmp/pokemon.bin",
@@ -61,7 +70,7 @@ def test_type(store: ModelRecordServiceBase):
6170
config = example_ti_config("key1")
6271
store.add_model(config)
6372
config1 = store.get_model("key1")
64-
assert isinstance(config1, TI_File_Config)
73+
assert isinstance(config1, TI_File_SD1_Config)
6574

6675

6776
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
@@ -122,7 +131,7 @@ def test_exists(store: ModelRecordServiceBase):
122131

123132

124133
def test_filter(store: ModelRecordServiceBase):
125-
config1 = MainDiffusersConfig(
134+
config1 = Main_Diffusers_SD1_Config(
126135
key="config1",
127136
path="/tmp/config1",
128137
name="config1",
@@ -132,8 +141,11 @@ def test_filter(store: ModelRecordServiceBase):
132141
file_size=1001,
133142
source="test/source",
134143
source_type=ModelSourceType.Path,
144+
variant=ModelVariantType.Normal,
145+
prediction_type=SchedulerPredictionType.Epsilon,
146+
repo_variant=ModelRepoVariant.Default,
135147
)
136-
config2 = MainDiffusersConfig(
148+
config2 = Main_Diffusers_SD1_Config(
137149
key="config2",
138150
path="/tmp/config2",
139151
name="config2",
@@ -143,17 +155,21 @@ def test_filter(store: ModelRecordServiceBase):
143155
file_size=1002,
144156
source="test/source",
145157
source_type=ModelSourceType.Path,
158+
variant=ModelVariantType.Normal,
159+
prediction_type=SchedulerPredictionType.Epsilon,
160+
repo_variant=ModelRepoVariant.Default,
146161
)
147-
config3 = VAEDiffusersConfig(
162+
config3 = VAE_Diffusers_SD1_Config(
148163
key="config3",
149164
path="/tmp/config3",
150165
name="config3",
151-
base=BaseModelType("sd-2"),
166+
base=BaseModelType.StableDiffusion1,
152167
type=ModelType.VAE,
153168
hash="CONFIG3HASH",
154169
file_size=1003,
155170
source="test/source",
156171
source_type=ModelSourceType.Path,
172+
repo_variant=ModelRepoVariant.Default,
157173
)
158174
for c in config1, config2, config3:
159175
store.add_model(c)
@@ -176,7 +192,7 @@ def test_filter(store: ModelRecordServiceBase):
176192

177193

178194
def test_unique(store: ModelRecordServiceBase):
179-
config1 = MainDiffusersConfig(
195+
config1 = Main_Diffusers_SD1_Config(
180196
path="/tmp/config1",
181197
base=BaseModelType.StableDiffusion1,
182198
type=ModelType.Main,
@@ -185,28 +201,35 @@ def test_unique(store: ModelRecordServiceBase):
185201
file_size=1004,
186202
source="test/source/",
187203
source_type=ModelSourceType.Path,
204+
variant=ModelVariantType.Normal,
205+
prediction_type=SchedulerPredictionType.Epsilon,
206+
repo_variant=ModelRepoVariant.Default,
188207
)
189-
config2 = MainDiffusersConfig(
208+
config2 = Main_Diffusers_SD2_Config(
190209
path="/tmp/config2",
191-
base=BaseModelType("sd-2"),
210+
base=BaseModelType.StableDiffusion2,
192211
type=ModelType.Main,
193212
name="nonuniquename",
194213
hash="CONFIG1HASH",
195214
file_size=1005,
196215
source="test/source/",
197216
source_type=ModelSourceType.Path,
217+
variant=ModelVariantType.Normal,
218+
prediction_type=SchedulerPredictionType.Epsilon,
219+
repo_variant=ModelRepoVariant.Default,
198220
)
199-
config3 = VAEDiffusersConfig(
221+
config3 = VAE_Diffusers_SD1_Config(
200222
path="/tmp/config3",
201-
base=BaseModelType("sd-2"),
223+
base=BaseModelType.StableDiffusion1,
202224
type=ModelType.VAE,
203225
name="nonuniquename",
204226
hash="CONFIG1HASH",
205227
file_size=1006,
206228
source="test/source/",
207229
source_type=ModelSourceType.Path,
230+
repo_variant=ModelRepoVariant.Default,
208231
)
209-
config4 = MainDiffusersConfig(
232+
config4 = Main_Diffusers_SD1_Config(
210233
path="/tmp/config4",
211234
base=BaseModelType.StableDiffusion1,
212235
type=ModelType.Main,
@@ -215,6 +238,9 @@ def test_unique(store: ModelRecordServiceBase):
215238
file_size=1007,
216239
source="test/source/",
217240
source_type=ModelSourceType.Path,
241+
variant=ModelVariantType.Normal,
242+
prediction_type=SchedulerPredictionType.Epsilon,
243+
repo_variant=ModelRepoVariant.Default,
218244
)
219245
# config1, config2 and config3 are compatible because they have unique combos
220246
# of name, type and base
@@ -229,7 +255,7 @@ def test_unique(store: ModelRecordServiceBase):
229255

230256

231257
def test_filter_2(store: ModelRecordServiceBase):
232-
config1 = MainDiffusersConfig(
258+
config1 = Main_Diffusers_SD1_Config(
233259
path="/tmp/config1",
234260
name="config1",
235261
base=BaseModelType.StableDiffusion1,
@@ -238,8 +264,11 @@ def test_filter_2(store: ModelRecordServiceBase):
238264
file_size=1008,
239265
source="test/source/",
240266
source_type=ModelSourceType.Path,
267+
variant=ModelVariantType.Normal,
268+
prediction_type=SchedulerPredictionType.Epsilon,
269+
repo_variant=ModelRepoVariant.Default,
241270
)
242-
config2 = MainDiffusersConfig(
271+
config2 = Main_Diffusers_SD1_Config(
243272
path="/tmp/config2",
244273
name="config2",
245274
base=BaseModelType.StableDiffusion1,
@@ -248,28 +277,37 @@ def test_filter_2(store: ModelRecordServiceBase):
248277
file_size=1009,
249278
source="test/source/",
250279
source_type=ModelSourceType.Path,
280+
variant=ModelVariantType.Normal,
281+
prediction_type=SchedulerPredictionType.Epsilon,
282+
repo_variant=ModelRepoVariant.Default,
251283
)
252-
config3 = MainDiffusersConfig(
284+
config3 = Main_Diffusers_SD2_Config(
253285
path="/tmp/config3",
254286
name="dup_name1",
255-
base=BaseModelType("sd-2"),
287+
base=BaseModelType.StableDiffusion2,
256288
type=ModelType.Main,
257289
hash="CONFIG3HASH",
258290
file_size=1010,
259291
source="test/source/",
260292
source_type=ModelSourceType.Path,
293+
variant=ModelVariantType.Normal,
294+
prediction_type=SchedulerPredictionType.Epsilon,
295+
repo_variant=ModelRepoVariant.Default,
261296
)
262-
config4 = MainDiffusersConfig(
297+
config4 = Main_Diffusers_SDXL_Config(
263298
path="/tmp/config4",
264299
name="dup_name1",
265-
base=BaseModelType("sdxl"),
300+
base=BaseModelType.StableDiffusionXL,
266301
type=ModelType.Main,
267302
hash="CONFIG3HASH",
268303
file_size=1011,
269304
source="test/source/",
270305
source_type=ModelSourceType.Path,
306+
variant=ModelVariantType.Normal,
307+
prediction_type=SchedulerPredictionType.Epsilon,
308+
repo_variant=ModelRepoVariant.Default,
271309
)
272-
config5 = VAEDiffusersConfig(
310+
config5 = VAE_Diffusers_SD1_Config(
273311
path="/tmp/config5",
274312
name="dup_name1",
275313
base=BaseModelType.StableDiffusion1,
@@ -278,6 +316,7 @@ def test_filter_2(store: ModelRecordServiceBase):
278316
file_size=1012,
279317
source="test/source/",
280318
source_type=ModelSourceType.Path,
319+
repo_variant=ModelRepoVariant.Default,
281320
)
282321
for c in config1, config2, config3, config4, config5:
283322
store.add_model(c)

tests/backend/ip_adapter/test_ip_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
4+
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
55
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
66
from invokeai.backend.util.test_utils import install_and_load_model
77

0 commit comments

Comments
 (0)