はじめに
scikit-learnのv0.22で、混同行列をプロットするための便利関数であるsklearn.metrics.plot_confusion_matrixが追加されました。
使いやすそうなので試してみます。
使い方
リファレンスはこちらです。
sklearn.metrics.plot_confusion_matrix — scikit-learn 0.22 documentation
引数のフォーマットを見ると、
sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, labels=None, sample_weight=None, normalize=None, display_labels=None, include_values=True, xticks_rotation='horizontal', values_format=None, cmap='viridis', ax=None)
あ、予測器とXとyを入れるタイプの関数だ。なんか微妙に使いづらいですね。この時点でなんか困惑気味ですが、やってみます。
import matplotlib.pyplot as plt from sklearn.datasets import load_wine from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.metrics import plot_confusion_matrix wine = load_wine() X_train, X_test, y_train, y_test = train_test_split( wine.data, wine.target, stratify=wine.target) clf = LogisticRegression() clf.fit(X_train, y_train) plot_confusion_matrix(clf, X_test, y_test, display_labels=wine.target_names, cmap=plt.cm.Blues,) plt.savefig("result.png")
ソースを見る限り、内部で交差検証などしてくれる訳ではないようなので、学習済みモデルとテストデータを渡してプロットさせます。また、labelsという引数がありますが微妙に罠っぽくて、表示に使われるラベルはdisplay_labelsの方です。
一応各引数の説明など。
- estimator, X, y_true
は、説明要らないよね。上で示したのと同じ使い方をします。
- labels
ラベルの順序を並び替えたり、一部のラベルのみ取り出してプロットしたいとき使うそうです。y_trueの中身が[0,0,0,1,1,1]なら[0,1]や[1,0]などが指定できます。別に要らないでしょう。
- sample_weightarray-like of shape (n_samples,), default=None
サンプルの重み。
- normalize{"true", "pred", "all"}, default=None
全体を正規化するかどうか。するならその方法を文字列で指定します。
- display_labels
表示されるラベルの名前はこちらで指定します。使用頻度は高いはずです。
- include_valuesbool, default=True
Falseに設定すると数字が出てこなくなります。普通は数字があったほうが好ましいでしょう。
- xticks_rotation{"vertical", "horizontal"}
x軸のラベルが回転するかどうか。デフォルトでは回転しません。
- values_format
"d"や".2f"などが指定できる。表示の書式で、format関数などに準じると思われる。
- cmap, ax
matplotlib関連です。デフォルトのcmapの"viridis"がexampleで「ダサいよねこれ」とBluesにされているあたり泣けます。
使いづらいのでConfusionMatrixDisplayを使うことにする
なんで混同行列を描くためだけにpredictメソッド走らせなきゃいかんのだと思ったので、仕様を確認します。すると、ConfusionMatrixDisplayなるクラスがあることがわかります。
It is recommend to use plot_confusion_matrix to create a ConfusionMatrixDisplay.
sklearn.metrics.ConfusionMatrixDisplay — scikit-learn 0.22 documentation
うるせえ、あるもんは使うんじゃ。
インスタンスを作ってplotメソッドを呼ぶと動きます。引数はだいたい上と共通ですが、インスタンスを作るときに
class sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix, display_labels)
なので柔軟性が多少上がります。
import matplotlib.pyplot as plt from sklearn.datasets import load_wine from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay wine = load_wine() X_train, X_test, y_train, y_test = train_test_split( wine.data, wine.target, stratify=wine.target) clf = LogisticRegression() clf.fit(X_train, y_train) y_pred = clf.predict(X_test) cmx = confusion_matrix(y_test, y_pred) cmd = ConfusionMatrixDisplay(cmx, wine.target_names) cmd.plot() plt.savefig("result.png")
結果の図は同じなので省略。任意の予測ラベルで描画しようと思えばできます。ちょっと微妙な感じもしますが、許容範囲でしょう。
まとめ
このようなものができましたので、今後は混同行列のプロットではそんなに困らないと思います。
あとどうでもいい話、そもそもこのブログは混同行列の描き方がわからなくて調べてまとめたのが始まりですが、
なんとなく原点回帰した感があって感動しています。