GAN(Generative Adversarial Network)とは何か? 簡単にしくみ解説!【理論編】
記事の目的
GAN(Generative Adversarial Network)の仕組みを、数式と図を使用して解説していきます。実際にdigitデータを例としてGeneratorとDiscriminatorの構成を確認します。その後、誤差関数や勾配法、誤差逆伝播法を解説。最後に、GANの学習から画像生成の流れを解説します。
目次
1. 使用するデータ
使用するデータはdigitデータで、8×8=64次元のデータです。以下の画像のように、0-9の数字画像のデータです。
2. Generator
Generatorの構成を、「Parameters」、「Forward」、「Backward」の3つにわけて解説します。今回、digitデータを使用して構成するGeneratorは3層です。
2.1 Parameters
Parametersは、Generatorのパラメータです。各層に2つずつパラメータがあり、計6つのパラメータがあります。「Generatorを学習させる」=「パラメータを更新する」です。
2.2 Forward
Forwardは、関数の中身です。入力がノイズで、出力が画像データの関数です。Generatorの役目は、「このforwardの関数が本物画像に近い偽物画像を出力するようにParameterを決定すること」です。
2.3 Backward
Backwardは、損失関数を各パラメータや各層の出力で微分した値です。この値が必要な理由は、「4. 誤差関数と勾配法」で解説します。
3. Discriminator
Discriminatorの構成もGeneratorと同様、「Parameters」、「Forward」、「Backward」の3つにわけて解説します。Discriminatorも3層で構成されています。
3.1 Parameters
Parametersは、Discriminatorのパラメータです。Generatorと同様に各層に2つずつパラメータがあり、計6つのパラメータがあります。「Discriminatorを学習させる」=「パラメータを更新する」です。
3.2 Forward
Forwardは、関数の中身です。入力が偽物画像か本物画像のデータで、出力が0-1の値です。Discriminatorの役目は、「このforwardの関数が本物画像と偽物画像を識別できるようにParameterを決定すること」です。
3.3 Backward
Backwardは、損失関数を各パラメータや各層の出力で微分した値です。この値が必要な理由は、「4. 誤差関数と勾配法」で解説します。
4. 誤差関数と勾配法
4.1 誤差関数
誤差関数は交差エントロピーを使用しています。誤差関数は、「正解ラベル(t)が1でDiscriminatorの出力が1に近い」か「正解ラベル(t)が0でDiscriminatorの出力が0に近い」場合に小さい値をとります。誤差関数の値が小さい場合は、偽物画像に対してDiscriminatorは0に近い値を出力し、本物画像に対してDiscriminatorは1に近い値を出力します。ゆえに、誤差関数が小さければ小さいほど、Discriminatorは偽物と本物を識別できるようになります。よって、「パラメータを学習させる」=「誤差関数の値が小さくなるようにパラメータを更新する」です。
4.2 勾配法
勾配法とは、誤差関数が小さくなるようにパラメータを設定する方法です。細かい話は省略しますが、θを各パラメータとして以下の式でパラメータを更新することで誤差関数の値を小さくできます。
$$ \theta = \{ W_i, b_i \ | \ i \in \{1, 2, …, 6\}\}$$
$$ \theta \leftarrow \theta \ – \ \eta \frac{\partial E}{\partial \theta}$$
4.3 パラメータの学習の流れ
パラメータの学習の流れは以下の通りです。2はパラメータ学習の上で必須ではないですが、抜けてると不自然なので加えています。
- Forward関数を使用して値を出力
- 損失関数を計算
- Backwardで勾配を計算
- 勾配法を使用してパラメータを更新
5. Backward (誤差逆伝播法)
GeneratorとDiscriminatorのBackwardで勾配を計算していますが、このとき誤差逆伝播法を使用して効率的に計算しています。誤差逆伝播法とは、1つ後の層の勾配の値を使用して計算量を減らす手法です。例えば、以下の画像ではDiscriminatorの3層目の勾配の値を2層の勾配の計算に使用することができています。勾配が誤差関数から、forward関数の逆方向に伝搬しているので誤差逆伝播法(backpropagation)と呼ばれているんだと思います(多分)。
6. GANの学習から画像生成
6.1 Step1. Discriminatorの学習1
以下の画像の青枠の部分を使用します。Generatorから偽物画像データを生成し、そのデータを使用してDiscriminatorを学習させます。
6.2 Step2. Discriminatorの学習2
以下の画像の青枠の部分を使用します。本物画像を使用してDiscriminatorを学習させます。
6.3 Step3. Generatorの学習
以下の画像の青枠の部分を使用します。GAN全体を使用してGeneratorを学習させます。なお、このときDiscriminatorのBackwardでDiscriminatorの勾配は計算していますが、学習(パラメータの更新)はしていません。
Step1-3を繰り返すことで、GAN全体の学習は終了します。
6.4 Step4. 画像の生成
GANのGeneratorの部分を使用して画像を生成させたのが以下の画像です。