人工知能に関する断創録

人工知能、認知科学、心理学、ロボティクス、生物学などに興味を持っています。このブログでは人工知能のさまざまな分野について調査したことをまとめています。最近は、機械学習、Deep Learning、Keras、PyTorchに関する記事が多いです。



PyTorch (12) Generative Adversarial Networks (MNIST)

前回(2018/2/28)の最後で次はConditional VAEだと言っていたけど思いっきり無視して (^^;) 今回はGenerative Adversarial Networks (GAN) やろう。いくつかのデータセットで実験しようと思っているけど今回は最初ということでMNISTから。

今回の実装は正確に言うとGeneratorとDiscriminatorに畳み込みニューラルネットを使っているので DCGAN(Deep Convolutional Generative Adversarial Networks) と呼ばれるGANにあたる。論文の設定とは微妙に違うところあるけど。

180303-gan-mnist.ipynb - Google ドライブ

まずはいつものimportから。

import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

import matplotlib.pyplot as plt
%matplotlib inline

cuda = torch.cuda.is_available()
if cuda:
    print('cuda available!')

GANのアーキテクチャ

手書きで申し訳ないが・・・パワポで清書するの面倒なので (^^;)

f:id:aidiary:20180305091835p:plain

  • Nはミニバッチサイズを意味する
  • Generatorは乱数のベクトル z から偽物画像 G(z) を生成するニューラルネットワーク
  • x はMNISTの本物画像
  • Discriminatorは入力された画像(本物画像 x または偽物画像 G(z))が本物(1)か偽物(0)かを識別するニューラルネットワーク

今回、MNISTの実験に使ったDiscriminatorのアーキテクチャは

f:id:aidiary:20180304212541p:plain

  • 普通の畳み込みニューラルネットワーク
  • ただし、MaxPoolingを使わずにstrideで画像サイズを半分にしていく
  • stride=2としているので画像サイズは半分になる

Generatorのアーキテクチャは

f:id:aidiary:20180304212702p:plain

  • 入力 z は62次元の乱数ベクトル、これをシードとして画像を生成する
  • 乱数から画像を生成する というのが面白い発想だと思った
  • 最初にLinear層を入れて4Dテンソルの形状までサイズを広げていく
  • 6272次元まで拡張したらview()で 7x7x128 の4Dテンソル(図ではバッチ次元は省略)に変換
  • ConvTranspose2D で画像サイズをMNISTの28x28ピクセルまで拡張していく、それとともにチャネル数を減らしていく
  • 出力は1チャンネル28x28ピクセルのMNISTの画像になる

ConvTranspose2d

  • 入力は4Dテンソル (N, C1, H1, W1) 出力は4Dテンソル (N, C2, H2, W2)
  • C2はパラメータで指定
  • H2 = (H1 - 1) * stride - 2 * padding + kernel_size
  • W2 = (W1 - 1) * stride - 2 * padding + kernel_size
  • 1つめのConvTranspose2d = (7-1) * 2 - 2 * 1 + 4 = 14
  • 2つめのConvTranspose2d = (14-1) * 2 - 2 * 1 + 4 = 28

TODO: 実は ConvTranspose2D の理解が怪しい・・・要調査

さっそく実装してみよう。

class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()
        
        self.fc = nn.Sequential(
            nn.Linear(62, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
        )
        
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv(x)
        return x


class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)
    
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc(x)
        return x

initialize_weights() はネットワークの重みを初期化する関数。PyTorchはデフォルト以外の重みを指定するときはこう書く必要があるのかな?ちょっと面倒かも。

def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

print すると下のように表示される。

