Commit 495b5ad

Merge pull request #7 from soham-chitnis10/atn
ATNs for ImageNet Dataset
2 parents 4733b3e + 38e52e5 commit 495b5ad

File tree

  • code_soup/ch5/algorithms

1 file changed: +170 -0 lines changed
code_soup/ch5/algorithms/atn.py

Lines changed: 170 additions & 0 deletions
@@ -331,3 +331,173 @@ def forward(self, x):
        logits = self.classifier_model(adv_out + x)
        softmax_logits = F.softmax(logits, dim=1)
        return adv_out + x, softmax_logits


class BilinearUpsample(nn.Module):
    """Resizes a feature map by a fixed scale factor using bilinear interpolation."""

    def __init__(self, scale_factor):
        super(BilinearUpsample, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(
            x, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
        )


class BaseDeconvAAE(SimpleAAE):
    """Adversarial autoencoding ATN that decodes a pretrained backbone's features with transposed convolutions."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: list = [192, 35, 35],
    ):

        if backbone_output_shape != [192, 35, 35]:
            raise ValueError("Backbone output shape must be [192, 35, 35].")

        super(BaseDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)


class ResizeConvAAE(SimpleAAE):
    """Adversarial autoencoding ATN built from convolutions and bilinear resize layers."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):

        super(ResizeConvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 128, 5, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(128, 256, 4, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(256, 512, 3, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(512, 512, 1, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),
            nn.Conv2d(256, 128, 4, padding=11),
            nn.ReLU(),
            nn.ZeroPad2d((8, 8, 8, 8)),
            nn.Conv2d(128, 3, 3, padding=11),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)


class ConvDeconvAAE(SimpleAAE):
    """Adversarial autoencoding ATN with a strided-convolution encoder and a transposed-convolution decoder."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):

        super(ConvDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 256, 3, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(512, 768, 3, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=2),
            nn.ReLU(),
            nn.ZeroPad2d((146, 145, 146, 145)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)


class BaseDeconvPATN(SimplePATN):
    """Perturbation ATN that decodes a pretrained backbone's features with transposed convolutions."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: list = [192, 35, 35],
    ):

        if backbone_output_shape != [192, 35, 35]:
            raise ValueError("Backbone output shape must be [192, 35, 35].")

        super(BaseDeconvPATN, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # TODO: Check if this is the right activation
        ]

        self.atn = nn.ModuleList(layers)


class ConvFCPATN(SimplePATN):
    """Perturbation ATN with a convolutional encoder followed by fully connected layers."""

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):

        super(ConvFCPATN, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 512, 3, stride=2, padding=22),
            nn.Conv2d(512, 256, 3, stride=2, padding=22),
            nn.Conv2d(256, 128, 3, stride=2, padding=22),
            nn.Flatten(),
            nn.Linear(184832, 512),
            nn.Linear(512, 268203),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)
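For orientation, a minimal usage sketch for the new classes follows; it is not part of the commit. It assumes the SimpleAAE base class defined earlier in atn.py supplies the forward pass that runs the input through self.atn and the wrapped classifier and returns the adversarial image together with softmax logits, as the context lines at the top of the hunk suggest. The stand-in classifier and backbone below are placeholders chosen only to satisfy the [192, 35, 35] shape check, and target_idx=123 is an arbitrary ImageNet class index.

    # Hypothetical sketch, not part of this commit: constructing a BaseDeconvAAE.
    import torch
    import torch.nn as nn

    from code_soup.ch5.algorithms.atn import BaseDeconvAAE

    # Stand-in ImageNet classifier: any module mapping a 3x299x299 image to 1000 logits.
    classifier = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, stride=4),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 1000),
    )

    # Stand-in backbone producing the required [192, 35, 35] feature map.
    backbone = nn.Sequential(
        nn.Conv2d(3, 192, kernel_size=3, stride=2, padding=1),
        nn.AdaptiveAvgPool2d((35, 35)),
    )

    atn = BaseDeconvAAE(
        classifier_model=classifier,
        pretrained_backbone=backbone,
        target_idx=123,  # arbitrary target class
    )

    x = torch.rand(2, 3, 299, 299)  # dummy ImageNet-sized batch
    adv_images, softmax_logits = atn(x)  # forward pass assumed to be defined in SimpleAAE

With a 299x299 input, the deconvolution stack in BaseDeconvAAE maps the 192x35x35 backbone features back to a 3x299x299 output, so the adversarial image matches the input resolution.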

0 commit comments
