
【pytorchで深層生成モデル#8】pytorchでCGAN
記事の目的
深層生成モデルのCGAN(Conditional GAN)をPyTorchを使用して実装していきます。ここにある全てのコードは、コピペで再現することが可能です。
目次
1 今回のモデル

2 準備

# [1]
# Notebook shell escape (`!`): runs the `nvidia-smi` command to show the GPU attached to this session, if any.
!nvidia-smi
# [2]
# データ作成に使用するライブラリ
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# モデル作成に使用するライブラリ
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# よく使用するライブラリ
import matplotlib.pyplot as plt
import numpy as np
# Seed PyTorch's global RNG so weight initialisation and noise sampling are repeatable.
torch.manual_seed(1)
# [3]
# Hyperparameters shared by the cells below.
n_class = 10      # number of label classes (width of the one-hot label vector)
n_channel = 100   # channel count of the noise tensor fed to the generator
n_epoch = 10      # number of passes over the training set
batch_size = 100  # samples per mini-batch
# [4]
# Build the MNIST training set and its shuffled mini-batch loader.
# ToTensor followed by Normalize(mean=0.5, std=0.5) rescales pixel values to
# [-1, 1], the same range as the generator's Tanh output.
root = './data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
mnist_train = datasets.MNIST(
    root=root,
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
# [5]
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Bare expression so a notebook cell displays the chosen device.
device
3 モデル

# [6]
class Generator(nn.Module):
    """Transposed-convolution generator conditioned on a class label.

    The input is a tensor with ``n_channel + n_class`` channels (noise
    concatenated with a one-hot label). For a 1x1 spatial input the four
    stages upsample 1 -> 3 -> 7 -> 14 -> 28 and the final Tanh bounds the
    single-channel output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel, stride, padding) of each hidden
        # upsampling stage; every one is ConvTranspose2d -> BatchNorm2d -> ReLU.
        hidden_specs = [
            (n_channel + n_class, 512, 3, 1, 0),
            (512, 256, 3, 2, 0),
            (256, 128, 4, 2, 1),
        ]
        stages = {}
        for idx, (c_in, c_out, kernel, stride, pad) in enumerate(hidden_specs):
            stages['layer{}'.format(idx)] = nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )
        # Output stage: no normalisation, Tanh squashes to the image value range.
        stages['layer3'] = nn.Sequential(
            nn.ConvTranspose2d(128, 1, 4, 2, 1),
            nn.Tanh(),
        )
        self.layers = nn.ModuleDict(stages)

    def forward(self, z):
        """Run ``z`` through every stage in order and return the generated image."""
        out = z
        for stage in self.layers.values():
            out = stage(out)
        return out
# [7]
class Discriminator(nn.Module):
    """Convolutional discriminator for label-conditioned images.

    The input is a tensor with ``1 + n_class`` channels (the image plus one
    constant plane per class). For 28x28 inputs the four stages downsample
    28 -> 14 -> 7 -> 3 -> 1, and the final Sigmoid yields one probability
    per sample.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            # First stage has no BatchNorm; the later hidden stages do.
            'layer0': nn.Sequential(
                nn.Conv2d(1 + n_class, 128, 4, 2, 1),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            'layer1': nn.Sequential(
                nn.Conv2d(128, 256, 4, 2, 1),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            'layer2': nn.Sequential(
                nn.Conv2d(256, 512, 3, 2, 0),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            'layer3': nn.Sequential(
                nn.Conv2d(512, 1, 3, 1, 0),
                nn.Sigmoid()
            )
        })

    def forward(self, x):
        """Return a 1-D tensor of per-sample "real" probabilities."""
        for layer in self.layers.values():
            x = layer(x)
        # Flatten (B, 1, 1, 1) to (B,). An argument-less squeeze() would also
        # drop the batch axis when B == 1 and return a 0-dim tensor that no
        # longer matches a (B,)-shaped BCE target.
        return x.view(-1)
# [8]
# Instantiate both networks on the selected device. The generator is built
# first so the seeded weight initialisation order stays fixed.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# [9]
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# One Adam optimiser per network, both with the same settings.
optimizerG = optim.Adam(
    generator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
    weight_decay=1e-5,
)
optimizerD = optim.Adam(
    discriminator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
    weight_decay=1e-5,
)
4 条件指定用関数

# [10]
def onehot_encode(label, device):
    """Convert integer class labels to one-hot vectors.

    Rows of an ``n_class`` x ``n_class`` identity matrix (created on
    ``device``) are selected by ``label`` and reshaped to
    ``(-1, n_class, 1, 1)`` so they can be concatenated along the channel
    axis of 4-D tensors.
    """
    identity = torch.eye(n_class, device=device)
    one_hot_rows = identity[label]
    return one_hot_rows.view(-1, n_class, 1, 1)
# [11]
def concat_image_label(image, label, device):
    """Append one-hot label planes to ``image`` along the channel axis.

    ``image`` must be 4-D ``(B, C, H, W)``. Each sample's one-hot label is
    broadcast to ``n_class`` constant planes of size ``H`` x ``W``, giving a
    result with ``C + n_class`` channels.
    """
    batch, _, height, width = image.shape
    label_planes = onehot_encode(label, device).expand(batch, n_class, height, width)
    return torch.cat((image, label_planes), dim=1)
# [12]
def concat_noise_label(noise, label, device):
    """Append the one-hot encoding of ``label`` to ``noise`` along dim 1.

    The one-hot tensor has shape ``(-1, n_class, 1, 1)``, so ``noise`` needs
    matching batch and spatial sizes for the concatenation to succeed.
    """
    label_onehot = onehot_encode(label, device)
    pieces = (noise, label_onehot)
    return torch.cat(pieces, dim=1)
5 モデルの学習

# [13]
# Training history, one entry per mini-batch.
G_losses = []
D_losses = []
D_x_list = []      # mean discriminator output on real images
D_G_z1_list = []   # mean discriminator output on fakes, before the generator step
D_G_z2_list = []   # mean discriminator output on fakes, during the generator step
# Training loop
for epoch in range(n_epoch):
    for x, t in dataloader:
        # --- Setup ---
        real_image = x.to(device)  # real images
        real_label = t.to(device)  # class labels of the real images
        # Size everything from the actual batch so a short final batch cannot
        # cause a shape mismatch against the targets.
        n_sample = real_image.size(0)
        noise = torch.randn(n_sample, n_channel, 1, 1, device=device)
        real_target = torch.full((n_sample,), 1., device=device)  # target for "real"
        fake_target = torch.full((n_sample,), 0., device=device)  # target for "fake"

        # --- Discriminator: real images paired with their true labels ---
        real_image_label = concat_image_label(real_image, real_label, device)
        discriminator.zero_grad()
        y = discriminator(real_image_label)
        errD_real = criterion(y, real_target)
        D_x = y.mean().item()

        # --- Discriminator: generated images paired with random labels ---
        fake_label = torch.randint(n_class, (n_sample,), dtype=torch.long, device=device)
        fake_noise_label = concat_noise_label(noise, fake_label, device)
        fake_image = generator(fake_noise_label)
        fake_image_label = concat_image_label(fake_image, fake_label, device)
        # detach() keeps this loss from back-propagating into the generator.
        y = discriminator(fake_image_label.detach())
        errD_fake = criterion(y, fake_target)
        D_G_z1 = y.mean().item()

        # --- Discriminator update ---
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        # --- Generator update: make the discriminator output "real" on fakes ---
        generator.zero_grad()
        y = discriminator(fake_image_label)
        errG = criterion(y, real_target)
        errG.backward()
        D_G_z2 = y.mean().item()
        optimizerG.step()

        # --- Record history ---
        D_losses.append(errD.item())
        G_losses.append(errG.item())
        D_x_list.append(D_x)
        D_G_z1_list.append(D_G_z1)
        D_G_z2_list.append(D_G_z2)

    # Reported once per epoch, using the metrics of the epoch's last mini-batch.
    print('Epoch:{}/{}, Loss_D: {:.3f}, Loss_G: {:.3f}, D(x): {:.3f}, D(G(z)): {:.3f}/{:.3f}'
          .format(epoch + 1, n_epoch, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
6 画像の生成

# [14]
# Build sample inputs: fresh noise plus class labels cycling 0, 1, ..., n_class - 1,
# so consecutive samples show consecutive classes.
sample_noise = torch.randn(batch_size, n_channel, 1, 1, device=device)
sample_label = torch.arange(batch_size, dtype=torch.long, device=device) % n_class
sample_noise_label = concat_noise_label(sample_noise, sample_label, device)
# eval() has to be called (not merely referenced) to switch BatchNorm to its
# running statistics; no_grad() skips building the autograd graph for inference.
generator.eval()
with torch.no_grad():
    y = generator(sample_noise_label)
# Visualise up to 50 generated images in a 5 x 10 grid.
fig = plt.figure(figsize=(20, 20))
plt.subplots_adjust(wspace=0.1, hspace=-0.8)
for i in range(min(50, y.size(0))):
    ax = fig.add_subplot(5, 10, i + 1, xticks=[], yticks=[])
    ax.imshow(y[i].view(28, 28).cpu().numpy(), "gray")