【pytorchで深層生成モデル#4】pytorchでAE

記事の目的

深層生成モデルのAE(AutoEncoder)をpytorchを使用して実装していきます。ここにある全てのコードは、コピペで再現することが可能です。

 

目次

  1. 今回のモデル
  2. 準備
  3. モデル
  4. モデルの学習
  5. 画像の生成

 

1 今回のモデル

 

2 準備

# [1]
!nvidia-smi

# [2]
# データ作成に使用するライブラリ
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# モデル作成に使用するライブラリ
import torch
import torch.nn as nn
import torch.optim as optim
# モデルの可視化に使用するライブラリ
import matplotlib.pyplot as plt
torch.manual_seed(10)

# [3]
# データセットの作成
transform = transforms.Compose([transforms.ToTensor(), lambda x: x.view(-1)])
root = './data'
mnist_dataset = datasets.MNIST(root=root,download=True,train=True,transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=100,shuffle=True)

# [4]
# gpuの指定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

 

3 モデル

# [5]
class AE(nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(784, 256)
        self.l2 = nn.Linear(256, 784)

    def forward(self, x):
        x = self.l1(x)
        z = torch.relu(x)

        z = self.l2(z)
        y = torch.sigmoid(z)

        return y

# [6]
model = AE().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters())

 

4 モデルの学習

# [7]
n_epoch = 10
for epoch in range(n_epoch):

    loss_mean = 0.
    for (x, t) in dataloader:

      # 学習準備
      x = x.to(device)
      model.train()
      
      # モデルの学習
      x_fake = model(x)
      loss = criterion(x_fake, x)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      
      # 損失関数の計算
      loss_mean += loss.item()
    loss_mean /= len(dataloader)
    
    print('Epoch: {}, Loss: {:.3f}'.format(epoch+1, loss_mean))

 

5 画像の生成

# [8]
# 元データの可視化
x, t = next(iter(dataloader))
real_image = x[0,].view(28,28).detach().numpy()
plt.imshow(real_image, cmap='binary_r')
plt.axis('off')

# [9]
# 生成データの可視化
model.eval()
x = x.to(device)
fake_image = model(x)[0,].view(28,28).detach().cpu().numpy()
plt.imshow(fake_image, cmap='binary_r')
plt.axis('off')