人工知能に関する断創録

このブログでは人工知能のさまざまな分野について調査したことをまとめています(更新停止: 2019年12月31日)

混合ガウスモデルとEM

今回は、9.2の混合ガウス分布のところです。混合ガウス分布はK個のガウス分布の線形重ね合わせで表されます。

(9.7) \hspace{50pt} \displaystyle p(x) = \sum_{k=1}^K \pi_k N(x|\mu_k, \Sigma_k)

ここで、π_kを混合係数と言い、k番目のガウス分布を選択する確率を表します。π_kは確率の条件を満たし、すべてのkについて足し合わせると1になります。ここら辺は、2.3.9の混合ガウス分布でも出てきました。この章では、混合ガウス分布を潜在変数zを導入して再定式化しています。zはK次元ベクトルで、K個の要素のうちどれか1つだけ1を取り、他は0、つまり1-of-K表現です。zはデータxがどのガウス分布から生成されたかを表し、下のような分布になります。

\displaystyle p(z_k = 1) = \pi_k
\displaystyle p(x|z_k = 1) = N(x|\mu_k, \Sigma_k)

そして、式(9.12)のようにこのzを陽に用いた形でp(x)を求めてもやっぱり混合ガウス分布の式 (9.7) になります(演習9.3)。

(9.12) \hspace{50pt} \displaystyle p(x) = \sum_{z} p(x, z) = \sum_{z} p(z) p(x|z) = \sum_{k=1}^K \pi_k N(x|\mu_k, \Sigma_k)

つまり、混合ガウス分布を「潜在変数zを含む別の式」で表現できたってことですね。何でこんなことするのか不明だったのですが、潜在変数を導入するとEMアルゴリズムが導入しやすいからなんでしょうねぇ。データxが与えられてもそのデータxがどのガウス分布から出ているかというzが分からないという状況です。

もう1つ重要なのが、潜在変数zの事後確率です。これを負担率と呼ぶとのこと。

