---
layout: post
title: "Unbalanced Validation Data Losses"
description: "How does having an unbalanced dataset affect your loss function?"
feature-img: "assets/img/rainbow.jpg"
thumbnail: "assets/img/anzac_hill_lights.jpg"
tags: [Computer Vision, Data Augmentation, Deep Learning, Python]
---
Most people would say you should use a validation set (and test set) that is representative of the real-world data. In many cases, I think this is a good idea. However, we have to be careful when doing this.

Let's look at an example: imagine you're doing image classification for a disease that 1 in 100 members of the population has.

So you create a train, val, and test set to train your model and evaluate your results. There are tradeoffs in how to make up the train set, but this post is going to focus on the validation set.

Let's start with it being representative, as that is the commonly provided guidance. This also lets you know how well your model will do in the real world.
```python
from tensorflow.keras.losses import BinaryCrossentropy
import tensorflow_addons as tfa
import numpy as np
import tensorflow as tf
from mlflow import log_metric, log_param, log_artifacts

# A roughly representative validation set: about 1 positive for every 100 negatives
num_0s = 1000000
num_1s = 10000
```
Let's use binary cross entropy, as that's the most common loss function in these cases.
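For a label y in {0, 1} and a predicted probability p, binary cross-entropy is -[y·log(p) + (1 - y)·log(1 - p)], averaged over the examples. Here's a minimal NumPy sketch of what the Keras loss below computes (`bce_sketch` is just an illustrative name; the real implementation also handles things like `from_logits` and label smoothing):

```python
# Minimal sketch of binary cross-entropy; Keras clips with its own epsilon and
# supports logits, label smoothing, etc.
def bce_sketch(y_true, y_pred, eps=1e-7):
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
```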
```python
bce = BinaryCrossentropy()
```
```python
def calc_precision(tp, fp):
    # eps avoids division by zero when there are no positive predictions
    return (tp + np.finfo(np.float64).eps) / (tp + fp + np.finfo(np.float64).eps)

def calc_recall(tp, fn):
    return (tp + np.finfo(np.float64).eps) / (tp + fn + np.finfo(np.float64).eps)

def calc_f1(pre, rec):
    return 2 * (pre * rec) / (pre + rec)

def get_pred_0s(num_0s, acc_0s):
    # Simulate predictions for the negative class with a given accuracy:
    # uniform scores on [0, max_rand_0s) round to 0 with probability acc_0s
    y_true_0s = np.zeros(num_0s)
    max_rand_0s = (1 - 0.5) / acc_0s
    y_pred_0s = np.random.uniform(low=0, high=max_rand_0s, size=num_0s)
    y_label_pred_0s = np.round(y_pred_0s)
    # num_correct = len(y_label_pred_0s) - sum(y_label_pred_0s)
    return y_true_0s, y_pred_0s, y_label_pred_0s

def get_pred_1s(num_1s, acc_1s):
    # Simulate predictions for the positive class with a given accuracy:
    # uniform scores on [min_rand_1s, 1) round to 1 with probability acc_1s
    y_true_1s = np.ones(num_1s)
    min_rand_1s = 1 - (1 - 0.5) / acc_1s
    y_pred_1s = np.random.uniform(low=min_rand_1s, high=1, size=num_1s)
    y_label_pred_1s = np.round(y_pred_1s)
    # num_correct = sum(y_label_pred_1s)
    return y_true_1s, y_pred_1s, y_label_pred_1s

def calc_values(num_0s, acc_0s, num_1s, acc_1s, loss_func):
    y_true_0s, y_pred_0s, y_label_pred_0s = get_pred_0s(num_0s, acc_0s)
    y_true_1s, y_pred_1s, y_label_pred_1s = get_pred_1s(num_1s, acc_1s)
    y_true = np.concatenate((y_true_0s, y_true_1s))
    y_pred = np.concatenate((y_pred_0s, y_pred_1s))
    y_label_preds = np.concatenate((y_label_pred_0s, y_label_pred_1s))
    loss = loss_func(y_true.astype(np.float32), y_pred.astype(np.float32)).numpy()
    tp = sum(y_label_pred_1s == y_true_1s)        # positives predicted as 1
    fp = sum(y_label_pred_0s - y_true_0s == 1)    # negatives predicted as 1
    fn = sum(y_true_1s - y_label_pred_1s)         # positives predicted as 0
    pre = calc_precision(tp, fp)
    recall = calc_recall(tp, fn)
    f1 = calc_f1(pre, recall)
    return loss, pre, recall, f1

def print_metrics(loss, pre, recall, f1):
    print('Loss: ', loss)
    print('Precision: ', pre)
    print("Recall: ", recall)
    print("F1 Score: ", f1)
```
```python
acc_0s = 0.999
acc_1s = 0.5
l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)
print_metrics(l, p, r, f)
```
```
Loss: 0.313947
Precision: 0.8396806918343589
Recall: 0.5049
F1 Score: 0.6306126272403673
```
```python
acc_0s = 0.98
acc_1s = 0.9
l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)
print_metrics(l, p, r, f)
```
```
Loss: 0.3151517
Precision: 0.30972900010383136
Recall: 0.8949
F1 Score: 0.4601856375183195
```
Which model is better? Well, for an initial screen you would probably be better off with the high-recall model. However, if you're just following the loss, it would push you away from that model and towards the one with much lower recall, which isn't what you want at all.
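To see why, it helps to split the loss by class: with the default mean reduction, the overall BCE is a count-weighted average of the per-class means, so the 1,000,000 negatives dominate it. Here's a quick sketch of that decomposition (it re-simulates the predictions, so the exact numbers will differ slightly from the run above):

```python
# Split the BCE into per-class contributions (sketch; fresh random draw, so
# the exact numbers won't match the printout above).
y_true_0s, y_pred_0s, _ = get_pred_0s(num_0s, acc_0s)
y_true_1s, y_pred_1s, _ = get_pred_1s(num_1s, acc_1s)

loss_0s = bce(y_true_0s.astype(np.float32), y_pred_0s.astype(np.float32)).numpy()
loss_1s = bce(y_true_1s.astype(np.float32), y_pred_1s.astype(np.float32)).numpy()

# The overall loss is the count-weighted average of the per-class means, so the
# 1,000,000 easy 0s swamp whatever happens on the 10,000 hard 1s.
overall = (num_0s * loss_0s + num_1s * loss_1s) / (num_0s + num_1s)
print(loss_0s, loss_1s, overall)
```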
Now let's use a balanced validation set instead.
```python
num_0s = 10000
num_1s = 10000

acc_0s = 0.999
acc_1s = 0.5
l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)
print_metrics(l, p, r, f)
```
```
Loss: 0.65621394
Precision: 0.9976128903918838
Recall: 0.5015
F1 Score: 0.6674652292540094
```
```python
acc_0s = 0.98
acc_1s = 0.9
l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)
print_metrics(l, p, r, f)
```
```
Loss: 0.33528897
Precision: 0.9771986970684039
Recall: 0.9
F1 Score: 0.9370119729307653
```
Now the loss is much more in line with what you would want.
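If physically rebalancing the validation set isn't an option, you can get a similar effect by weighting the loss itself. Here's a sketch using the `sample_weight` argument that Keras losses accept; the counts, accuracies, and weighting choice below are just illustrative:

```python
# Sketch: upweight each positive so the 10,000 positives collectively count as
# much as the 1,000,000 negatives. Keras still divides by the number of
# examples rather than the sum of the weights, so the absolute scale changes;
# what matters is the relative comparison between models.
n_neg, n_pos = 1000000, 10000  # illustrative counts matching the example above

y_true_0s, y_pred_0s, _ = get_pred_0s(n_neg, 0.98)
y_true_1s, y_pred_1s, _ = get_pred_1s(n_pos, 0.9)

y_true = np.concatenate((y_true_0s, y_true_1s)).astype(np.float32).reshape(-1, 1)
y_pred = np.concatenate((y_pred_0s, y_pred_1s)).astype(np.float32).reshape(-1, 1)

weights = np.where(y_true[:, 0] == 1, n_neg / n_pos, 1.0).astype(np.float32)
print(bce(y_true, y_pred, sample_weight=weights).numpy())
```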
Focal loss, which was introduced in the RetinaNet paper, was specifically designed for situations with unbalanced data. Let's try substituting this in.
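As a reminder of what this loss does, here is a rough NumPy sketch of the formula from the paper, using the same default alpha=0.25 and gamma=2.0 that tensorflow_addons uses (this is not tfa's exact implementation, and `focal_loss_sketch` is just an illustrative name):

```python
# Sketch of the focal loss: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),
# where p_t is the predicted probability of the true class. The (1 - p_t)^gamma
# term down-weights easy, confidently classified examples, which here is the
# mass of easy negatives. Returns per-example losses (no reduction).
def focal_loss_sketch(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    p_t = np.where(y_true == 1, y_pred, 1 - y_pred)
    alpha_t = np.where(y_true == 1, alpha, 1 - alpha)
    return -alpha_t * (1 - p_t) ** gamma * np.log(np.clip(p_t, eps, 1.0))
```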
```python
sfce = tfa.losses.SigmoidFocalCrossEntropy()

num_0s = 1000000
num_1s = 10000

acc_0s = 0.999
acc_1s = 0.5
l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, sfce)
print_metrics(l, p, r, f)
```
```
Loss: 31629.947
Precision: 0.835135584761271
Recall: 0.502
F1 Score: 0.6270688901380302
```
```python
acc_0s = 0.98
acc_1s = 0.9
l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, sfce)
print_metrics(l, p, r, f)
```
```
Loss: 32311.969
Precision: 0.31202799140857757
Recall: 0.9007
F1 Score: 0.46348993979313535
```
Result: using a focal loss might help to mitigate some of the difference, although in this simulation the unbalanced validation set still gives the lower-recall model the lower loss.