【pytorchで深層生成モデル#5】pytorchでVAE

記事の目的

深層生成モデルのVAE(Variable Auto Encoder)をpytorchを使用して実装していきます。ここにある全てのコードは、コピペで再現することが可能です。

 

目次

  1. 今回のモデル
  2. 準備
  3. モデル
  4. モデルの学習
  5. 画像の生成

 

1 今回のモデル

 

2 準備

# [1]
!nvidia-smi

# [2]
# データ作成に使用するライブラリ
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# モデル作成に使用するライブラリ
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# よく使用するライブラリ
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(1)

# [3]
# データの読み込み
transform = transforms.Compose([transforms.ToTensor(), lambda x: x.view(-1)])
root = './data'
mnist_dataset = datasets.MNIST(root=root,download=True,train=True,transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=100,shuffle=True)

# [4]
# gpuの指定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

 

3 モデル

# [5]
class VAE(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = Encoder(device=device)
        self.decoder = Decoder(device=device)

    def forward(self, x):
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)
        return y, z

    def reparameterize(self, mean, var):
        z = mean + torch.sqrt(var) * torch.randn(mean.size()).to(self.device)
        return z

    def criterion(self, x):
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)
        L1 =  - torch.mean(torch.sum(x * torch.log(y) + (1 - x) * torch.log(1 - y), dim=1))
        L2 = - 1/2 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var, dim=1))
        L =  L1 + L2

        return L

# [6]
class Encoder(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(784, 256)
        self.l2 = nn.Linear(256, 128)
        self.l_mean = nn.Linear(128, 2)
        self.l_var = nn.Linear(128, 2)

    def forward(self, x):
        h = self.l1(x)
        h = torch.relu(h)
        h = self.l2(h)
        h = torch.relu(h)
        mean = self.l_mean(h)
        var = self.l_var(h)
        var = F.softplus(var)

        return mean, var

# [7]
class Decoder(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(2, 128)
        self.l2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 784)

    def forward(self, x):
        h = self.l1(x)
        h = torch.relu(h)
        h = self.l2(h)
        h = torch.relu(h)
        h = self.out(h)
        y = torch.sigmoid(h)

        return y

# [8]
model = VAE(device=device).to(device)
criterion = model.criterion
optimizer = optim.Adam(model.parameters())

 

4 モデルの学習

# [9]
n_epoch = 8

for epoch in range(n_epoch):
    loss_mean = 0.
    for (x, t) in dataloader:

      # 学習準備
      x = x.to(device)
      model.train()
      
      # モデルの学習
      loss = criterion(x)
      optimizer.zero_grad()
      loss.backward()     
      optimizer.step()

      # 損失関数の計算
      loss_mean += loss.item()
    loss_mean /= len(dataloader)

    print('Epoch: {}, Loss: {:.3f}'.format(epoch+1, loss_mean))

 

5 画像の生成

# [10]
# データの生成
model.eval()
z = torch.randn(10, 2, device = device)
images = model.decoder(z)
images = images.view(-1, 28, 28)
images = images.squeeze().detach().cpu().numpy()

# データの可視化
for i, image in enumerate(images):
    plt.subplot(2, 5, i+1)
    plt.imshow(image, cmap='binary_r')
    plt.axis('off')
plt.tight_layout()
plt.show()

# [11]
# データ可視化の前準備
img_size=28
n_image = 10
img_size_spaced = img_size + 2
matrix_image = np.zeros((img_size_spaced*n_image, img_size_spaced*n_image))  # 全体の画像

# 潜在変数の作成
z_1 = torch.linspace(-3, 3, n_image)  # 行
z_2 = torch.linspace(-3, 3, n_image)  # 列

#  潜在変数を変化させて画像を生成
for i, z1 in enumerate(z_1):
    for j, z2 in enumerate(z_2):
        x = torch.tensor([float(z1), float(z2)], device=device)
        images = model.decoder(x)
        images = images.view(-1, 28, 28)
        images = images.squeeze().detach().cpu().numpy()
        top = i*img_size_spaced
        left = j*img_size_spaced
        matrix_image[top : top+img_size, left : left+img_size] = images

# データの可視化
plt.figure(figsize=(8, 8))
plt.imshow(matrix_image.tolist(), cmap="Greys_r")
plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
plt.show()