静かなる名辞

pythonとプログラミングのこと



【python】sklearnのtolってなんだ?

 公式ドキュメントをよく読む方なら、色々なモジュールに"tol"というオプションがあることに気づいていると思います。たとえばSVCだと、こんな風に書いてあります。他のモジュールも似たり寄ったりですが。

tol : float, optional (default=1e-3)

Tolerance for stopping criterion.

 出典:
sklearn.svm.SVC — scikit-learn 0.19.1 documentation

 よくわからないし、なんとなく重要そうじゃないし、デフォルトから変える必要もないでしょ、ということで無視されがちなパラメタですが、冷静に考えたら何なのかまったくわからない。実は重要だったりすると大変ですね。ということで、調べました。

 まずtoleranceという単語とcriterionという単語に対応する日本語がぱっと出てきません(そこからかよ)。仕方がないのでググると、「耐性」や「公差」「許容誤差」のような意味が出てきます。なんとなく雰囲気は伝わりました。

toleranceの意味・使い方 - 英和辞典 Weblio辞書

 ちなみに後者のcriterionは基準という意味でした。これは雰囲気でもなんでもなく、そのまんま基準です。

 よって、「Tolerance for stopping criterion.」を訳すと「打ち切るための許容誤差の基準」となり、なんとなくわかるようでわかりません。

 仕方ないので色々キーワードを変えて検索していると、CrossValidatedの質問を見つけました。

machine learning - What exactly is tol (tolerance) used as stopping criteria in sklearn models? - Cross Validated

tol will change depending on the objective function being minimized and the algorithm they use to find the minimum, and thus will depend on the model you are fitting. There is no universal tolerance to scikit .
超訳:目的関数とそれを最小化するアルゴリズムに依存するから、モデルによってちげーよ。sklearnには普遍的なtolなどない

 あっ、そうですか・・・。

 これで終わってしまうのも寂しいので、SVMでtolを変えて結果が変わるか実験してみます。

# coding: UTF-8

import time

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import cross_validate, StratifiedKFold as SKF

def main():
    digits = load_digits()
    
    svm1 = SVC(C=5, gamma=0.001, tol=1)
    svm2 = SVC(C=5, gamma=0.001, tol=0.5)
    svm3 = SVC(C=5, gamma=0.001, tol=0.001)
    svm4 = SVC(C=5, gamma=0.001, tol=0.00001)

    skf = SKF(random_state=0)

    scoring = {"p": "precision_macro",
               "r": "recall_macro",
               "f":"f1_macro"}

    skf = SKF(n_splits=5, shuffle=True, random_state=0)

    for svm, tol in zip([svm1, svm2, svm3, svm4], [1, 0.5, 0.001, 0.00001]):
        t1 = time.time()
        scores = cross_validate(svm, digits.data, digits.target,
                                cv=skf, scoring=scoring)
        t2 = time.time()
        print("tol:{0:5} time:{1:8.3f} p:{2:.3f} r:{3:.3f} f:{4:.3f}".format(
            tol,t2-t1,
            scores["test_p"].mean(),
            scores["test_r"].mean(),
            scores["test_f"].mean()))

if __name__ == "__main__":
    main()

 結果は、

tol:    1 time:   0.661 p:0.990 r:0.989 f:0.989
tol:  0.5 time:   0.833 p:0.991 r:0.990 f:0.991
tol:0.001 time:   1.193 p:0.992 r:0.992 f:0.992
tol:1e-05 time:   1.226 p:0.992 r:0.992 f:0.992

 性能への影響はごく僅かですが、トータルの処理時間は若干(数10%程度)削れるようです。少しでも軽くしたい、というときは検討する(性能への悪影響が抑えられる範囲でtolを上げる)価値はありますね。