Practical session 5: Out-of-distribution detection

You can open this practical session in Colab here: https://colab.research.google.com/drive/1CFM5sCZDkzK7BKBCH8-U2pbNORkfvBH1?usp=sharing

or download the notebook directly here.

This last lab session focuses on two settings where good uncertainty estimation is crucial: failure prediction and out-of-distribution detection.

Goal: Get hands-on experience applying uncertainty estimation to out-of-distribution detection

All Imports and Useful Functions

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score
from IPython import display
from tqdm.notebook import tqdm
from matplotlib.pyplot import imread
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {}
# kwargs = {'num_workers': 10, 'pin_memory': True} if use_cuda else {}

def plot_predicted_images(selected_idx, images, pred, labels, uncertainties, hists, mc_samples):
    """ Plot predicted images along with mean-pred probabilities histogram, maxprob frequencies and some class histograms
    across sampliong

    Args:
    selected_ix: (array) chosen index in the uncertainties tensor
    images: (tensor) images from the test set
    pred: (tensor) class predictions by the model
    labels: (tensor) true labels of the given dataset
    uncertainties: (tensor) uncertainty estimates of the given dataset
    errors: (tensor) 0/1 vector whether the model wrongly predicted a sample
    hists : (array) number of occurences by class in each sample fo the given dataset, only with MCDropout
    mc_samples: (tensor) prediction matrix for s=100 samples, only with MCDropout

    Returns:
        None
    """
    plt.figure(figsize=(15,10))

    for i, idx in enumerate(selected_idx):
        # Plot original image
        plt.subplot(5,6,1+6*i)
        plt.axis('off')
        plt.title(f'var-ratio={uncertainties[idx]:.3f}, \n gt={labels[idx]}, pred={pred[idx]}')
        plt.imshow(images[idx], cmap='gray')

        # Plot mean probabilities
        plt.subplot(5,6,1+6*i+1)
        plt.title('Mean probs')
        plt.bar(range(10), mc_samples[idx].mean(0))
        plt.xticks(range(10))

        # Plot frequencies
        plt.subplot(5,6,1+6*i+2)
        plt.title('Maxprob frequencies')
        plt.bar(range(10), hists[idx])
        plt.xticks(range(10))

        # Plot probability histograms for the three most frequent classes
        list_plotprobs = [hists[idx].argsort()[-1], hists[idx].argsort()[-2], hists[idx].argsort()[-3]]
        ymax = max([max(np.histogram(mc_samples[idx][:,c])[0]) for c in list_plotprobs])
        for j, c in enumerate(list_plotprobs):
            plt.subplot(5,6,1+6*i+(3+j))
            plt.title(f'Samples probs of class {c}')
            plt.hist(mc_samples[idx][:,c], bins=np.arange(0,1.1,0.1))
            plt.ylim(0,np.ceil(ymax/10)*10)
            plt.xticks(np.arange(0,1,0.1), rotation=60)

    plt.tight_layout()
    plt.show()

Out-of-distribution detection

Modern neural networks are known to generalize well when the training and testing data are sampled from the same distribution. However, when deploying neural networks in real-world applications, there is often very little control over the testing data distribution. It is important for classifiers to be aware of uncertainty when shown new kinds of inputs, i.e., out-of-distribution examples. Therefore, being able to accurately detect out-of-distribution examples can be practically important for visual recognition tasks.

Kuzushiji-MNIST

In this section, we will use Kuzushiji-MNIST (KMNIST), a drop-in replacement for the MNIST dataset (28x28 grayscale, 70,000 images) containing 10 classes of cursive Japanese (Kuzushiji) characters, as out-of-distribution samples for our model trained on MNIST. We will compare the uncertainty-estimation methods used previously with ODIN.

# Load KMNIST dataset
kmnist_test_dataset = datasets.KMNIST('data', train=False, download=True, transform=transform)
kmnist_test_loader = DataLoader(kmnist_test_dataset, batch_size=128)

# Visualize some images
images, labels = next(iter(kmnist_test_loader))
fig, axes = plt.subplots(nrows=4, ncols=4)
for i, (image, label) in enumerate(zip(images, labels)):
    if i >= 16:
        break
    axes[i // 4][i % 4].imshow(image[0], cmap='gray')
    axes[i // 4][i % 4].set_title(f"{kmnist_test_dataset.classes[label]}")
    axes[i // 4][i % 4].set_xticks([])
    axes[i // 4][i % 4].set_yticks([])
fig.set_size_inches(4, 4)
fig.tight_layout()

We compute the precision, recall, and AUPR metrics for OOD detection with MCP and with MCDropout using mutual information.

# Compute predictions for MCP method on MNIST
_, _, uncertainty_mcp, errors_mcp, _, _ = predict_test_set(lenet, test_loader, mode='mcp')

# Same on KMNIST
_, _, uncertainty_kmnist, errors_kmnist, _, _ = predict_test_set(lenet, kmnist_test_loader, mode='mcp')

# Concatenate predictions from MNIST and KMNIST, labeling KMNIST samples as out-of-distribution
tot_uncertainty = np.concatenate((uncertainty_mcp, uncertainty_kmnist))
in_distribution = np.concatenate((np.zeros_like(uncertainty_mcp), np.ones_like(uncertainty_kmnist)))

# Compute the precision-recall curve and the AUPR
precision_ood_mcp, recall_ood_mcp, _ = precision_recall_curve(in_distribution, -tot_uncertainty)
aupr_ood_mcp = average_precision_score(in_distribution, -tot_uncertainty)
# Same for MCDropout with mutual information
_, _, uncertainty_mutinf, _, _, _ = predict_test_set(lenet, test_loader, mode='mut_inf')
_, _, uncertainty_mutinf_kmnist, _, _, _ = predict_test_set(lenet, kmnist_test_loader, mode='mut_inf')
tot_uncertainty = np.concatenate((uncertainty_mutinf, uncertainty_mutinf_kmnist))
in_distribution = np.concatenate((np.zeros_like(uncertainty_mutinf), np.ones_like(uncertainty_mutinf_kmnist)))

precision_ood_ent, recall_ood_ent, _ = precision_recall_curve(in_distribution, tot_uncertainty)
aupr_ood_ent = average_precision_score(in_distribution, tot_uncertainty)
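
Note the sign convention used above: precision_recall_curve and average_precision_score treat higher scores as more likely to be positive (here, out-of-distribution). MCP is a confidence (high for in-distribution samples), so we negate it, whereas mutual information is already an uncertainty (high for OOD samples) and is used as-is. A quick sanity check with made-up scores:

# Toy check of the sign convention (hypothetical scores, not model outputs):
# two confident in-distribution samples followed by two low-MCP OOD samples
toy_mcp = np.array([0.99, 0.95, 0.60, 0.55])
toy_is_ood = np.array([0, 0, 1, 1])
print(average_precision_score(toy_is_ood, -toy_mcp))  # 1.0: negated MCP ranks OOD samples first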

We will now implement the ODIN method.

ODIN [Liang et al., ICLR 2018] is a threshold-based detector that enhances the maximum softmax probability with two extensions:

  • temperature scaling: \(\textit{p}(y= c \vert \mathbf{x}, \mathbf{w}, T) = \frac{\exp(f_c( \mathbf{x}, \mathbf{w}) / T)}{\sum_{k=1}^K \exp(f_k( \mathbf{x}, \mathbf{w}) / T)}\)

where \(T \in \mathbb{R}^{+}\) is the temperature

  • inverse adversarial perturbation: \(\tilde{\mathbf{x}} = \mathbf{x} - \epsilon \, \mathrm{sign} \big ( - \nabla_{\mathbf{x}} \log \textit{p}(y = y^* \vert \mathbf{x}, \mathbf{w}, T) \big )\)

Both techniques aim to push the MCP of in-distribution samples above that of out-of-distribution samples. Here, we set the hyperparameters \(T=5\) and \(\epsilon=0.0014\).
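
To build intuition for temperature scaling, here is a small numeric illustration with made-up logits (not produced by the actual model):

# Hypothetical logits for a single input; a higher T flattens the softmax
logits = torch.tensor([[2.0, 1.0, 0.5]])
for T in [1, 5]:
    probs = F.softmax(logits / T, dim=1)
    print(f"T={T}: MCP = {probs.max().item():.3f}")  # T=1: 0.629, T=5: 0.391
# ODIN exploits the fact that this flattening affects in- and out-of-distribution
# inputs differently, widening the gap between their MCPs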

[CODING TASK] Implement ODIN preprocessing

def odin_preprocessing(model, input, epsilon):
    # We perform the inverse adversarial perturbation
    # You can find some help in the link below:
    # https://pytorch.org/tutorials/beginner/fgsm_tutorial.html

    # ============ YOUR CODE HERE ============
    # 1. Set requires_grad attribute of tensor. Important for Attack

    # 2. Forward pass the data through the model

    # 3. Calculate the loss w.r.t to class predictions

    # 4. Zero all existing gradients

    # 5. Calculate gradients of model in backward pass

    # 6. Collect sign of datagrad

    # 7. Normalizing the gradient to the same space of image
    sign_input_grad = sign_input_grad / 0.3081

    # 8. Apply FGSM Attack

    return perturbed_input
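
For reference, below is a minimal sketch of one possible implementation. It assumes the model returns raw logits and that inputs were normalized with the MNIST standard deviation (0.3081), as in the skeleton; temperature scaling is assumed to be handled separately inside predict_test_set. It mirrors the skeleton's signature under a hypothetical name so it does not overwrite your own solution:

def odin_preprocessing_solution(model, input, epsilon):
    # 1. Track gradients with respect to the input
    input = input.clone().detach().requires_grad_(True)

    # 2. Forward pass the data through the model
    output = model(input)

    # 3. Loss w.r.t. the predicted class y* (argmax of the logits);
    #    cross-entropy equals -log p(y = y* | x, w), so its input gradient
    #    is -grad log p, matching the formula above
    loss = F.cross_entropy(output, output.argmax(dim=1))

    # 4. Zero all existing gradients
    model.zero_grad()

    # 5. Backward pass to obtain the gradient w.r.t. the input
    loss.backward()

    # 6. Collect the sign of the input gradient
    sign_input_grad = input.grad.data.sign()

    # 7. Normalize the gradient to the same space as the (normalized) image
    sign_input_grad = sign_input_grad / 0.3081

    # 8. Inverse FGSM step: move against the loss gradient
    perturbed_input = input - epsilon * sign_input_grad
    return perturbed_input.detach()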
# Compute predictions for ODIN on MNIST
_, _, uncertainty_odin, errors_odin, _, _ = predict_test_set(lenet, test_loader, mode='odin', temp=5, epsilon=0.0014)

# Compute predictions for ODIN on KMNIST
_, _, uncertainty_kmnist, errors_kmnist, _, _ = predict_test_set(lenet, kmnist_test_loader, mode='odin', temp=5, epsilon=0.0014)

# Concatenate predictions from MNIST and KMNIST, labeling KMNIST samples as out-of-distribution
tot_uncertainty = np.concatenate((uncertainty_odin, uncertainty_kmnist))
in_distribution = np.concatenate((np.zeros_like(uncertainty_odin), np.ones_like(uncertainty_kmnist)))

# Compute the precision-recall curve and the AUPR
precision_ood_odin, recall_ood_odin, _ = precision_recall_curve(in_distribution, -tot_uncertainty)
aupr_ood_odin = average_precision_score(in_distribution, -tot_uncertainty)

Let’s look at the comparative results for OOD detection.

plt.figure(figsize=(7,7))
plt.title('Precision-recall curve for OOD detection')
plt.plot(recall_ood_mcp, precision_ood_mcp, label = f'MCP, AUPR = {aupr_ood_mcp:.2%}')
plt.plot(recall_ood_ent, precision_ood_ent, label = f'MCDropout (MutInf), AUPR = {aupr_ood_ent:.2%}')
plt.plot(recall_ood_odin, precision_ood_odin, label = f'ODIN, AUPR = {aupr_ood_odin:.2%}')
plt.legend()
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.show()

[Question 3.1]: Compare the precision-recall curves of each OOD method along with their AUPR values. Which method performs best, and why?