Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions code/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def latex_table_certified_accuracy(outfile: str, radius_start: float, radius_sto
f.write("& $r = {:.3}$".format(radius))
f.write("\\\\\n")

f.write("\midrule\n")
f.write(r"\midrule\n")

for i, method in enumerate(methods):
f.write(method.legend)
Expand Down Expand Up @@ -153,54 +153,54 @@ def markdown_table_certified_accuracy(outfile: str, radius_start: float, radius_
if __name__ == "__main__":
latex_table_certified_accuracy(
"analysis/latex/vary_noise_cifar10", 0.25, 1.5, 0.25, [
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "$\sigma = 0.12$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"$\sigma = 0.12$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
])
markdown_table_certified_accuracy(
"analysis/markdown/vary_noise_cifar10", 0.25, 1.5, 0.25, [
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "σ = 0.12"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "σ = 0.25"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "σ = 0.50"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "σ = 1.00"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"σ = 0.12"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"σ = 0.25"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"σ = 0.50"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"σ = 1.00"),
])
latex_table_certified_accuracy(
"analysis/latex/vary_noise_imagenet", 0.5, 3.0, 0.5, [
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
])
markdown_table_certified_accuracy(
"analysis/markdown/vary_noise_imagenet", 0.5, 3.0, 0.5, [
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "σ = 0.25"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "σ = 0.50"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "σ = 1.00"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"σ = 0.25"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"σ = 0.50"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"σ = 1.00"),
])
plot_certified_accuracy(
"analysis/plots/vary_noise_cifar10", "CIFAR-10, vary $\sigma$", 1.5, [
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "$\sigma = 0.12$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"),
"analysis/plots/vary_noise_cifar10", r"CIFAR-10, vary $\sigma$", 1.5, [
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"$\sigma = 0.12$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
])
plot_certified_accuracy(
"analysis/plots/vary_train_noise_cifar_050", "CIFAR-10, vary train noise, $\sigma=0.5$", 1.5, [
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50"), "train $\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "train $\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50"), "train $\sigma = 1.00$"),
"analysis/plots/vary_train_noise_cifar_050", r"CIFAR-10, vary train noise, $\sigma=0.5$", 1.5, [
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50"), r"train $\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"train $\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50"), r"train $\sigma = 1.00$"),
])
plot_certified_accuracy(
"analysis/plots/vary_train_noise_imagenet_050", "ImageNet, vary train noise, $\sigma=0.5$", 1.5, [
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50"), "train $\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "train $\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50"), "train $\sigma = 1.00$"),
"analysis/plots/vary_train_noise_imagenet_050", r"ImageNet, vary train noise, $\sigma=0.5$", 1.5, [
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50"), r"train $\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"train $\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50"), r"train $\sigma = 1.00$"),
])
plot_certified_accuracy(
"analysis/plots/vary_noise_imagenet", "ImageNet, vary $\sigma$", 4, [
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"),
"analysis/plots/vary_noise_imagenet", r"ImageNet, vary $\sigma$", 4, [
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
])
plot_certified_accuracy(
"analysis/plots/high_prob", "Approximate vs. High-Probability", 2.0, [
Expand Down
2 changes: 1 addition & 1 deletion code/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _imagenet(split: str) -> Dataset:
elif split == "test":
subdir = os.path.join(dir, "val")
transform = transforms.Compose([
transforms.Scale(256),
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()
])
Expand Down
2 changes: 1 addition & 1 deletion code/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def accuracy(output, target, topk=(1,)):

res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

Expand Down
67 changes: 67 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
ansicon==1.89.0
asttokens==3.0.0
blessed==1.21.0
colorama==0.4.6
comm==0.2.2
contourpy==1.3.3
cycler==0.12.1
debugpy==1.8.14
decorator==5.2.1
executing==2.2.0
filelock==3.13.1
fonttools==4.59.1
fsspec==2024.6.1
gpustat==1.1.1
ipykernel==6.29.5
ipython==9.2.0
ipython_pygments_lexers==1.1.1
jedi==0.19.2
Jinja2==3.1.4
jinxed==1.3.0
joblib==1.5.1
jupyter_client==8.6.3
jupyter_core==5.7.2
kiwisolver==1.4.9
MarkupSafe==2.1.5
matplotlib==3.10.5
matplotlib-inline==0.1.7
mpmath==1.3.0
nest-asyncio==1.6.0
networkx==3.3
numpy==1.26.4
nvidia-ml-py==13.580.65
opencv-python==4.11.0.86
packaging==25.0
pandas==2.3.2
parso==0.8.4
patsy==1.0.1
pillow==11.0.0
platformdirs==4.3.8
prompt_toolkit==3.0.51
psutil==7.0.0
pure_eval==0.2.3
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
pywin32==310
pyzmq==26.4.0
scikit-learn==1.7.0
scipy==1.16.0
seaborn==0.13.2
setGPU==0.0.7
six==1.17.0
skorch==1.1.0
stack-data==0.6.3
statsmodels==0.14.5
sympy==1.13.3
tabulate==0.9.0
threadpoolctl==3.6.0
torch==2.8.0+cu126
torchvision==0.23.0+cu126
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
typing_extensions==4.13.2
tzdata==2025.2
wcwidth==0.2.13