The basic intuition behind the attack is that samples used in training will be farther away from the decision boundary than non-training data.
The attacker perturbs a sample and measures how much noise is needed to "change the classifier's mind", i.e. to make the model predict a different label. Since the model is more confident on training data, more perturbation is required to force a misclassification there, so the amount of perturbation needed serves as a proxy for the sample's distance from the decision boundary. Both of the attacks listed below estimate this distance with an adversarial perturbation technique called HopSkipJump.
Given some estimate of a sample's distance from the model's decision boundary, the attacker compares it to a threshold $\tau$. Any distance greater than $\tau$ will cause the sample to be classified as a training sample.
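To make this decision rule concrete, the sketch below estimates boundary distances with ART's HopSkipJump evasion attack and thresholds them. It is only an illustration of the principle: the classifier, the evaluation samples and the threshold tau are placeholders for objects produced later in this walkthrough.
import numpy as np
from art.attacks.evasion import HopSkipJump

def boundary_distances(classifier, x, max_iter=10, max_eval=1000, init_eval=10):
    """Estimate each sample's L2 distance to the decision boundary with HopSkipJump."""
    hsj = HopSkipJump(classifier=classifier, targeted=False, norm=2,
                      max_iter=max_iter, max_eval=max_eval, init_eval=init_eval,
                      verbose=False)
    x_adv = hsj.generate(x)  # nearest misclassified points the attack can find
    return np.linalg.norm((x_adv - x).reshape(len(x), -1), axis=1)

# decision rule: a sample far from the boundary is flagged as a training member
# is_member = boundary_distances(art_model, x_eval) > tau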
There are two ways to learn the distance threshold $\tau$:
Supervised calibration: in this scenario the attacker knows, for a subset of the data, whether or not it was used in training. They compute the distances of these samples to the decision boundary and set $\tau$ to the value that maximizes membership inference accuracy on this subset. Misclassified samples are regarded as non-training samples.
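A rough sketch of this selection step (my own illustration, not ART's exact internals) could look as follows, given precomputed boundary distances for known members and non-members:
import numpy as np

def choose_tau_supervised(dist_members, dist_nonmembers):
    """Pick the threshold that best separates known members from known non-members."""
    candidates = np.unique(np.concatenate([dist_members, dist_nonmembers]))
    accuracies = [(np.sum(dist_members > t) + np.sum(dist_nonmembers <= t))
                  / (len(dist_members) + len(dist_nonmembers)) for t in candidates]
    return candidates[int(np.argmax(accuracies))]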
Unsupervised calibration: here the attacker generates random data and uses the same perturbation technique as before to measure its distance from the decision boundary. The attacker then sets $\tau$ to a suitable top-$t$ percentile of these distances.
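The idea can be sketched as below (again only an illustration of the principle, not ART's implementation), reusing the hypothetical boundary_distances helper from the first sketch; the input shape and clip values are assumptions matching the MNIST setup that follows.
import numpy as np

def choose_tau_unsupervised(classifier, top_t=50, num_samples=500,
                            input_shape=(1, 28, 28), clip_values=(0.0, 255.0)):
    """Calibrate tau from boundary distances of randomly generated inputs."""
    rng = np.random.default_rng(0)
    low, high = clip_values
    x_rand = rng.uniform(low, high, size=(num_samples,) + input_shape).astype(np.float32)
    dists = boundary_distances(classifier, x_rand)  # helper from the sketch above
    return np.percentile(dists, top_t)              # the chosen percentile of the distances
The rest of this walkthrough uses ART's built-in implementations of both calibration strategies, starting with training the target model on MNIST.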
import torch
from torch import nn
import numpy as np
from art.utils import load_mnist
# data: load raw MNIST and add a channel dimension so each image has shape (1, 28, 28)
(x_train, y_train), (x_test, y_test), _min, _max = load_mnist(raw=True)
x_train = np.expand_dims(x_train, axis=1).astype(np.float32)
x_test = np.expand_dims(x_test, axis=1).astype(np.float32)
# model: a small CNN for MNIST
model = nn.Sequential(
    nn.Conv2d(1, 16, 4, stride=2, padding=1),   # 1x28x28 -> 16x14x14
    nn.ReLU(),
    nn.Conv2d(16, 32, 4, stride=2, padding=1),  # 16x14x14 -> 32x7x7
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 7 * 7, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
)
import torch.optim as optim
from art.estimators.classification.pytorch import PyTorchClassifier
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
art_model = PyTorchClassifier(model=model, loss=criterion, optimizer=optimizer,
                              channels_first=True, input_shape=(1, 28, 28),
                              nb_classes=10, clip_values=(_min, _max))
art_model.fit(x_train, y_train, nb_epochs=10, batch_size=128)
pred = np.array([np.argmax(arr) for arr in art_model.predict(x_test)])
print('Base model accuracy: ', np.sum(pred == y_test) / len(y_test))
Base model accuracy: 0.9801
from art.attacks.inference.membership_inference import LabelOnlyDecisionBoundary
mia_label_only = LabelOnlyDecisionBoundary(art_model)
# number of samples used to calibrate distance threshold
attack_train_size = 1500
attack_test_size = 1500
x = np.concatenate([x_train, x_test])
y = np.concatenate([y_train, y_test])
# ground-truth membership labels: 1 = used in training, 0 = not used
training_sample = np.array([1] * len(x_train) + [0] * len(x_test))
mia_label_only.calibrate_distance_threshold(x_train[:attack_train_size], y_train[:attack_train_size],
x_test[:attack_test_size], y_test[:attack_test_size])
from numpy.random import choice
# evaluation data
n = 500
eval_data_idx = choice(len(x), n)
x_eval, y_eval = x[eval_data_idx], y[eval_data_idx]
eval_label = training_sample[eval_data_idx]
pred_label = mia_label_only.infer(x_eval, y_eval)
from sklearn.metrics import accuracy_score
print("Accuracy: %f" % accuracy_score(eval_label, pred_label))
Accuracy: 0.656000
mia_label_only_unsupervised = LabelOnlyDecisionBoundary(art_model)
# calibrate the distance threshold in an UNSUPERVISED way, i.e. without any samples of known membership
mia_label_only_unsupervised.calibrate_distance_threshold_unsupervised(top_t=50, num_samples=500,
                                                                      max_queries=2, verbose=True,
                                                                      batch_size=256)
pred_label_unsupervised = mia_label_only_unsupervised.infer(x_eval, y_eval)
print("Accuracy: %f" % accuracy_score(eval_label, pred_label_unsupervised))
Accuracy: 0.868000
As we can see, the attacker needs no extra data, no ground-truth labels for such data, and no knowledge about its membership status to mount a successful membership inference attack.
All that is required are the output labels observed from the model.
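As a quick additional sanity check, one could also break down the attack's predictions by true membership status, reusing eval_label and pred_label_unsupervised from above:
from sklearn.metrics import classification_report

# per-class view: how the attack performs on members vs. non-members
print(classification_report(eval_label, pred_label_unsupervised,
                            target_names=["non-member", "member"]))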