Travaux pratiques : Generative Adversarial Networks¶
L’objectif de cette séance de TP est d’illustrer par la pratique le fonctionnement des réseaux de neurones génératifs antagonistes (ou Generative Adversarial Networks, GAN). Cette séance est un peu moins guidée que les précédentes, n’hésitez pas à solliciter l’équipe enseignante en cas de difficultés, notamment liées à la programmation en PyTorch.
Rappelons pour commencer le principe des réseaux génératifs antagonistes. Ce modèle génératif met en compétition deux réseaux de neurones \({D}\) et \({G}\) que l’on appellera par la suite le discriminateur et le générateur, respectivement.
Note : on trouve parfois dans la litérature une analogie avec la falsification d’œuvres d’art. \({D}\) est alors appelé le « critique » et \({G}\) est appelé le « faussaire ».
L’objectif de \({G}\) est de transformer un bruit aléatoire \(z\) en un échantillon \(\tilde{x}\) le plus similaire possible aux observations réelles \(x \in \mathbf{X}\).
À l’inverse, l’objectif de \({D}\) est d’apprendre à reconnaître les « faux » échantillons \(\tilde{x}\) des vrais observations \(x\).
Pour implémenter un tel modèle, commençons par importer quelques bibliothèques et sous-modules utiles de PyTorch.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
Jeux de données jouet¶
Points sur un cercle¶
Dans un premier temps, considérons un jeu de données simple : des points répartis le long du cercle unité. Une façon d’obtenir ces points est d’échantilloner uniformément selon une loi normale et de diviser chaque vecteur par sa norme pour le rendre unitaire :
X = np.random.randn(100, 2)
X /= np.linalg.norm(X, axis=1)[:, None].repeat(2, axis=1)
fig = plt.figure(figsize=(6, 6))
plt.scatter(X[:,0], X[:,1])
plt.title("Nuage de points") and plt.show()
On définit ici les paramètres du problème (dimension \(n\) des données d’entrée, ici \(n=2\), et dimension de l’espace latent). Ils pourront être modifiés par la suite ci-besoin.
data_dim = X.shape[-1]
hidden_dim = 10
Question
Écrire en utilisant l’interface nn.Sequential de PyTorch:
un générateur G entièrement connecté à 3 couches entièrement connectées qui projette un vecteur \(z\) de l’espace latent de dimension
hidden_dimvers un échantillon de dimensiondata_dim(dimension des données d’entrée),un discriminateur D entièrement connecté qui, à partir d’un vecteur de données (réel ou faux) produit en sortie un valeur entre 0 et 1 (obtenue en passant le score dans une sigmoïde).
Pour créer le jeu d’apprentissage, torch dispose d’une fonction bien
pratique qui permet de transformer automatiquement une matrice
d’observation en Dataset. Nous pouvons créer le DataLoader dans
la foulée :
from torch.utils import data
dataset = data.TensorDataset(torch.Tensor(X))
dataloader = data.DataLoader(dataset, batch_size=100)
(modifiez le paramètre batch_size si jamais vous ne disposez pas de
suffisamment de mémoire ou si les calculs sont trop lents)
L’apprentissage des poids des réseaux \(\mathcal{D}\) et \(\mathcal{G}\) se fait par des mises à jour séparées. Nous aurons donc besoin de deux optimiseurs différents, un qui porte sur les paramètres du générateur et l’autre qui portent sur les paramètres du discriminateur :
G_optimizer = torch.optim.Adam(generator.parameters())
D_optimizer = torch.optim.Adam(discriminator.parameters())
Question
Compléter la boucle d’apprentissage ci-dessous, et notamment le calcul des fonctions de coût pour le générateur et pour le discriminateur. On rappelle que :
le discriminateur cherche à maximiser \(\mathcal{L}_D = \log D(x) + \log(1- D(\hat{x}))\) où \(x\) sont des données réelles et \(\hat{x}\) des données générées (\(\hat{x} = G(z)\)), c’est-à-dire que la sortie de la sigmoïde du discriminateur doit valoir 1 pour les données réelles et 0 pour les données fausses,
le générateur cherche à minimiser \(\mathcal{L}_G = \log (1 - D(\hat{x})) = \log (1 - D(G(z))\), c’est à dire tromper le discriminateur en le poussant à prédire que des données fausses sont réelles.
L’algorithme d’apprentissage du GAN est donc le suivant :
Tant que la convergence n’est pas atteinte
Tirer un batch \(x\) de données réelles
Tirer un bruit aléatoire \(z \in \mathcal{N}(0,1)\)
Générer des fausses données \(G(z)\)
Calculer la fonction de coût de \(D\) sur les données réelles + fausses
Faire un pas d’optimisation de \(D\)
Calculer la fonction de coût de \(G\) sur les données synthétiques
Faire un pas d’optimisation sur \(G\)
On rappelle que la méthode .backward() permet de rétropropager les
gradients d’un tenseur et que optimizer.step() permet ensuite de
réaliser un pas de descente de gradient (mise à jour des poids).
Attention: il ne faut pas rétropropager le gradient dans \(G\)
lorsque vous réalisez une itération sur \(D\). Cela peut se faire à
l’aide de la méthode .detach() qui permet de signifier à PyTorch que
vous n’aurez pas besoin du gradient pour le tenseur concerné.
from tqdm.notebook import tqdm, trange
for epoch in trange(10000):
for real_data, in dataloader:
bs = len(real_data)
# Classe "faux" = 0
fake_labels = torch.zeros(len(fake_data))
# Classe "vrai" = 1
true_labels = torch.ones(len(real_data))
# À compléter
# ...
#
# N'oubliez pas d'appeler optimizer.zero_grad() pour réinitialiser les gradients à 0 !
Question
Tirez aléatoirement des vecteurs \(z\) de bruit selon une loi normale. Générez les points associés et visualisez côte à côte le nuage de points réel et le nuage de points synthétique.
z = torch.randn(1000, hidden_dim)
with torch.no_grad():
samples = generator(z)
fig = plt.figure(figsize=(12, 6))
fig.add_subplot(121)
plt.scatter(real_data[:, 0], real_data[:, 1])
plt.title("Points réels")
fig.add_subplot(122)
plt.scatter(samples[:, 0], samples[:, 1])
plt.title("Points générés")
plt.show()
Question
Colorez les points générés en fonction de la classe (“real” ou “fake”) prédite par le discriminateur. On mettera le seuil à 0.5.
Demi-lunes¶
Question
Optionnel (passez cette question si vous êtes pressés) Remplacez le jeu de données \(\mathbf{X}\) défini plus haut par les deux demi-lunes ci-dessous et répondez aux mêmes questions.
X, y = datasets.make_moons(n_samples=1000, noise=0.05)
plt.scatter(X[:,0], X[:, 1], c=y)
plt.show()
Swiss roll¶
Le swiss roll (gâteau roulé) est un jeu de données tridimensionnel formant un plan replié sur lui-même.
X, y = datasets.make_swiss_roll(n_samples=2000, noise=0.05)
import mpl_toolkits.mplot3d.axes3d as p3
fig = plt.figure()
ax = p3.Axes3D(fig)
ax.scatter(X[:, 0], X[:,1], X[:,2], c=y)
plt.show()
Question
Optionnel (passez cette question si vous êtes pressés) Répétez l’expérience ci-dessous. Que constatez-vous lorsque vous diminuez la dimensionalité de l’espace latent ?
Génération de chiffres manuscrits avec MNIST¶
Question
Reprendre le code précédent et l’adapter à la dimensionalité de MNIST. On rappelle qu’une image de MNIST est de dimensions \(28\times28\), c’est-à-dire un vecteur de dimension 784 une fois aplati.
Implémenter un GAN conditionnel sur MNIST. On utilisera comme vecteur de conditionnement \(y\) la classe du chiffre sous forme de vecteur suivant l’encodage one-hot. Visualiser les chiffres obtenus.
Question
Cette question est optionnelle et sert d’approfondissement. Remplacez
\(G\) et \(D\) définis ci-dessus par des modèles convolutifs. On
pourra notamment utiliser à bon escient les modules Upsample ou
Conv2DTranspose pour gérer l’augmentation des dimensions dans le
générateur.