人工知能に関する断創録

このブログでは人工知能のさまざまな分野について調査したことをまとめています(更新停止: 2019年12月31日)

KerasでVGG16を使う

今回は、Deep Learningの画像応用において代表的なモデルであるVGG16をKerasから使ってみた。この学習済みのVGG16モデルは画像に関するいろいろな面白い実験をする際の基礎になるためKerasで取り扱う方法をちゃんと理解しておきたい。

ソースコード: test_vgg16

VGG16の概要

VGG16*1は2014年のILSVRC(ImageNet Large Scale Visual Recognition Challenge)で提案された畳み込み13層とフル結合3層の計16層から成る畳み込みニューラルネットワーク。層の数が多いだけで一般的な畳み込みニューラルネットと大きな違いはなく、同時期に提案されたGoogLeNetに比べるとシンプルでわかりやすい。ImageNetと呼ばれる大規模な画像データセットを使って訓練したモデルが公開されている。

VGG16の出力層は1000ユニットあり、1000クラスを分類するニューラルネットである。1000クラスのリストは1000 synsets for Task 2にある。あとでこの1000クラスの画像をクローリングする方法もまとめたい。

KerasのVGG16モデル

KerasではVGG16モデルがkeras.applications.vgg16モジュールに実装されているため簡単に使える。これはImageNetの大規模画像セットで学習済みのモデルなので自分で画像を集めて学習する必要がない

(注)少し古いバージョンのKerasだと自分でモデル構造を書いて、.h5ファイル形式の重みをダウンロードする必要があった(参考: VGG16 model for Keras)が最新の1.2.0では不要になっている。バックエンドに合わせて変換された重みファイルを自動ダウンロードしてくれる。keras.applications.vgg16が実装される前に書かれた記事も多いので要注意。

from keras.applications.vgg16 import VGG16
model = VGG16(include_top=True, weights='imagenet', input_tensor=None, input_shape=None)

VGG16クラスは4つの引数を取る。

  • include_topはVGG16のトップにある1000クラス分類するフル結合層(FC)を含むか含まないかを指定する。今回は画像分類を行いたいためFCを含んだ状態で使う。FCを捨ててVGG16を特徴抽出器として使うことでいろいろ面白いことができるがまた今度取り上げたい。
  • weightsはVGG16の重みの種類を指定する。VGG16は単にモデル構造であるため必ずしもImageNetを使って学習しなければいけないわけではない。しかし、現状ではImageNetで学習した重みしか提供されていない。Noneにするとランダム重みになる。自分で集めた画像で学習する猛者はこちらか?
  • input_tensorは自分でモデルに画像を入力したいときに使うが今回は未使用。あとでVGG16のFine-tuningをする際に使う。
  • input_shapeは入力画像の形状を指定する。include_top=Trueにして画像分類器として使う場合は (224, 224, 3) で固定なのでNoneでOK。何か中途半端な解像度だけどこれがImageNetの標準サイズのようだ。

読み込んだモデルをちょっと調べてみよう。

% print(model)
<keras.engine.training.Model at 0x2b05220aa978>

どうやらVGG16はKerasで一般的なSequentialモデルではなく、別のクラスのようだ。dir(model)をするとわかるが、Sequentialモデルで層を積み重ねるのによく使っていたadd()がないので注意。たとえば、VGG16に新たに層を付け加えるときにちょっとした工夫がいる。これもあとで詳しく取り上げたい。

summary()するとモデル構造が見られる。

% model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_26 (InputLayer)            (None, 224, 224, 3)   0
____________________________________________________________________________________________________
block1_conv1 (Convolution2D)     (None, 224, 224, 64)  1792        input_26[0][0]
____________________________________________________________________________________________________
block1_conv2 (Convolution2D)     (None, 224, 224, 64)  36928       block1_conv1[0][0]
____________________________________________________________________________________________________
block1_pool (MaxPooling2D)       (None, 112, 112, 64)  0           block1_conv2[0][0]
____________________________________________________________________________________________________
block2_conv1 (Convolution2D)     (None, 112, 112, 128) 73856       block1_pool[0][0]
____________________________________________________________________________________________________
block2_conv2 (Convolution2D)     (None, 112, 112, 128) 147584      block2_conv1[0][0]
____________________________________________________________________________________________________
block2_pool (MaxPooling2D)       (None, 56, 56, 128)   0           block2_conv2[0][0]
____________________________________________________________________________________________________
block3_conv1 (Convolution2D)     (None, 56, 56, 256)   295168      block2_pool[0][0]
____________________________________________________________________________________________________
block3_conv2 (Convolution2D)     (None, 56, 56, 256)   590080      block3_conv1[0][0]
____________________________________________________________________________________________________
block3_conv3 (Convolution2D)     (None, 56, 56, 256)   590080      block3_conv2[0][0]
____________________________________________________________________________________________________
block3_pool (MaxPooling2D)       (None, 28, 28, 256)   0           block3_conv3[0][0]
____________________________________________________________________________________________________
block4_conv1 (Convolution2D)     (None, 28, 28, 512)   1180160     block3_pool[0][0]
____________________________________________________________________________________________________
block4_conv2 (Convolution2D)     (None, 28, 28, 512)   2359808     block4_conv1[0][0]
____________________________________________________________________________________________________
block4_conv3 (Convolution2D)     (None, 28, 28, 512)   2359808     block4_conv2[0][0]
____________________________________________________________________________________________________
block4_pool (MaxPooling2D)       (None, 14, 14, 512)   0           block4_conv3[0][0]
____________________________________________________________________________________________________
block5_conv1 (Convolution2D)     (None, 14, 14, 512)   2359808     block4_pool[0][0]
____________________________________________________________________________________________________
block5_conv2 (Convolution2D)     (None, 14, 14, 512)   2359808     block5_conv1[0][0]
____________________________________________________________________________________________________
block5_conv3 (Convolution2D)     (None, 14, 14, 512)   2359808     block5_conv2[0][0]
____________________________________________________________________________________________________
block5_pool (MaxPooling2D)       (None, 7, 7, 512)     0           block5_conv3[0][0]
____________________________________________________________________________________________________
flatten (Flatten)                (None, 25088)         0           block5_pool[0][0]
____________________________________________________________________________________________________
fc1 (Dense)                      (None, 4096)          102764544   flatten[0][0]
____________________________________________________________________________________________________
fc2 (Dense)                      (None, 4096)          16781312    fc1[0][0]
____________________________________________________________________________________________________
predictions (Dense)              (None, 1000)          4097000     fc2[0][0]
====================================================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
____________________________________________________________________________________________________

