1616 UnknownModelException ,
1717)
1818from invokeai .app .services .model_records .model_records_base import ModelRecordChanges
19- from invokeai .backend .model_manager import BaseModelType , ModelFormat , ModelType
20- from invokeai .backend .model_manager .config import (
21- ControlAdapterDefaultSettings ,
22- MainDiffusersConfig ,
19+ from invokeai .backend .model_manager .configs .controlnet import ControlAdapterDefaultSettings
20+ from invokeai .backend .model_manager .configs .main import (
21+ Main_Diffusers_SD1_Config ,
22+ Main_Diffusers_SD2_Config ,
23+ Main_Diffusers_SDXL_Config ,
2324 MainModelDefaultSettings ,
24- TI_File_Config ,
25- VAEDiffusersConfig ,
2625)
27- from invokeai .backend .model_manager .taxonomy import ModelSourceType
26+ from invokeai .backend .model_manager .configs .textual_inversion import TI_File_SD1_Config
27+ from invokeai .backend .model_manager .configs .vae import VAE_Diffusers_SD1_Config
28+ from invokeai .backend .model_manager .taxonomy import (
29+ BaseModelType ,
30+ ModelFormat ,
31+ ModelRepoVariant ,
32+ ModelSourceType ,
33+ ModelType ,
34+ ModelVariantType ,
35+ SchedulerPredictionType ,
36+ )
2837from invokeai .backend .util .logging import InvokeAILogger
2938from tests .fixtures .sqlite_database import create_mock_sqlite_database
3039
@@ -40,8 +49,8 @@ def store(
4049 return ModelRecordServiceSQL (db , logger )
4150
4251
43- def example_ti_config (key : Optional [str ] = None ) -> TI_File_Config :
44- config = TI_File_Config (
52+ def example_ti_config (key : Optional [str ] = None ) -> TI_File_SD1_Config :
53+ config = TI_File_SD1_Config (
4554 source = "test/source/" ,
4655 source_type = ModelSourceType .Path ,
4756 path = "/tmp/pokemon.bin" ,
@@ -61,7 +70,7 @@ def test_type(store: ModelRecordServiceBase):
6170 config = example_ti_config ("key1" )
6271 store .add_model (config )
6372 config1 = store .get_model ("key1" )
64- assert isinstance (config1 , TI_File_Config )
73+ assert isinstance (config1 , TI_File_SD1_Config )
6574
6675
6776def test_raises_on_violating_uniqueness (store : ModelRecordServiceBase ):
@@ -122,7 +131,7 @@ def test_exists(store: ModelRecordServiceBase):
122131
123132
124133def test_filter (store : ModelRecordServiceBase ):
125- config1 = MainDiffusersConfig (
134+ config1 = Main_Diffusers_SD1_Config (
126135 key = "config1" ,
127136 path = "/tmp/config1" ,
128137 name = "config1" ,
@@ -132,8 +141,11 @@ def test_filter(store: ModelRecordServiceBase):
132141 file_size = 1001 ,
133142 source = "test/source" ,
134143 source_type = ModelSourceType .Path ,
144+ variant = ModelVariantType .Normal ,
145+ prediction_type = SchedulerPredictionType .Epsilon ,
146+ repo_variant = ModelRepoVariant .Default ,
135147 )
136- config2 = MainDiffusersConfig (
148+ config2 = Main_Diffusers_SD1_Config (
137149 key = "config2" ,
138150 path = "/tmp/config2" ,
139151 name = "config2" ,
@@ -143,17 +155,21 @@ def test_filter(store: ModelRecordServiceBase):
143155 file_size = 1002 ,
144156 source = "test/source" ,
145157 source_type = ModelSourceType .Path ,
158+ variant = ModelVariantType .Normal ,
159+ prediction_type = SchedulerPredictionType .Epsilon ,
160+ repo_variant = ModelRepoVariant .Default ,
146161 )
147- config3 = VAEDiffusersConfig (
162+ config3 = VAE_Diffusers_SD1_Config (
148163 key = "config3" ,
149164 path = "/tmp/config3" ,
150165 name = "config3" ,
151- base = BaseModelType ( "sd-2" ) ,
166+ base = BaseModelType . StableDiffusion1 ,
152167 type = ModelType .VAE ,
153168 hash = "CONFIG3HASH" ,
154169 file_size = 1003 ,
155170 source = "test/source" ,
156171 source_type = ModelSourceType .Path ,
172+ repo_variant = ModelRepoVariant .Default ,
157173 )
158174 for c in config1 , config2 , config3 :
159175 store .add_model (c )
@@ -176,7 +192,7 @@ def test_filter(store: ModelRecordServiceBase):
176192
177193
178194def test_unique (store : ModelRecordServiceBase ):
179- config1 = MainDiffusersConfig (
195+ config1 = Main_Diffusers_SD1_Config (
180196 path = "/tmp/config1" ,
181197 base = BaseModelType .StableDiffusion1 ,
182198 type = ModelType .Main ,
@@ -185,28 +201,35 @@ def test_unique(store: ModelRecordServiceBase):
185201 file_size = 1004 ,
186202 source = "test/source/" ,
187203 source_type = ModelSourceType .Path ,
204+ variant = ModelVariantType .Normal ,
205+ prediction_type = SchedulerPredictionType .Epsilon ,
206+ repo_variant = ModelRepoVariant .Default ,
188207 )
189- config2 = MainDiffusersConfig (
208+ config2 = Main_Diffusers_SD2_Config (
190209 path = "/tmp/config2" ,
191- base = BaseModelType ( "sd-2" ) ,
210+ base = BaseModelType . StableDiffusion2 ,
192211 type = ModelType .Main ,
193212 name = "nonuniquename" ,
194213 hash = "CONFIG1HASH" ,
195214 file_size = 1005 ,
196215 source = "test/source/" ,
197216 source_type = ModelSourceType .Path ,
217+ variant = ModelVariantType .Normal ,
218+ prediction_type = SchedulerPredictionType .Epsilon ,
219+ repo_variant = ModelRepoVariant .Default ,
198220 )
199- config3 = VAEDiffusersConfig (
221+ config3 = VAE_Diffusers_SD1_Config (
200222 path = "/tmp/config3" ,
201- base = BaseModelType ( "sd-2" ) ,
223+ base = BaseModelType . StableDiffusion1 ,
202224 type = ModelType .VAE ,
203225 name = "nonuniquename" ,
204226 hash = "CONFIG1HASH" ,
205227 file_size = 1006 ,
206228 source = "test/source/" ,
207229 source_type = ModelSourceType .Path ,
230+ repo_variant = ModelRepoVariant .Default ,
208231 )
209- config4 = MainDiffusersConfig (
232+ config4 = Main_Diffusers_SD1_Config (
210233 path = "/tmp/config4" ,
211234 base = BaseModelType .StableDiffusion1 ,
212235 type = ModelType .Main ,
@@ -215,6 +238,9 @@ def test_unique(store: ModelRecordServiceBase):
215238 file_size = 1007 ,
216239 source = "test/source/" ,
217240 source_type = ModelSourceType .Path ,
241+ variant = ModelVariantType .Normal ,
242+ prediction_type = SchedulerPredictionType .Epsilon ,
243+ repo_variant = ModelRepoVariant .Default ,
218244 )
219245 # config1, config2 and config3 are compatible because they have unique combos
220246 # of name, type and base
@@ -229,7 +255,7 @@ def test_unique(store: ModelRecordServiceBase):
229255
230256
231257def test_filter_2 (store : ModelRecordServiceBase ):
232- config1 = MainDiffusersConfig (
258+ config1 = Main_Diffusers_SD1_Config (
233259 path = "/tmp/config1" ,
234260 name = "config1" ,
235261 base = BaseModelType .StableDiffusion1 ,
@@ -238,8 +264,11 @@ def test_filter_2(store: ModelRecordServiceBase):
238264 file_size = 1008 ,
239265 source = "test/source/" ,
240266 source_type = ModelSourceType .Path ,
267+ variant = ModelVariantType .Normal ,
268+ prediction_type = SchedulerPredictionType .Epsilon ,
269+ repo_variant = ModelRepoVariant .Default ,
241270 )
242- config2 = MainDiffusersConfig (
271+ config2 = Main_Diffusers_SD1_Config (
243272 path = "/tmp/config2" ,
244273 name = "config2" ,
245274 base = BaseModelType .StableDiffusion1 ,
@@ -248,28 +277,37 @@ def test_filter_2(store: ModelRecordServiceBase):
248277 file_size = 1009 ,
249278 source = "test/source/" ,
250279 source_type = ModelSourceType .Path ,
280+ variant = ModelVariantType .Normal ,
281+ prediction_type = SchedulerPredictionType .Epsilon ,
282+ repo_variant = ModelRepoVariant .Default ,
251283 )
252- config3 = MainDiffusersConfig (
284+ config3 = Main_Diffusers_SD2_Config (
253285 path = "/tmp/config3" ,
254286 name = "dup_name1" ,
255- base = BaseModelType ( "sd-2" ) ,
287+ base = BaseModelType . StableDiffusion2 ,
256288 type = ModelType .Main ,
257289 hash = "CONFIG3HASH" ,
258290 file_size = 1010 ,
259291 source = "test/source/" ,
260292 source_type = ModelSourceType .Path ,
293+ variant = ModelVariantType .Normal ,
294+ prediction_type = SchedulerPredictionType .Epsilon ,
295+ repo_variant = ModelRepoVariant .Default ,
261296 )
262- config4 = MainDiffusersConfig (
297+ config4 = Main_Diffusers_SDXL_Config (
263298 path = "/tmp/config4" ,
264299 name = "dup_name1" ,
265- base = BaseModelType ( "sdxl" ) ,
300+ base = BaseModelType . StableDiffusionXL ,
266301 type = ModelType .Main ,
267302 hash = "CONFIG3HASH" ,
268303 file_size = 1011 ,
269304 source = "test/source/" ,
270305 source_type = ModelSourceType .Path ,
306+ variant = ModelVariantType .Normal ,
307+ prediction_type = SchedulerPredictionType .Epsilon ,
308+ repo_variant = ModelRepoVariant .Default ,
271309 )
272- config5 = VAEDiffusersConfig (
310+ config5 = VAE_Diffusers_SD1_Config (
273311 path = "/tmp/config5" ,
274312 name = "dup_name1" ,
275313 base = BaseModelType .StableDiffusion1 ,
@@ -278,6 +316,7 @@ def test_filter_2(store: ModelRecordServiceBase):
278316 file_size = 1012 ,
279317 source = "test/source/" ,
280318 source_type = ModelSourceType .Path ,
319+ repo_variant = ModelRepoVariant .Default ,
281320 )
282321 for c in config1 , config2 , config3 , config4 , config5 :
283322 store .add_model (c )
0 commit comments