from __future__ import absolute_import, division, print_function, unicode_literals
import sys
# IPython shell escape: pip-install imagenet_stubs into the interpreter running this
# kernel (sys.executable); it supplies the sample images read via get_image_paths().
!{sys.executable} -m pip install git+https://github.com/nottombrown/imagenet_stubs
# Make modules in the parent directory importable
# (NOTE(review): presumably a local ART checkout — confirm).
sys.path.append("..")
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
import tensorflow as tf
# Build all models in TF1-style graph (non-eager) mode.
# NOTE(review): presumably required by the ART/TensorFlow versions this notebook
# targets — confirm before removing.
tf.compat.v1.disable_eager_execution()
import imagenet_stubs
import numpy as np
import tensorflow.keras
from tensorflow.keras.preprocessing import image
# preprocess_input is imported but not used by the visible cells.
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
import tensorflow.keras.backend as k
from matplotlib import pyplot as plt
# Only referenced by the commented-out clear_output() calls in the attack loops.
from IPython.display import clear_output
from art.estimators.classification import KerasClassifier
from art.attacks.evasion import HopSkipJump
from art.utils import to_categorical
Collecting git+https://github.com/nottombrown/imagenet_stubs
Cloning https://github.com/nottombrown/imagenet_stubs to /private/var/folders/_5/f0k1rpfj41x01f5_19w2vrkc0000gn/T/pip-b1brxr92-build
Requirement already satisfied (use --upgrade to upgrade): imagenet-stubs==0.0.7 from git+https://github.com/nottombrown/imagenet_stubs in /Users/minhtn/ibm/installation/miniconda3/lib/python3.6/site-packages
You are using pip version 9.0.1, however version 20.2.3 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
Using TensorFlow backend.
# Full-size (224, 224, 3) mean image: every pixel holds the same three channel
# means. ART's `preprocessing=(subtrahend, divisor)` computes
# (x - subtrahend) / divisor, so with a divisor of ones this only removes the mean.
# NOTE(review): these values look like Keras' BGR-ordered "caffe" channel means,
# while image.load_img yields RGB arrays — confirm whether a channel flip is needed.
mean_imagenet = np.full((224, 224, 3), [103.939, 116.779, 123.68])
model = ResNet50(weights='imagenet')
# Inputs are raw pixel values, clipped to [0, 255] by the classifier.
classifier = KerasClassifier(
    clip_values=(0, 255),
    model=model,
    preprocessing=(mean_imagenet, np.ones((224, 224, 3))),
)
WARNING:tensorflow:From /Users/minhtn/ibm/installation/miniconda3/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. WARNING:tensorflow:From /Users/minhtn/ibm/installation/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead. WARNING:tensorflow:From /Users/minhtn/ibm/installation/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
target_image_name = 'koala.jpg'
init_image_name = 'tractor.jpg'


def _load_stub_image(file_name, target_size=(224, 224)):
    """Load the first imagenet_stubs image whose path ends with ``file_name``.

    Returns the image resized to ``target_size`` as the array produced by
    ``image.img_to_array``.

    Raises:
        FileNotFoundError: if no stub image path ends with ``file_name``.
            Previously the variable was simply left undefined, which surfaced
            later as a confusing NameError.
    """
    for image_path in imagenet_stubs.get_image_paths():
        if image_path.endswith(file_name):
            return image.img_to_array(image.load_img(image_path, target_size=target_size))
    raise FileNotFoundError("No imagenet_stubs image path ends with %r" % file_name)


target_image = _load_stub_image(target_image_name)
init_image = _load_stub_image(init_image_name)