重み(#Param)がある層を数えていくと全部で16個あることがわかる。今回は、include_top=Trueなのでfc1fc2predictionsという層が追加されているのが確認できる。また、最後のpredictions層の形状が (None, 1000) で1000クラスの分類であることもわかる。Noneはサイズが決まっていないことを意味し、ここでは入力サンプル数(入力バッチ数)を意味する。

VGG16で一般物体認識

VGG16モデルが読み込めたのでさっそく画像を入力して分類するプログラムを書いてみよう。今回はコマンドラインから分類したい画像ファイル名を引数として入力するようにした。実際は、VGG16のロードに時間がかかるので起動後にプロンプトでファイル名を入力できるようにした方がよさそう。

from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing import image
import numpy as np
import sys

"""
ImageNetで学習済みのVGG16モデルを使って入力画像のクラスを予測する
"""

if len(sys.argv) != 2:
    print("usage: python test_vgg16.py [image file]")
    sys.exit(1)

filename = sys.argv[1]

# 学習済みのVGG16をロード
# 構造とともに学習済みの重みも読み込まれる
model = VGG16(weights='imagenet')
# model.summary()

# 引数で指定した画像ファイルを読み込む
# サイズはVGG16のデフォルトである224x224にリサイズされる
img = image.load_img(filename, target_size=(224, 224))

# 読み込んだPIL形式の画像をarrayに変換
x = image.img_to_array(img)

# 3次元テンソル(rows, cols, channels) を
# 4次元テンソル (samples, rows, cols, channels) に変換
# 入力画像は1枚なのでsamples=1でよい
x = np.expand_dims(x, axis=0)

# Top-5のクラスを予測する
# VGG16の1000クラスはdecode_predictions()で文字列に変換される
preds = model.predict(preprocess_input(x))
results = decode_predictions(preds, top=5)[0]
for result in results:
    print(result)

今回は、複数の画像をまとめて入力せずに1枚だけ入力するようにした(実際はバッチ単位で1000枚入力してまとめて予測も可)。

画像の入力はkeras.preprocessing.imageモジュールを使うといろいろ便利。load_img()で指定したサイズにリサイズして画像がロードできる。また、img_to_array()でPIL形式の画像をNumPy array形式に変換できる。

load_img()でロードした画像は (rows, cols, channels) の3Dテンソルなのでこれにサンプル数 samples を追加した4Dテンソルに変換する必要がある。

クラスの予測はpredict()で行う。VGG16用の平均を引く前処理 preprocess_input() を通した4Dテンソルを入力とする。predict()の戻り値はNNの出力であり1000クラスの確率値である。このままではどのクラスが何なのか非常にわかりづらい。VGG16用のdecode_predictions()を使うと確率値が高い順にクラス名を出力してくれる。

いくつかImageNetからクローリングした適当な画像を入力して認識結果を見てみよう。実際のところVGG16の訓練データがどれかわからない。なのでImageNetから適当に拾ってきた下の画像がたまたま訓練内データに含まれていた可能性もあるので注意。

f:id:aidiary:20170104210653j:plain:w300

('n02328150', 'Angora', 0.98844689)★
('n02326432', 'hare', 0.0081565334)
('n02325366', 'wood_rabbit', 0.0029539457)
('n02342885', 'hamster', 0.00032567442)
('n02364673', 'guinea_pig', 7.7807999e-05)

結果は認識結果のTop-5を出力している。各結果は (WordNet ID, クラス名, 確率) の3つ組からなる。確率が高い順にTop-5が出力されているのがわかると思う。この結果は、写真がAngora(アンゴラ)である確率が98.8%という意味。実際、入力画像はアンゴラ(ウサギの一種)なので正解!これはすごい。正解のクラスに★を付けておく。

f:id:aidiary:20170104210228g:plain:w300

('n02009912', 'American_egret', 0.9786104)★
('n02012849', 'crane', 0.020076046)
('n02009229', 'little_blue_heron', 0.0011492056)
('n02007558', 'flamingo', 6.9171525e-05)
('n02006656', 'spoonbill', 4.3949272e-05)

あっている。

f:id:aidiary:20170104210658j:plain:w300

('n04147183', 'schooner', 0.99256611)★
('n04612504', 'yawl', 0.006880702)
('n03947888', 'pirate', 0.00047176969)
('n02981792', 'catamaran', 2.9590283e-05)
('n04483307', 'trimaran', 1.5395544e-05)

正解!他の4個も船に関連したクラスが出てくるのがすごい。

f:id:aidiary:20170104210702j:plain:w300

('n04208210', 'shovel', 0.98540133)★
('n04367480', 'swab', 0.0067446162)
('n02906734', 'broom', 0.0027940609)
('n03498962', 'hatchet', 0.0025928672)
('n03481172', 'hammer', 0.0012616612)

正解。

f:id:aidiary:20170104210705j:plain:w300

('n03028079', 'church', 0.76123959)
('n02699494', 'altar', 0.21932022)★
('n04523525', 'vault', 0.010518107)
('n03854065', 'organ', 0.004585206)
('n03781244', 'monastery', 0.00092995056)

これはおしい。正解はaltarなのだがchurchでもいいだろう。

f:id:aidiary:20170104211944j:plain:w300

('n02109961', 'Eskimo_dog', 0.58899087)
('n02110185', 'Siberian_husky', 0.37142584)★
('n02110063', 'malamute', 0.032857068)
('n03218198', 'dogsled', 0.0035322146)
('n02105412', 'kelpie', 0.00089457975)

これもおしい。犬の種類を間違えているけどシベリアンハスキーは2つ目に出てくる。

次にImageNetにないデータを使ってみよう。つまり、正解のクラスがわからない。

f:id:aidiary:20170104210223j:plain:w300

('n02504013', 'Indian_elephant', 0.63950682)
('n02504458', 'African_elephant', 0.31736749)
('n01871265', 'tusker', 0.034719132)
('n02437312', 'Arabian_camel', 0.004969846)
('n02410509', 'bison', 0.00089000992)

実際の種類はよくわからないけど1位、2位ともに象の種類なのであっていそう。

f:id:aidiary:20170104210714j:plain:w300

('n03085013', 'computer_keyboard', 0.78958303)
('n04264628', 'space_bar', 0.13960978)
('n04505470', 'typewriter_keyboard', 0.050729375)
('n03793489', 'mouse', 0.0087937126)
('n04074963', 'remote_control', 0.0026325041)

この前買ったHHKの写真だけどキーボードなので正解。

f:id:aidiary:20170104210733j:plain:w300

('n02090622', 'borzoi', 0.21255279)
('n02111889', 'Samoyed', 0.16779339)
('n02109961', 'Eskimo_dog', 0.04057616)
('n02104029', 'kuvasz', 0.035596009)
('n03026506', 'Christmas_stocking', 0.031400401)

うちのコロちゃんの写真。ボルゾイ、サモエド、エスキモー犬、クーバースと犬の名前を出してきたので「犬」ということは認識できているようだ。雑種なんだけどさ(笑)

f:id:aidiary:20170104210738j:plain:w300

('n02134084', 'ice_bear', 0.35386214)
('n02114548', 'white_wolf', 0.20906797)
('n02104029', 'kuvasz', 0.13921976)
('n02437616', 'llama', 0.08272367)
('n02109961', 'Eskimo_dog', 0.048978001)

うちのくうちゃんの写真。1位はシロクマ・・・一体なぜと思ったけどよくみたらまあわからなくもない(笑)。2位は白狼で3位はクーバースなのでなかなかよい線。

まあこんな感じでなかなかうまくいっているようにみえるが実は限界も多い。VGG16を学習した際に選ばれた1000クラス以外はどうやっても認識できないのだ。たとえば、認識がすごく簡単そうなひまわりの画像を入れてみよう。

f:id:aidiary:20170104212004j:plain:w300

('n11939491', 'daisy', 0.97069114)
('n02206856', 'bee', 0.011910294)
('n01944390', 'snail', 0.0033144613)
('n02219486', 'ant', 0.0021784289)
('n02281406', 'sulphur_butterfly', 0.0020090577)

VGG16の1000クラスにはsunflowerがないのでどうやっても認識できない。まあdaisyと一番近そうな花の名前を出してきたのはすごいけどね。

実はVGG16を利用してImageNetの1000クラスに含まれていない画像もちゃんと認識するFine-tuningという技術がある。次回取り上げたい。

参考

*1:VGGが何の略か結局わからない・・・Visual Geometry Groupという研究グループ名だとTwitterで教えていただきました。ありがとうございます!