KerasでVGG16を使う
今回は、Deep Learningの画像応用において代表的なモデルであるVGG16をKerasから使ってみた。この学習済みのVGG16モデルは画像に関するいろいろな面白い実験をする際の基礎になるためKerasで取り扱う方法をちゃんと理解しておきたい。
ソースコード: test_vgg16
VGG16の概要
VGG16*1は2014年のILSVRC(ImageNet Large Scale Visual Recognition Challenge)で提案された畳み込み13層とフル結合3層の計16層から成る畳み込みニューラルネットワーク。層の数が多いだけで一般的な畳み込みニューラルネットと大きな違いはなく、同時期に提案されたGoogLeNetに比べるとシンプルでわかりやすい。ImageNetと呼ばれる大規模な画像データセットを使って訓練したモデルが公開されている。
VGG16の出力層は1000ユニットあり、1000クラスを分類するニューラルネットである。1000クラスのリストは1000 synsets for Task 2にある。あとでこの1000クラスの画像をクローリングする方法もまとめたい。
KerasのVGG16モデル
KerasではVGG16モデルがkeras.applications.vgg16
モジュールに実装されているため簡単に使える。これはImageNetの大規模画像セットで学習済みのモデルなので自分で画像を集めて学習する必要がない。
(注)少し古いバージョンのKerasだと自分でモデル構造を書いて、.h5ファイル形式の重みをダウンロードする必要があった(参考: VGG16 model for Keras)が最新の1.2.0では不要になっている。バックエンドに合わせて変換された重みファイルを自動ダウンロードしてくれる。keras.applications.vgg16
が実装される前に書かれた記事も多いので要注意。
from keras.applications.vgg16 import VGG16 model = VGG16(include_top=True, weights='imagenet', input_tensor=None, input_shape=None)
VGG16
クラスは4つの引数を取る。
include_top
はVGG16のトップにある1000クラス分類するフル結合層(FC)を含むか含まないかを指定する。今回は画像分類を行いたいためFCを含んだ状態で使う。FCを捨ててVGG16を特徴抽出器として使うことでいろいろ面白いことができるがまた今度取り上げたい。weights
はVGG16の重みの種類を指定する。VGG16は単にモデル構造であるため必ずしもImageNetを使って学習しなければいけないわけではない。しかし、現状ではImageNetで学習した重みしか提供されていない。None
にするとランダム重みになる。自分で集めた画像で学習する猛者はこちらか?input_tensor
は自分でモデルに画像を入力したいときに使うが今回は未使用。あとでVGG16のFine-tuningをする際に使う。input_shape
は入力画像の形状を指定する。include_top=True
にして画像分類器として使う場合は(224, 224, 3)
で固定なのでNone
でOK。何か中途半端な解像度だけどこれがImageNetの標準サイズのようだ。
読み込んだモデルをちょっと調べてみよう。
% print(model) <keras.engine.training.Model at 0x2b05220aa978>
どうやらVGG16
はKerasで一般的なSequential
モデルではなく、別のクラスのようだ。dir(model)
をするとわかるが、Sequential
モデルで層を積み重ねるのによく使っていたadd()
がないので注意。たとえば、VGG16に新たに層を付け加えるときにちょっとした工夫がいる。これもあとで詳しく取り上げたい。
summary()
するとモデル構造が見られる。
% model.summary() ____________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ==================================================================================================== input_26 (InputLayer) (None, 224, 224, 3) 0 ____________________________________________________________________________________________________ block1_conv1 (Convolution2D) (None, 224, 224, 64) 1792 input_26[0][0] ____________________________________________________________________________________________________ block1_conv2 (Convolution2D) (None, 224, 224, 64) 36928 block1_conv1[0][0] ____________________________________________________________________________________________________ block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 block1_conv2[0][0] ____________________________________________________________________________________________________ block2_conv1 (Convolution2D) (None, 112, 112, 128) 73856 block1_pool[0][0] ____________________________________________________________________________________________________ block2_conv2 (Convolution2D) (None, 112, 112, 128) 147584 block2_conv1[0][0] ____________________________________________________________________________________________________ block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 block2_conv2[0][0] ____________________________________________________________________________________________________ block3_conv1 (Convolution2D) (None, 56, 56, 256) 295168 block2_pool[0][0] ____________________________________________________________________________________________________ block3_conv2 (Convolution2D) (None, 56, 56, 256) 590080 block3_conv1[0][0] ____________________________________________________________________________________________________ block3_conv3 (Convolution2D) (None, 56, 56, 256) 590080 block3_conv2[0][0] ____________________________________________________________________________________________________ block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 block3_conv3[0][0] ____________________________________________________________________________________________________ block4_conv1 (Convolution2D) (None, 28, 28, 512) 1180160 block3_pool[0][0] ____________________________________________________________________________________________________ block4_conv2 (Convolution2D) (None, 28, 28, 512) 2359808 block4_conv1[0][0] ____________________________________________________________________________________________________ block4_conv3 (Convolution2D) (None, 28, 28, 512) 2359808 block4_conv2[0][0] ____________________________________________________________________________________________________ block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 block4_conv3[0][0] ____________________________________________________________________________________________________ block5_conv1 (Convolution2D) (None, 14, 14, 512) 2359808 block4_pool[0][0] ____________________________________________________________________________________________________ block5_conv2 (Convolution2D) (None, 14, 14, 512) 2359808 block5_conv1[0][0] ____________________________________________________________________________________________________ block5_conv3 (Convolution2D) (None, 14, 14, 512) 2359808 block5_conv2[0][0] ____________________________________________________________________________________________________ block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 block5_conv3[0][0] ____________________________________________________________________________________________________ flatten (Flatten) (None, 25088) 0 block5_pool[0][0] ____________________________________________________________________________________________________ fc1 (Dense) (None, 4096) 102764544 flatten[0][0] ____________________________________________________________________________________________________ fc2 (Dense) (None, 4096) 16781312 fc1[0][0] ____________________________________________________________________________________________________ predictions (Dense) (None, 1000) 4097000 fc2[0][0] ==================================================================================================== Total params: 138,357,544 Trainable params: 138,357,544 Non-trainable params: 0 ____________________________________________________________________________________________________
重み(#Param)がある層を数えていくと全部で16個あることがわかる。今回は、include_top=True
なのでfc1
、fc2
、predictions
という層が追加されているのが確認できる。また、最後のpredictions
層の形状が (None, 1000)
で1000クラスの分類であることもわかる。None
はサイズが決まっていないことを意味し、ここでは入力サンプル数(入力バッチ数)を意味する。
VGG16で一般物体認識
VGG16モデルが読み込めたのでさっそく画像を入力して分類するプログラムを書いてみよう。今回はコマンドラインから分類したい画像ファイル名を引数として入力するようにした。実際は、VGG16のロードに時間がかかるので起動後にプロンプトでファイル名を入力できるようにした方がよさそう。
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions from keras.preprocessing import image import numpy as np import sys """ ImageNetで学習済みのVGG16モデルを使って入力画像のクラスを予測する """ if len(sys.argv) != 2: print("usage: python test_vgg16.py [image file]") sys.exit(1) filename = sys.argv[1] # 学習済みのVGG16をロード # 構造とともに学習済みの重みも読み込まれる model = VGG16(weights='imagenet') # model.summary() # 引数で指定した画像ファイルを読み込む # サイズはVGG16のデフォルトである224x224にリサイズされる img = image.load_img(filename, target_size=(224, 224)) # 読み込んだPIL形式の画像をarrayに変換 x = image.img_to_array(img) # 3次元テンソル(rows, cols, channels) を # 4次元テンソル (samples, rows, cols, channels) に変換 # 入力画像は1枚なのでsamples=1でよい x = np.expand_dims(x, axis=0) # Top-5のクラスを予測する # VGG16の1000クラスはdecode_predictions()で文字列に変換される preds = model.predict(preprocess_input(x)) results = decode_predictions(preds, top=5)[0] for result in results: print(result)
今回は、複数の画像をまとめて入力せずに1枚だけ入力するようにした(実際はバッチ単位で1000枚入力してまとめて予測も可)。
画像の入力はkeras.preprocessing.image
モジュールを使うといろいろ便利。load_img()
で指定したサイズにリサイズして画像がロードできる。また、img_to_array()
でPIL形式の画像をNumPy array形式に変換できる。
load_img()
でロードした画像は (rows, cols, channels)
の3Dテンソルなのでこれにサンプル数 samples
を追加した4Dテンソルに変換する必要がある。
クラスの予測はpredict()
で行う。VGG16用の平均を引く前処理 preprocess_input()
を通した4Dテンソルを入力とする。predict()
の戻り値はNNの出力であり1000クラスの確率値である。このままではどのクラスが何なのか非常にわかりづらい。VGG16用のdecode_predictions()
を使うと確率値が高い順にクラス名を出力してくれる。
いくつかImageNetからクローリングした適当な画像を入力して認識結果を見てみよう。実際のところVGG16の訓練データがどれかわからない。なのでImageNetから適当に拾ってきた下の画像がたまたま訓練内データに含まれていた可能性もあるので注意。
('n02328150', 'Angora', 0.98844689)★ ('n02326432', 'hare', 0.0081565334) ('n02325366', 'wood_rabbit', 0.0029539457) ('n02342885', 'hamster', 0.00032567442) ('n02364673', 'guinea_pig', 7.7807999e-05)
結果は認識結果のTop-5を出力している。各結果は (WordNet ID, クラス名, 確率) の3つ組からなる。確率が高い順にTop-5が出力されているのがわかると思う。この結果は、写真がAngora(アンゴラ)である確率が98.8%という意味。実際、入力画像はアンゴラ(ウサギの一種)なので正解!これはすごい。正解のクラスに★を付けておく。
('n02009912', 'American_egret', 0.9786104)★ ('n02012849', 'crane', 0.020076046) ('n02009229', 'little_blue_heron', 0.0011492056) ('n02007558', 'flamingo', 6.9171525e-05) ('n02006656', 'spoonbill', 4.3949272e-05)
あっている。
('n04147183', 'schooner', 0.99256611)★ ('n04612504', 'yawl', 0.006880702) ('n03947888', 'pirate', 0.00047176969) ('n02981792', 'catamaran', 2.9590283e-05) ('n04483307', 'trimaran', 1.5395544e-05)
正解!他の4個も船に関連したクラスが出てくるのがすごい。
('n04208210', 'shovel', 0.98540133)★ ('n04367480', 'swab', 0.0067446162) ('n02906734', 'broom', 0.0027940609) ('n03498962', 'hatchet', 0.0025928672) ('n03481172', 'hammer', 0.0012616612)
正解。
('n03028079', 'church', 0.76123959) ('n02699494', 'altar', 0.21932022)★ ('n04523525', 'vault', 0.010518107) ('n03854065', 'organ', 0.004585206) ('n03781244', 'monastery', 0.00092995056)
これはおしい。正解はaltarなのだがchurchでもいいだろう。
('n02109961', 'Eskimo_dog', 0.58899087) ('n02110185', 'Siberian_husky', 0.37142584)★ ('n02110063', 'malamute', 0.032857068) ('n03218198', 'dogsled', 0.0035322146) ('n02105412', 'kelpie', 0.00089457975)
これもおしい。犬の種類を間違えているけどシベリアンハスキーは2つ目に出てくる。
次にImageNetにないデータを使ってみよう。つまり、正解のクラスがわからない。
('n02504013', 'Indian_elephant', 0.63950682) ('n02504458', 'African_elephant', 0.31736749) ('n01871265', 'tusker', 0.034719132) ('n02437312', 'Arabian_camel', 0.004969846) ('n02410509', 'bison', 0.00089000992)
実際の種類はよくわからないけど1位、2位ともに象の種類なのであっていそう。
('n03085013', 'computer_keyboard', 0.78958303) ('n04264628', 'space_bar', 0.13960978) ('n04505470', 'typewriter_keyboard', 0.050729375) ('n03793489', 'mouse', 0.0087937126) ('n04074963', 'remote_control', 0.0026325041)
この前買ったHHKの写真だけどキーボードなので正解。
('n02090622', 'borzoi', 0.21255279) ('n02111889', 'Samoyed', 0.16779339) ('n02109961', 'Eskimo_dog', 0.04057616) ('n02104029', 'kuvasz', 0.035596009) ('n03026506', 'Christmas_stocking', 0.031400401)
うちのコロちゃんの写真。ボルゾイ、サモエド、エスキモー犬、クーバースと犬の名前を出してきたので「犬」ということは認識できているようだ。雑種なんだけどさ(笑)
('n02134084', 'ice_bear', 0.35386214) ('n02114548', 'white_wolf', 0.20906797) ('n02104029', 'kuvasz', 0.13921976) ('n02437616', 'llama', 0.08272367) ('n02109961', 'Eskimo_dog', 0.048978001)
うちのくうちゃんの写真。1位はシロクマ・・・一体なぜと思ったけどよくみたらまあわからなくもない(笑)。2位は白狼で3位はクーバースなのでなかなかよい線。
まあこんな感じでなかなかうまくいっているようにみえるが実は限界も多い。VGG16を学習した際に選ばれた1000クラス以外はどうやっても認識できないのだ。たとえば、認識がすごく簡単そうなひまわりの画像を入れてみよう。
('n11939491', 'daisy', 0.97069114) ('n02206856', 'bee', 0.011910294) ('n01944390', 'snail', 0.0033144613) ('n02219486', 'ant', 0.0021784289) ('n02281406', 'sulphur_butterfly', 0.0020090577)
VGG16の1000クラスにはsunflowerがないのでどうやっても認識できない。まあdaisyと一番近そうな花の名前を出してきたのはすごいけどね。
実はVGG16を利用してImageNetの1000クラスに含まれていない画像もちゃんと認識するFine-tuningという技術がある。次回取り上げたい。
参考
- Very Deep Convolutional Networks for Large-Scale Image Recognition
- ImageNet
- ImageNet classification with Python and Keras
- Building powerful image classification models using very little data
- Kerasで学ぶ転移学習
- 3日で作る高速特定物体認識システム
*1:VGGが何の略か結局わからない・・・Visual Geometry Groupという研究グループ名だとTwitterで教えていただきました。ありがとうございます!