人工知能に関する断創録

人工知能、認知科学、心理学、ロボティクス、生物学などに興味を持っています。このブログでは人工知能のさまざまな分野について調査したことをまとめています。最近は、機械学習、Deep Learning、Keras、PyTorchに関する記事が多いです。



PyTorch (15) CycleGAN (horse2zebra)

今回はCycleGANの実験をした。CycleGANはあるドメインの画像を別のドメインの画像に変換できる。アプリケーションを見たほうがイメージしやすいので論文の図1の画像を引用。

f:id:aidiary:20180324132034p:plain

  • モネの絵を写真に変換する(またはその逆)
  • 馬の画像をシマウマに変換する(またはその逆)
  • 夏の景色を冬の景色に変換する(またはその逆)

こんな魔法のようなことが実現できる。

似たような技術にpix2pixという技術がある(両方ともUC Berkeley)が、これは変換元画像と変換先画像の1対1のペアの訓練データが必要になる。その一方で、CycleGANはこのようなペアとなる訓練画像が必要ないという利点がある。ドメインAの画像セット(馬の画像セット)とドメインBの画像セット(シマウマの画像セット)だけがあればよい。

f:id:aidiary:20180324132433p:plain

いろいろなアプリケーションがあるが、今回はウマをシマウマに変換するというアプリケーションを実験してみた。論文の著者がPyTorch版のコードを公開してくれているのでそれを参考にした。ただ、いろいろなパラメータやアルゴリズムで実験できるように大量のオプションがあってわかりにくかったため必要最小限のコードだけ抜き出してJupyter Notebook形式で書き直した。

Jupyter Notebook

全コードを説明するととても長くなるので今回から省略しよう (^^;) 今後も自分のメモ書き程度でまとめていくことになるかも・・・

horse2zebraデータセット

CycleGANで使われるデータセットは著者らのレポジトリのスクリプトを使えば簡単に取得できるが、ここから直接ダウンロードできる。

Model

CycleGANはその名の通りGANの一種であるため画像を生成するGeneratorとその画像が本物か偽物かを判定するDiscriminatorから構成される。今回の実験では、Generatorは9ブロックのResNet、Discriminatorは一般的なCNNとした。

CycleGANはGeneratorが2つ(G_AとG_B)とDiscriminatorが2つ(D_AとD_B)の4つのサブモデルから構成される。

self.netG_A = Generator()
self.netG_B = Generator()
self.netD_A = Discriminator()
self.netD_B = Discriminator()

これらの関係がわかりにくいので論文の図をベースに図示してみた。論文と実装では表記が異なるので混乱しやすいのだが、紫字がコードに相当する。

f:id:aidiary:20180324142105p:plain

  • CycleGANはドメインAの画像集合(馬)とドメインBの画像集合(シマウマ)を相互に変換する
  • netG_A は馬(real_A)から偽のシマウマ(fake_B)を生成するGenerator
  • netG_B はシマウマ(real_B)から偽の馬(fake_A)を生成するGenerator
  • netD_A は本物のシマウマ(real_B)と生成した偽のシマウマ(fake_B)を見分けるDiscriminator
  • netD_B は本物の馬(real_A)と生成した偽の馬(fake_A)を見分けるDiscriminator
  • これまでのGANではGANの入力が1次元のノイズベクトルだったが、CycleGANではソースドメインの画像になる

Loss

CycleGANのキモとなるLossの定義。次の3つのLossを組み合わせて最適化している。

1. Adversarial Loss

  • 一般的なGANで使われるLoss
  • Generatorは、生成した偽物の画像(馬またはシマウマ)をDiscriminatorに本物と判定させたい
# GAN loss D_A(G_A(A))
# G_Aとしては生成した偽物画像が本物(True)とみなしてほしい
fake_B = self.netG_A(real_A)
pred_fake = self.netD_A(fake_B)
loss_G_A = self.criterionGAN(pred_fake, True)

# GAN loss D_B(G_B(B))
# G_Bとしては生成した偽物画像が本物(True)とみなしてほしい
fake_A = self.netG_B(real_B)
pred_fake = self.netD_B(fake_A)
loss_G_B = self.criterionGAN(pred_fake, True)
  • Discriminatorは、本物画像を入れたときは本物と判定し、偽物画像を入れたときは偽物と判定したい
  • これまで実装したGANのDiscriminatorは0または1のスカラーを出力していたが、Discriminatorの出力は30x30ピクセルのfeature mapになっている。その場合は、30x30を0または1で埋め尽くした行列との間でMSELossをとればOK
# 本物画像を入れたときは本物と認識するほうがよい
pred_real = self.netD_A(real_B)
loss_D_real = self.criterionGAN(pred_real, True)

