はじめに
train_test_splitはsklearnをはじめて学んだ頃からよくお世話になっています。しかし、stratifyを指定しないとまずいことが起こり得ると最近気づきました。
stratifyって何?
層化という言葉を聞いたことがある方が一定数いると思いますが、それです。あるいは、交差検証でStratifiedKFoldを使ったことのある人もだいたい理解しているでしょう。
要するに、クラスラベル(など)ごとにサンプルを取ってくるということを意味します。2クラス分類、100サンプルで元の各クラスの比率が50,50であれば、10件取り出しても5,5ずつになることが保証される、というのがここでいうstratify(層化)の意味です。
これを指定しないと、けっこういい加減なことになります。
指定しないで試してみる
まずdigitsでやります。こんな感じ。
from sklearn.datasets import load_digits from sklearn.model_selection import train_test_split from sklearn.svm import SVC from sklearn.metrics import classification_report def main(): dataset = load_digits() X_train, X_test, y_train, y_test\ = train_test_split(dataset.data, dataset.target) svm = SVC(gamma="scale") svm.fit(X_train, y_train) prediction = svm.predict(X_test) print(classification_report(y_test, prediction)) if __name__ == "__main__": main()
結果
precision recall f1-score support 0 1.00 1.00 1.00 58 1 0.94 1.00 0.97 49 2 1.00 1.00 1.00 46 3 1.00 1.00 1.00 51 4 1.00 0.98 0.99 42 5 1.00 0.98 0.99 42 6 0.98 1.00 0.99 43 7 1.00 1.00 1.00 46 8 0.98 0.93 0.96 46 9 1.00 1.00 1.00 27 accuracy 0.99 450 macro avg 0.99 0.99 0.99 450 weighted avg 0.99 0.99 0.99 450
supportに注目してください。0は58, 9は27で、倍以上ばらついています。あまり良い評価とは言えないということです。
stratifyを指定します。
X_train, X_test, y_train, y_test\
= train_test_split(dataset.data, dataset.target,
stratify=dataset.target)
結果
precision recall f1-score support 0 1.00 0.98 0.99 45 1 0.96 1.00 0.98 46 2 1.00 1.00 1.00 44 3 1.00 1.00 1.00 46 4 0.98 0.96 0.97 45 5 1.00 1.00 1.00 46 6 1.00 1.00 1.00 45 7 1.00 0.98 0.99 45 8 0.98 0.95 0.96 43 9 0.96 1.00 0.98 45 accuracy 0.99 450 macro avg 0.99 0.99 0.99 450 weighted avg 0.99 0.99 0.99 450
若干のぶれがありますが、概ね数がそろっています。これによってより妥当な評価が可能になると思われます。
次に不均衡データの場合について考えてみます。同じデータセットで、5かどうかを判別するとしましょう。
from sklearn.datasets import load_digits from sklearn.model_selection import train_test_split from sklearn.svm import SVC from sklearn.metrics import classification_report def main(): dataset = load_digits() y = dataset.target == 5 X_train, X_test, y_train, y_test\ = train_test_split(dataset.data, y) svm = SVC(gamma="scale") svm.fit(X_train, y_train) prediction = svm.predict(X_test) print(classification_report( y_test, prediction)) if __name__ == "__main__": main()
結果(何回かやった中で偏りがひどかった奴です)
precision recall f1-score support False 1.00 1.00 1.00 415 True 1.00 0.97 0.99 35 accuracy 1.00 450 macro avg 1.00 0.99 0.99 450 weighted avg 1.00 1.00 1.00 450
stratifyを指定します。
X_train, X_test, y_train, y_test\
= train_test_split(dataset.data, y, stratify=y)
結果
precision recall f1-score support False 1.00 1.00 1.00 404 True 0.98 0.96 0.97 46 accuracy 0.99 450 macro avg 0.99 0.98 0.98 450 weighted avg 0.99 0.99 0.99 450
この場合は、何回やってもデータ個数の比率は保証されます。
まとめ
ということで、stratifyを指定した方が良いでしょう。
testに回すサンプル数が少ないときは特に相対的にクラス間の比率のばらつきが大きくなるので、やっておくべきだと思います。testに回すサンプル数が多い場合はまだやらなくてもなんとかなるかもしれませんが、それでもやった方が結果が安定するはずです。