# Show each image together with the class index the classifier predicts for it.
# Pixel values are assumed to lie in [0, 255] (the classifier's clip_values), so
# uint8 is the natural dtype for imshow.
print("Target image is: ", np.argmax(classifier.predict(np.array([target_image]))[0]))
plt.imshow(target_image.astype(np.uint8))
plt.show()
print("Init image is: ", np.argmax(classifier.predict(np.array([init_image]))[0]))
plt.imshow(init_image.astype(np.uint8))
plt.show()
Target image is: 105
Init image is: 866
# Untargeted HopSkipJump against the target (koala) image, run in chunks.
# The first generate() call runs with max_iter=0. Each later call resumes from
# the previous adversarial example (resume=True) for another `iter_step` iterations.
attack = HopSkipJump(classifier=classifier, targeted=False, max_iter=0, max_eval=1000, init_eval=10)
iter_step = 10
x_adv = None
for chunk in range(20):
    x_adv = attack.generate(x=np.array([target_image]), x_adv_init=x_adv, resume=True)
    # clear_output()  # uncomment to keep only the latest chunk's output
    steps_done = chunk * iter_step
    distance = np.linalg.norm((x_adv[0] - target_image).ravel())
    predicted = np.argmax(classifier.predict(x_adv)[0])
    print("Adversarial image at step %d." % steps_done, "L2 error", distance,
          "and class label %d." % predicted)
    plt.imshow(x_adv[0].astype(np.uint))
    plt.show(block=False)
    attack.max_iter = iter_step
HopSkipJump: 100%|██████████| 1/1 [00:00<00:00, 1.22it/s]
Adversarial image at step 0. L2 error 15147.211 and class label 112.
HopSkipJump: 100%|██████████| 1/1 [00:30<00:00, 30.74s/it]
Adversarial image at step 10. L2 error 8259.469 and class label 359.
HopSkipJump: 100%|██████████| 1/1 [00:48<00:00, 48.43s/it]
Adversarial image at step 20. L2 error 6047.2925 and class label 359.
HopSkipJump: 100%|██████████| 1/1 [00:58<00:00, 58.10s/it]
Adversarial image at step 30. L2 error 4446.662 and class label 359.
HopSkipJump: 100%|██████████| 1/1 [01:09<00:00, 69.60s/it]
Adversarial image at step 40. L2 error 3667.595 and class label 359.
HopSkipJump: 100%|██████████| 1/1 [01:18<00:00, 78.53s/it]
Adversarial image at step 50. L2 error 3021.8582 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:20<00:00, 80.93s/it]
Adversarial image at step 60. L2 error 2527.9243 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:33<00:00, 93.12s/it]
Adversarial image at step 70. L2 error 2139.856 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:35<00:00, 95.85s/it]
Adversarial image at step 80. L2 error 1847.0266 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:39<00:00, 99.96s/it]
Adversarial image at step 90. L2 error 1587.3302 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:43<00:00, 103.27s/it]
Adversarial image at step 100. L2 error 1390.9678 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:48<00:00, 108.16s/it]
Adversarial image at step 110. L2 error 1228.7748 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:52<00:00, 112.21s/it]
Adversarial image at step 120. L2 error 1098.0759 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:53<00:00, 113.12s/it]
Adversarial image at step 130. L2 error 1000.0699 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:57<00:00, 117.46s/it]
Adversarial image at step 140. L2 error 912.61176 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:08<00:00, 128.22s/it]
Adversarial image at step 150. L2 error 842.8174 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:18<00:00, 138.04s/it]
Adversarial image at step 160. L2 error 778.2984 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:23<00:00, 143.82s/it]
Adversarial image at step 170. L2 error 725.36945 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:28<00:00, 148.45s/it]
Adversarial image at step 180. L2 error 675.5015 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:31<00:00, 151.85s/it]
Adversarial image at step 190. L2 error 636.62787 and class label 356.
# Untargeted HopSkipJump restricted by a random 0/1 mask with the image's shape
# (each entry is 1 with probability 0.1). Per the ART docs, features where the
# mask is 0 are not perturbed.
attack = HopSkipJump(classifier=classifier, targeted=False, max_iter=0, max_eval=1000, init_eval=10)
iter_step = 10
x_adv = None
mask = np.random.binomial(n=1, p=0.1, size=np.prod(target_image.shape)).reshape(target_image.shape)
for chunk in range(20):
    # The first pass runs with max_iter=0; later passes resume for iter_step iterations.
    x_adv = attack.generate(x=np.array([target_image]), x_adv_init=x_adv, resume=True, mask=mask)
    # clear_output()  # uncomment to keep only the latest chunk's output
    steps_done = chunk * iter_step
    distance = np.linalg.norm((x_adv[0] - target_image).ravel())
    predicted = np.argmax(classifier.predict(x_adv)[0])
    print("Adversarial image at step %d." % steps_done, "L2 error", distance,
          "and class label %d." % predicted)
    plt.imshow(x_adv[0].astype(np.uint))
    plt.show(block=False)
    attack.max_iter = iter_step
