Practical session 5: Out-of-distribution detection¶
You can open this practical session in colab here : https://colab.research.google.com/drive/1CFM5sCZDkzK7BKBCH8-U2pbNORkfvBH1?usp=sharing
or download the notebook directly here.
This last lab session focuses on two examples where good uncertainty estimation is crucial: failure prediction and out-of-distribution detection.
Goal: get hands-on experience applying uncertainty estimation to out-of-distribution detection
All Imports and Useful Functions¶
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score
from IPython import display
from tqdm.notebook import tqdm
from matplotlib.pyplot import imread
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {}
#kwargs = {'num_workers': 10, 'pin_memory': True} if use_cuda else {}
def plot_predicted_images(selected_idx, images, pred, labels, uncertainties, hists, mc_samples):
    """ Plot predicted images along with the mean-prediction probability histogram, maxprob
    frequencies and some per-class probability histograms across MC sampling
    Args:
        selected_idx: (array) chosen indices in the uncertainties tensor
        images: (tensor) images from the test set
        pred: (tensor) class predictions by the model
        labels: (tensor) true labels of the given dataset
        uncertainties: (tensor) uncertainty estimates of the given dataset
        hists: (array) number of occurrences of each class across the MC samples, only with MC Dropout
        mc_samples: (tensor) prediction matrix for s=100 samples, only with MC Dropout
    Returns:
        None
    """
    plt.figure(figsize=(15, 10))
    for i, idx in enumerate(selected_idx):
        # Plot original image
        plt.subplot(5, 6, 1 + 6 * i)
        plt.axis('off')
        plt.title(f'var-ratio={uncertainties[idx]:.3f}, \n gt={labels[idx]}, pred={pred[idx]}')
        plt.imshow(images[idx], cmap='gray')
        # Plot mean probabilities across MC samples
        plt.subplot(5, 6, 1 + 6 * i + 1)
        plt.title('Mean probs')
        plt.bar(range(10), mc_samples[idx].mean(0))
        plt.xticks(range(10))
        # Plot how often each class was the argmax across MC samples
        plt.subplot(5, 6, 1 + 6 * i + 2)
        plt.title('Maxprob frequencies')
        plt.bar(range(10), hists[idx])
        plt.xticks(range(10))
        # Plot probability histograms for the three most frequent classes
        list_plotprobs = [hists[idx].argsort()[-1], hists[idx].argsort()[-2], hists[idx].argsort()[-3]]
        ymax = max([max(np.histogram(mc_samples[idx][:, c])[0]) for c in list_plotprobs])
        for j, c in enumerate(list_plotprobs):
            plt.subplot(5, 6, 1 + 6 * i + (3 + j))
            plt.title(f'Samples probs of class {c}')
            plt.hist(mc_samples[idx][:, c], bins=np.arange(0, 1.1, 0.1))
            plt.ylim(0, np.ceil(ymax / 10) * 10)
            plt.xticks(np.arange(0, 1, 0.1), rotation=60)
    plt.tight_layout()
    plt.show()
Out-of-distribution detection¶
Modern neural networks are known to generalize well when the training and testing data are sampled from the same distribution. However, when deploying neural networks in real-world applications, there is often very little control over the testing data distribution. It is important for classifiers to be aware of uncertainty when shown new kinds of inputs, i.e., out-of-distribution examples. Therefore, being able to accurately detect out-of-distribution examples can be practically important for visual recognition tasks.
In this section, we will use Kuzushiji-MNIST (KMNIST), a drop-in replacement for the MNIST dataset (28x28 grayscale, 70,000 images) containing 10 classes of Kuzushiji (cursive Japanese) characters, as out-of-distribution samples for our model trained on MNIST. We will compare the uncertainty estimation methods used previously with ODIN.
# Load KMNIST dataset
kmnist_test_dataset = datasets.KMNIST('data', train=False, download=True, transform=transform)
kmnist_test_loader = DataLoader(kmnist_test_dataset, batch_size=128)
# Visualize some images
images, labels = next(iter(kmnist_test_loader))
fig, axes = plt.subplots(nrows=4, ncols=4)
for i, (image, label) in enumerate(zip(images, labels)):
    if i >= 16:
        break
    axes[i // 4][i % 4].imshow(image[0], cmap='gray')
    axes[i // 4][i % 4].set_title(f"{kmnist_test_dataset.classes[label]}")
    axes[i // 4][i % 4].set_xticks([])
    axes[i // 4][i % 4].set_yticks([])
fig.set_size_inches(4, 4)
fig.tight_layout()
We compute the precision-recall curve and the AUPR metric for OOD detection with MCP and with MC Dropout using mutual information.
# Compute predictions for MCP method on MNIST
_, _, uncertainty_mcp, errors_mcp, _, _ = predict_test_set(lenet, test_loader, mode='mcp')
# Same on KMNIST
_, _, uncertainty_kmnist, errors_kmnist, _, _ = predict_test_set(lenet, kmnist_test_loader, mode='mcp')
# Concatenate predictions, flagging KMNIST samples as out-of-distribution (1) and MNIST samples as in-distribution (0)
tot_uncertainty = np.concatenate((uncertainty_mcp, uncertainty_kmnist))
is_ood = np.concatenate((np.zeros_like(uncertainty_mcp), np.ones_like(uncertainty_kmnist)))
# Obtain the precision-recall curve and AUPR; MCP is a confidence score (high = confident),
# so we negate it to get a score that is high for OOD samples
precision_ood_mcp, recall_ood_mcp, _ = precision_recall_curve(is_ood, -tot_uncertainty)
aupr_ood_mcp = average_precision_score(is_ood, -tot_uncertainty)
# Compute for MC Dropout with mutual information
_, _, uncertainty_mutinf, _, _, _ = predict_test_set(lenet, test_loader, mode='mut_inf')
_, _, uncertainty_mutinf_kmnist, _, _, _ = predict_test_set(lenet, kmnist_test_loader, mode='mut_inf')
tot_uncertainty = np.concatenate((uncertainty_mutinf, uncertainty_mutinf_kmnist))
is_ood = np.concatenate((np.zeros_like(uncertainty_mutinf), np.ones_like(uncertainty_mutinf_kmnist)))
# Mutual information is already high for uncertain samples, so no negation is needed
precision_ood_ent, recall_ood_ent, _ = precision_recall_curve(is_ood, tot_uncertainty)
aupr_ood_ent = average_precision_score(is_ood, tot_uncertainty)
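For reference, the mode='mut_inf' uncertainty corresponds to the mutual information between the prediction and the dropout mask: the entropy of the mean predictive distribution minus the mean entropy of the individual sampled predictions. Below is a minimal NumPy sketch of this quantity; the array probs and its shape are illustrative, and predict_test_set is assumed to compute something equivalent.
def mutual_information(probs, eps=1e-10):
    # probs: hypothetical array of shape (n_images, n_mc_samples, n_classes)
    mean_probs = probs.mean(axis=1)  # predictive distribution, shape (n_images, n_classes)
    # Entropy of the mean prediction: total uncertainty
    entropy_of_mean = -(mean_probs * np.log(mean_probs + eps)).sum(axis=-1)
    # Mean entropy of the sampled predictions: aleatoric part
    mean_of_entropies = -(probs * np.log(probs + eps)).sum(axis=-1).mean(axis=1)
    # The difference isolates the epistemic (model) uncertainty
    return entropy_of_mean - mean_of_entropies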
We will now implement the ODIN method.
ODIN [Liang et al., ICLR 2018] is a threshold-based detector that enhances the maximum softmax probability with two extensions:
temperature scaling: \(\textit{p}(y= c \vert \mathbf{x}, \mathbf{w}, T) = \frac{\exp(f_c( \mathbf{x}, \mathbf{w}) / T)}{\sum_{k=1}^K \exp(f_k( \mathbf{x}, \mathbf{w}) / T)}\)
where \(T \in \mathbb{R}^{+}\)
inverse adversarial perturbation: \(\tilde{\mathbf{x}} = \mathbf{x} - \epsilon \, \mathrm{sign} \big( - \nabla_{\mathbf{x}} \log \textit{p}(y = y^* \vert \mathbf{x}, \mathbf{w}, T) \big)\)
Both techniques aim to push the MCP of in-distribution samples above the MCP of out-of-distribution samples. Here, we set the hyperparameters to \(T=5\) and \(\epsilon=0.0014\).
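Concretely, temperature scaling simply divides the logits by \(T\) before the softmax; a larger temperature flattens the predicted distribution. A quick illustration with made-up logits:
# Made-up logits, for illustration only
logits = torch.tensor([[2.0, 0.5, -1.0]])
for T in (1.0, 5.0):
    print(f'T={T}:', F.softmax(logits / T, dim=1))
# T=1 yields a peaked distribution; T=5 a much flatter one, which
# spreads out MCP values and helps separate in- from out-of-distribution inputs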
[CODING TASK] Implement ODIN preprocessing
def odin_preprocessing(model, input, epsilon):
    # We perform the inverse adversarial perturbation
    # You can find some help in the link below:
    # https://pytorch.org/tutorials/beginner/fgsm_tutorial.html
    # ============ YOUR CODE HERE ============
    # 1. Set the requires_grad attribute of the input tensor (needed to get its gradient)
    # 2. Forward pass the data through the model
    # 3. Calculate the loss w.r.t. the class predictions
    # 4. Zero all existing gradients
    # 5. Calculate gradients of the model in a backward pass
    # 6. Collect the sign of the input gradient
    # 7. Normalize the gradient to the space of the image (MNIST images are normalized with std 0.3081)
    sign_input_grad = sign_input_grad / 0.3081
    # 8. Apply the inverse FGSM perturbation
    return perturbed_input
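For reference, here is one possible solution sketch, not necessarily the intended one: it assumes the model outputs raw logits and that temperature scaling is applied separately inside predict_test_set; the 0.3081 factor is the std used to normalize MNIST images.
def odin_preprocessing_sketch(model, input, epsilon):
    # 1. Track gradients w.r.t. the input (must not be called under torch.no_grad())
    input.requires_grad_(True)
    # 2. Forward pass
    output = model(input)
    # 3. Loss w.r.t. the model's own class predictions
    pred = output.argmax(dim=1)
    loss = F.nll_loss(F.log_softmax(output, dim=1), pred)
    # 4. Zero all existing gradients
    model.zero_grad()
    # 5. Backward pass to get the gradient w.r.t. the input
    loss.backward()
    # 6. Collect the sign of the input gradient
    sign_input_grad = input.grad.data.sign()
    # 7. Normalize the gradient to the space of the (normalized) image
    sign_input_grad = sign_input_grad / 0.3081
    # 8. Subtract the perturbation: sign(grad of the loss) = sign(-grad log p),
    #    so the perturbation increases the confidence of the predicted class
    perturbed_input = input - epsilon * sign_input_grad
    return perturbed_input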
# Compute predictions for ODIN on MNIST
_, _, uncertainty_odin, errors_odin, _, _ = predict_test_set(lenet, test_loader, mode='odin', temp=5, epsilon=0.0014)
# Compute predictions for ODIN on KMNIST (same hyperparameters: the detector cannot know which samples are OOD)
_, _, uncertainty_kmnist, errors_kmnist, _, _ = predict_test_set(lenet, kmnist_test_loader, mode='odin', temp=5, epsilon=0.0014)
# Concatenate predictions, flagging KMNIST samples as out-of-distribution
tot_uncertainty = np.concatenate((uncertainty_odin, uncertainty_kmnist))
is_ood = np.concatenate((np.zeros_like(uncertainty_odin), np.ones_like(uncertainty_kmnist)))
# Obtain the precision-recall curve and AUPR; ODIN's score is a (temperature-scaled) MCP, so we negate it
precision_ood_odin, recall_ood_odin, _ = precision_recall_curve(is_ood, -tot_uncertainty)
aupr_ood_odin = average_precision_score(is_ood, -tot_uncertainty)
Let’s look at the comparative results for OOD detection.
plt.figure(figsize=(7,7))
plt.title('Precision-recall curve for OOD detection')
plt.plot(recall_ood_mcp, precision_ood_mcp, label = f'MCP, AUPR = {aupr_ood_mcp:.2%}')
plt.plot(recall_ood_ent, precision_ood_ent, label = f'MCDropout (MutInf), AUPR = {aupr_ood_ent:.2%}')
plt.plot(recall_ood_odin, precision_ood_odin, label = f'ODIN, AUPR = {aupr_ood_odin:.2%}')
plt.legend()
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.show()
[Question 3.1]: Compare the precision-recall curves of each OOD detection method along with their AUPR values. Which method performs best, and why?