@@ -331,3 +331,173 @@ def forward(self, x):
         logits = self.classifier_model(adv_out + x)
         softmax_logits = F.softmax(logits, dim=1)
         return adv_out + x, softmax_logits
+
+
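+# Thin nn.Module wrapper around F.interpolate, so bilinear resizing can be
+# composed with other layers inside an nn.ModuleList.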
+class BilinearUpsample(nn.Module):
+    def __init__(self, scale_factor):
+        super(BilinearUpsample, self).__init__()
+        self.scale_factor = scale_factor
+
+    def forward(self, x):
+        return F.interpolate(
+            x, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
+        )
+
+
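+# BaseDeconvAAE decodes a pretrained backbone's 192x35x35 feature map back to
+# image space. Spatial trace (H = W): 35 -> pad -> 37 -> 74 -> 147 -> 294 ->
+# pad -> 299 -> 299, giving a 3x299x299 output; the 192x35x35 shape matches an
+# intermediate Inception-v3 feature map, which suggests 299x299 inputs.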
+class BaseDeconvAAE(SimpleAAE):
+    def __init__(
+        self,
+        classifier_model: torch.nn.Module,
+        pretrained_backbone: torch.nn.Module,
+        target_idx: int,
+        alpha: float = 1.5,
+        beta: float = 0.010,
+        backbone_output_shape: list = [192, 35, 35],
+    ):
+
+        if backbone_output_shape != [192, 35, 35]:
+            raise ValueError("Backbone output shape must be [192, 35, 35].")
+
+        super(BaseDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)
+
+        layers = [
+            pretrained_backbone,
+            nn.ZeroPad2d((1, 1, 1, 1)),
+            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
+            nn.ReLU(),
+            nn.ZeroPad2d((3, 2, 3, 2)),
+            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
+            nn.Tanh(),
+        ]
+
+        self.atn = nn.ModuleList(layers)
+
+
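+# ResizeConvAAE swaps strided/transposed convolutions for bilinear
+# resize-then-convolve stages, a common way to avoid the checkerboard
+# artifacts transposed convolutions can introduce.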
+class ResizeConvAAE(SimpleAAE):
+    def __init__(
+        self,
+        classifier_model: torch.nn.Module,
+        target_idx: int,
+        alpha: float = 1.5,
+        beta: float = 0.010,
+    ):
+
+        super(ResizeConvAAE, self).__init__(classifier_model, target_idx, alpha, beta)
+
+        layers = [
+            nn.Conv2d(3, 128, 5, padding=11),
+            nn.ReLU(),
+            BilinearUpsample(scale_factor=0.5),
+            nn.Conv2d(128, 256, 4, padding=11),
+            nn.ReLU(),
+            BilinearUpsample(scale_factor=0.5),
+            nn.Conv2d(256, 512, 3, padding=11),
+            nn.ReLU(),
+            BilinearUpsample(scale_factor=0.5),
+            nn.Conv2d(512, 512, 1, padding=11),
+            nn.ReLU(),
+            BilinearUpsample(scale_factor=2),
+            nn.Conv2d(512, 256, 3, padding=11),
+            nn.ReLU(),
+            BilinearUpsample(scale_factor=2),
+            nn.Conv2d(256, 128, 4, padding=11),
+            nn.ReLU(),
+            nn.ZeroPad2d((8, 8, 8, 8)),
+            nn.Conv2d(128, 3, 3, padding=11),
+            nn.Tanh(),
+        ]
+
+        self.atn = nn.ModuleList(layers)
+
+
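+# ConvDeconvAAE: a strided-convolution encoder down to a 768-channel
+# bottleneck, followed by a transposed-convolution decoder back to 3 channels.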
+class ConvDeconvAAE(SimpleAAE):
+    def __init__(
+        self,
+        classifier_model: torch.nn.Module,
+        target_idx: int,
+        alpha: float = 1.5,
+        beta: float = 0.010,
+    ):
+
+        super(ConvDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)
+
+        layers = [
+            nn.Conv2d(3, 256, 3, stride=2, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(256, 512, 3, stride=2, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(512, 768, 3, stride=2, padding=2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=2),
+            nn.ReLU(),
+            nn.ZeroPad2d((146, 145, 146, 145)),
+            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
+            nn.Tanh(),
+        ]
+
+        self.atn = nn.ModuleList(layers)
+
+
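+# BaseDeconvPATN: the same backbone-plus-deconv stack as BaseDeconvAAE, but
+# trained through the SimplePATN base class (a perturbation ATN rather than
+# an adversarial autoencoder).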
+class BaseDeconvPATN(SimplePATN):
+    def __init__(
+        self,
+        classifier_model: torch.nn.Module,
+        pretrained_backbone: torch.nn.Module,
+        target_idx: int,
+        alpha: float = 1.5,
+        beta: float = 0.010,
+        backbone_output_shape: list = [192, 35, 35],
+    ):
+
+        if backbone_output_shape != [192, 35, 35]:
+            raise ValueError("Backbone output shape must be [192, 35, 35].")
+
+        super(BaseDeconvPATN, self).__init__(classifier_model, target_idx, alpha, beta)
+
+        layers = [
+            pretrained_backbone,
+            nn.ZeroPad2d((1, 1, 1, 1)),
+            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
+            nn.ReLU(),
+            nn.ZeroPad2d((3, 2, 3, 2)),
+            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
+            nn.Tanh(),  # TODO: check if this is the right activation
+        ]
+
+        self.atn = nn.ModuleList(layers)
+
+
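+# ConvFCPATN: three strided convolutions feeding fully connected layers; the
+# final Linear emits 268203 = 3 * 299 * 299 values, i.e. a flattened
+# 3x299x299 image, presumably reshaped downstream.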
+class ConvFCPATN(SimplePATN):
+    def __init__(
+        self,
+        classifier_model: torch.nn.Module,
+        target_idx: int,
+        alpha: float = 1.5,
+        beta: float = 0.010,
+    ):
+
+        super(ConvFCPATN, self).__init__(classifier_model, target_idx, alpha, beta)
+
+        layers = [
+            nn.Conv2d(3, 512, 3, stride=2, padding=22),
+            nn.Conv2d(512, 256, 3, stride=2, padding=22),
+            nn.Conv2d(256, 128, 3, stride=2, padding=22),
+            nn.Flatten(),
+            nn.Linear(184832, 512),
+            nn.Linear(512, 268203),
+            nn.Tanh(),
+        ]
+
+        self.atn = nn.ModuleList(layers)
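+
+
+# Minimal usage sketch (hypothetical names; assumes the forward pass shown
+# above, where the ATN output is added to the input image):
+#
+#     classifier = torchvision.models.inception_v3(pretrained=True).eval()
+#     atn = ResizeConvAAE(classifier, target_idx=7)
+#     adv_images, probs = atn(images)  # images: (N, 3, H, W)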