HopSkipJump: 100%|██████████| 1/1 [00:01<00:00, 1.06s/it]
Adversarial image at step 0. L2 error 12889.198 and class label 354.
HopSkipJump: 100%|██████████| 1/1 [00:38<00:00, 38.48s/it]
Adversarial image at step 10. L2 error 7398.0615 and class label 359.
HopSkipJump: 100%|██████████| 1/1 [00:54<00:00, 54.56s/it]
Adversarial image at step 20. L2 error 4836.4653 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:05<00:00, 65.75s/it]
Adversarial image at step 30. L2 error 3370.0251 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:19<00:00, 79.80s/it]
Adversarial image at step 40. L2 error 2583.0674 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:34<00:00, 94.31s/it]
Adversarial image at step 50. L2 error 2119.813 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:40<00:00, 100.32s/it]
Adversarial image at step 60. L2 error 1762.1847 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:42<00:00, 102.18s/it]
Adversarial image at step 70. L2 error 1546.3962 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:50<00:00, 110.12s/it]
Adversarial image at step 80. L2 error 1386.7443 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:04<00:00, 124.02s/it]
Adversarial image at step 90. L2 error 1237.5015 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:13<00:00, 133.32s/it]
Adversarial image at step 100. L2 error 1133.0155 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [01:59<00:00, 119.38s/it]
Adversarial image at step 110. L2 error 1043.9689 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:06<00:00, 126.47s/it]
Adversarial image at step 120. L2 error 964.1652 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:12<00:00, 132.32s/it]
Adversarial image at step 130. L2 error 896.62427 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:13<00:00, 133.93s/it]
Adversarial image at step 140. L2 error 843.6781 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:20<00:00, 140.56s/it]
Adversarial image at step 150. L2 error 797.6606 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:24<00:00, 144.22s/it]
Adversarial image at step 160. L2 error 753.5747 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:28<00:00, 148.94s/it]
Adversarial image at step 170. L2 error 716.2186 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:36<00:00, 156.52s/it]
Adversarial image at step 180. L2 error 680.78467 and class label 356.
HopSkipJump: 100%|██████████| 1/1 [02:34<00:00, 154.01s/it]
Adversarial image at step 190. L2 error 653.2177 and class label 356.
# Targeted HopSkipJump: start from the init (tractor) image via x_adv_init and move
# it towards the target (koala) image while it stays classified as the init class.
# The first generate() call runs with max_iter=0. Each later call resumes
# (resume=True) for another `iter_step` iterations.
attack = HopSkipJump(classifier=classifier, targeted=True, max_iter=0, max_eval=1000, init_eval=10)
iter_step = 10
x_adv = np.array([init_image])
# Derive the one-hot target label from the classifier's own prediction on the
# init image instead of hard-coding the class index and class count. This
# guarantees the starting sample x_adv_init is classified as the target class.
init_scores = classifier.predict(x_adv)[0]
y_target = to_categorical([np.argmax(init_scores)], init_scores.shape[0])
for i in range(20):
    x_adv = attack.generate(x=np.array([target_image]), y=y_target, x_adv_init=x_adv, resume=True)
    # clear_output()  # uncomment to keep only the latest step's output
    print("Adversarial image at step %d." % (i * iter_step), "L2 error",
          np.linalg.norm(np.reshape(x_adv[0] - target_image, [-1])),
          "and class label %d." % np.argmax(classifier.predict(x_adv)[0]))
    # Values are clipped to [0, 255] by the classifier, so uint8 displays them as RGB.
    plt.imshow(x_adv[0].astype(np.uint8))
    plt.show(block=False)
    attack.max_iter = iter_step
