人工知能に関する断創録

このブログでは人工知能のさまざまな分野について調査したことをまとめています(更新停止: 2019年12月31日)

PyTorch (13) GAN (Fashion MNIST)

今回はDCGANをFashion MNISTのデータで試してみた。このデータは使うの始めてだな〜

画像サイズがMNISTとまったく同じで 1x28x28 なのでネットワーク構造は何も変えなくてよい (^^;) 今回は手抜きして変えたところだけ掲載します。

180303-gan-mnist.ipynb - Google ドライブ

PyTorchにはFashion MNISTをロードする関数があるのでそれを使うだけ。

from torchvision import datasets

# load dataset
transform = transforms.Compose([
    transforms.ToTensor()
])
dataset = datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  

描画してみよう。

# データの可視化
def imshow(img):
    npimg = img.numpy()
    # [c, h, w] => [h, w, c]
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

images, labels = iter(data_loader).next()
images, labels = images[:25], labels[:25]
imshow(make_grid(images, nrow=5, padding=1))
plt.axis('off')

f:id:aidiary:20180306225310p:plain

洋服とかバッグとか靴とかの画像みたいですね。

その他の訓練コードなどは前回とまったく同じなので省略します。全部見たい方は↑のJupyter Notebookを参照してください。

学習曲線はMNISTと同じようにGeneratorがDiscriminatorに負ける感じに推移しました。

f:id:aidiary:20180306225451p:plain

from IPython.display import Image
Image('logs/epoch_001.png')

エポック1だとあまり再現できないけど・・・

f:id:aidiary:20180306225527p:plain

Image('logs/epoch_010.png')

f:id:aidiary:20180306225620p:plain

Image('logs/epoch_025.png')

f:id:aidiary:20180306225648p:plain

エポック25まで学習させるとオリジナルに似た画像が生成できました!

おしまい (^^;)