PyTorch (14) GAN (CelebA)
今回はDCGANをCelebAのデータで試してみた。このデータもよく見るけど使うの始めてだな。これまでのMNIST(2018/3/4)やFashion MNISTのGANと違ってカラー画像でサイズも大きいので少し修正が必要。
180303-gan.ipynb - Google ドライブ
CelebA dataset
CelebAのサイトではGoogle Driveを使って画像ファイルを提供している。ブラウザ上から直接ダウンロードしてきてもよいが、AWSなどクラウド環境を使っているときはいちいちローカルにダウンロードしてそれをAWSにアップするのが面倒だ。コマンドラインツールでAWS上に直接ダウンロードしたくて少し悩んだ。
Google Drive上のファイルをCLIで操作する gdrive という便利なツールがあるが、これは自分のGoogle Driveのファイルしか操作できない。他人のファイルを共有して自分のところに持ってきてもダウンロードできなかった。
いろいろ調べたところこのスクリプトを使うと他人のGoogle Driveのファイルをコマンドライン上からダウンロードできることがわかった。ファイルのIDを変えればCelebA以外のデータもダウンロードできるのでとても便利だ。
python get_drive_file.py 0B7EVK8r0v71pZjFTYXZWM3FlRnM img_align_celeba.zip
モデル構造
入力画像が 3x64x64
とちょっと大きくなったのでネットワークの出力サイズが少し違うが、基本的にはこれまでと同じ。
class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.fc = nn.Sequential( nn.Linear(62, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Linear(1024, 128 * 16 * 16), nn.BatchNorm1d(128 * 16 * 16), nn.ReLU(), ) self.deconv = nn.Sequential( nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1), nn.Sigmoid(), ) initialize_weights(self) def forward(self, input): x = self.fc(input) x = x.view(-1, 128, 16, 16) x = self.deconv(x) return x class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.conv = nn.Sequential( nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), ) self.fc = nn.Sequential( nn.Linear(128 * 16 * 16, 1024), nn.BatchNorm1d(1024), nn.LeakyReLU(0.2), nn.Linear(1024, 1), nn.Sigmoid(), ) initialize_weights(self) def forward(self, input): x = self.conv(input) x = x.view(-1, 128 * 16 * 16) x = self.fc(x) return x
データロード
CelebAは画像サイズがそろっていないのでデータ変換が必要。画像の中心の160x160ピクセルをクロップして、それを64x64にリサイズする処理が入る。画像が入ったディレクトリ(data/celebA
)から直接データを読み込むにはImageFolder
が使える。
# load dataset transform = transforms.Compose([ transforms.CenterCrop(160), transforms.Resize((64, 64)), transforms.ToTensor() ]) dataset = datasets.ImageFolder('data/celebA', transform) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
変換後の画像を描画するとこんな感じ。
# データの可視化 def imshow(img): npimg = img.numpy() # [c, h, w] => [h, w, c] plt.imshow(np.transpose(npimg, (1, 2, 0))) images, labels = iter(data_loader).next() images, labels = images[:25], labels[:25] imshow(make_grid(images, nrow=5, padding=1)) plt.axis('off')
訓練スクリプトは、GAN(2018/3/4)と同じなので省略。
実験結果
まずは学習曲線。これまでと同じでGeneratorがDiscriminatorに負ける感じで徐々に推移する。
1エポック目のGeneratorの生成画像
from IPython.display import Image Image('logs/epoch_001.png')
10エポック目
25エポック目
危ない人(笑)もいるけどうまく生成できている。
これらの生成画像に似た顔の人が訓練画像内にいないかどうかを確かめたいのだけどどうすればよいのだろう?潜在空間上でk-NNすればよいのかな?