PyTorch (13) GAN (Fashion MNIST)
今回はDCGANをFashion MNISTのデータで試してみた。このデータは使うの始めてだな〜
画像サイズがMNISTとまったく同じで 1x28x28 なのでネットワーク構造は何も変えなくてよい (^^;) 今回は手抜きして変えたところだけ掲載します。
180303-gan-mnist.ipynb - Google ドライブ
PyTorchにはFashion MNISTをロードする関数があるのでそれを使うだけ。
from torchvision import datasets # load dataset transform = transforms.Compose([ transforms.ToTensor() ]) dataset = datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
描画してみよう。
# データの可視化 def imshow(img): npimg = img.numpy() # [c, h, w] => [h, w, c] plt.imshow(np.transpose(npimg, (1, 2, 0))) images, labels = iter(data_loader).next() images, labels = images[:25], labels[:25] imshow(make_grid(images, nrow=5, padding=1)) plt.axis('off')
洋服とかバッグとか靴とかの画像みたいですね。
その他の訓練コードなどは前回とまったく同じなので省略します。全部見たい方は↑のJupyter Notebookを参照してください。
学習曲線はMNISTと同じようにGeneratorがDiscriminatorに負ける感じに推移しました。
from IPython.display import Image Image('logs/epoch_001.png')
エポック1だとあまり再現できないけど・・・
Image('logs/epoch_010.png')
Image('logs/epoch_025.png')
エポック25まで学習させるとオリジナルに似た画像が生成できました!
おしまい (^^;)