
【pytorchで深層生成モデル#5】pytorchでVAE
記事の目的
深層生成モデルのVAE(Variational Auto Encoder)をpytorchを使用して実装していきます。ここにある全てのコードは、コピペで再現することが可能です。
目次
1 今回のモデル
2 準備
x
29
29
1
# [1]
2
!nvidia-smi
3
4
# [2]
5
# データ作成に使用するライブラリ
6
from torchvision import datasets
7
import torchvision.transforms as transforms
8
from torch.utils.data import DataLoader
9
# モデル作成に使用するライブラリ
10
import torch
11
import torch.nn as nn
12
import torch.optim as optim
13
import torch.nn.functional as F
14
# よく使用するライブラリ
15
import matplotlib.pyplot as plt
16
import numpy as np
17
torch.manual_seed(1)
18
19
# [3]
20
# データの読み込み
21
transform = transforms.Compose([transforms.ToTensor(), lambda x: x.view(-1)])
22
root = './data'
23
mnist_dataset = datasets.MNIST(root=root,download=True,train=True,transform=transform)
24
dataloader = DataLoader(mnist_dataset, batch_size=100,shuffle=True)
25
26
# [4]
27
# gpuの指定
28
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
device
3 モデル
1
72
72
1
# [5]
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``.

    The encoder maps an input batch to the mean and variance of a diagonal
    Gaussian over the latent variable; a latent sample is drawn with the
    reparameterization trick and passed through the decoder.

    Args:
        device: Stored on the instance and forwarded to the encoder and
            decoder constructors.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = Encoder(device=device)
        self.decoder = Decoder(device=device)

    def forward(self, x):
        """Encode ``x``, sample a latent ``z`` and decode it.

        Returns:
            Tuple ``(y, z)``: the decoder output and the sampled latent.
        """
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)
        return y, z

    def reparameterize(self, mean, var):
        """Draw ``z = mean + sqrt(var) * noise`` with standard-normal noise.

        The noise is created with ``randn_like`` so it always matches the
        device and dtype of ``mean``, rather than being sampled on the CPU
        and copied to a device recorded at construction time.
        """
        return mean + torch.sqrt(var) * torch.randn_like(mean)

    def criterion(self, x, eps=1e-7):
        """Return the VAE loss (reconstruction term + KL term) for ``x``.

        Args:
            x: Input batch; values are used as Bernoulli targets, so they
                are expected to lie in ``[0, 1]``. Per-sample terms are
                summed over ``dim=1`` and averaged over the batch.
            eps: Small constant used to keep every ``log`` argument strictly
                positive.

        Returns:
            Scalar tensor ``L1 + L2`` where ``L1`` is the binary
            cross-entropy between ``x`` and the reconstruction and ``L2`` is
            the KL divergence from the standard normal prior.
        """
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)

        # A saturated sigmoid can output exactly 0 or 1; log(0) = -inf and
        # 0 * -inf = NaN would poison the loss, so clamp before the log.
        y = y.clamp(min=eps, max=1 - eps)
        L1 = - torch.mean(torch.sum(x * torch.log(y) + (1 - x) * torch.log(1 - y), dim=1))

        # Same guard for the variance inside the KL term.
        log_var = torch.log(var.clamp(min=eps))
        L2 = - 0.5 * torch.mean(torch.sum(1 + log_var - mean**2 - var, dim=1))

        return L1 + L2
