人工知能に関する断創録

このブログでは人工知能のさまざまな分野について調査したことをまとめています(更新停止: 2019年12月31日)

多項式曲線フィッティング

PRMLをよく理解する&復習のために自分でもコーディングしていきます。Rを使っている方がいて(Rでベイズ線形回帰の予測分布)Rでやろうかなぁと思ったのですが慣れているPythonを使うことにしました。Pythonにも数値計算用のSciPyNumPy、グラフ描画のmatplotlibというRに匹敵するライブラリがそろっています。デフォルトでは入っていないので別途インストールしてください。

まずは、PRML1.1の多項式曲線フィッティングです。いわゆる最小二乗法ってやつですね。観測値xとtの訓練集合から多項式曲線のパラメータwを求めるという課題です。まず、訓練データ集合から作ります。PRMLでは、sin(2πx)の関数値を計算したあとに、ガウス分布に従う小さなランダムノイズを加えて対応するtを作っています。Pythonだと下のような感じかな。

実行すると下のようなグラフが表示されます。PRMLの図1.2と同じようなグラフです。

f:id:aidiary:20100327101228p:plain

次に式 (1.1) のM次多項式関数を作ります。wlistがパラメータベクトルwです。次数がMのときパラメータ数はM+1なので注意してください。

(1.1) \hspace{50pt} \displaystyle y(x,\bf{w}) = w_0 + w_1x + w_2x^2 + \cdots + w_Mx^M = \sum_{j=0}^M w_jx^j

次からが本番でいよいよ訓練データ集合xlist, tlistからパラメータwlistを求めます。式(1.2)の二乗誤差関数を最小化するようなwを見つけることを考えます。

(1.2) \hspace{50pt} \displaystyle E(\bf{w}) = \frac{1}{2} \sum_{n=1}^N \{y(x_n, \bf{w}) - t_n\}^2

E(w)を各w_iで偏微分して0とおくとM+1個の連立方程式(正規方程式)が立てられ、この連立方程式を解くとパラメータwが求められます。これをやれってのが、演習1.1ですが、紙と鉛筆でごりごり解くと式(1.122)と式(1.123)の連立方程式になります。

(1.122) \hspace{50pt} \displaystyle \sum_{j=0}^M A_{ij} w_j = T_i
(1.123) \hspace{50pt} \displaystyle A_{ij} = \sum_{n=1}^N (x_n)^{i+j} \hspace{50pt} T_i = \sum_{n=1}^N (x_n)^i t_n

式(1.122)は各w_iについて得られるのでM次多項式の場合はM+1個の連立方程式になります。連立方程式はnumpy.linalg.solve()という関数で解くことができます。この部分を実装すると下のようになります。連立方程式をAw=Tという行列表現にして、AとTを式(1.123)で計算し、最後に連立方程式の解wlistを求めます。

最後に、モデル曲線をプロットして完成です。今までのを全部まとめると下のようになります。

Mの値をいろいろ変えてグラフを書いてみました。図1.4と同じようにM=0, M=1, M=3, M=9の場合です。

f:id:aidiary:20100327110958p:plain
f:id:aidiary:20100327110959p:plain
f:id:aidiary:20100327111000p:plain
f:id:aidiary:20100327111001p:plain

M=0やM=1だと多項式モデルが貧弱すぎてデータをちゃんとフィッティングできていません。M=3はちょうどよいです。M=9はモデルが複雑すぎて訓練データはちゃんとフィッティングできるけど他の部分ではまったくダメで明らかに過学習してます。上のプログラムはパラメータwlistの値を出力していますが、表1.1にあるようにM=9だと非常に大きな値になることがわかります。

過学習を防ぐためには(1)データ数を増やす(2)正則化するという解決策があるようです。まずは、簡単にデータ数を増やしてみます。

    xlist = np.linspace(0, 1, 100)  # 100点取る

結果は、

f:id:aidiary:20100327113918p:plain

確かにデータ数を増やすと過学習は深刻な問題ではないようです。次に、データ数は増やさないで正則化してみます。式(1.4)のように誤差関数に罰金項をつけるだけで本当に解決できるんですかね?

(1.4) \hspace{50pt} \displaystyle \tilde{E}(\bf{w}) = \frac{1}{2} \sum_{n=1}^N \{y(x_n, \bf{w}) - t_n\}^2 + \frac{\lambda}{2} ||\bf{w}||^2

同様にパラメータwを解析的に計算します(演習1.2です)。答えは教科書に載っていませんが、係数行列Aの対角成分にλを足せばよいという結果が得られました。

ためしに先ほど過学習していたM=9でlnλ=-18、λ=exp(-18)として正則化項を導入してみます。λの値すごくちっさいですね。こんなんで効果があるんでしょうか?

f:id:aidiary:20100327112704p:plain

おおおー、過学習が抑えられています。パラメータwlistの値を見ても小さくなっていることがわかります。誤差関数に罰金項をつけるだけで抑えられるってのは新鮮な驚きでした。sin以外もちゃんとフィッティングできるか試してみたいですね。