前回の記事は複数のクラスタリング手法を説明しました。 GMM k-means++ Mini Batch K-Means Spectral Clustering

今回は、GMMのパラメーター設定を解説します。

GMMはGaussian mixture modelsの略称です。GMM は”ソフト クラスタリング” 方式と見なすことができます。ソフトクラスタリング方式では、1つの点に対して複数のクラスに所属する確率を出す事ができます。

sklearn.mixture のパッケージでGMMのパラメータは下記になります。

class sklearn.mixture.GaussianMixture(n_components=1, covariance_type=’full’, tol=0.001, reg_covar=1e-06, max_iter=100, n_init=1, init_params=’kmeans’, weights_init=None, means_init=None, precisions_init=None, random_state=None, warm_start=False, verbose=0, verbose_interval=10)

では、サンプルデータセットを作成します。

In [1]:
# Generate some data
from sklearn.datasets.samples_generator import make_blobs
import warnings
warnings.filterwarnings('ignore')

X, y_true = make_blobs(n_samples=400, 
                       centers=4,
                       cluster_std=0.60, 
                       random_state=0)
X = X[:, ::-1] # flip axes for better plotting

GMMモデルのパラメータ設定

In [2]:
from sklearn.mixture import GMM
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)

モデル学習して、図を作成

In [3]:
import matplotlib.pyplot as plt
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)
plt.title('n_components=4')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');

n_components : int, defaults to 1

混合成分の数

In [4]:
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)

gmm_cls2 = GMM(n_components=3, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm2 = gmm_cls2.fit(X)
labels2 = gmm2.predict(X)

fig = plt.figure(figsize=(10,10),dpi=70)

plt.subplot(221)
plt.title('n_components=4')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');

plt.subplot(222)
plt.title('n_components=3')
plt.scatter(X[:, 0], X[:, 1], c=labels2, s=40, cmap='viridis');

plt.show()
In [5]:
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)

gmm_cls2 = GMM(n_components=3, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm2 = gmm_cls2.fit(X)
labels2 = gmm2.predict(X)

fig = plt.figure(figsize=(10,10),dpi=70)

plt.subplot(221)
plt.title('n_components=4')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');

plt.subplot(222)
plt.title('n_components=3')
plt.scatter(X[:, 0], X[:, 1], c=labels2, s=40, cmap='viridis');

plt.show()

covariance_type : {‘full’ (default), ‘tied’, ‘diag’, ‘spherical’}

  • ‘full’:各構成要素には独自の一般共分散行列があります。
  • ‘spherical’:各構成要素には独自の分散があります。
  • ‘tied’:すべての構成要素が同じ一般的な共分散行列を共有します。
  • ‘diag’:各構成要素には独自の対角共分散行列があります。
In [6]:
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)

gmm_cls2 = GMM(n_components=4, 
              covariance_type='tied', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm2 = gmm_cls2.fit(X)
labels2 = gmm2.predict(X)

gmm_cls3 = GMM(n_components=4, 
              covariance_type='diag', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm3 = gmm_cls3.fit(X)
labels3 = gmm3.predict(X)

gmm_cls4 = GMM(n_components=4, 
              covariance_type='spherical', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm4 = gmm_cls4.fit(X)
labels4 = gmm4.predict(X)

fig, ax = plt.subplots(2, 2, figsize=(10,10),dpi=70)

plt.subplot(221)
plt.title('covariance_type=full')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');

plt.subplot(222)
plt.title('covariance_type=tied')
plt.scatter(X[:, 0], X[:, 1], c=labels2, s=40, cmap='viridis');

plt.subplot(223)
plt.title('covariance_type=diag')
plt.scatter(X[:, 0], X[:, 1], c=labels3, s=40, cmap='viridis');

plt.subplot(224)
plt.title('covariance_type=spherical')
plt.scatter(X[:, 0], X[:, 1], c=labels4, s=40, cmap='viridis');

plt.show()

tol : float, defaults to 1e-3.

収束しきい値

下限平均利得がこのしきい値を下回ると、EMインターネットは停止します。

In [7]:
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)

gmm_cls2 = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.1, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm2 = gmm_cls2.fit(X)
labels2 = gmm2.predict(X)

fig = plt.figure(figsize=(10,10),dpi=70)

plt.subplot(221)
plt.title('tol=0.001')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');

plt.subplot(222)
plt.title('tol=0.1')
plt.scatter(X[:, 0], X[:, 1], c=labels2, s=40, cmap='viridis');

plt.show()

n_init : int, defaults to 1.

実行する初期化の数 最良の結果が保たれます。

In [8]:
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)

gmm_cls2 = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=10, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm2 = gmm_cls2.fit(X)
labels2 = gmm2.predict(X)

fig = plt.figure(figsize=(10,10),dpi=70)

plt.subplot(221)
plt.title('n_init=1')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');

plt.subplot(222)
plt.title('n_init=10')
plt.scatter(X[:, 0], X[:, 1], c=labels2, s=40, cmap='viridis');

plt.show()

init_params : {‘kmeans’, ‘random’}, defaults to ‘kmeans’.

重み、平均値および精度を初期化するために使用される方法

  • 'kmeans':責任はkmeansを使って初期化されます。

  • 'random':責任はランダムに初期化されます。

In [9]:
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=0)
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)

gmm_cls2 = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='random', 
              random_state=111, 
              verbose=0)
gmm2 = gmm_cls2.fit(X)
labels2 = gmm2.predict(X)

fig = plt.figure(figsize=(10,10),dpi=70)

plt.subplot(221)
plt.title('init_params=kmeans')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');

plt.subplot(222)
plt.title('init_params=random')
plt.scatter(X[:, 0], X[:, 1], c=labels2, s=40, cmap='viridis');

plt.show()

verbose : int, default to 0

verbose 出力を有効にします。 1の場合、現在の初期化と各イテレーションステップを表示します。 1より大きい場合は、対数確率と各ステップに必要な時間も出力されます。

In [11]:
gmm_cls = GMM(n_components=4, 
              covariance_type='full', 
              tol=0.001, 
              n_init=1, 
              init_params='kmeans', 
              random_state=111, 
              verbose=1)
gmm = gmm_cls.fit(X)
labels = gmm.predict(X)

gmm = gmm_cls.fit(X)
labels = gmm.predict(X)
plt.title('verbose=1')
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');
Expectation-maximization algorithm started.
Initialization 1
	EM iteration 1
	EM iteration 2
	EM iteration 3
	EM iteration 4
	EM iteration 5
	EM iteration 6
	EM iteration 7
	EM iteration 8
	EM iteration 9
	EM iteration 10
	EM iteration 11
	EM iteration 12
		EM algorithm converged.
Expectation-maximization algorithm started.
Initialization 1
	EM iteration 1
	EM iteration 2
		EM algorithm converged.