静かなる名辞

pythonとプログラミングのこと



【python】MeanShiftのbandwidthを変えるとどうなるか実験してみた

 前回の記事ではMeanShiftクラスタリングを試してみた。

www.haya-programming.com

 このMeanShiftにはbandwidthというパラメータがあり、クラスタ数を決定する上で重要な役割を果たしているはずである。

 いまいち結果に納得がいかないというとき、bandwidthをいじって改善が見込めるのかどうか確認してみます。

プログラム

 例によってirisとwineで比較。簡単に書きました。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.datasets import load_iris, load_wine
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.decomposition import PCA

def process(dataset, name):
    origin_bandwidth = estimate_bandwidth(dataset.data)
    rates = np.logspace(np.log10(0.2), np.log10(5), 11)
    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(24,18))

    PCA_X = PCA().fit_transform(dataset.data)
    for target in range(3):
        axes[0,0].scatter(PCA_X[dataset.target==target, 0],
                        PCA_X[dataset.target==target, 1],
                        c=cm.Paired(target/3))
    axes[0,0].set_title("original label", fontsize=28)

    for r, ax in zip(rates, axes.ravel()[1:]):
        ms = MeanShift(bandwidth=r*origin_bandwidth, n_jobs=-1)
        y = ms.fit_predict(dataset.data)
        n_cluster = ms.cluster_centers_.shape[0]
        for target in range(n_cluster):
            ax.scatter(PCA_X[y==target, 0],
                       PCA_X[y==target, 1],
                       c=cm.Paired(target/n_cluster))
        ax.set_title("r:{0:.3f} b:{1:.3f}".format(
            r, origin_bandwidth), fontsize=28)
    fig.savefig(name+".png")

def main():
    iris = load_iris()
    wine = load_wine()

    process(iris, "iris")
    process(wine, "wine")

if __name__ == "__main__":
    main()

 bandwidthをsklearn.cluster.estimate_bandwidthの推定値(デフォルトで用いられる値)の1/5倍から5倍まで変化させ、結果をプロットします。

結果

 プロットされた結果を示します。

 結果の図の見方は、まずタイトルが

  • b

 sklearn.cluster.estimate_bandwidthによる推定値

  • r

 かけた比率

 という風に対応しており、あとは便宜的に2次元上に主成分分析で写像した散布図が、クラスタごとに色分けされて出ています。一枚目が本来のクラスに基づく色分け、r=1の図が推定値による色分けです。

 まずiris。

iris.png
iris.png
 きれいに元通りになるrは今回見た中にはありませんでした。クラスタ数的にはr=0.525とr=0.725の間くらいで3クラスタになりそうですが、この図を見るとそれでうまく元通りまとまるかは疑問です。

 次にwine。

wine.png
wine.png
 こちらもうまく元通りにはならないようです。そもそもデータが悪いという話はあると思います。

結論

 確かにクラスタ数は変わるが、クラスタリングの良し悪しが改善するかはなんともいえない。データをスケーリングしたり、もっと色々頑張ると改善は見込めるかもしれません。