HopSkipJump: 100%|██████████| 1/1 [00:00<00:00, 1834.78it/s]
Adversarial image at step 0. L2 error 44399.297 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [00:38<00:00, 38.73s/it]
Adversarial image at step 10. L2 error 15533.123 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [00:52<00:00, 52.57s/it]
Adversarial image at step 20. L2 error 13701.819 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:00<00:00, 60.65s/it]
Adversarial image at step 30. L2 error 11656.512 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:10<00:00, 70.40s/it]
Adversarial image at step 40. L2 error 10472.013 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:17<00:00, 77.89s/it]
Adversarial image at step 50. L2 error 8748.627 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:25<00:00, 85.27s/it]
Adversarial image at step 60. L2 error 7440.469 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:30<00:00, 90.55s/it]
Adversarial image at step 70. L2 error 6323.5347 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:36<00:00, 96.99s/it]
Adversarial image at step 80. L2 error 5699.7773 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:41<00:00, 101.47s/it]
Adversarial image at step 90. L2 error 4760.8467 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:47<00:00, 107.23s/it]
Adversarial image at step 100. L2 error 4279.1885 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:52<00:00, 112.58s/it]
Adversarial image at step 110. L2 error 3736.5771 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:51<00:00, 111.61s/it]
Adversarial image at step 120. L2 error 3415.3271 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:58<00:00, 118.21s/it]
Adversarial image at step 130. L2 error 3092.7314 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:04<00:00, 124.78s/it]
Adversarial image at step 140. L2 error 2844.5347 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:10<00:00, 130.12s/it]
Adversarial image at step 150. L2 error 2662.4702 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:15<00:00, 135.94s/it]
Adversarial image at step 160. L2 error 2455.6245 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:16<00:00, 136.99s/it]
Adversarial image at step 170. L2 error 2294.5833 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:24<00:00, 144.74s/it]
Adversarial image at step 180. L2 error 2154.4878 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:31<00:00, 151.74s/it]
Adversarial image at step 190. L2 error 2001.887 and class label 866.
# Targeted HopSkipJump with a perturbation mask: start from the init (tractor)
# image and move it towards the target (koala) image, perturbing only features
# where the mask is 1 (per the ART docs, mask==0 features are left unchanged).
# The first generate() call runs with max_iter=0. Later calls resume for
# `iter_step` iterations each.
attack = HopSkipJump(classifier=classifier, targeted=True, max_iter=0, max_eval=1000, init_eval=10)
iter_step = 10
x_adv = np.array([init_image])
# Random 0/1 mask with the image's shape; each entry is 1 with probability 0.9.
mask = np.random.binomial(n=1, p=0.9, size=np.prod(target_image.shape))
mask = mask.reshape(target_image.shape)
# Derive the one-hot target label from the classifier's prediction on the init
# image rather than hard-coding the class index and number of classes.
init_scores = classifier.predict(x_adv)[0]
y_target = to_categorical([np.argmax(init_scores)], init_scores.shape[0])
for i in range(20):
    x_adv = attack.generate(
        x=np.array([target_image]), y=y_target, x_adv_init=x_adv, resume=True, mask=mask
    )
    # clear_output()  # uncomment to keep only the latest step's output
    print("Adversarial image at step %d." % (i * iter_step), "L2 error",
          np.linalg.norm(np.reshape(x_adv[0] - target_image, [-1])),
          "and class label %d." % np.argmax(classifier.predict(x_adv)[0]))
    # Values are clipped to [0, 255] by the classifier, so uint8 displays them as RGB.
    plt.imshow(x_adv[0].astype(np.uint8))
    plt.show(block=False)
    attack.max_iter = iter_step
