静かなる名辞

pythonとプログラミングのこと

2019/03/22:TechAcademyがteratailの質問・回答を盗用していた件
2019/03/26:TechAcademy盗用事件 公式発表と深まる疑念


【python】ロジスティック回帰で確率値で学習させる

はじめに

 ロジスティック回帰は回帰という名前なのにほとんど二項判別に使われますが、たまに本当に回帰に使うときもあります。0.1とか0.4とか0.6のような目的変数を使ってモデルを作る、というケースです。

 ちょっとした目的で必要になるかもしれないと思ってやろうとしたら、意外と手間取ったのでメモしておきます。

データ

 たとえば「普及率」のようなデータに対してあてはめを行うとき、こういうケースが出てきます。

 こちらで紹介されている、日本のカラーテレビ普及率のデータを使います。

データ解析・マイニングとR言語

 説明変数が年、目的変数が普及率です。

 とりあえずこんな配列にしておきます。

import numpy as np

x = np.array([1966, 1967, 1968, 1969, 1970,
              1971, 1972, 1973, 1974, 1975,
              1976, 1977, 1978, 1979, 1980,
              1981, 1982, 1983, 1984]).reshape(-1, 1)

y = np.array([0.003, 0.016, 0.054, 0.139, 0.263,
              0.423, 0.611, 0.758, 0.859, 0.903,
              0.937, 0.954, 0.978, 0.978, 0.982,
              0.985, 0.989, 0.988, 0.992])

scikit-learnでは(たぶん)できない

 誰でもまっさきに思いつく方法は、sklearnのLogisticRegressionを使うことです。しかし、これは

Logistic Regression (aka logit, MaxEnt) classifier.

sklearn.linear_model.LogisticRegression — scikit-learn 0.21.2 documentation

 と書いてあるとおり、判別用のモデルです。ユーザガイドもひたすら判別の話をしているだけです。

1.1. Generalized Linear Models — scikit-learn 0.21.2 documentation

 まあでも、もしかしたらできるかもしれないので、やってみましょう。

import numpy as np
from sklearn.linear_model import LogisticRegression

x = np.array([1966, 1967, 1968, 1969, 1970,
              1971, 1972, 1973, 1974, 1975,
              1976, 1977, 1978, 1979, 1980,
              1981, 1982, 1983, 1984]).reshape(-1, 1)

y = np.array([0.003, 0.016, 0.054, 0.139, 0.263,
              0.423, 0.611, 0.758, 0.859, 0.903,
              0.937, 0.954, 0.978, 0.978, 0.982,
              0.985, 0.989, 0.988, 0.992])

lr = LogisticRegression()
lr.fit(x, y)  # => ValueError: Unknown label type: 'continuous'
lr.predict(x)

 しってた。

statsmodelsでやる

 仕方がないので、statsmodelsを使います。statsmodelsはPythonでR風のことをするためのライブラリなので、参考ページと同じことができるはずです。

 APIをぜんぜん把握していないので、qiitaの解説記事を見ながらやります。

Statsmodels でロジスティック回帰を行う際の注意点 - Qiita

import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt

x = np.array([1966, 1967, 1968, 1969, 1970,
              1971, 1972, 1973, 1974, 1975,
              1976, 1977, 1978, 1979, 1980,
              1981, 1982, 1983, 1984]).reshape(-1, 1)

y = np.array([0.003, 0.016, 0.054, 0.139, 0.263,
              0.423, 0.611, 0.758, 0.859, 0.903,
              0.937, 0.954, 0.978, 0.978, 0.982,
              0.985, 0.989, 0.988, 0.992])

x_c = sm.add_constant(x)  
# ↑interceptのためにやらないといけないらしい(えぇ…)
lr = sm.Logit(y, x_c)
lr_result = lr.fit()
print(lr_result.params)
print(lr_result.summary())

y_pred = lr.predict(lr_result.params)
plt.scatter(x.ravel(), y_pred, c="b", alpha=0.2)
plt.plot(x.ravel(), y_pred, c="b")
plt.savefig("result.png")

 y, xって書くのがきもいとか、どうでもいいところが気になります。

 これはこれでRに慣れてる人にはいいと思うのですが、scikit-learnライクなAPIも用意してくれていたらと思わなくはありません。

 結果

Optimization terminated successfully.
         Current function value: 0.180377
         Iterations 9
[-1.23730786e+03  6.27547565e-01]
                           Logit Regression Results                           
==============================================================================
Dep. Variable:                      y   No. Observations:                   19
Model:                          Logit   Df Residuals:                       17
Method:                           MLE   Df Model:                            1
Date:                Sun, 30 Jun 2019   Pseudo R-squ.:                  0.7003
Time:                        22:00:31   Log-Likelihood:                -3.4272
converged:                       True   LL-Null:                       -11.435
Covariance Type:            nonrobust   LLR p-value:                 6.284e-05
==============================================================================
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
const      -1237.3079    591.413     -2.092      0.036   -2396.456     -78.159
x1             0.6275      0.300      2.092      0.036       0.040       1.215
==============================================================================

結果
結果

 なんかよくわからないけど、一応できているみたいです。

まとめ

 いろいろ分析していくとstatsmodelsもけっきょく必要になるときがあるので、慣れた方が良いのかなぁと思ったりしました。