diff --git a/code_soup/ch5/models/gan.py b/code_soup/ch5/models/gan.py index 71e62c0..d8236fb 100644 --- a/code_soup/ch5/models/gan.py +++ b/code_soup/ch5/models/gan.py @@ -169,6 +169,8 @@ def step(self, data: torch.Tensor) -> Tuple: Discriminator loss D_G_z2: Average discriminator outputs for the all fake batch after updating discriminator + errG: + Generator loss """ real_image, _ = data real_image = real_image.to(self.device) @@ -176,6 +178,7 @@ def step(self, data: torch.Tensor) -> Tuple: label = torch.full( (batch_size,), self.real_label, dtype=torch.float, device=self.device ) + self.discriminator.zero_grad() # Forward pass real batch through D output = self.discriminator(real_image).view(-1) # Calculate loss on all-real batch @@ -211,4 +214,4 @@ def step(self, data: torch.Tensor) -> Tuple: D_G_z2 = output.mean().item() # Update G self.generator.optimizer.step() - return D_x, D_G_z1, errD, D_G_z2 + return D_x, D_G_z1, errD, D_G_z2, errG diff --git a/code_soup/common/vision/perturbations.py b/code_soup/common/vision/perturbations.py index c68bc3c..bdc85cf 100644 --- a/code_soup/common/vision/perturbations.py +++ b/code_soup/common/vision/perturbations.py @@ -3,13 +3,19 @@ import numpy as np import torch +import torch.nn as nn + +from math import log10 from code_soup.common.perturbation import Perturbation class VisualPerturbation(Perturbation): """ - Docstring for VisualPerturbations + An abstract method for various Visual Perturbation Metrics + Methods + __init__(self, original : Union[np.ndarray, torch.Tensor], perturbed: Union[np.ndarray, torch.Tensor]) + - init method """ def __init__( @@ -21,16 +27,41 @@ def __init__( Docstring #Automatically cast to Tensor using the torch.from_numpy() in the __init__ using if """ - raise NotImplementedError - def calculate_LPNorm(self, p: Union[int, str]): - raise NotImplementedError + if type(original) == torch.Tensor: + self.original = original + else: + self.original = torch.from_numpy(original) + print(self.original.shape) - def calculate_PSNR(self): - raise NotImplementedError + if type(perturbed) == torch.Tensor: + self.perturbed = perturbed + else: + self.perturbed = torch.from_numpy(perturbed) - def calculate_RMSE(self): - raise NotImplementedError + def flatten(self, array : torch.tensor) -> torch.Tensor: + return array.flatten() + + def totensor(self, array : np.ndarray) -> torch.Tensor: + return torch.from_numpy(array) + + def subtract(self,original : torch.Tensor, perturbed : torch.Tensor) -> torch.Tensor: + return torch.sub(original, perturbed) + + def calculate_LPNorm(self, p: Union[int, str]) -> float: + if p == 'inf': + return torch.linalg.vector_norm(self.flatten(self.subtract(self.original,self.perturbed)), ord = float('inf')).item() + elif p == 'fro': + return self.calculate_LPNorm(2) + else: + return torch.linalg.norm(self.flatten(self.subtract(self.original,self.perturbed)), ord = p).item() + + def calculate_PSNR(self) -> float: + return 20 * log10(1.0/self.calculate_RMSE()) + + def calculate_RMSE(self) -> float: + loss = nn.MSELoss() + return (loss(self.original, self.perturbed)**0.5).item() def calculate_SAM(self): raise NotImplementedError