HopSkipJump: 100%|██████████| 1/1 [00:00<00:00, 1114.91it/s]
Adversarial image at step 0. L2 error 42160.312 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [00:40<00:00, 40.37s/it]
Adversarial image at step 10. L2 error 20118.164 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [00:56<00:00, 56.42s/it]
Adversarial image at step 20. L2 error 17655.967 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:06<00:00, 66.46s/it]
Adversarial image at step 30. L2 error 15877.944 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:15<00:00, 75.47s/it]
Adversarial image at step 40. L2 error 13526.768 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:20<00:00, 80.74s/it]
Adversarial image at step 50. L2 error 11542.905 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:27<00:00, 87.36s/it]
Adversarial image at step 60. L2 error 8880.384 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:33<00:00, 93.60s/it]
Adversarial image at step 70. L2 error 7306.2417 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:40<00:00, 100.22s/it]
Adversarial image at step 80. L2 error 6212.25 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:46<00:00, 106.80s/it]
Adversarial image at step 90. L2 error 5457.6064 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:52<00:00, 112.10s/it]
Adversarial image at step 100. L2 error 4831.744 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:56<00:00, 116.91s/it]
Adversarial image at step 110. L2 error 4265.8936 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:58<00:00, 118.64s/it]
Adversarial image at step 120. L2 error 3713.8616 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [01:59<00:00, 119.58s/it]
Adversarial image at step 130. L2 error 3174.1445 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:01<00:00, 121.41s/it]
Adversarial image at step 140. L2 error 2791.7537 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:09<00:00, 129.83s/it]
Adversarial image at step 150. L2 error 2502.5105 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:12<00:00, 132.56s/it]
Adversarial image at step 160. L2 error 2268.6338 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:16<00:00, 136.67s/it]
Adversarial image at step 170. L2 error 2045.2914 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:23<00:00, 143.29s/it]
Adversarial image at step 180. L2 error 1866.911 and class label 866.
HopSkipJump: 100%|██████████| 1/1 [02:16<00:00, 136.50s/it]
Adversarial image at step 190. L2 error 1717.2174 and class label 866.
The HopSkipJump attack also supports non-square input images. The following cell shows how to build a ResNet50-based classifier for non-square images (224×150 here) that can then be attacked.
# Adjust image shape here
image_shape = (224, 150)
input_shape = tuple(image_shape) + (3,)

# Mean image of the new input shape: each pixel holds the same three channel means
# (subtracted by ART via preprocessing=(subtrahend, divisor)).
mean_imagenet = np.full(input_shape, [103.939, 116.779, 123.68])

# Convolutional ResNet50 trunk for the new input shape; the ImageNet classification
# head is dropped (include_top=False) and replaced below.
model = ResNet50(weights='imagenet', input_shape=input_shape, include_top=False)


def _kr_initialize(shape, dtype=None):
    """Random-normal kernel initializer for the new Dense head.

    Uses the kernel ``shape`` that Keras passes in, rather than recomputing it
    from ``model.output.shape`` with ``Dimension.value``. That attribute only
    exists on TF1 shapes and raises AttributeError with TF2 int shapes.
    ``int(d)`` accepts both plain ints and TF1 ``Dimension`` entries.
    """
    return k.variable(value=np.random.randn(*[int(d) for d in shape]), dtype=dtype)


head = Flatten()(model.output)
# 'zeros' replaces keras.initializers.Zeros(): the bare name `keras` is never
# imported in this notebook (only tensorflow.keras is), so that raised NameError.
head = Dense(1000, kernel_initializer=_kr_initialize, bias_initializer='zeros')(head)
new_model = Model(inputs=model.input, outputs=head)
classifier = KerasClassifier(clip_values=(0, 255), model=new_model, preprocessing=(mean_imagenet, 1))
# Then call classifier.fit() to train the new weights