Example of robust training on CIFAR10.

Figure: accuracy per epoch, with panels "Clean data accuracies" and "Adversarial data accuracies".

Out:

Files already downloaded and verified
Files already downloaded and verified
Training on Linf ball(0.03137254901960784).
Train Accuracy: 32.6%
Train Adv Accuracy: 22.9%
Test Accuracy: 30.7%
Test Adv Accuracy: 21.2%
Train Accuracy: 46.3%
Train Adv Accuracy: 27.6%
Test Accuracy: 44.9%
Test Adv Accuracy: 27.0%
Train Accuracy: 54.6%
Train Adv Accuracy: 29.2%
Test Accuracy: 38.1%
Test Adv Accuracy: 27.2%
Train Accuracy: 59.5%
Train Adv Accuracy: 30.7%
Test Accuracy: 53.8%
Test Adv Accuracy: 27.7%
Train Accuracy: 63.4%
Train Adv Accuracy: 31.9%
Test Accuracy: 52.7%
Test Adv Accuracy: 26.3%
Train Accuracy: 65.8%
Train Adv Accuracy: 31.9%
Test Accuracy: 44.2%
Test Adv Accuracy: 28.0%
Train Accuracy: 67.4%
Train Adv Accuracy: 32.4%
Test Accuracy: 57.8%
Test Adv Accuracy: 32.5%
Train Accuracy: 69.1%
Train Adv Accuracy: 32.1%
Test Accuracy: 59.3%
Test Adv Accuracy: 26.5%
Train Accuracy: 70.2%
Train Adv Accuracy: 33.0%
Test Accuracy: 58.7%
Test Adv Accuracy: 29.1%
Train Accuracy: 71.2%
Train Adv Accuracy: 33.3%
Test Accuracy: 62.5%
Test Adv Accuracy: 31.9%
Train Accuracy: 71.6%
Train Adv Accuracy: 33.1%
Test Accuracy: 60.6%
Test Adv Accuracy: 28.7%
Train Accuracy: 72.2%
Train Adv Accuracy: 33.3%
Test Accuracy: 60.7%
Test Adv Accuracy: 30.0%
Train Accuracy: 72.6%
Train Adv Accuracy: 33.1%
Test Accuracy: 66.1%
Test Adv Accuracy: 23.6%
Train Accuracy: 73.2%
Train Adv Accuracy: 33.6%
Test Accuracy: 60.9%
Test Adv Accuracy: 28.9%
Train Accuracy: 73.6%
Train Adv Accuracy: 34.0%
Test Accuracy: 60.7%
Test Adv Accuracy: 31.1%
Train Accuracy: 74.0%
Train Adv Accuracy: 34.0%
Test Accuracy: 63.8%
Test Adv Accuracy: 27.6%
Train Accuracy: 74.4%
Train Adv Accuracy: 34.1%
Test Accuracy: 63.6%
Test Adv Accuracy: 28.6%
Train Accuracy: 74.7%
Train Adv Accuracy: 33.8%
Test Accuracy: 62.9%
Test Adv Accuracy: 26.4%
Train Accuracy: 74.9%
Train Adv Accuracy: 33.9%
Test Accuracy: 61.2%
Test Adv Accuracy: 28.0%
Train Accuracy: 75.1%
Train Adv Accuracy: 34.2%
Test Accuracy: 61.8%
Test Adv Accuracy: 28.1%
Train Accuracy: 75.0%
Train Adv Accuracy: 34.1%
Test Accuracy: 63.5%
Test Adv Accuracy: 32.7%
Train Accuracy: 75.4%
Train Adv Accuracy: 34.2%
Test Accuracy: 62.5%
Test Adv Accuracy: 28.4%
Train Accuracy: 75.3%
Train Adv Accuracy: 34.1%
Test Accuracy: 62.2%
Test Adv Accuracy: 30.5%
Train Accuracy: 75.8%
Train Adv Accuracy: 34.2%
Test Accuracy: 56.4%
Test Adv Accuracy: 29.4%
Train Accuracy: 76.0%
Train Adv Accuracy: 33.9%
Test Accuracy: 62.2%
Test Adv Accuracy: 27.7%
Train Accuracy: 76.0%
Train Adv Accuracy: 34.6%
Test Accuracy: 63.8%
Test Adv Accuracy: 28.6%
Train Accuracy: 76.1%
Train Adv Accuracy: 34.3%
Test Accuracy: 58.4%
Test Adv Accuracy: 30.3%
Train Accuracy: 76.1%
Train Adv Accuracy: 34.6%
Test Accuracy: 66.1%
Test Adv Accuracy: 30.0%
Train Accuracy: 75.9%
Train Adv Accuracy: 34.6%
Test Accuracy: 64.3%
Test Adv Accuracy: 31.4%
Train Accuracy: 76.3%
Train Adv Accuracy: 34.3%
Test Accuracy: 60.1%
Test Adv Accuracy: 31.6%
Train Accuracy: 76.4%
Train Adv Accuracy: 34.4%
Test Accuracy: 51.9%
Test Adv Accuracy: 31.9%
Train Accuracy: 76.6%
Train Adv Accuracy: 34.8%
Test Accuracy: 65.5%
Test Adv Accuracy: 33.0%
Train Accuracy: 76.5%
Train Adv Accuracy: 34.5%
Test Accuracy: 56.4%
Test Adv Accuracy: 31.8%
Train Accuracy: 76.5%
Train Adv Accuracy: 34.2%
Test Accuracy: 65.2%
Test Adv Accuracy: 29.7%
Train Accuracy: 76.6%
Train Adv Accuracy: 34.6%
Test Accuracy: 62.9%
Test Adv Accuracy: 31.8%
Train Accuracy: 76.5%
Train Adv Accuracy: 34.5%
Test Accuracy: 58.2%
Test Adv Accuracy: 25.9%
Train Accuracy: 76.4%
Train Adv Accuracy: 34.7%
Test Accuracy: 63.8%
Test Adv Accuracy: 29.7%
Train Accuracy: 77.0%
Train Adv Accuracy: 34.7%
Test Accuracy: 57.9%
Test Adv Accuracy: 30.2%
Train Accuracy: 77.0%
Train Adv Accuracy: 34.0%
Test Accuracy: 64.6%
Test Adv Accuracy: 29.4%
Train Accuracy: 76.7%
Train Adv Accuracy: 34.5%
Test Accuracy: 62.4%
Test Adv Accuracy: 28.2%
Train Accuracy: 76.8%
Train Adv Accuracy: 34.7%
Test Accuracy: 59.6%
Test Adv Accuracy: 26.4%
Train Accuracy: 77.3%
Train Adv Accuracy: 34.7%
Test Accuracy: 64.6%
Test Adv Accuracy: 27.7%
Train Accuracy: 77.0%
Train Adv Accuracy: 35.0%
Test Accuracy: 57.5%
Test Adv Accuracy: 30.7%
Train Accuracy: 76.8%
Train Adv Accuracy: 34.8%
Test Accuracy: 65.8%
Test Adv Accuracy: 32.9%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.1%
Test Accuracy: 61.3%
Test Adv Accuracy: 27.5%
Train Accuracy: 77.4%
Train Adv Accuracy: 34.6%
Test Accuracy: 62.4%
Test Adv Accuracy: 31.8%
Train Accuracy: 77.0%
Train Adv Accuracy: 34.6%
Test Accuracy: 57.0%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.2%
Train Adv Accuracy: 35.2%
Test Accuracy: 57.7%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.2%
Train Adv Accuracy: 35.0%
Test Accuracy: 65.6%
Test Adv Accuracy: 24.2%
Train Accuracy: 77.6%
Train Adv Accuracy: 34.8%
Test Accuracy: 63.5%
Test Adv Accuracy: 30.3%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 63.7%
Test Adv Accuracy: 30.8%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.0%
Test Accuracy: 63.5%
Test Adv Accuracy: 30.8%
Train Accuracy: 77.4%
Train Adv Accuracy: 35.1%
Test Accuracy: 59.3%
Test Adv Accuracy: 32.5%
Train Accuracy: 77.4%
Train Adv Accuracy: 34.8%
Test Accuracy: 56.7%
Test Adv Accuracy: 30.5%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.3%
Test Accuracy: 67.7%
Test Adv Accuracy: 32.0%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 61.4%
Test Adv Accuracy: 32.5%
Train Accuracy: 77.6%
Train Adv Accuracy: 34.9%
Test Accuracy: 63.1%
Test Adv Accuracy: 31.6%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.0%
Test Accuracy: 63.5%
Test Adv Accuracy: 29.2%
Train Accuracy: 77.2%
Train Adv Accuracy: 35.5%
Test Accuracy: 61.5%
Test Adv Accuracy: 33.4%
Train Accuracy: 77.6%
Train Adv Accuracy: 35.2%
Test Accuracy: 66.1%
Test Adv Accuracy: 27.1%
Train Accuracy: 77.6%
Train Adv Accuracy: 34.9%
Test Accuracy: 63.4%
Test Adv Accuracy: 31.2%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 62.4%
Test Adv Accuracy: 31.3%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.0%
Test Accuracy: 65.2%
Test Adv Accuracy: 28.5%
Train Accuracy: 77.9%
Train Adv Accuracy: 34.8%
Test Accuracy: 63.9%
Test Adv Accuracy: 27.9%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.1%
Test Accuracy: 63.7%
Test Adv Accuracy: 28.0%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.5%
Test Accuracy: 58.6%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.1%
Test Accuracy: 63.4%
Test Adv Accuracy: 31.5%
Train Accuracy: 77.9%
Train Adv Accuracy: 34.8%
Test Accuracy: 57.1%
Test Adv Accuracy: 31.2%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 62.8%
Test Adv Accuracy: 31.0%
Train Accuracy: 77.1%
Train Adv Accuracy: 34.7%
Test Accuracy: 57.0%
Test Adv Accuracy: 32.3%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.1%
Test Accuracy: 64.8%
Test Adv Accuracy: 29.7%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.3%
Test Accuracy: 65.1%
Test Adv Accuracy: 29.5%
Train Accuracy: 78.0%
Train Adv Accuracy: 34.9%
Test Accuracy: 64.1%
Test Adv Accuracy: 31.1%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.7%
Test Accuracy: 58.1%
Test Adv Accuracy: 33.8%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.2%
Test Accuracy: 66.6%
Test Adv Accuracy: 30.8%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.4%
Test Accuracy: 57.5%
Test Adv Accuracy: 27.9%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.1%
Test Accuracy: 60.9%
Test Adv Accuracy: 29.9%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.1%
Test Accuracy: 64.3%
Test Adv Accuracy: 26.8%
Train Accuracy: 77.8%
Train Adv Accuracy: 34.9%
Test Accuracy: 54.2%
Test Adv Accuracy: 30.5%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.0%
Test Accuracy: 54.6%
Test Adv Accuracy: 31.0%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.2%
Test Accuracy: 62.4%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.4%
Test Accuracy: 61.1%
Test Adv Accuracy: 34.1%
Train Accuracy: 78.1%
Train Adv Accuracy: 34.8%
Test Accuracy: 51.0%
Test Adv Accuracy: 31.5%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.0%
Test Accuracy: 64.0%
Test Adv Accuracy: 26.5%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.1%
Test Accuracy: 54.2%
Test Adv Accuracy: 33.4%
Train Accuracy: 77.6%
Train Adv Accuracy: 35.3%
Test Accuracy: 63.6%
Test Adv Accuracy: 33.1%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.4%
Test Accuracy: 65.6%
Test Adv Accuracy: 29.6%
Train Accuracy: 78.0%
Train Adv Accuracy: 35.1%
Test Accuracy: 65.7%
Test Adv Accuracy: 27.6%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.5%
Test Accuracy: 57.6%
Test Adv Accuracy: 32.7%
Train Accuracy: 78.0%
Train Adv Accuracy: 35.0%
Test Accuracy: 50.4%
Test Adv Accuracy: 30.7%
Train Accuracy: 78.3%
Train Adv Accuracy: 35.0%
Test Accuracy: 59.2%
Test Adv Accuracy: 27.7%
Train Accuracy: 78.0%
Train Adv Accuracy: 34.5%
Test Accuracy: 59.7%
Test Adv Accuracy: 32.5%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.4%
Test Accuracy: 57.2%
Test Adv Accuracy: 29.3%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.3%
Test Accuracy: 53.1%
Test Adv Accuracy: 30.8%
Train Accuracy: 77.7%
Train Adv Accuracy: 34.9%
Test Accuracy: 59.7%
Test Adv Accuracy: 29.9%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.3%
Test Accuracy: 56.8%
Test Adv Accuracy: 28.0%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.0%
Test Accuracy: 57.8%
Test Adv Accuracy: 31.8%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.0%
Test Accuracy: 55.6%
Test Adv Accuracy: 30.0%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.1%
Test Accuracy: 63.7%
Test Adv Accuracy: 30.3%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.6%
Test Accuracy: 63.1%
Test Adv Accuracy: 35.0%
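The log prints four accuracy lines per epoch: clean and adversarial accuracy on the training set, then on the test set. Below is a minimal sketch for turning a saved copy of that output into per-epoch records, for example to re-plot or summarize the curves. It assumes the output was saved as plain text in exactly this format. `parse_log` is a hypothetical helper and is not part of chop.

```python
import re

# One accuracy line of the log, e.g. "Test Adv Accuracy: 21.2%".
LINE_RE = re.compile(r"^(Train|Test)( Adv)? Accuracy: ([0-9.]+)%$")


def parse_log(text):
    """Group the accuracy lines of the log into one dict per epoch.

    Lines that do not match (download messages, the "Training on ..."
    header) are skipped; each "Train Accuracy" line starts a new epoch.
    """
    epochs = []
    for line in text.splitlines():
        match = LINE_RE.match(line.strip())
        if match is None:
            continue
        split, adv, value = match.groups()
        key = split.lower() + ("_adv" if adv else "")
        if key == "train" or not epochs:
            epochs.append({})
        epochs[-1][key] = float(value)
    return epochs


# The first two epochs of the log above.
log = """Training on Linf ball(0.03137254901960784).
Train Accuracy: 32.6%
Train Adv Accuracy: 22.9%
Test Accuracy: 30.7%
Test Adv Accuracy: 21.2%
Train Accuracy: 46.3%
Train Adv Accuracy: 27.6%
Test Accuracy: 44.9%
Test Adv Accuracy: 27.0%
"""

epochs = parse_log(log)
# Index of the epoch with the highest adversarial test accuracy.
best = max(range(len(epochs)), key=lambda i: epochs[i]["test_adv"])
print(len(epochs), best, epochs[best]["test_adv"])  # prints: 2 1 27.0
```

The clean test accuracy in this run fluctuates noticeably from one epoch to the next (for example, 51.0% followed by 64.0%). Statistics over several epochs can therefore say more than the final line alone.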

import matplotlib.pyplot as plt
import torch
from easydict import EasyDict
from torch.optim import SGD
from torchvision import models

import chop
from chop.adversary import Adversary

device = torch.device('cuda' if torch.cuda.is_available()
                      else 'cpu')

n_epochs = 100
batch_size = 128
batch_size_test = 100

loaders = chop.data.load_cifar10(train_batch_size=batch_size,
                                 test_batch_size=batch_size_test,
                                 data_dir='~/datasets',
                                 augment_train=True)

trainloader, testloader = loaders.train, loaders.test
n_train = len(trainloader.dataset)
n_test = len(testloader.dataset)

model = models.resnet18(pretrained=False)
model.to(device)

criterion = torch.nn.CrossEntropyLoss()

optimizer = SGD(model.parameters(), lr=.1, momentum=.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
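# CosineAnnealingLR (default eta_min=0) sets the learning rate after t calls
# to scheduler.step() to
#     lr_t = 0.5 * 0.1 * (1 + cos(pi * t / T_max)).
# If scheduler.step() is called once per epoch, then with T_max=200 and
# n_epochs=100 the rate only decays from 0.1 to 0.05 (cos(pi / 2) = 0).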

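# Define the perturbation constraint set: an L-infinity ball of radius
# alpha = 8/255, i.e. at most 8 levels per pixel on the 0-255 scale.
max_iter_train = 7  # attack iterations per training batch
max_iter_test = 20  # more attack iterations at evaluation time
alpha = 8. / 255
constraint = chop.constraints.LinfBall(alpha)
# Per-example losses (reduction='none') for the adversary.
criterion_adv = torch.nn.CrossEntropyLoss(reduction='none')

print(f"Training on L{constraint.p} ball({alpha}).")


# Attack: projected gradient descent in the style of Madry et al.
adversary = Adversary(chop.optim.minimize_pgd_madry)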

results = EasyDict(train_acc=[], test_acc=[],
                   train_acc_adv=[], test_acc_adv=[],
                   train_adv_loss=[],
                   test_adv_loss=[])

for _ in range(n_epochs):

    # Train
    n_correct = 0
    n_correct_adv = 0

    model.train()

    for data, target in trainloader:
        data = data.to(device)
        target = target.to(device)

        # The projection depends on the current batch, so it is
        # defined inside the loop as a closure over `data`.
        @torch.no_grad()
        def image_constraint_prox(delta, step_size=None):
            """Projects perturbation delta
            so that 0. <= data + delta <= 1."""

            adv_img = torch.clamp(data + delta, 0, 1)
            delta = adv_img - data
            return delta

        @torch.no_grad()
        def prox(delta, step_size=None):
            # Project onto the Linf ball, then onto the valid image range.
            delta = constraint.prox(delta, step_size)
            delta = image_constraint_prox(delta, step_size)
            return delta

        # Compute an adversarial perturbation for this batch.
        _, delta = adversary.perturb(data, target, model,
                                     criterion_adv,
                                     prox=prox,
                                     lmo=constraint.lmo,
                                     step=2. / max_iter_train,
                                     max_iter=max_iter_train)

        optimizer.zero_grad()

        output = model(data)
        output_adv = model(data + delta)
        # Robust (adversarial) training: minimize the loss on the
        # perturbed inputs, not on the clean ones.
        loss = criterion(output_adv, target)
        loss.backward()

        optimizer.step()

        pred = torch.argmax(output, dim=-1)
        pred_adv = torch.argmax(output_adv, dim=-1)

        n_correct += (pred == target).sum().item()
        n_correct_adv += (pred_adv == target).sum().item()

    # Advance the cosine learning-rate schedule once per epoch.
    scheduler.step()

    results.train_acc.append(100. * n_correct / n_train)
    results.train_acc_adv.append(100. * n_correct_adv / n_train)
    print(f"Train Accuracy: {results.train_acc[-1]:.1f}%")
    print(f"Train Adv Accuracy: {results.train_acc_adv[-1]:.1f}%")

    # Test
    n_correct = 0
    n_correct_adv = 0

    model.eval()

    for data, target in testloader:
        data = data.to(device)
        target = target.to(device)

        @torch.no_grad()
        def image_constraint_prox(delta, step_size=None):
            """Projects perturbation delta
            so that 0. <= data + delta <= 1."""

            adv_img = torch.clamp(data + delta, 0, 1)
            delta = adv_img - data
            return delta

        @torch.no_grad()
        def prox(delta, step_size=None):
            delta = constraint.prox(delta, step_size)
            delta = image_constraint_prox(delta, step_size)
            return delta

        # Evaluate against a stronger attack (more iterations) than in
        # training. The attack needs input gradients, so it runs outside
        # torch.no_grad().
        _, delta = adversary.perturb(data, target, model,
                                     criterion_adv,
                                     prox=prox,
                                     lmo=constraint.lmo,
                                     step=2. / max_iter_test,
                                     max_iter=max_iter_test)

        with torch.no_grad():
            output = model(data)
            output_adv = model(data + delta)

            pred = torch.argmax(output, dim=-1)
            pred_adv = torch.argmax(output_adv, dim=-1)

        n_correct += (pred == target).sum().item()
        n_correct_adv += (pred_adv == target).sum().item()

    results.test_acc.append(100. * n_correct / n_test)
    results.test_acc_adv.append(100. * n_correct_adv / n_test)

    print(f"Test Accuracy: {results.test_acc[-1]:.1f}%")
    print(f"Test Adv Accuracy: {results.test_acc_adv[-1]:.1f}%")


fig, ax = plt.subplots(nrows=2, sharex=True)

ax[0].set_title("Clean data accuracies")
ax[0].plot(results.train_acc, label='Train Acc')
ax[0].plot(results.test_acc, label='Test Acc')
ax[0].set_ylabel("Accuracy (%)")
ax[0].legend()  # plt.legend() would only label the last panel
ax[1].set_title("Adversarial data accuracies")
ax[1].plot(results.train_acc_adv, label='Train Acc Adv')
ax[1].plot(results.test_acc_adv, label='Test Acc Adv')
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("Accuracy (%)")
ax[1].legend()
fig.tight_layout()
plt.show()
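The `prox` closure in the script composes two projections. The first is `constraint.prox`, onto the L∞ ball of radius `alpha`. The second is `image_constraint_prox`, which keeps `data + delta` inside [0, 1]. Both sets are coordinate-wise intervals, so clipping to one and then the other gives the exact projection onto their intersection, the per-pixel interval [max(-alpha, -x_i), min(alpha, 1 - x_i)].

The pure-Python sketch below uses that projection inside a standard sign-gradient PGD loop in the style of Madry et al. It rests on several assumptions:

* `LinfBall.prox` amounts to elementwise clipping.
* A linear loss with a known gradient stands in for the network.
* The random start is omitted.

`project` and `pgd_linf` are hypothetical names. The step size is in pixel units and does not reproduce the script's `step=2. / max_iter_train` argument, whose exact meaning depends on chop's `minimize_pgd_madry`.

```python
import math


def project(delta, x, eps):
    """Project delta onto {d : |d_i| <= eps and 0 <= x_i + d_i <= 1}.

    Assumption: this mirrors clipping to the L-infinity ball (as assumed
    for LinfBall.prox) followed by the [0, 1] image clamp in the script.
    """
    out = []
    for d, xi in zip(delta, x):
        d = min(max(d, -eps), eps)           # L-infinity ball of radius eps
        d = min(max(xi + d, 0.0), 1.0) - xi  # keep the pixel inside [0, 1]
        out.append(d)
    return out


def sign(v):
    return (v > 0) - (v < 0)


def pgd_linf(grad_fn, x, eps, step_size, n_iter):
    """Projected sign-gradient ascent on the loss (no random start)."""
    delta = [0.0] * len(x)
    for _ in range(n_iter):
        g = grad_fn(delta)
        delta = [d + step_size * sign(gi) for d, gi in zip(delta, g)]
        delta = project(delta, x, eps)
    return delta


eps = 8. / 255
x = [0.5, 0.5, 0.99]   # hypothetical pixel values
w = [1.0, -2.0, 3.0]   # loss = sum(w_i * delta_i), so its gradient is w
delta = pgd_linf(lambda d: w, x, eps, step_size=2. / 255, n_iter=7)
print([round(d, 4) for d in delta])  # prints: [0.0314, -0.0314, 0.01]
```

The third pixel is capped at 0.01 rather than `eps`, because 0.99 + 0.01 already reaches the upper end of the valid image range.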

Total running time of the script: (624 minutes 28.935 seconds)

Estimated memory usage: 2425 MB

Gallery generated by Sphinx-Gallery