This notebook reproduces the evaluation of Tramer, Carlini, Brendel, and Madry (2020) using ART, focusing on Section 12, which evaluates "EMPIR: Ensembles of Mixed Precision Deep Networks for Increased Robustness against Adversarial Attacks".
This notebook uses code from Sen et al. (2020), available at https://github.com/sancharisen/EMPIR
Before running this notebook, download the CIFAR-10 EMPIR models from https://github.com/sancharisen/EMPIR into the local directory containing this notebook and save the three models into the directories ./CIFARconv/Model1, ./CIFARconv/Model2, and ./CIFARconv/Model3.
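A quick sanity check (illustrative, not part of the original notebook) that the checkpoint directories are in place:
import os
# Verify the three EMPIR checkpoint directories exist before running the rest of the notebook
for path in ['./CIFARconv/Model1', './CIFARconv/Model2', './CIFARconv/Model3']:
    assert os.path.isdir(path), 'Missing model directory: {0}'.format(path)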
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import sys
import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import keras
from keras.datasets import cifar10
from keras.utils import np_utils
import numpy as np
from art.estimators.classification import TensorFlowClassifier
from art.attacks.evasion import ProjectedGradientDescent
Using TensorFlow backend.
%%bash
if ! [[ -d "./EMPIR" ]]
then
git clone git@github.com:sancharisen/EMPIR.git
fi
touch ./EMPIR/__init__.py
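If no SSH key is configured for GitHub, cloning over HTTPS works as well:
%%bash
if ! [[ -d "./EMPIR" ]]
then
git clone https://github.com/sancharisen/EMPIR.git
fi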
sys.path.append("./")
sys.path.append("./EMPIR")
from EMPIR.cleverhans.utils_tf import model_eval_ensemble
sess = tf.Session()
keras.backend.set_session(sess)
tf.set_random_seed(1234)
# CIFAR10-specific dimensions
img_rows = 32
img_cols = 32
channels = 3
nb_classes = 10
# Model specifications
nb_filters = 32
batch_size = 128
nb_samples = 10000
abits = 2   # activation precision (bits) of the first low-precision model
wbits = 4   # weight precision (bits) of the first low-precision model
abits2 = 2  # activation precision (bits) of the second low-precision model
wbits2 = 2  # weight precision (bits) of the second low-precision model
model_path1 = './CIFARconv/Model1'
model_path2 = './CIFARconv/Model2'
model_path3 = './CIFARconv/Model3'
# Scaling input to softmax
INIT_T = 1.0
def data_cifar10():
"""
Preprocess CIFAR10 dataset
:return:
"""
# These values are specific to CIFAR10
img_rows = 32
img_cols = 32
nb_classes = 10
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 3)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
return X_train, Y_train, X_test, Y_test
# Get CIFAR10 test data
X_train, Y_train, X_test, Y_test = data_cifar10()
assert Y_train.shape[1] == 10
# Label smoothing of the training labels, as in the original EMPIR code
label_smooth = .1
Y_train = Y_train.clip(label_smooth / 9., 1. - label_smooth)
X_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples
# Create placeholders
x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, channels))
y = tf.placeholder(tf.float32, shape=(None, 10))
phase = tf.placeholder(tf.bool, name="phase")  # training-phase flag; False at inference
logits_scalar = tf.placeholder_with_default(INIT_T, shape=(), name="logits_temperature")  # softmax temperature
%%capture
from EMPIR.cleverhans_tutorials.tutorial_models import make_ensemble_three_cifar_cnn
model = make_ensemble_three_cifar_cnn(phase, logits_scalar, 'lp1_', 'lp2_', 'fp_', wbits, abits, wbits2,
abits2, input_shape=(None, img_rows, img_cols, channels),
nb_filters=nb_filters)
%%capture
preds_index = model.ensemble_call(x, reuse=False)
preds_one_hot = tf.one_hot(preds_index, depth=nb_classes)
preds_prob = model.get_probs(x)
# Global variables are created in model-construction order: the first five belong to the
# first low-precision model, the next five to the second, and the rest to the full-precision model.
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
stored_variables = ['lp_conv1_init/k', 'lp_conv2_init/k', 'lp_conv3_init/k', 'lp_ip1init/W',
                    'lp_logits_init/W']
variable_dict = dict(zip(stored_variables, variables[:5]))
# Restore the first set of variables from model_path1
saver = tf.train.Saver(variable_dict)
saver.restore(sess, tf.train.latest_checkpoint(model_path1))
# Restore the second set of variables from model_path2
variable_dict = dict(zip(stored_variables, variables[5:10]))
saver2 = tf.train.Saver(variable_dict)
saver2.restore(sess, tf.train.latest_checkpoint(model_path2))
stored_variables = ['fp_conv1_init/k', 'fp_conv2_init/k', 'fp_conv3_init/k', 'fp_ip1init/W',
'fp_logits_init/W']
variable_dict = dict(zip(stored_variables, variables[10:]))
saver3 = tf.train.Saver(variable_dict)
saver3.restore(sess, tf.train.latest_checkpoint(model_path3))
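As a sanity check (illustrative), TF1's report_uninitialized_variables shows whether any ensemble weights failed to restore from the three checkpoints:
# Any variable name printed here was not restored from a checkpoint
print('Uninitialized variables:', sess.run(tf.report_uninitialized_variables()))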
# Evaluate the accuracy of the CIFAR10 model on legitimate test examples
eval_params = {'batch_size': batch_size}
accuracy = model_eval_ensemble(sess, x, y, preds_index, X_test, Y_test, phase=phase, args=eval_params)
print('Test accuracy on legitimate test examples: {0}'.format(accuracy))
Test accuracy on legitimate test examples: 0.7256
def get_accuracy(X, Y, batch_size, predictions):
    """Compute accuracy over complete batches only; a trailing partial batch is dropped."""
    sum_correct = 0
    sum_samples = 0
    with sess.as_default():
        nb_batches = int(X.shape[0] / batch_size)
        for i_batch in range(nb_batches):
            i_start = i_batch * batch_size
            i_end = i_start + batch_size
            if i_end <= X.shape[0]:
                feed_dict = {x: X[i_start:i_end],
                             phase: False}
                y_pred = sess.run(predictions, feed_dict=feed_dict)
                sum_correct += np.sum(np.argmax(Y[i_start:i_end], axis=1) == np.argmax(y_pred, axis=1))
                sum_samples += batch_size
    accuracy = sum_correct / sum_samples
    return accuracy
accuracy_test_benign = get_accuracy(X=X_test, Y=Y_test, batch_size=batch_size, predictions=preds_one_hot)
print('The accuracy on benign test samples: {0:.2f}%'.format(accuracy_test_benign *100))
The accuracy on benign test samples: 72.57%
This agrees with the unperturbed accuracy of 72.56% reported by Sen et al. (2020); the small difference (72.57% vs. 72.56%) arises because get_accuracy evaluates complete batches only.
loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=preds_prob, from_logits=False,
label_smoothing=0))
feed_dict = {phase: False}
classifier_empir = TensorFlowClassifier(input_ph=x,
output=preds_prob,
labels_ph=y,
train=None,
loss=loss,
learning=phase,
sess=sess,
channels_first=False,
clip_values=(0, 1),
preprocessing=(0, 1),
feed_dict=feed_dict)
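A quick, illustrative comparison of the ART wrapper's predictions (computed from preds_prob) with the ensemble's voted labels; occasional disagreements are possible because the two outputs are computed differently:
# The wrapper's argmax over preds_prob should mostly match the ensemble's voted labels
preds_art = classifier_empir.predict(X_test[:8])
print('ART wrapper:', np.argmax(preds_art, axis=1))
print('Ensemble:   ', sess.run(preds_index, feed_dict={x: X_test[:8], phase: False}))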
attack_empir = ProjectedGradientDescent(classifier=classifier_empir,
norm=np.inf,
eps=0.1,
eps_step=0.01,
max_iter=40,
targeted=False,
num_random_init=1,
batch_size=batch_size)
X_test_adv = attack_empir.generate(X_test[:nb_samples], Y_test[:nb_samples])
accuracy_test_adv = get_accuracy(X=X_test_adv, Y=Y_test, batch_size=batch_size, predictions=preds_one_hot)
print('Accuracy on adversarial test examples created by ART using EMPIR\'s loss: '
'{0:.2f}%'.format(accuracy_test_adv * 100))
Accuracy on adversarial test examples created by ART using EMPIR's loss: 11.57%
The accuracy on adversarial test examples created by ART using EMPIR's loss is comparable to the 13.55% reported by Sen et al. (2020). Tramer et al. (2020) observe that EMPIR's voted output yields poor gradients and instead attack the average of the three members' softmax outputs; the following cells reproduce this adaptive loss.
%%capture
feed_dict = {x: X_test[0:12], phase: False}
x_1 = x
for layer in model.layers1:
x_1 = layer.fprop(x_1, reuse=False)
assert x_1 is not None
preds_prob_1 = x_1
x_2 = x
for layer in model.layers2:
x_2 = layer.fprop(x_2, reuse=False)
assert x_2 is not None
preds_prob_2 = x_2
x_3 = x
for layer in model.layers3:
x_3 = layer.fprop(x_3, reuse=False)
assert x_3 is not None
preds_prob_3 = x_3
%%capture
preds_prob_new = (preds_prob_1 + preds_prob_2 + preds_prob_3) / 3
loss_new = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=preds_prob_new, from_logits=False,
label_smoothing=0)
feed_dict = {phase: False}
classifier_eval = TensorFlowClassifier(input_ph=x,
output=preds_prob_new,
labels_ph=y,
train=None,
loss=loss_new,
learning=phase,
sess=sess,
channels_first=False,
clip_values=(0, 1),
preprocessing=(0, 1),
feed_dict=feed_dict)
attack_eval = ProjectedGradientDescent(classifier=classifier_eval,
norm=np.inf,
eps=0.031,
eps_step=0.0078,
max_iter=100,
targeted=False,
num_random_init=1,
batch_size=batch_size)
X_test_adv_final = attack_eval.generate(X_test[:nb_samples], Y_test[:nb_samples])
accuracy_test_adv_final = get_accuracy(X=X_test_adv_final, Y=Y_test, batch_size=batch_size,
predictions=preds_one_hot)
print('Accuracy on adversarial test examples created by ART using the loss by Tramer et al. (2020): '
'{0:.2f}%.'.format(accuracy_test_adv_final * 100))
Accuracy on adversarial test examples created by ART using the loss by Tramer et al. (2020): 1.41%.
This closely matches the accuracy of 1.5% reported by Tramer et al. (2020).
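The attack budget appears to correspond to the common CIFAR-10 setting of an L-infinity radius of 8/255 with step size 2/255; this correspondence is an assumption based on the rounding, and can be checked directly:
# eps=0.031 and eps_step=0.0078 look like rounded 8/255 and 2/255
print('8/255 = {0:.4f}'.format(8 / 255))  # ~0.0314
print('2/255 = {0:.4f}'.format(2 / 255))  # ~0.0078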
# Get accuracy on benign test samples for each model separately
accuracy_test_benign_1 = get_accuracy(X=X_test, Y=Y_test, batch_size=batch_size, predictions=preds_prob_1)
print('Model 1 - Accuracy on benign test samples: {0:.2f}%.'.format(accuracy_test_benign_1 * 100))
accuracy_test_benign_2 = get_accuracy(X=X_test, Y=Y_test, batch_size=batch_size, predictions=preds_prob_2)
print('Model 2 - Accuracy on benign test samples: {0:.2f}%.'.format(accuracy_test_benign_2 * 100))
accuracy_test_benign_3 = get_accuracy(X=X_test, Y=Y_test, batch_size=batch_size, predictions=preds_prob_3)
print('Model 3 - Accuracy on benign test samples: {0:.2f}%.'.format(accuracy_test_benign_3 * 100))
Model 1 - Accuracy on benign test samples: 64.55%.
Model 2 - Accuracy on benign test samples: 61.80%.
Model 3 - Accuracy on benign test samples: 74.54%.
# Get accuracy on adversarial test examples for each model separately
for i_pred, preds_prob_i in enumerate([preds_prob_1, preds_prob_2, preds_prob_3]):
loss_i = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=preds_prob_i, from_logits=False,
label_smoothing=0)
classifier_eval_i = TensorFlowClassifier(input_ph=x,
output=preds_prob_i,
labels_ph=y,
train=None,
loss=loss_i,
learning=phase,
sess=sess,
channels_first=False,
clip_values=(0, 1),
preprocessing=(0, 1),
feed_dict=feed_dict)
attack_eval_i = ProjectedGradientDescent(classifier=classifier_eval_i,
norm=np.inf,
eps=0.031,
eps_step=0.0078,
max_iter=100,
targeted=False,
num_random_init=1,
batch_size=batch_size)
X_test_adv_i = attack_eval_i.generate(X_test[:nb_samples], Y_test[:nb_samples])
accuracy_test_adv_i = get_accuracy(X=X_test_adv_i, Y=Y_test, batch_size=batch_size,
predictions=preds_prob_i)
print('Model {0} - Accuracy on adversarial test examples: {1:.2f}%.'.format(i_pred + 1,
accuracy_test_adv_i * 100))
Model 1 - Accuracy on adversarial test examples: 0.98%.
Model 2 - Accuracy on adversarial test examples: 1.06%.
Model 3 - Accuracy on adversarial test examples: 0.02%.