In this demo we show how a set of existing samples and their predicted labels can be used to run a black-box membership inference attack against a model to which we no longer have access. We demonstrate this on the Nursery dataset (the original dataset can be found here: https://archive.ics.uci.edu/ml/datasets/nursery).
We have already preprocessed the dataset so that all categorical features are one-hot encoded, and the data has been scaled using sklearn's StandardScaler.
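For reference, that kind of preprocessing can be sketched with pandas and scikit-learn as below. This is purely illustrative (load_nursery already returns the preprocessed data), and the column names are hypothetical placeholders rather than the actual Nursery attributes.
# Illustrative sketch only: load_nursery performs the equivalent preprocessing internally.
# 'parents' and 'children' are hypothetical placeholder columns.
import pandas as pd
from sklearn.preprocessing import StandardScaler

raw = pd.DataFrame({'parents': ['usual', 'pretentious', 'great_pret'], 'children': [1, 2, 3]})
encoded = pd.get_dummies(raw, columns=['parents'])  # one-hot encode the categorical feature
scaled = StandardScaler().fit_transform(encoded.to_numpy(dtype=float))  # zero mean, unit variance per column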
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
from art.utils import load_nursery
(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5)
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from art.estimators.classification.pytorch import PyTorchClassifier
class ModelToAttack(nn.Module):

    def __init__(self, num_classes, num_features):
        super(ModelToAttack, self).__init__()

        self.fc1 = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.Tanh(),
        )

        self.fc2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Tanh(),
        )

        self.classifier = nn.Linear(512, num_classes)
        # self.softmax = nn.Softmax(dim=1)  # not needed: nn.CrossEntropyLoss expects raw logits

    def forward(self, x):
        out = self.fc1(x)
        out = self.fc2(out)
        return self.classifier(out)
mlp_model = ModelToAttack(4, 24)
mlp_model = torch.nn.DataParallel(mlp_model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mlp_model.parameters(), lr=0.0001)
class NurseryDataset(Dataset):

    def __init__(self, x, y=None):
        self.x = torch.from_numpy(x.astype(np.float64)).type(torch.FloatTensor)

        if y is not None:
            self.y = torch.from_numpy(y.astype(np.int8)).type(torch.LongTensor)
        else:
            self.y = torch.zeros(x.shape[0])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        if idx >= len(self.x):
            raise IndexError("Invalid Index")
        return self.x[idx], self.y[idx]
train_set = NurseryDataset(x_train, y_train)
train_loader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=0)
# train the target model
for epoch in range(20):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = mlp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
mlp_art_model = PyTorchClassifier(model=mlp_model, loss=criterion, optimizer=optimizer, input_shape=(24,), nb_classes=4)
pred = np.array([np.argmax(arr) for arr in mlp_art_model.predict(x_test.astype(np.float32))])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
Base model accuracy: 0.9720592775548008
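Membership inference attacks exploit the fact that a model tends to behave differently on its training data than on unseen data. As a quick sanity check, the training-set accuracy can be computed in the same way and compared against the test accuracy above:
# accuracy on the training set, for comparison with the test accuracy
train_pred = np.array([np.argmax(arr) for arr in mlp_art_model.predict(x_train.astype(np.float32))])
print('Train accuracy: ', np.sum(train_pred == y_train) / len(y_train))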
attack_train_ratio = 0.5
attack_member_size = int(len(x_train) * attack_train_ratio)
attack_nonmember_size = int(len(x_test) * attack_train_ratio)
# For training the attack model
attack_x_member = x_train[:attack_member_size].astype(np.float32)
attack_x_nonmember = x_test[:attack_nonmember_size].astype(np.float32)
predicted_y_member = mlp_art_model.predict(attack_x_member)
predicted_y_nonmember = mlp_art_model.predict(attack_x_nonmember)
# For testing the attack model
attack_x_member_test = x_train[attack_member_size:].astype(np.float32)
attack_x_nonmember_test = x_test[attack_nonmember_size:].astype(np.float32)
predicted_y_member_test = mlp_art_model.predict(attack_x_member_test)
predicted_y_nonmember_test = mlp_art_model.predict(attack_x_nonmember_test)
from art.estimators.classification import BlackBoxClassifier
existing_samples = np.vstack((attack_x_member, attack_x_nonmember, attack_x_member_test, attack_x_nonmember_test))
existing_predictions = np.vstack((predicted_y_member, predicted_y_nonmember, predicted_y_member_test, predicted_y_nonmember_test))
classifier = BlackBoxClassifier((existing_samples, existing_predictions), x_train[0].shape, 4)
From this point on we no longer need access to the model; the attack runs using only the set of predictions we collected earlier.
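As a quick illustration that the classifier is backed purely by the stored (sample, prediction) pairs, querying it on one of the supplied samples should return the prediction recorded earlier (a minimal sketch, assuming the sample is matched exactly in the lookup):
# the class predicted by the black-box wrapper should equal the stored prediction for the same sample
lookup_pred = classifier.predict(existing_samples[:1])
print(np.argmax(lookup_pred), np.argmax(existing_predictions[0]))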
from art.attacks.inference.membership_inference import MembershipInferenceBlackBox
bb_attack = MembershipInferenceBlackBox(classifier, attack_model_type='rf')
# train attack model
bb_attack.fit(attack_x_member, y_train[:attack_member_size], attack_x_nonmember, y_test[:attack_nonmember_size])
# infer
inferred_member_bb = bb_attack.infer(attack_x_member_test, y_train[attack_member_size:])
inferred_nonmember_bb = bb_attack.infer(attack_x_nonmember_test, y_test[attack_nonmember_size:])
# check accuracy
member_acc = np.sum(inferred_member_bb) / len(inferred_member_bb)
nonmember_acc = 1 - (np.sum(inferred_nonmember_bb) / len(inferred_nonmember_bb))
acc = (member_acc * len(inferred_member_bb) + nonmember_acc * len(inferred_nonmember_bb)) / (len(inferred_member_bb) + len(inferred_nonmember_bb))
print("Member Accuracy", member_acc)
print("Non-Member Accuracy", nonmember_acc)
print("Accuracy", acc)
Member Accuracy 0.4949058351343007
Non-Member Accuracy 0.7063908613769683
Accuracy 0.6006483482556345
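The member and non-member accuracies can also be summarized as the precision and recall of the attack, treating "member" as the positive class. A minimal sketch computed directly from the inferred arrays above:
# precision: of the samples predicted as members, how many are true members
# recall: of the true members, how many were predicted as members
true_positives = np.sum(inferred_member_bb)
false_positives = np.sum(inferred_nonmember_bb)
precision = true_positives / (true_positives + false_positives)
recall = true_positives / len(inferred_member_bb)
print("Attack precision", precision)
print("Attack recall", recall)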