静かなる名辞

pythonとプログラミングのこと

scipy.optimize.curve_fitを使っていろいろな関数にフィットさせてみる

はじめに

 scipy.optimize.curve_fitを使うと曲線あてはめができます。いろいろな関数にフィッティングさせてみて、うまくいくかどうか試してみます。

scipy.optimize.curve_fit — SciPy v1.2.1 Reference Guide

f(x) = x + a

 ただの足し算。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a):
    return x + a

def main():
    n = 50
    x = np.linspace(-10, 10, n)
    plt.figure()
    for i, a in enumerate([2, 5, 14]):
        y = func(x, a) + np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label="a={:.4f}".format(*params))
    plt.legend()
    plt.savefig("result1.png")

if __name__ == "__main__":
    main()

result1.png
result1.png

 問題なさそうです。

f(x) = a * x

 掛け算。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a):
    return x * a

def main():
    n = 50
    x = np.linspace(-10, 10, n)
    plt.figure()
    for i, a in enumerate([1, 1.5, 2]):
        y = func(x, a) + np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label="a={:.4f}".format(*params))
    plt.legend()
    plt.savefig("result2.png")

if __name__ == "__main__":
    main()

result2.png
result2.png

 これも問題なし。

f(x) = a*x + b

 一次式です。これくらいはできないと困ります。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a, b):
    return x * a + b

def main():
    n = 50
    x = np.linspace(-10, 10, n)
    plt.figure()
    for i, pt in enumerate([[1, 2], [3, 4], [5,6]]):
        y = func(x, *pt) + np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label="a={:.4f}, b={:.4f}".format(*params))
    plt.legend()
    plt.savefig("result3.png")

if __name__ == "__main__":
    main()

result3.png
result3.png

 難なくこなしました。

4次多項式

 次数の少し大きめの多項式でやってみます。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a, b, c, d, e,):
    return a*x**4 + b*x**3 + c*x**2 + d*x + e

def main():
    n = 50
    x = np.linspace(-10, 10, n)
    plt.figure()
    for i, pt in enumerate([[1,2,3,4,5],
                            [2,3,4,5,6],
                            [3,4,5,6,7]]):
        # 値のスケールが大きいのでノイズの大きさを調整
        y = func(x, *pt) + 1000*np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label=("a={:.3f} b={:.3f}"
                        "c={:.3f} d={:.3f}"
                        "e={:.3f}").format(*params))
    plt.legend()
    plt.savefig("result4.png")

if __name__ == "__main__":
    main()

result4.png
result4.png

 当たり前ですが、全体の傾向は高次の係数に支配されるため、低次の係数ほど元の値と乖離します。それでもノイズが小さければ使えるんですが。

aのx乗

 なんとなく難しそうな気がしました。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a):
    return a**x

def main():
    n = 50
    x = np.linspace(-10, 10, n)
    plt.figure()
    for i, pt in enumerate([[1], [1.2], [1.4]]):
        y = func(x, *pt) + np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label=("a={:.3f}").format(*params))
    plt.legend()
    plt.savefig("result5.png")

if __name__ == "__main__":
    main()

result5.png
result5.png

 意外となんとかなりました。けっこういい精度のように見えます。

xのa乗

 こちらもやってみます。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a):
    return x**a

def main():
    n = 50
    x = np.linspace(-10, 10, n)
    plt.figure()
    for i, pt in enumerate([[1], [1.2], [1.4]]):
        y = func(x, *pt) + np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label=("a={:.3f}").format(*params))
    plt.legend()
    plt.savefig("result6.png")

if __name__ == "__main__":
    main()

 上のコードを実行しようとすると、エラーを吐きます。

***.py:6: RuntimeWarning: invalid value encountered in power
  return x**a
/***/site-packages/scipy/optimize/minpack.py:715: OptimizeWarning: Covariance of the parameters could not be estimated
  category=OptimizeWarning)
Traceback (most recent call last):
  File "***.py", line 22, in <module>
    main()
  File "***.py", line 14, in main
    params, cov = curve_fit(func, x, y)
  File "/***/site-packages/scipy/optimize/minpack.py", line 654, in curve_fit
    ydata = np.asarray_chkfinite(ydata)
  File "/***/site-packages/numpy/lib/function_base.py", line 1033, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs

 区間が駄目なのだろうと思って、0を含まない区間に変えます。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a):
    return x**a

def main():
    n = 50
    x = np.linspace(1, 10, n)
    plt.figure()
    for i, pt in enumerate([[1], [1.2], [1.4]]):
        y = func(x, *pt) + np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label=("a={:.3f}").format(*params))
    plt.legend()
    plt.savefig("result6.png")

if __name__ == "__main__":
    main()

result6.png
result6.png

 今度はできました。0を何乗しても0なので別に問題ない気がしますが、内部の計算に失敗するのかもしれません。もしかして微分して0になる系は厳しい?

対数の底

 対数の底をパラメータで探ってみます。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a):
    return np.log(x)/np.log(a)

def main():
    n = 50
    x = np.linspace(1, 10, n)
    plt.figure()
    for i, pt in enumerate([[2], [3], [4]]):
        y = func(x, *pt) + 0.1*np.random.randn(n)
        params, cov = curve_fit(func, x, y)
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label=("a={:.3f}").format(*params))
    plt.legend()
    plt.savefig("result7.png")

if __name__ == "__main__":
    main()
***.py:6: RuntimeWarning: divide by zero encountered in true_divide
  return np.log(x)/np.log(a)
***.py:6: RuntimeWarning: invalid value encountered in true_divide
  return np.log(x)/np.log(a)
/***/site-packages/scipy/optimize/minpack.py:787: OptimizeWarning: Covariance of the parameters could not be estimated
  category=OptimizeWarning)

 たぶん初期値が駄目なパターンなので、初期値として1.1を与えてみます。それ以外は変更なし。

        params, cov = curve_fit(func, x, y, p0=[1.1])

result7.png
result7.png

 問題なくフィッティングできました。

 p0を指定しなかった場合、パラメータの初期値は0になります。関数によっては0から始まると困ったことになるので、適当な初期値を与えることが重要です。

シグモイド関数

 割と複雑な関数。シグモイドで曲線あてはめしたいというのはそれなりに実用的なユースケースでしょう。

 yにノイズを足すと0より小さかったり1より大きい値が出てきてしまうので、xをずらしました。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a):
    return (np.tanh(a*x/2)+1)/2

def main():
    n = 50
    x = np.linspace(-10, 10, n)
    plt.figure()
    for i, pt in enumerate([[1], [10], [100]]):
        y = func(x + 0.2*np.random.randn(n), *pt)
        params, cov = curve_fit(func, x, y, p0=[1.1])
        plt.scatter(x, y, c="rgb"[i])
        plt.plot(x, func(x, *params), c="rgb"[i],
                 label=("a={:.3f}").format(*params))
    plt.legend()
    plt.savefig("result8.png")

if __name__ == "__main__":
    main()

 このまま実行すると、数回に一回くらいうまくフィットしないことがあります。

/*/lib/python3.5/site-packages/scipy/optimize/minpack.py:787: OptimizeWarning: Covariance of the parameters could not be estimated
  category=OptimizeWarning)

 そのときの結果がこんな感じです。

result8.png
result8.png

 データが悪い。初期値を真値に近づけると多少改善するかもしれません。

まだやっていないもの

  • 各種確率密度関数など(正規分布、指数分布、多項分布あたり)
  • シグモイド関数に似ているもの(ゴンペルツなど)

 このあたりは現時点ではやっていませんが、基本的には同様にできると考えられるので、気が向いたら追記します。