人工知能に関する断創録

人工知能、認知科学、心理学、ロボティクス、生物学などに興味を持っています。このブログでは人工知能のさまざまな分野について調査したことをまとめています。最近は、機械学習、Deep Learning、Kerasに関する記事が多いです。



VGG16のFine-tuningによる17種類の花の分類

前回(2017/1/10)は、VGG16をFine-tuningして犬か猫を分類できる2クラス分類のニューラルネットワークを学習した。今回は、同様のアプローチで17種類の花を分類するニューラルネットワークを学習してみたい。前回の応用編みたいな感じ。この実験はオリジナルなので結果がどう出るかわからなかったけどうまくいったのでまとめてみた。

リポジトリ:17flowers

使用したデータは、VGG16を提案したOxford大学のグループが公開している 17 Category Flower Dataset である。下のような17種類の花の画像データ。とっても美しい。

f:id:aidiary:20170131203715j:plain:h700

前に実験した(2017/1/4)ようにデフォルトのVGG16ではひまわりさえ分類できなかったが、VGG16をFine-tuningすることで果たしてこれら17種類の花(ひまわりもある)を分類できるようになるのだろうか?さっそく試してみよう!

セットアップ

17flowers.tgzというファイルをダウンロードする。これを解凍するとjpgというディレクトリの中にimage_0001.jpgからimage_1360.jpgまで各クラス80枚、計1360枚の画像が含まれている。1360枚は畳み込みニューラルネットワークをスクラッチから学習するには心もとないデータ数であるが、今回はVGG16を少量データでチューニングする転移学習を使うので十分だろう。

各クラス80枚で17クラスなので1360枚なのだが、各画像がどのラベルなのかがわからない。とりあえずサンプル画像と見比べて下のようにラベルを割り振ったlabels.txtというファイルを作成した。たとえば、1行目はimage_0001.jpgからimage_0080.jpgまでがTulipであることを意味する。

1       80      Tulip
81      160     Snowdrop
161     240     LilyValley
241     320     Bluebell
321     400     Crocus
401     480     Iris
481     560     Tigerlily
561     640     Daffodil
641     720     Fritillary
721     800     Sunflower
801     880     Daisy
881     960     ColtsFoot
961     1040    Dandelion
1041    1120    Cowslip
1121    1200    Buttercup
1201    1280    Windflower
1281    1360    Pansy

まずは、Kerasからロードしやすいように以下のsetup.pyで画像ファイルを分割する。各クラスで訓練データが70枚、テストデータが10枚になるように分割した。犬猫分類でやったようにサブディレクトリにクラス名を付けておくと自動的に認識してくれる。このクラス名を付けるために先ほどのlabels.txtを使った。

import os
import shutil
import random

IN_DIR = 'jpg'
TRAIN_DIR = 'train_images'
TEST_DIR = 'test_images'

if not os.path.exists(TRAIN_DIR):
    os.mkdir(TRAIN_DIR)

if not os.path.exists(TEST_DIR):
    os.mkdir(TEST_DIR)

# name => (start idx, end idx)
flower_dics = {}

with open('labels.txt') as fp:
    for line in fp:
        line = line.rstrip()
        cols = line.split()

        assert len(cols) == 3

        start = int(cols[0])
        end = int(cols[1])
        name = cols[2]

        flower_dics[name] = (start, end)

# 花ごとのディレクトリを作成
for name in flower_dics:
    os.mkdir(os.path.join(TRAIN_DIR, name))
    os.mkdir(os.path.join(TEST_DIR, name))

# jpgをスキャン
for f in sorted(os.listdir(IN_DIR)):
    # image_0001.jpg => 1
    prefix = f.replace('.jpg', '')
    idx = int(prefix.split('_')[1])

    for name in flower_dics:
        start, end = flower_dics[name]
        if idx in range(start, end + 1):
            source = os.path.join(IN_DIR, f)
            dest = os.path.join(TRAIN_DIR, name)
            shutil.copy(source, dest)
            continue

# 訓練データの各ディレクトリからランダムに10枚をテストとする
for d in os.listdir(TRAIN_DIR):
    files = os.listdir(os.path.join(TRAIN_DIR, d))
    random.shuffle(files)
    for f in files[:10]:
        source = os.path.join(TRAIN_DIR, d, f)
        dest = os.path.join(TEST_DIR, d)
        shutil.move(source, dest)

1. 小さな畳み込みニューラルネットをスクラッチから学習する

前回と同様に今回もベースラインとしてLeNet相当の小さな畳み込みニューラルネットワークを学習してみよう*1。まずはモデルを構築する。

