Skip to content

Commit c91971d

Browse files
authored
Added Example for DynamicBackdoorGAN
1 parent 66515a1 commit c91971d

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# -*- coding: utf-8 -*-
2+
"""DynamicBackdoorGAN_Demo.ipynb
3+
4+
Automatically generated by Colab.
5+
6+
Original file is located at
7+
https://colab.research.google.com/drive/1Uxw5hHxnvtDh2-dC5cHgSfBMNMl05lpD
8+
"""
9+
10+
# ✅ Imports
11+
import torch
12+
import torch.nn as nn
13+
import numpy as np
14+
from torch.utils.data import Subset
15+
from torchvision import datasets, transforms, models
16+
from art.estimators.classification import PyTorchClassifier
17+
from art.utils import to_categorical
18+
from art.attacks.poisoning import PoisoningAttackBackdoor
19+
20+
21+
# ✅ User Config
22+
# Experiment settings consumed by run_dynamic_backdoor_experiment().
config = dict(
    dataset="MNIST",  # one of: CIFAR10, CIFAR100, MNIST
    model_name="densenet121",  # one of: resnet18, resnet50, mobilenetv2, densenet121
    poison_ratio=0.1,  # fraction of the training samples that receive the trigger
    target_label=0,  # class index assigned to every poisoned sample
    epochs=30,
    batch_size=128,
    epsilon=0.5,  # multiplier applied to the generated trigger before it is added
    train_subset=None,  # None keeps the full training split
    test_subset=None,  # None keeps the full test split
)
33+
34+
35+
# ✅ Trigger Generator
36+
class TriggerGenerator(nn.Module):
    """Three-layer convolutional network that turns an image batch into a trigger.

    Every convolution uses a 3x3 kernel with padding 1, so the output has the
    same spatial size and channel count as the input. The final ``Tanh`` keeps
    each output value within [-1, 1].

    Args:
        input_channels: Number of channels of the input (and output) images.
    """

    def __init__(self, input_channels=3):
        super().__init__()
        hidden = 32
        stages = [
            nn.Conv2d(input_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, input_channels, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the trigger pattern generated for the batch ``x``."""
        trigger = self.net(x)
        return trigger
50+
51+
52+
# ✅ ART-Compatible Poisoning Attack
53+
class DynamicBackdoorGAN(PoisoningAttackBackdoor):
    """Backdoor poisoning attack whose trigger is produced per image by a generator.

    A poisoned image is ``clamp(image + epsilon * generator(image), 0, 1)``.

    Args:
        generator: ``nn.Module`` mapping an image batch to a same-shaped trigger.
        target_label: Class index assigned to every poisoned sample.
        backdoor_rate: Fraction of the samples passed to ``poison`` that are poisoned.
        classifier: ART classifier; its ``device`` and ``nb_classes`` are used here.
        epsilon: Multiplier applied to the trigger before it is added to the image.
    """

    # Spatial size every image is brought to before the generator sees it.
    IMAGE_SIZE = (32, 32)

    def __init__(self, generator, target_label, backdoor_rate, classifier, epsilon=0.5):
        # Triggers are applied by apply_trigger(); the identity function is only
        # a placeholder for the parent's `perturbation` argument.
        super().__init__(perturbation=lambda x: x)
        self.classifier = classifier
        self.generator = generator.to(classifier.device)
        self.target_label = target_label
        self.backdoor_rate = backdoor_rate
        self.epsilon = epsilon

    def _resize(self, images):
        """Bilinearly resize ``images`` (N, C, H, W) to ``IMAGE_SIZE`` if needed."""
        if tuple(images.shape[-2:]) == tuple(self.IMAGE_SIZE):
            return images
        return nn.functional.interpolate(
            images, size=self.IMAGE_SIZE, mode="bilinear", align_corners=False
        )

    def apply_trigger(self, images, batch_size=512):
        """Return the triggered version of ``images`` as a tensor on the classifier device.

        Args:
            images: Float tensor of shape (N, C, H, W).
            batch_size: Number of images pushed through the generator at once.
                Chunking bounds peak memory; the result does not depend on it
                because the generator runs in eval mode.

        Returns:
            Tensor of shape (N, C, *IMAGE_SIZE) with values clamped to [0, 1].
        """
        device = self.classifier.device
        self.generator.eval()

        n_images = images.shape[0]
        if n_images == 0:
            # Nothing to poison: return an empty batch of the expected shape
            # instead of calling torch.cat on an empty list.
            empty_shape = (0, images.shape[1]) + tuple(self.IMAGE_SIZE)
            return images.new_empty(empty_shape).to(device)

        chunks = []
        with torch.no_grad():
            for start in range(0, n_images, batch_size):
                batch = self._resize(images[start:start + batch_size]).to(device)
                triggers = self.generator(batch)
                chunks.append((batch + self.epsilon * triggers).clamp(0, 1))
        return torch.cat(chunks, dim=0)

    def poison(self, x, y):
        """Poison the first ``int(backdoor_rate * len(x))`` samples of ``(x, y)``.

        Args:
            x: Images as an array of shape (N, C, H, W).
            y: One-hot labels of shape (N, nb_classes).

        Returns:
            Tuple ``(images, labels)`` of float32 NumPy arrays. The leading
            samples carry the trigger and the target label; the rest are the
            original samples, resized to ``IMAGE_SIZE`` so both parts can be
            concatenated.
        """
        x_tensor = torch.as_tensor(x, dtype=torch.float32)
        n_poison = int(self.backdoor_rate * x_tensor.shape[0])

        poisoned = self.apply_trigger(x_tensor[:n_poison])
        # The clean part goes through the same resize as the poisoned part;
        # otherwise the concatenation fails for inputs that are not IMAGE_SIZE.
        clean = self._resize(x_tensor[n_poison:]).to(self.classifier.device)
        poisoned_images = torch.cat([poisoned, clean], dim=0).cpu().numpy()

        # np.argmax returns a fresh array, so the caller's labels are not modified.
        labels = np.argmax(y, axis=1)
        labels[:n_poison] = self.target_label
        new_labels = to_categorical(labels, nb_classes=self.classifier.nb_classes)

        return poisoned_images.astype(np.float32), new_labels.astype(np.float32)

    def evaluate(self, x_clean, y_clean):
        """Return the attack success rate (in percent) on triggered clean samples.

        Samples whose true label already equals ``target_label`` are excluded:
        predicting the target class for them says nothing about the backdoor
        and would inflate the rate.

        Args:
            x_clean: Clean images as an array of shape (N, C, H, W).
            y_clean: True labels, one-hot (N, nb_classes) or class indices (N,).
                Pass ``None`` to score every sample.

        Returns:
            Percentage of scored samples classified as ``target_label`` once the
            trigger is applied; ``0.0`` when no sample is left to score.
        """
        x_array = np.asarray(x_clean)
        if y_clean is not None:
            true_labels = np.asarray(y_clean)
            if true_labels.ndim > 1:
                true_labels = np.argmax(true_labels, axis=1)
            x_array = x_array[true_labels != self.target_label]

        if len(x_array) == 0:
            return 0.0

        x_tensor = torch.as_tensor(x_array, dtype=torch.float32)
        poisoned_test = self.apply_trigger(x_tensor).cpu().numpy().astype(np.float32)

        preds = self.classifier.predict(poisoned_test)
        pred_labels = np.argmax(preds, axis=1)
        return 100.0 * float(np.mean(pred_labels == self.target_label))
99+
100+
101+
# ✅ Utility: Load Data
102+
103+
def _dataset_to_arrays(split, num_classes):
    """Materialize a torchvision dataset into ``(images, one_hot_labels)`` arrays."""
    images, labels = [], []
    # Single pass: every item access decodes and transforms the image, so
    # collecting images and labels together halves the loading work.
    for image, label in split:
        images.append(image)
        labels.append(label)
    return torch.stack(images).numpy(), to_categorical(labels, nb_classes=num_classes)


def get_data(dataset="CIFAR10", train_subset=None, test_subset=None):
    """Download a torchvision dataset and return it as NumPy arrays.

    Images are resized to 32x32 and converted with ``ToTensor``; MNIST is first
    expanded to three channels so all datasets share one input layout.

    Args:
        dataset: ``"CIFAR10"``, ``"CIFAR100"`` or ``"MNIST"``.
        train_subset: If not ``None``, keep only the first ``train_subset``
            training samples.
        test_subset: If not ``None``, keep only the first ``test_subset``
            test samples.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test, num_classes)`` where the
        ``y_*`` arrays are one-hot encoded.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # name -> (torchvision dataset class, number of classes, grayscale source?)
    registry = {
        "CIFAR10": (datasets.CIFAR10, 10, False),
        "CIFAR100": (datasets.CIFAR100, 100, False),
        "MNIST": (datasets.MNIST, 10, True),
    }
    if dataset not in registry:
        raise ValueError("Unsupported dataset")
    dataset_cls, num_classes, is_grayscale = registry[dataset]

    steps = [transforms.Resize((32, 32)), transforms.ToTensor()]
    if is_grayscale:
        steps.insert(0, transforms.Grayscale(num_output_channels=3))
    transform = transforms.Compose(steps)

    train_set = dataset_cls(root="./data", train=True, download=True, transform=transform)
    test_set = dataset_cls(root="./data", train=False, download=True, transform=transform)

    if train_subset is not None:
        train_set = Subset(train_set, range(train_subset))
    if test_subset is not None:
        test_set = Subset(test_set, range(test_subset))

    x_train, y_train = _dataset_to_arrays(train_set, num_classes)
    x_test, y_test = _dataset_to_arrays(test_set, num_classes)

    return x_train, y_train, x_test, y_test, num_classes
140+
141+
142+
# ✅ Utility: Get ART Classifier
143+
def get_classifier(config):
    """Wrap a torchvision model selected by ``config`` in an ART ``PyTorchClassifier``.

    Args:
        config: Mapping with the keys ``"model_name"`` (``resnet18``,
            ``resnet50``, ``mobilenetv2`` or ``densenet121``), ``"nb_classes"``
            and ``"input_shape"``; ``"learning_rate"`` is optional and
            defaults to 0.001.

    Returns:
        A ``PyTorchClassifier`` using cross-entropy loss and an Adam optimizer,
        placed on the GPU when CUDA is available and on the CPU otherwise.

    Raises:
        ValueError: If ``config["model_name"]`` is not supported.
    """
    model_name = config["model_name"]
    nb_classes = config["nb_classes"]
    input_shape = config["input_shape"]
    lr = config.get("learning_rate", 0.001)

    # Config name -> name of the constructor in torchvision.models.
    constructor_names = {
        "resnet18": "resnet18",
        "resnet50": "resnet50",
        "mobilenetv2": "mobilenet_v2",
        "densenet121": "densenet121",
    }
    if model_name not in constructor_names:
        raise ValueError(f"Unsupported model: {model_name}")
    build_model = getattr(models, constructor_names[model_name])
    network = build_model(num_classes=nb_classes)

    criterion = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(network.parameters(), lr=lr)
    device_type = "gpu" if torch.cuda.is_available() else "cpu"

    return PyTorchClassifier(
        model=network,
        loss=criterion,
        optimizer=adam,
        input_shape=input_shape,
        nb_classes=nb_classes,
        clip_values=(0.0, 1.0),
        device_type=device_type,
    )
173+
174+
175+
# ✅ Full Experiment
176+
def _accuracy(classifier, x, y):
    """Return the fraction of samples in ``x`` whose prediction matches one-hot ``y``."""
    predicted = np.argmax(classifier.predict(x), axis=1)
    return np.mean(predicted == np.argmax(y, axis=1))


def run_dynamic_backdoor_experiment(config):
    """Train on clean data, continue training on poisoned data, and report metrics.

    The same classifier is fitted twice: first on the clean training set, then
    on the set returned by ``DynamicBackdoorGAN.poison``. Clean accuracy,
    accuracy after poisoned training, and the attack success rate are printed.

    Args:
        config: Mapping with the keys ``"dataset"``, ``"model_name"``,
            ``"poison_ratio"``, ``"target_label"``, ``"epochs"``,
            ``"batch_size"`` and ``"epsilon"``; ``"train_subset"`` and
            ``"test_subset"`` are optional. The mapping is not modified.

    Returns:
        Dict with the keys ``"clean_accuracy"`` and ``"poisoned_accuracy"``
        (fractions in [0, 1]) and ``"attack_success_rate"`` (percent).
    """
    x_train, y_train, x_test, y_test, num_classes = get_data(
        dataset=config["dataset"],
        train_subset=config.get("train_subset"),
        test_subset=config.get("test_subset"),
    )

    # Work on a copy so the derived entries do not leak into the caller's dict.
    settings = dict(config)
    settings["nb_classes"] = num_classes
    settings["input_shape"] = x_train.shape[1:]

    classifier = get_classifier(settings)

    # Clean training
    classifier.fit(
        x_train, y_train, nb_epochs=settings["epochs"], batch_size=settings["batch_size"]
    )
    clean_acc = _accuracy(classifier, x_test, y_test)
    print(f"✅ Clean Accuracy: {clean_acc * 100:.2f}%")

    # Poison training
    generator = TriggerGenerator()
    attack = DynamicBackdoorGAN(
        generator,
        settings["target_label"],
        settings["poison_ratio"],
        classifier,
        epsilon=settings["epsilon"],
    )
    x_poison, y_poison = attack.poison(x_train, y_train)

    classifier.fit(
        x_poison, y_poison, nb_epochs=settings["epochs"], batch_size=settings["batch_size"]
    )
    poisoned_acc = _accuracy(classifier, x_test, y_test)
    print(f"🎯 Poisoned Accuracy: {poisoned_acc * 100:.2f}%")

    asr = attack.evaluate(x_test, y_test)
    print(f"💥 Attack Success Rate (ASR): {asr:.2f}%")

    return {
        "clean_accuracy": float(clean_acc),
        "poisoned_accuracy": float(poisoned_acc),
        "attack_success_rate": float(asr),
    }
209+
210+
211+
# ✅ Run
212+
# Runs the whole experiment with the module-level `config`. There is no
# `__main__` guard, so this also executes when the file is imported.
run_dynamic_backdoor_experiment(config)

0 commit comments

Comments
 (0)