Generator(
  (fc): Sequential(
    (0): Linear(in_features=62, out_features=1024)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=6272)
    (4): BatchNorm1d(6272, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU()
  )
  (deconv): Sequential(
    (0): ConvTranspose2d (128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU()
    (3): ConvTranspose2d (64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Sigmoid()
  )
)

Discriminator(
  (conv): Sequential(
    (0): Conv2d (1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(0.2)
    (2): Conv2d (64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU(0.2)
  )
  (fc): Sequential(
    (0): Linear(in_features=6272, out_features=1024)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
    (2): LeakyReLU(0.2)
    (3): Linear(in_features=1024, out_features=1)
    (4): Sigmoid()
  )
)

DCGANの論文にはGANの訓練が安定する工夫として以下があげられている。

  • DiscriminatorでPoolingの代わりにStrided Convolutionを使う
  • GeneratorはFractional-strided Convolutionを使う
  • Generator、DiscriminatorともにBatchNormを使う
  • 層が深いときはFC層を除去してGlobal Average Poolingを使う
  • GeneratorにはReLUを使う
  • ただし出力層のみTanhを使う(今回は画像を0-1標準化したのでSigmoid使用)
  • DiscriminatorにはLeakyReLUを使う

FC層除去以外は適応済みかな。Fractional-strided Convolutionsってのが ConvTranspose2d と同じものを指しているのかがわからない。あとで調査する。

Optimizerの作成

ハイパーパラメータの定義。こういうのはコマンドライン引数から指定できたほうが汎用性があってよいのだけど今回はJupyter Notebook上での簡単な実験なので直書きする。

# hyperparameters
batch_size = 128
lr = 0.0002
z_dim = 62
num_epochs = 25
sample_num = 16
log_dir = './logs'

モデルオブジェクトを作成して、Optimizerを作成。

# initialize network
G = Generator()
D = Discriminator()
if cuda:
    G.cuda()
    D.cuda()

# optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# loss
criterion = nn.BCELoss()
  • GeneratorとDiscriminatorはそれぞれ別のOptimizerがある ことに注意
  • PyTorchはOptimizerの更新対象となるパラメータを第1引数で指定することになっている(Kerasにはなかった)
  • この機能のおかげで D_optimizer.step() でパラメータ更新を走らせたときにDiscriminatorのパラメータしか更新されない。Generatorのパラメータは固定される。
  • 参考先では loss.cuda() しているがlossはパラメータがないのでしなくてOK
  • What exactly does Loss.cuda() do? - PyTorch Forums

データのロード

MNISTをロード。前処理で ToTensor() しかしてないので画像は [0, 1] で標準化される。つまり、Generatorの出力層の活性化関数はSigmoid にする必要がある。[-1, 1] に標準化したときは Tanh を使う。

# load dataset
transform = transforms.Compose([
    transforms.ToTensor()
])
dataset = datasets.MNIST('data/mnist', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

GANのloss function

Discriminatorの目的関数(数式のLはLossではなく、Likelihoodだと思われる)を数式で書くと下のようになる*1

 L_D = E[log(D(x))] + E[log(1 - D(G(z)))] \to max

この式を解読すると下のように考えられる。

  • Discriminatorの目的関数は正解画像を入れたときと偽物画像を入れたときの2項から構成される
  • D(・)はDiscriminatorの出力で0〜1を返す。1に近いほど本物と判定されたことを意味する
  • G(z)はGeneratorが乱数zから生成した偽物画像を表す
  • 本物画像xをDに入力したとき(1項目)はD(x)が1(本物)に近いほどlogの値が大きくなる(=Dは本物画像を本物と正しく判定できた)
  • 偽物画像G(z)をDに入力したとき(2項目)はD(G(z))が0(偽物)に近いほどlogの値が大きくなる(=Dは偽物画像を偽物と正しく判定できた)

次はGeneratorの目的関数。

 L_G = E[log(D(G(z)))] \to max

この式を解読すると下のように考えられる。

  • 偽物画像G(z)をDに入力したときにD(G(z))が1(本物)に近いほどlogの値が大きくなる(=Gが生成した偽物画像をDに本物画像と誤って判定させることができた)

上の式をコードに落とすと下のようになる。

def train(D, G, criterion, D_optimizer, G_optimizer, data_loader):
    # 訓練モードへ
    D.train()
    G.train()

    # 本物のラベルは1
    y_real = Variable(torch.ones(batch_size, 1))
    # 偽物のラベルは0
    y_fake = Variable(torch.zeros(batch_size, 1))
    if cuda:
        y_real = y_real.cuda()
        y_fake = y_fake.cuda()

    D_running_loss = 0
    G_running_loss = 0
    for batch_idx, (real_images, _) in enumerate(data_loader):
        # 一番最後、バッチサイズに満たない場合は無視する
        if real_images.size()[0] != batch_size:
            break

        z = torch.rand((batch_size, z_dim))
        if cuda:
            real_images, z = real_images.cuda(), z.cuda()
        real_images, z = Variable(real_images), Variable(z)

        # Discriminatorの更新
        D_optimizer.zero_grad()

        # Discriminatorにとって本物画像の認識結果は1(本物)に近いほどよい
        # E[log(D(x))]
        D_real = D(real_images)
        D_real_loss = criterion(D_real, y_real)

        # DiscriminatorにとってGeneratorが生成した偽物画像の認識結果は0(偽物)に近いほどよい
        # E[log(1 - D(G(z)))]
        # fake_imagesを通じて勾配がGに伝わらないようにdetach()して止める
        fake_images = G(z)
        D_fake = D(fake_images.detach())
        D_fake_loss = criterion(D_fake, y_fake)

        # 2つのlossの和を最小化する
        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()  # これでGのパラメータは更新されない!
        D_running_loss += D_loss.data[0]

        # Generatorの更新
        G_optimizer.zero_grad()

        # GeneratorにとってGeneratorが生成した画像の認識結果は1(本物)に近いほどよい
        # E[log(D(G(z)))
        fake_images = G(z)
        D_fake = D(fake_images)
        G_loss = criterion(D_fake, y_real)
        G_loss.backward()
        G_optimizer.step()
        G_running_loss += G_loss.data[0]
    
    D_running_loss /= len(data_loader)
    G_running_loss /= len(data_loader)
    
    return D_running_loss, G_running_loss
  • 先の数式解釈で 0に近い方がよい、1に近い方がよいと言っていたのを正解ラベルとのBCELoss(Binary Cross Entropy Loss)で置き換えているのがポイント
  • GANはDiscriminatorのパラメータ更新とGeneratorのパラメータ更新を順番に繰り返す
  • Discriminatorのパラメータ更新をするときはGeneratorのパラメータは固定する必要がある(Kerasはこの実装が面倒だった)
  • PyTorchはOptimizerのパラメータ指定と detach() で実装する
  • Variableをdetach()するとそれ以上の勾配伝搬が止まる
  • 上の実装ではGeneratorの出力である fake_imagesdetach() しているのでそれより前のGeneratorに勾配は伝搬されない
  • D_optimizer にはDのパラメータしか渡しておらず、D_optimizer.step() してもGeneratorのパラメータは更新されないためdetach()しなくても結果は変わらない
  • ただし、明示的に detach() することで fake_images を通じてGeneratorに勾配が伝搬することを防ぎ、計算が高速化されるとのこと
  • why is detach necessary · Issue #116 · pytorch/examples · GitHub
  • ちなみに参考元のコードはほとんどの実装でdetach()が入ってない。公式のは入ってるので入れておいた

画像を生成する関数。学習途中のエポックでGeneratorを使ってサンプル画像を生成するのに使う。

def generate(epoch, G, log_dir='logs'):
    G.eval()
    
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # 生成のもとになる乱数を生成
    sample_z = torch.rand((64, z_dim))
    if cuda:
        sample_z = sample_z.cuda()
    sample_z = Variable(sample_z, volatile=True)
    
    # Generatorでサンプル生成
    samples = G(sample_z).data.cpu()
    save_image(samples, os.path.join(log_dir, 'epoch_%03d.png' % (epoch)))

訓練ループ!

history = {}
history['D_loss'] = []
history['G_loss'] = []
for epoch in range(num_epochs):
    D_loss, G_loss = train(D, G, criterion, D_optimizer, G_optimizer, data_loader)
    
    print('epoch %d, D_loss: %.4f G_loss: %.4f' % (epoch + 1, D_loss, G_loss))
    history['D_loss'].append(D_loss)
    history['G_loss'].append(G_loss)
    
    # 特定のエポックでGeneratorから画像を生成してモデルも保存
    if epoch == 0 or epoch == 9 or epoch == 24:
        generate(epoch + 1, G, log_dir)
        torch.save(G.state_dict(), os.path.join(log_dir, 'G_%03d.pth' % (epoch + 1)))
        torch.save(D.state_dict(), os.path.join(log_dir, 'D_%03d.pth' % (epoch + 1)))

# 学習履歴を保存
with open(os.path.join(log_dir, 'history.pkl'), 'wb') as f:
    pickle.dump(history, f)

実験結果

まずは、DiscriminatorとGeneratorのlossをプロットしてみる。

with open(os.path.join(log_dir, 'history.pkl'), 'rb') as f:
    history = pickle.load(f)

D_loss, G_loss = history['D_loss'], history['G_loss']
plt.plot(D_loss, label='D_loss')
plt.plot(G_loss, label='G_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.grid()

f:id:aidiary:20180304223039p:plain

  • 学習が進むとDiscriminatorが強くなってGeneratorが競り負けていることがわかる
  • GANの発明者のGoodfellowも言及しているが、Generatorは贋作者(counterfeiter)でDiscriminatorは警察(police)に例えられる
  • つまり、Generatorを応援することは犯罪者を応援することにつながる(笑)この例えあまりよくないよね〜
  • Generatorもっと頑張れ!

途中のエポック(1、10、25エポック)で生成した画像を描画してみよう。

from IPython.display import Image
Image('logs/epoch_001.png')

f:id:aidiary:20180304223146p:plain

Image('logs/epoch_010.png')

f:id:aidiary:20180304223206p:plain

Image('logs/epoch_025.png')

f:id:aidiary:20180304223225p:plain

学習が進むにつれてDiscriminatorに見破られないような精巧な偽物画像をGeneratorが生成できるようになったことがわかる。

追記(2018/3/9)

コメントでご指摘いただいたようにGeneratorの更新の前にノイズを新しく生成しなおす必要があります。元の論文でもDiscriminatorの更新とGeneratorの更新では別のノイズを使っていました。

        # Generatorの更新
        z = torch.rand((batch_size, z_dim))  # <= ★これを追加
        if cuda:
            z = z.cuda()
        z = Variable(z)

        G_optimizer.zero_grad()

        # GeneratorにとってGeneratorが生成した画像の認識結果は1(本物)に近いほどよい
        # E[log(D(G(z)))
        fake_images = G(z)
        D_fake = D(fake_images)
        G_loss = criterion(D_fake, y_real)
        G_loss.backward()
        G_optimizer.step()
        G_running_loss += G_loss.data[0]

Notebookも更新済みです。こちらでも実験してみたのですが、大きく結果が変わるということはないようでした。

参考

GANの参考資料は他にもたくさんあるんだけど今回の記事に関連のあるものだけ

*1:論文ではDiscriminatorとGeneratorの目的関数が1つの式にまとめているが2つに分けた方がわかりやすい!