(9.13) \hspace{50pt} \displaystyle \gamma(z_{nk}) = p(z_k = 1|x_n) = \frac{\pi_k N(x_n|\mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j N(x_n|\mu_j, \Sigma_j)}

で、今回の問題は、データ行列Xから混合ガウス分布のパラメータ(μ,Σ,π)を最尤推定することです。ただ、今までと違って潜在変数zがからんでくるために解析的に解けず、EMアルゴリズムという繰り返し最適化の手法を使いましょうって流れにつながります。

最尤推定なのでまず対数尤度関数を立てます。

(9.14) \hspace{50pt} \displaystyle \ln p(X|\pi,\mu,\Sigma) = \sum_{n=1}^N \ln \{ \sum_{k=1}^K \pi_k N(x_n|\mu_k, \Sigma_k) \}

EMアルゴリズムを使う状況では、logの中に潜在変数の和が出てくるのが特徴のようです。で、これをパラメータ(μ,Σ,π)で偏微分して0と置いた式を立てるとそれぞれ

(9.17) \hspace{50pt} \displaystyle \mu_k = \frac{1}{N_k} \sum_{n=1}^N \gamma(z_{nk}) x_n
(9.19) \hspace{50pt} \displaystyle \Sigma_k = \frac{1}{N_k} \sum_{n=1}^N \gamma(z_{nk})(x_n - \mu_k)(x_n - \mu_k)^T
(9.22) \hspace{50pt} \displaystyle \pi_k = \frac{N_k}{N}

ここで、N_kは、

(9.18) \hspace{50pt} \displaystyle N_k = \sum_{n=1}^N \gamma(z_{nk})

となります。今までだとここで終わってめでたし、めでたしなわけですが、今回はそうはいかず潜在変数zの事後確率(負担率γ)が出てきてしまいます。負担率の式(9.13)を見ると、負担率それ自体が今求めようとしたパラメータ(μ、Σ、π)で定義されるのでこれでは定義が循環してます。

そこで、EMアルゴリズムが登場ってわけですねー。

混合ガウス分布のためのEMアルゴリズム

  1. 平均μ_k、分散Σ_k、混合係数π_kを初期化し、対数尤度(9.14)の初期値を計算する。
  2. Eステップ: 現在のパラメータ値を使って、(9.13)の負担率(潜在変数の事後確率)を計算する。
  3. Mステップ: 現在の負担率を使って、(9.17)(9.19)(9.22)でパラメータ値を再計算する。
  4. 対数尤度(9.14)の変化を見て収束性を確認し、収束基準を満たしてなければEステップに戻る

では、Pythonでプログラムしてみます。今回もPRMLの原著サポートページのfaithful.txtというデータを使うので同じフォルダにおいてください。

#coding:utf-8

# 混合ガウス分布のEMアルゴリズム

import numpy as np
from pylab import *

K = 2  # 混合ガウス分布の数(固定)

def scale(X):
    """データ行列Xを属性ごとに標準化したデータを返す"""
    # 属性の数(=列の数)
    col = X.shape[1]
    
    # 属性ごとに平均値と標準偏差を計算
    mu = np.mean(X, axis=0)
    sigma = np.std(X, axis=0)
    
    # 属性ごとデータを標準化
    for i in range(col):
        X[:,i] = (X[:,i] - mu[i]) / sigma[i]
    
    return X

def gaussian(x, mean, cov):
    """多変量ガウス関数"""
    temp1 = 1 / ((2 * np.pi) ** (x.size/2.0))
    temp2 = 1 / (np.linalg.det(cov) ** 0.5)
    temp3 = - 0.5 * np.dot(np.dot(x - mean, np.linalg.inv(cov)), x - mean)
    return temp1 * temp2 * np.exp(temp3)

def likelihood(X, mean, cov, pi):
    """対数尤度関数"""
    sum = 0.0
    for n in range(len(X)):
        temp = 0.0
        for k in range(K):
            temp += pi[k] * gaussian(X[n], mean[k], cov[k])
        sum += np.log(temp)
    return sum

if __name__ == "__main__":
    # 訓練データをロード
    data = np.genfromtxt("faithful.txt")
    X = data[:, 0:2]
    X = scale(X)  # データを標準化(各次元が平均0、分散1になるように)
    N = len(X)    # データ数
    
    # 訓練データから混合ガウス分布のパラメータをEMアルゴリズムで推定する
    
    # 平均、分散、混合係数を初期化
    mean = np.random.rand(K, 2)
    cov = zeros( (K, 2, 2) ) 
    for k in range(K):
        cov[k] = [[1.0, 0.0], [0.0, 1.0]]
    pi = np.random.rand(K)
    
    # 負担率の空配列を用意
    gamma = zeros( (N, K) )
    
    # 対数尤度の初期値を計算
    like = likelihood(X, mean, cov, pi)

    turn = 0
    while True:
        print turn, like
        
        # E-step : 現在のパラメータを使って、負担率を計算
        for n in range(N):
            # 分母はkによらないので最初に1回だけ計算
            denominator = 0.0
            for j in range(K):
                denominator += pi[j] * gaussian(X[n], mean[j], cov[j])
            # 各kについて負担率を計算
            for k in range(K):
                gamma[n][k] = pi[k] * gaussian(X[n], mean[k], cov[k]) / denominator
        
        # M-step : 現在の負担率を使って、パラメータを再計算
        for k in range(K):
            # Nkを計算する
            Nk = 0.0
            for n in range(N):
                Nk += gamma[n][k]
            
            # 平均を再計算
            mean[k] = array([0.0, 0.0])
            for n in range(N):
                mean[k] += gamma[n][k] * X[n]
            mean[k] /= Nk
            
            # 共分散を再計算
            cov[k] = array([[0.0,0.0], [0.0,0.0]])
            for n in range(N):
                temp = X[n] - mean[k]
                cov[k] += gamma[n][k] * matrix(temp).reshape(2, 1) * matrix(temp).reshape(1, 2)  # 縦ベクトルx横ベクトル
            cov[k] /= Nk
            
            # 混合係数を再計算
            pi[k] = Nk / N
            
        # 収束判定
        new_like = likelihood(X, mean, cov, pi)
        diff = new_like - like
        if diff < 0.01:
            break
        like = new_like
        turn += 1

    # ガウス分布の平均を描画
    for k in range(K):
        scatter(mean[k, 0], mean[k, 1], c='r', marker='o')
    
    # 等高線を描画
    xlist = np.linspace(-2.5, 2.5, 50)
    ylist = np.linspace(-2.5, 2.5, 50)
    x, y = np.meshgrid(xlist, ylist)
    for k in range(K):
        z = bivariate_normal(x, y, np.sqrt(cov[k,0,0]), np.sqrt(cov[k,1,1]), mean[k,0], mean[k,1], cov[k,0,1])
        cs = contour(x, y, z, 3, colors='k', linewidths=1)
    
    # 訓練データを描画
    plot(X[:,0], X[:,1], 'gx')
    
    xlim(-2.5, 2.5)
    ylim(-2.5, 2.5)
    show()

結果は、

f:id:aidiary:20100521173130p:plain

K=2としているのでデータを生成したであろう2つのガウス分布(平均と分散で表される)が求められています。ただ、この図からじゃπがわからないですね。πを出力すると

[ 0.64405793  0.35594207]

となります。0.64が右上の分布、0.36が左下の分布のπです。データを見ると右上のガウス分布の方が生成された点が多いので、右上のガウス分布が多く選択される(=πの値が大きい)ってことで合っていそうです。対数尤度関数の変化を出力してみると、

0 -1126.70035851
1 -544.30111944
2 -544.01831258
3 -543.705876101
4 -543.320368192
5 -542.818355622
6 -542.133201031
7 -541.14318814
8 -539.597134229
9 -536.909831316
10 -531.510655647
11 -518.509080616
12 -484.009551238
13 -435.076263391
14 -412.31655532
15 -398.788737436
16 -388.532170712
17 -385.597427177
18 -385.466168682

のようにどんどん大きくなり、最後には収束するのがわかります。EMアルゴリズムはEステップとMステップの繰り返しで対数尤度が増大し、いずれ収束するとのこと。これが、「9.4 一般のEMアルゴリズム」で示されているようです。ただ、この節は難しい。式の導出はできたけどその意味がよくわからない・・・PRML合宿まとめサイトに載っている資料も当たってみようかと思ってます。

EMアルゴリズムの説明というとQが出てくる一般式だけでどういう風に使うのか理解に苦しんでいたのですが、PRMLでは具体例が先で理解しやすかったです。Qを使った一般化は「9.3 EMアルゴリズムのもう一つの解釈」で成されています。EMアルゴリズムはフレームワークみたいなもので応用によってEステップとMステップの更新式は違うのですね。他にもいろいろな具体例を見て、使いこなせるようになりたいなぁ。