28
29
# [6]
class Encoder(nn.Module):
    """Fully connected encoder: 784 -> 256 -> 128 -> (mean, var) of size 2.

    Args:
        device: Stored on the instance; the layers themselves do not use it.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        # Shared trunk
        self.l1 = nn.Linear(784, 256)
        self.l2 = nn.Linear(256, 128)
        # Two heads on top of the trunk
        self.l_mean = nn.Linear(128, 2)
        self.l_var = nn.Linear(128, 2)

    def forward(self, x):
        """Return ``(mean, var)`` for the input ``x``.

        ``var`` is passed through softplus so that it is positive.
        """
        hidden = torch.relu(self.l1(x))
        hidden = torch.relu(self.l2(hidden))
        mean = self.l_mean(hidden)
        var = F.softplus(self.l_var(hidden))
        return mean, var
49
50
# [7]
class Decoder(nn.Module):
    """Fully connected decoder: 2 -> 128 -> 256 -> 784 with sigmoid output.

    Args:
        device: Stored on the instance; the layers themselves do not use it.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(2, 128)
        self.l2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 784)

    def forward(self, x):
        """Map a latent ``x`` to an output whose entries lie in [0, 1]."""
        hidden = torch.relu(self.l1(x))
        hidden = torch.relu(self.l2(hidden))
        logits = self.out(hidden)
        return torch.sigmoid(logits)
68
69
# [8]
# Build the VAE and move its parameters to the selected device
# (nn.Module.to moves the module in place and returns the same object).
model = VAE(device=device)
model = model.to(device)
# Adam with its default hyper-parameters over all model parameters.
optimizer = optim.Adam(model.parameters())
# The loss is exposed as a bound method of the model.
criterion = model.criterion
4 モデルの学習
1
22
22
1
# [9]
n_epoch = 8

for epoch in range(n_epoch):
    # Training mode does not change between batches, so set it once per
    # epoch instead of on every iteration.
    model.train()
    loss_mean = 0.
    # The dataset also yields labels; they are not used here.
    for (x, _labels) in dataloader:

        # Move the batch to the training device
        x = x.to(device)

        # One optimisation step
        loss = criterion(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss
        loss_mean += loss.item()

    # Average of the per-batch losses over the epoch
    loss_mean /= len(dataloader)

    print('Epoch: {}, Loss: {:.3f}'.format(epoch+1, loss_mean))
5 画像の生成
1
43
43
1
# [10]
# Generate data
model.eval()
# Pure inference: disable autograd so no graph is built for the decoder pass
# (this also makes a later .detach() unnecessary).
with torch.no_grad():
    z = torch.randn(10, 2, device = device)
    images = model.decoder(z)
    # Keep the leading batch axis; a squeeze() here would be a no-op for 10
    # samples and would wrongly drop the axis for a single sample.
    images = images.view(-1, 28, 28)
    images = images.cpu().numpy()

# Visualise the samples on a 2 x 5 grid
for i, image in enumerate(images):
    plt.subplot(2, 5, i+1)
    plt.imshow(image, cmap='binary_r')
    plt.axis('off')
plt.tight_layout()
plt.show()
16
17
# [11]
# Preparation for the visualisation
img_size = 28
n_image = 10
img_size_spaced = img_size + 2  # each tile gets a 2-pixel gap
matrix_image = np.zeros((img_size_spaced*n_image, img_size_spaced*n_image))  # full canvas

# Latent coordinates to sweep
z_1 = torch.linspace(-3, 3, n_image)  # rows
z_2 = torch.linspace(-3, 3, n_image)  # columns

# Decode one image per latent grid point and paste it into the canvas.
# Inference only, so autograd is disabled.
with torch.no_grad():
    for i, z1 in enumerate(z_1):
        for j, z2 in enumerate(z_2):
            x = torch.tensor([float(z1), float(z2)], device=device)
            # Use img_size rather than a second hard-coded 28 so the tile
            # shape and the canvas layout cannot drift apart.
            images = model.decoder(x).view(img_size, img_size).cpu().numpy()
            top = i*img_size_spaced
            left = j*img_size_spaced
            matrix_image[top : top+img_size, left : left+img_size] = images

# Visualise the canvas (imshow accepts the ndarray directly)
plt.figure(figsize=(8, 8))
plt.imshow(matrix_image, cmap="Greys_r")
plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
plt.show()