# ドメインAから生成した偽物画像を入れたときは偽物と認識するほうがよい
# fake_Bを生成したGeneratorまで勾配が伝搬しないようにdetach()する
pred_fake = self.netD_A(fake_B.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# 本物画像を入れたときは本物と認識するほうがよい
pred_real = self.netD_B(real_A)
loss_D_real = self.criterionGAN(pred_real, True)

# 偽物画像を入れたときは偽物と認識するほうがよい
pred_fake = self.netD_B(fake_A.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)

2. Cycle Consistency Loss

  • CycleGANの名前のもととなる重要なLoss
  • 馬の画像をシマウマのドメインに変換し、さらに戻したときに元の馬の画像が復元されるようにしたい
  • シマウマの画像を馬のドメインに変換し、さらに戻したときに元のシマウマの画像が復元されるようにしたい
  • 実際は画像間のL1Loss
  • このLossがあるおかげで馬とシマウマが対応すると学習できている?
  • 背景などはドメインと無関係なものはなるべく維持しようとする?

f:id:aidiary:20180324150739p:plain

# forward cycle loss
# real_A => fake_B => rec_Aが元のreal_Aに近いほどよい
rec_A = self.netG_B(fake_B)
loss_cycle_A = self.criterionCycle(rec_A, real_A) * lambda_A

# backward cycle loss
# real_B => fake_A => rec_Bが元のreal_Bに近いほどよい
rec_B = self.netG_A(fake_A)
loss_cycle_B = self.criterionCycle(rec_B, real_B) * lambda_B

3. Identity Mapping Loss

  • Generatorに対象ドメインでない画像を入れたときに何もしないようにするLoss
  • たとえば、netG_A は馬をシマウマに変換するが、シマウマをいれたときはそのままシマウマを出力する
  • netG_B はシマウマを馬に変換するが、馬をいれたときはそのまま馬を出力する
# G_A, G_Bは変換先ドメインの本物画像を入力したときはそのまま出力するべき
# netG_AはドメインAの画像からドメインBの画像を生成するGeneratorだが
# ドメインBの画像も入れることができる
# その場合は何も変換してほしくないという制約
idt_A = self.netG_A(real_B)
loss_idt_A = self.criterionIdt(idt_A, real_B) * lambda_B * lambda_idt

idt_B = self.netG_B(real_A)
loss_idt_B = self.criterionIdt(idt_B, real_A) * lambda_A * lambda_idt

CycleGANではこれら3つのLossを重み付けしてGeneratorとDiscriminatorのパラメータを更新していく。

実験結果

TensorboardXで出力したログファイルは下のコマンドで読み込める。

tensorboard --logdir logs_cyclegan_horse2zebra

各Lossの推移を描画してみると下のようになった。やはりこれまでと同様にGeneratorのLossは上がる傾向が見える。

f:id:aidiary:20180324152447p:plain:w300 f:id:aidiary:20180324152512p:plain:w300 f:id:aidiary:20180324152554p:plain:w300 f:id:aidiary:20180324152623p:plain:w300 f:id:aidiary:20180324152654p:plain:w300 f:id:aidiary:20180324152712p:plain:w300 f:id:aidiary:20180324152727p:plain:w300 f:id:aidiary:20180324152742p:plain:w300

テストデータを使って馬をシマウマにしてみた。

model = CycleGAN()
model.log_dir = 'logs_cyclegan_horse2zebra/'
model.load('epoch195')

batch = iter(test_loader).next()

# 馬をシマウマに変換
fake_B = model.netG_A(batch['A'])

左が本物の馬の画像で右が netG_A で変換したシマウマの画像。

epoch 5

f:id:aidiary:20180324174940p:plain

epoch 50

f:id:aidiary:20180324174557p:plain

epoch 195

f:id:aidiary:20180324174506p:plain

次は netG_B を使ってシマウマを馬にしてみた。左が本物のシマウマの画像で右が netG_B で変換した馬の画像。

# シマウマを馬に変換
fake_A = model.netG_B(batch['B'])

f:id:aidiary:20180324154851p:plain

  • 茶色い馬はシマウマになりやすい。白馬はデータが少ないのかうまくいかなかった
  • シマウマを馬にするのはあまりできなかった。シマウマの縞がどうしても残ってしまう。もしかしてAとBを逆にして再学習しないとダメ?

要調査事項

  • netG_A で馬からシマウマ、netG_B でシマウマから馬と相互変換が同時に学習できるのか?と思ったのだが、元の実装では、--which_direction というオプションがあり、デフォルトでは AtoB (馬からシマウマへ)になっていた。逆方向の変換をきちんと学習させるにはAをシマウマ、Bを馬にひっくり返して再学習する必要があるのかも。
  • netG_A に馬やシマウマ以外の画像を入れると何が起きるのか?
  • Discriminatorの出力を0 or 1 のスカラーにせずにFeature mapにしているのはなぜなのか?FC層を入れるとパラメータが増えるから?
  • Image2Imageのタスクではバッチサイズを1にして InstanceNorm2d を使う方がよいのか?

ImagePoolの意味

原論文に次のような文章があった。実際はこれなくても学習はできたが、合ったほうが安定するようだ。

Second, to reduce model oscillation [14], we follow Shrivastava et al’s strategy [45] and update the discriminators using a history of generated images rather than the ones produced by the latest generative networks. We keep an image buffer that stores the 50 previously generated images.

馬やシマウマ以外の画像を変換すると何が起きるのか?

def convert2zebra(filename):
    img = Image.open(filename).convert('RGB')

    img_tensor = test_dataset.transform(img)
    img_tensor.unsqueeze_(0)

    fake_B = model.netG_A(Variable(img_tensor))

    plt.figure(figsize=(10, 20))

    plt.subplot(1, 2, 1)
    imshow(make_grid(img_tensor, nrow=2))
    plt.axis('off')

    plt.subplot(1, 2, 2)
    imshow(make_grid(fake_B.data, nrow=2))
    plt.axis('off')

convert2zebra('data/cat.jpg')
  • おそらく訓練データに含まれる馬っぽい馬しか変換できないだろうと思ったがその通りだった
  • 茶色に反応するようで茶色い犬や猫や熊などは縞模様が浮き出る
  • 白馬はやはりダメみたい。やはり限界はあるか・・・
  • Cycle Consistency LossやIdentity Lossが機能しているためか入力画像に近い画像がちゃんと生成されるのはすごい

f:id:aidiary:20180324165123p:plain f:id:aidiary:20180324165218p:plain f:id:aidiary:20180324165726p:plain f:id:aidiary:20180324165226p:plain f:id:aidiary:20180324165235p:plain f:id:aidiary:20180324165241p:plain

参考