静かなる名辞

pythonとプログラミングのこと



【python】sklearnのRidgeとLassoを使ってみる

はじめに

 Rdige、Lassoといえば割と定番の正則化アルゴリズムです。

 特にLassoはスパースな解を得てくれるという触れ込みです。なんだかカッコいいので、昔から触ってみたいと思っていました。

実験

 このような関数fを考えます。

def f(x):
    return -3*x + 5*(x**2) + 6*(x**3) - (x**4)  + 5

 (pythonの関数として記述していますが、数学的な関数です)

 これを回帰してみましょう。

  • ただの線形回帰
  • Ridge
  • Lasso

 で実験します。関数を外挿して、どこまでまともな結果が得られるかそれぞれで確認します。

 回帰に使うモデルは10次多項式回帰です。単純にやったらぐねぐねになってまともな結果は得られないんですが、さてどうなるか。

 なお、各モデルのクラスはsklearn.linear_model以下にあります。また、多項式回帰はsklearn.preprocessing.PolynomialFeaturesとsklearn.pipeline.Pipelineを組み合わせて行えます(他の方法もあるかもだけど)。

 参考:

www.haya-programming.com



 以上のことを踏まえて、以下のようなプログラムを書きました。長く見えますが、同じような処理の繰り返しが多いだけで(要するにちゃんとループ等にしていないだけで)、処理としては単純に「データを作る」「回帰する」「予測する」「プロットする」という流れをやっているだけです。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline

def main():
    def f(x):
        return -3*x + 5*(x**2) + 6*(x**3) - (x**4)  + 5

    x = np.linspace(-10, 10, 50)
    y_true = f(x)

    x_test = np.linspace(-13, 13, 100)
    y_test = f(x_test)
    
    seed = np.random.RandomState(seed=0)
    y_samples = y_true + seed.randn(*x.shape)*500

    # least squares model
    poly = PolynomialFeatures(degree=10)
    lin = LinearRegression()
    lin_pl = Pipeline([("poly_10", poly), ("lin", lin)])
    lin_pl.fit(x.reshape(-1, 1), y_samples)
    y_lin = lin_pl.predict(x_test.reshape(-1, 1))
    
    # ridge model
    poly = PolynomialFeatures(degree=10)
    ridge = Ridge(alpha=10, max_iter=2000)
    ridge_pl = Pipeline([("poly_10", poly), ("ridge", ridge)])
    ridge_pl.fit(x.reshape(-1, 1), y_samples)
    y_ridge = ridge_pl.predict(x_test.reshape(-1, 1))

    # lasso model
    poly = PolynomialFeatures(degree=10)
    lasso = Lasso(alpha=10, max_iter=5000)
    lasso_pl = Pipeline([("poly_10", poly), ("lasso", lasso)])
    lasso_pl.fit(x.reshape(-1, 1), y_samples)
    y_lasso = lasso_pl.predict(x_test.reshape(-1, 1))

    print("最小二乗法")
    print(lin_pl.named_steps.lin.intercept_)
    print(lin_pl.named_steps.lin.coef_)

    print("Ridge")
    print(ridge_pl.named_steps.ridge.intercept_)
    print(ridge_pl.named_steps.ridge.coef_)

    print("Lasso")
    print(lasso_pl.named_steps.lasso.intercept_)
    print(lasso_pl.named_steps.lasso.coef_)

    plt.plot(x_test, y_test, label="本当の値")
    plt.plot(x_test, y_lin, label="予測値(最小二乗法)", linestyle="--")
    plt.plot(x_test, y_ridge, label="予測値(Ridge)", linestyle="--")
    plt.plot(x_test, y_lasso, label="予測値(Lasso)", linestyle="--")

    plt.scatter(x, y_samples, label="サンプル")
    plt.legend()
    plt.savefig("result.png")

if __name__ == "__main__":
    main()

結果

 とりあえず、切片と係数はこんな感じになります。

最小二乗法
44.889422533642346
[ 0.00000000e+00  1.04897038e+02  7.73241212e+00 -8.56765602e+00
 -9.18049844e-01  4.80313010e-01 -1.00932865e-02 -6.16744609e-03
  1.90299005e-04  2.67071826e-05 -9.88817760e-07]
Ridge
47.06589279466334
[ 0.00000000e+00  7.07372729e+01  6.90235252e+00 -5.10402716e+00
 -8.60138166e-01  3.71899115e-01 -1.15878477e-02 -4.83914018e-03
  2.06482860e-04  2.11304610e-05 -1.05109554e-06]
Lasso
45.1551327659522
[ 0.00000000e+00  7.31967685e+00  1.08654718e+01  3.66382021e+00
 -1.30143108e+00  4.95092775e-02  2.99149652e-03 -4.74590934e-04
  2.19344863e-05  1.51480557e-06 -2.59682534e-07]

 どれも微妙かも・・・ノイズをまぶしすぎたのと、データ量が少ないので(要するに条件を人為的に悪くしたので*1)こんな感じです。

 本来必要ない高次の係数を見ると、比較的Lassoが健闘しているかな?

 肝心のグラフは、

回帰結果のグラフ
回帰結果のグラフ

 こんな感じで、相対的にLassoがマシに見えます。

 試しに、RidgeとLassoの正則化係数alphaを100倍の1000にしてみた結果が以下のものです(どの程度のスケールにすれば良いのかよくわからないので適当にやっている)。

最小二乗法
44.889422533642346
[ 0.00000000e+00  1.04897038e+02  7.73241212e+00 -8.56765602e+00
 -9.18049844e-01  4.80313010e-01 -1.00932865e-02 -6.16744609e-03
  1.90299005e-04  2.67071826e-05 -9.88817760e-07]
Ridge
63.546256794998726
[ 0.00000000e+00  2.30805586e+00  5.71618427e-01  1.66536218e+00
 -4.16707117e-01  1.62932068e-01 -2.30602486e-02 -2.30145211e-03
  3.30923961e-04  1.05415628e-05 -1.53055663e-06]
Lasso
89.0329535932658
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  2.79840777e+00
 -8.39251562e-01  9.44222102e-02 -3.28942545e-03 -1.10602103e-03
  4.33327616e-05  4.27414852e-06 -2.01546485e-07]

結果のグラフ2
結果のグラフ2

 それなりに改善しているように見えます。

まとめ

 これですべてを解決してくれるかというと微妙な感じですが、それなりに有効ではあるみたいです。

*1:そうしないとどれも似通った結果になる