はじめに
random_stateを設定して「結果を固定したい」ことはよくありますが、「結果を変えたい」ってあんまりないですよね。いろいろな条件下で比較して検定するときくらいでしょうか。
それでも、変わるだろうなと思って変えたら変わらなくて困るというパターンがたまに発生します。
例を示します。
>>> import numpy as np >>> X = np.arange(10).reshape(-1, 1) >>> y = np.array([0]*5 + [1]*5) >>> X array([[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]) >>> y array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) >>> from sklearn.model_selection import StratifiedKFold >>> skf = StratifiedKFold(n_splits=5, random_state=0) >>> result_a = list(skf.split(X, y)) >>> from pprint import pprint >>> pprint(result_a) [(array([1, 2, 3, 4, 6, 7, 8, 9]), array([0, 5])), (array([0, 2, 3, 4, 5, 7, 8, 9]), array([1, 6])), (array([0, 1, 3, 4, 5, 6, 8, 9]), array([2, 7])), (array([0, 1, 2, 4, 5, 6, 7, 9]), array([3, 8])), (array([0, 1, 2, 3, 5, 6, 7, 8]), array([4, 9]))] >>> skf = StratifiedKFold(n_splits=5, random_state=1) # random_state変更 >>> result_b = list(skf.split(X, y)) >>> pprint(result_b) # 明らかに同じものである [(array([1, 2, 3, 4, 6, 7, 8, 9]), array([0, 5])), (array([0, 2, 3, 4, 5, 7, 8, 9]), array([1, 6])), (array([0, 1, 3, 4, 5, 6, 8, 9]), array([2, 7])), (array([0, 1, 2, 4, 5, 6, 7, 9]), array([3, 8])), (array([0, 1, 2, 3, 5, 6, 7, 8]), array([4, 9]))]
ね、なんででしょう。
原因と対処
実はKFoldやStratifiedKFoldには、shuffleというパラメータがあります。デフォルトはFalseです。そしてこのときにはrandom_stateはno effect(効果なし)です。
上の結果からお察しの通り、何もしなければ結果は元のyの順番通り出てきます。ドキュメントにも書いてあるんですが、random_state指定すればいけるんやな! と思うとハマりますよね。
sklearn.model_selection.StratifiedKFold — scikit-learn 0.21.2 documentation
このことを踏まえて、上のコードは以下のように直せば期待通りの結果が得られます。
>>> skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0) >>> result_c = list(skf.split(X, y)) >>> skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1) >>> result_d = list(skf.split(X, y)) >>> pprint(result_c) [(array([0, 1, 3, 4, 6, 7, 8, 9]), array([2, 5])), (array([1, 2, 3, 4, 5, 6, 8, 9]), array([0, 7])), (array([0, 2, 3, 4, 5, 7, 8, 9]), array([1, 6])), (array([0, 1, 2, 4, 5, 6, 7, 8]), array([3, 9])), (array([0, 1, 2, 3, 5, 6, 7, 9]), array([4, 8]))] >>> pprint(result_d) [(array([0, 1, 3, 4, 6, 7, 8, 9]), array([2, 5])), (array([0, 2, 3, 4, 5, 6, 8, 9]), array([1, 7])), (array([0, 1, 2, 3, 5, 6, 7, 8]), array([4, 9])), (array([1, 2, 3, 4, 5, 6, 7, 9]), array([0, 8])), (array([0, 1, 2, 4, 5, 7, 8, 9]), array([3, 6]))]
shuffle=Trueが必要ということですね。
まとめ
わかってしまえばしょうもない話なのですが、shuffle=Trueしないでrandom_stateを指定しているコードもたまにブログやqiitaの記事で見かけたりして、「意味あるんかいなそれ」と思ってしまったりします。一種の「おまじない」ですかね。
なお、機械学習の評価指標は交差検証の分割の仕方によってもぶれますので、0.003とかくらいの話をするときには色々な分割の仕方で試して検定するという手続きが本来は望ましいです。なかなかそこまでやるのは大変ですが。