GAN(Generative Adversarial Network)とは何か? 簡単にしくみ解説!【概要編】
記事の目的
GAN(Generative Adversarial Network)の仕組みを、数式なしで図を使用して解説していきます。GANとは何かの概要を解説した後、仕組みを4つのStepに分けて解説します。なお、「GANが深層生成モデルである」や、「DeepLearningの応用である」といった話はしません。
目次
- 1. GANの概要
- 2. Step1 Discriminatorの学習1
- 3. Step2 Discriminatorの学習2
- 4. Step3 Generatorの学習
- 5. Step4 画像の生成 (digitデータ使用)
1. GANの概要
GANとは、すごく簡単に言えば「たくさんの画像を学習して、その画像と似たオリジナルの画像を生成するアルゴリズム」です。GANは、「Generator」と「Discrimanator」から構成されており、この2つが競い合いながら学習します。この2つは簡単に言えば、「複雑な関数」です。以下の画像はGANの全体の構造です。
Generatorは、ノイズ(ベクトル)から画像を生成する複雑な関数です。Discriminatorは、Generatorが生成した偽物の画像と、本物の画像を識別できるようにする複雑な関数です。これらの関数(DiccriminatorとGenerator)が上手く画像を生成できるように関数のパラメータを決定(学習)させます。
Generatorは本物の画像に近い画像を出力するように頑張り、DiscriminatorはGeneratorの偽物画像と本物の画像を識別できるように頑張ります。GeneratorはDiscriminatorを騙せるように、DiscriminatorはGeneratorに騙されないように競い合って学習するわけです。
2. Step1 Discriminatorの学習1
Generatorから偽物画像を出力し、その画像を用いてDiscriminatorが学習します。正解ラベルを0として、Discriminatorの出力が0に近づくようにDiscriminatorのパラメータを学習します。
3. Step2 Discriminatorの学習2
本物の画像を用いてGeneratorが学習します。正解ラベルを1として、Discriminatorの出力が1に近づくようにDiscriminatorのパラメータを学習します。
step1, 2を終えて、Discriminatorは偽物と本物の画像を識別できるようになるわけです。
4. Step3 Generatorの学習
GAN全体を用いてGeneratorを学習します。正解ラベルを1として、Discriminatorの出力が1に近づくようにGeneratorのパラメータを学習します。正解ラベルを1としてGeneratorを学習させるため、Generatorが本物に近い画像を出力するようにパラメータが学習されます。
Step1〜3を繰り返し行うことで、パラメータの学習をします。
5. Step4 画像生成 (digitデータ使用)
- 画像を生成するのに必要なのは、Generatorの部分のみです。
- 実際に、pythonに元々入っている以下のdigitデータを使用して学習させ、画像を生成します。
- ノイズ(16次元のベクトル)から、本物にそっくりな画像のデータを生成することができました。