はじめに
RandomForestではOOBエラー(Out-of-bag error、OOB estimate、OOB誤り率)を見ることができます。交差検証と同様に汎化性能を見れます。
原理の説明とかは他に譲るのですが、これはちゃんと交差検証のように使えるのでしょうか? もちろん原理的には使えるのでしょうが、実際どうなるのかはやってみないとわかりません。
もしかしたらもう他の人がやっているかもしれませんが*1、自分でやった方が納得感があります*2。
ということで、やってみました。
みたいこと
とりあえずトイデータでやってみて、交差検証の場合とスコアを比べる。交差検証は分割のkを変えて様子を見る必要があるでしょう。
また、モデルの性能がよくなったり悪くなったりしたとき、交差検証と同様のスコアの変化が見れるかも確認してみる必要がありそうです。
プログラム
こんなプログラムを書きました。
import time import numpy as np from sklearn.datasets import load_iris, load_digits, load_wine from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import StratifiedKFold from sklearn.metrics import precision_recall_fscore_support as prf def test_func(dataset): rfc = RandomForestClassifier(n_estimators=300, oob_score=True, n_jobs=-1) t1 = time.time() rfc.fit(dataset.data, dataset.target) t2 = time.time() oob_pred = rfc.oob_decision_function_.argmax(axis=1) print("{0:6} p:{2:.4f} r:{3:.4f} f1:{4:.4f} time:{1:.4f}".format( "oob", t2-t1, *prf(dataset.target, oob_pred, average="macro"))) rfc = RandomForestClassifier(n_estimators=300, oob_score=False, n_jobs=-1) for k in [2,4,6,8]: skf = StratifiedKFold(n_splits=k) trues = [] preds = [] t1 = time.time() for train_idx, test_idx in skf.split(dataset.data, dataset.target): rfc.fit(dataset.data[train_idx], dataset.target[train_idx]) trues.append(dataset.target[test_idx]) preds.append(rfc.predict(dataset.data[test_idx])) t2 = time.time() print("{0:6} p:{2:.4f} r:{3:.4f} f1:{4:.4f} time:{1:.4f}".format( "CV k={}".format(k), t2-t1, *prf(np.hstack(trues), np.hstack(preds), average="macro"))) def main(): iris = load_iris() digits = load_digits() wine = load_wine() print("iris") test_func(iris) print("\ndigits") test_func(digits) print("\nwine") test_func(wine) print("\niris + noise") iris.data += np.random.randn(*iris.data.shape)*iris.data.std() test_func(iris) print("\ndigits + noise") digits.data += np.random.randn(*digits.data.shape)*digits.data.std() test_func(digits) print("\nwine + noise") wine.data += np.random.randn(*wine.data.shape)*wine.data.std() test_func(wine) if __name__ == "__main__": main()
注目ポイント。
- iris, digits, wineでためしました
- rfc.oob_decision_function_.argmax(axis=1)でOOBで推定されたラベルが得られるので、それを使って精度、再現率、F1値を計算しています(マクロ平均)。交差検証でも同様に計算することで、同じ指標で比較を可能にしています(accuracyだけだと寂しいので・・・)
- 処理の所要時間も測った
- 特徴量にノイズを付与して分類させることで、条件が悪いときのスコアも確認。ノイズは特徴量全体の標準偏差くらいの正規分布を付与しました。理論的な根拠は特にないです(だいたい軸によってスケールが違うのを無視しているのだし・・・)
だいたいこんな感じで、あとは普通にやってます*3。
結果
テキスト出力をそのまんま。
iris oob p:0.9534 r:0.9533 f1:0.9533 time:0.5774 CV k=2 p:0.9534 r:0.9533 f1:0.9533 time:1.3196 CV k=4 p:0.9600 r:0.9600 f1:0.9600 time:2.7399 CV k=6 p:0.9600 r:0.9600 f1:0.9600 time:4.1367 CV k=8 p:0.9600 r:0.9600 f1:0.9600 time:5.9125 digits oob p:0.9795 r:0.9794 f1:0.9794 time:0.9511 CV k=2 p:0.9282 r:0.9271 f1:0.9272 time:1.9101 CV k=4 p:0.9429 r:0.9422 f1:0.9420 time:4.0144 CV k=6 p:0.9519 r:0.9515 f1:0.9515 time:5.9703 CV k=8 p:0.9500 r:0.9494 f1:0.9494 time:7.8814 wine oob p:0.9762 r:0.9803 f1:0.9780 time:0.6950 CV k=2 p:0.9748 r:0.9812 f1:0.9774 time:1.2872 CV k=4 p:0.9714 r:0.9746 f1:0.9728 time:3.2011 CV k=6 p:0.9603 r:0.9643 f1:0.9619 time:4.4640 CV k=8 p:0.9714 r:0.9746 f1:0.9728 time:6.1454 iris + noise oob p:0.4723 r:0.4800 f1:0.4759 time:0.6677 CV k=2 p:0.5298 r:0.5267 f1:0.5279 time:1.3729 CV k=4 p:0.5456 r:0.5533 f1:0.5489 time:3.1605 CV k=6 p:0.4776 r:0.4800 f1:0.4788 time:4.1999 CV k=8 p:0.5043 r:0.5000 f1:0.5009 time:6.0875 digits + noise oob p:0.7850 r:0.7853 f1:0.7834 time:1.4698 CV k=2 p:0.7365 r:0.7377 f1:0.7342 time:2.8328 CV k=4 p:0.7567 r:0.7568 f1:0.7543 time:5.4918 CV k=6 p:0.7682 r:0.7697 f1:0.7671 time:9.1825 CV k=8 p:0.7717 r:0.7714 f1:0.7692 time:12.1008 wine + noise oob p:0.5377 r:0.5483 f1:0.5348 time:0.7344 CV k=2 p:0.4911 r:0.5034 f1:0.4873 time:1.4287 CV k=4 p:0.4828 r:0.5058 f1:0.4855 time:3.1024 CV k=6 p:0.5375 r:0.5439 f1:0.5320 time:4.3943 CV k=8 p:0.4778 r:0.5129 f1:0.4840 time:5.9890
これからわかることとしては、
- 全体的にOOBエラーとCVで求めたスコアはそこそこ近いので、OOBエラーはそこそこ信頼できると思います
- OOBは短い時間で済むので、お得です
- CVの場合はkを大きくすると性能が上がりますが、これは学習に使うデータ量がk=2なら全体の1/2、k=4なら3/4、k=8なら7/8という風に増加していくからです
- OOBエラーがCVのスコアを上回る場合、下回る場合ともにあるようです。OOBエラーは、学習しているデータ量はほぼleave one outに近いものの、木の本数が設定値の約1/3くらいになるという性質があります。学習データ量の有効性が高いデータセットではCVの場合より高いスコアに、木の本数の有効性が高いデータセットではCVの場合に対して低いスコアになるということでしょう
まあ、とりあえず妥当に評価できるんじゃねえの? という感じがします。
もちろん同じ条件下で計測したスコアではないので、OOBエラーと交差検証の結果を直接比較することはできませんが、OOBエラー同士の優劣で性能を見積もる分にはたぶん問題ないでしょう*4。
注意点としては、OOBエラーは全体の木の約1/3(厳密には36%くらい)を使って予測するので、実際の結果よりは悪めに出る可能性があります。木の本数を多めにおごってやると良いでしょう。
結論
OOBでもいい。