読者です 読者をやめる 読者になる 読者になる

人工知能に関する断創録

人工知能、認知科学、心理学、ロボティクス、生物学などに興味を持っています。このブログでは人工知能のさまざまな分野について調査したことをまとめています。最近は、機械学習・Deep Learningに関する記事が多いです。



Machine Learning with Scikit Learn (Part IV)

機械学習 scikit-learn

Machine Learning with Scikit Learn (Part III)(2015/9/8)のつづき。5.2節ではSupport Vector Machine (SVM)が詳しく取り上げられている。

5.2 In Depth - Support Vector Machines

scikit-learnでは分類用のSupport Vector Classifierssklearn.svm.SVCと回帰用のSupport Vector Regressorsklearn.svm.SVRの2種類が用意されている。この節では主にSVCが取り上げられている。以後、SVMと言ったらSVCを指す。

SVMの識別関数

SVMの識別関数は下式のようになる。

 \displaystyle \hat{y} = sign(\alpha_0 + \sum_{j} \alpha_j y^{(j)} k(x^{(j)}, x))

ここで、(x^{(j)}, y^{(j)})はj番目の訓練データ、xはテストサンプル、kはカーネル関数を意味する。\alpha_jが訓練データから学習されるSVMのパラメータ。SVMのパラメータは訓練データと同じ数だけある。

SVMのパラメータを学習するといくつかのサンプルx^{(j)}のみ0ではない\alpha_jが得られて、その他は0になる。この0でないa_jを持つサンプルx^{(j)}サポートベクトルと呼ばれ、分類境界を決定する上で大きな役割を果たす。

SVMの識別関数の導出は以前PRMLを勉強したときにまとめた。下の記事の式(7.13)が上と同じ式。

カーネル関数

SVMはカーネル関数を変えることで挙動が変わる。scikit-learnには、

  • 線形カーネル(linear
  • 多項式カーネル(poly
  • ガウスカーネル(rbf
  • シグモイドカーネル(sigmoid

が用意されている。独自のカーネル関数を定義することもできる。デフォルトではガウスカーネルが使われる。

k(x, x') = \exp (-\gamma ||x - x'||^2)

ガウスカーネルはパラメータgammaを取り、分類境界の滑らかさを表す。gammaが小さいほど滑らかな境界が得られる。

SVMの正則化パラメータC

前回(2015/9/8)も取り上げたけどSVMではパラメータCを用いて正則化の強さを調整できる。このCはペナルティ項の重みを意味する。つまり、Cが大きいほど誤分類のペナルティが大きくなるため訓練データをなるべく分類しようとして過学習気味になる。一方、Cを小さくすると正則化が強まり、汎化性能が高まる。ただし、Cを小さくしすぎるとモデルが単純になりすぎて逆に性能が出なくなる。

ここら辺も前にPRMLでソフトマージンSVMを勉強したときにまとめた。

SVMのパラメータ

ここまでをまとめるとガウスカーネルを用いたSVMには、gammaCという2つの調整すべきパラメータがある。このチュートリアルのノートブックには、これらのパラメータによって分類曲線がどう変わるかをインタラクティブに体験できるデモがついている。

f:id:aidiary:20150909215651p:plain

上のつまみをいじるとリアルタイムにサポートベクトル(黒い枠付きの丸)と分類曲線が変わる(この記事のはスクリーンショットなのでできない)。iPython Notebookはこんなこともできるのか~面白いな。

練習問題

最後に手書き数字認識のデータを使ってSVMのパラメータをグリッドサーチで求めろという練習問題があったので解いてみた。下のような感じかな?SVCのカーネルはデフォルトがrbfなので書かなくてもOK。

実行結果は、

{'C': 10, 'gamma': 0.001} 0.991833704529
0.993333333333

となる。パラメータはC = 10 & gamma = 0.001のときが最適で、そのときのValidation Scoreは99.18%でTest Scoreは99.33%となる。SVMはニューラルネットが伸びてきたせいであまり注目されなくなった気がするけど分類精度はかなりのものだ。

Part Vにつづく。次はRandom Forest。