model = Sequential()
model.add(Convolution2D(32, 3, 3, input_shape=(img_rows, img_cols, channels)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Convolution2D(32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Convolution2D(64, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

今まで何回もやってきたので特に難しいところはない。今回は17クラスの多クラス分類なので損失関数にcategorical_crossentropyを使う。

画像ファイルしか提供されていないときはデータの読み込みにImageDataGeneratorを使うと便利。今回もデータ拡張(2016/12/12)には前回と同じくshear_rangezoom_rangehorizontal_flipを使ったがデータの特徴を見て慎重に決めるとより精度が向上するかも。

# ディレクトリの画像を使ったジェネレータ
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1.0 / 255)

上の設定を使って実際にジェネレータを作成。画像ファイルを含むディレクトリを指定するだけでよい。ここで、classesを自分で指定すると順番にクラスラベルを割り振ってくれる。color_modeclass_modeは何に使われるかいまいち把握できていないがrgbcategoricalでよさそう。

classes = ['Tulip', 'Snowdrop', 'LilyValley', 'Bluebell', 'Crocus',
           'Iris', 'Tigerlily', 'Daffodil', 'Fritillary', 'Sunflower',
           'Daisy', 'ColtsFoot', 'Dandelion', 'Cowslip', 'Buttercup',
           'Windflower', 'Pansy']

train_generator = train_datagen.flow_from_directory(
    directory='train_images',
    target_size=(img_rows, img_cols),
    color_mode='rgb',
    classes=classes,
    class_mode='categorical',
    batch_size=batch_size,
    shuffle=True)

test_generator = test_datagen.flow_from_directory(
    directory='test_images',
    target_size=(img_rows, img_cols),
    color_mode='rgb',
    classes=classes,
    class_mode='categorical',
    batch_size=batch_size,
    shuffle=True)

ジェネレータができたのでジェネレータが生成する画像の4Dテンソルを使ってモデルを訓練する。ここは前回と同じ。

history = model.fit_generator(
    train_generator,
    samples_per_epoch=samples_per_epoch,
    nb_epoch=nb_epoch,
    validation_data=test_generator,
    nb_val_samples=nb_val_samples)
save_history(history, os.path.join(result_dir, 'history_smallcnn.txt'))

学習途中の損失と精度はあとで参照できるようにファイルに保存した。

2. VGG16をFine-tuningする

次に前回と似た感じでVGG16をFine-tuningしてみよう。最後の畳込み層ブロックとフル結合層のみ重みを再調整する。

# VGG16モデルと学習済み重みをロード
# Fully-connected層(FC)はいらないのでinclude_top=False)
input_tensor = Input(shape=(img_rows, img_cols, 3))
vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

# FC層を構築
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(nb_classes, activation='softmax'))

# VGG16とFCを接続
model = Model(input=vgg16.input, output=top_model(vgg16.output))

# 最後のconv層の直前までの層をfreeze
for layer in model.layers[:15]:
    layer.trainable = False

# Fine-tuningのときはSGDの方がよい
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])

前回とほとんど同じ。違いは、FC層の出力が17クラス分類になるので活性化関数にsoftmaxを使っているところと損失関数にcategorical_crossentropyを使っているところくらい。あとはほとんど同じなので省略。

実験結果

損失と精度の履歴は下のようになった。小さな畳み込みニューラルネットでは60%程度の分類精度しか出ないが、VGG16をFine-tuningすると85%くらいまで分類精度が跳ね上がるのが確認できた。

f:id:aidiary:20170131204317p:plain f:id:aidiary:20170131204327p:plain

1000クラスにひまわりがなかったVGG16ではあるが、花の分類に役立つ特徴抽出器は学習できていたようだ。

花の分類例

f:id:aidiary:20170131205422j:plain:w200

% python predict.py test_images/Sunflower/image_0724.jpg
input: test_images/Sunflower/image_0724.jpg
('Sunflower', 0.9999969)
('ColtsFoot', 1.3097971e-06)
('Iris', 8.8473638e-07)
('Tigerlily', 6.5053348e-07)
('Buttercup', 1.2474243e-07)

念願のひまわりもちゃんと分類できている!

f:id:aidiary:20170131205545j:plain:w200

% python predict.py test_images/Iris/image_0410.jpg
input: test_images/Iris/image_0410.jpg
('Iris', 0.99993575)
('Crocus', 1.9791674e-05)
('Sunflower', 1.734318e-05)
('Buttercup', 9.1189122e-06)
('Fritillary', 4.9292394e-06)

統計屋にはおなじみのアイリスもちゃんと分類できる。

f:id:aidiary:20170131205647j:plain:w200

% python predict.py test_images/Dandelion/image_0966.jpg
input: test_images/Dandelion/image_0966.jpg
('Dandelion', 0.99502832)
('ColtsFoot', 0.0034611411)
('Sunflower', 0.0014925624)
('Tigerlily', 5.2598648e-06)
('Buttercup', 4.0573868e-06)

蜂がいてもなんのその。何となくひまわりに似ているけどちゃんと区別できているのがすごい。3番目にひまわりが出ていた。

f:id:aidiary:20170131205854j:plain:w200

% python predict.py test_images/Windflower/image_1219.jpg
input: test_images/Windflower/image_1219.jpg
('Windflower', 0.99989629)
('Pansy', 4.6885198e-05)
('Daisy', 3.5976835e-05)
('LilyValley', 8.592202e-06)
('Snowdrop', 5.5862884e-06)

日本語だとアネモネ。これも正解。

f:id:aidiary:20170131210015j:plain:w200

% python predict.py test_images/Daisy/image_0807.jpg
input: test_images/Daisy/image_0807.jpg
('Daisy', 0.99801069)
('Sunflower', 0.0019772816)
('ColtsFoot', 1.0563146e-05)
('Pansy', 8.2547245e-07)
('Windflower', 4.5572747e-07)

デイジーも正解。

こういうきれいな画像を見ているとテンション上がる。とっても楽しい実験だった。

次は畳込みニューラルネットのフィルタの可視化方法を深堀りしてDeep Dreamへとつなげていきたい。何か今さら感が強いんだけどね(^^;

参考

*1:ちなみにここでVGG16をスクラッチから学習しようとすると全然精度が出ないのは確認済み