はじめに
scipy.optimize.curve_fitを使うと曲線あてはめができます。いろいろな関数にフィッティングさせてみて、うまくいくかどうか試してみます。
f(x) = x + a
ただの足し算。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a): return x + a def main(): n = 50 x = np.linspace(-10, 10, n) plt.figure() for i, a in enumerate([2, 5, 14]): y = func(x, a) + np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label="a={:.4f}".format(*params)) plt.legend() plt.savefig("result1.png") if __name__ == "__main__": main()
問題なさそうです。
スポンサーリンク
f(x) = a * x
掛け算。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a): return x * a def main(): n = 50 x = np.linspace(-10, 10, n) plt.figure() for i, a in enumerate([1, 1.5, 2]): y = func(x, a) + np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label="a={:.4f}".format(*params)) plt.legend() plt.savefig("result2.png") if __name__ == "__main__": main()
これも問題なし。
f(x) = a*x + b
一次式です。これくらいはできないと困ります。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a, b): return x * a + b def main(): n = 50 x = np.linspace(-10, 10, n) plt.figure() for i, pt in enumerate([[1, 2], [3, 4], [5,6]]): y = func(x, *pt) + np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label="a={:.4f}, b={:.4f}".format(*params)) plt.legend() plt.savefig("result3.png") if __name__ == "__main__": main()
難なくこなしました。
4次多項式
次数の少し大きめの多項式でやってみます。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a, b, c, d, e,): return a*x**4 + b*x**3 + c*x**2 + d*x + e def main(): n = 50 x = np.linspace(-10, 10, n) plt.figure() for i, pt in enumerate([[1,2,3,4,5], [2,3,4,5,6], [3,4,5,6,7]]): # 値のスケールが大きいのでノイズの大きさを調整 y = func(x, *pt) + 1000*np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label=("a={:.3f} b={:.3f}" "c={:.3f} d={:.3f}" "e={:.3f}").format(*params)) plt.legend() plt.savefig("result4.png") if __name__ == "__main__": main()
当たり前ですが、全体の傾向は高次の係数に支配されるため、低次の係数ほど元の値と乖離します。それでもノイズが小さければ使えるんですが。
aのx乗
なんとなく難しそうな気がしました。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a): return a**x def main(): n = 50 x = np.linspace(-10, 10, n) plt.figure() for i, pt in enumerate([[1], [1.2], [1.4]]): y = func(x, *pt) + np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label=("a={:.3f}").format(*params)) plt.legend() plt.savefig("result5.png") if __name__ == "__main__": main()
意外となんとかなりました。けっこういい精度のように見えます。
xのa乗
こちらもやってみます。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a): return x**a def main(): n = 50 x = np.linspace(-10, 10, n) plt.figure() for i, pt in enumerate([[1], [1.2], [1.4]]): y = func(x, *pt) + np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label=("a={:.3f}").format(*params)) plt.legend() plt.savefig("result6.png") if __name__ == "__main__": main()
上のコードを実行しようとすると、エラーを吐きます。
***.py:6: RuntimeWarning: invalid value encountered in power return x**a /***/site-packages/scipy/optimize/minpack.py:715: OptimizeWarning: Covariance of the parameters could not be estimated category=OptimizeWarning) Traceback (most recent call last): File "***.py", line 22, in <module> main() File "***.py", line 14, in main params, cov = curve_fit(func, x, y) File "/***/site-packages/scipy/optimize/minpack.py", line 654, in curve_fit ydata = np.asarray_chkfinite(ydata) File "/***/site-packages/numpy/lib/function_base.py", line 1033, in asarray_chkfinite "array must not contain infs or NaNs") ValueError: array must not contain infs or NaNs
区間が駄目なのだろうと思って、0を含まない区間に変えます。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a): return x**a def main(): n = 50 x = np.linspace(1, 10, n) plt.figure() for i, pt in enumerate([[1], [1.2], [1.4]]): y = func(x, *pt) + np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label=("a={:.3f}").format(*params)) plt.legend() plt.savefig("result6.png") if __name__ == "__main__": main()
今度はできました。0を何乗しても0なので別に問題ない気がしますが、内部の計算に失敗するのかもしれません。もしかして微分して0になる系は厳しい?
対数の底
対数の底をパラメータで探ってみます。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a): return np.log(x)/np.log(a) def main(): n = 50 x = np.linspace(1, 10, n) plt.figure() for i, pt in enumerate([[2], [3], [4]]): y = func(x, *pt) + 0.1*np.random.randn(n) params, cov = curve_fit(func, x, y) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label=("a={:.3f}").format(*params)) plt.legend() plt.savefig("result7.png") if __name__ == "__main__": main()
***.py:6: RuntimeWarning: divide by zero encountered in true_divide return np.log(x)/np.log(a) ***.py:6: RuntimeWarning: invalid value encountered in true_divide return np.log(x)/np.log(a) /***/site-packages/scipy/optimize/minpack.py:787: OptimizeWarning: Covariance of the parameters could not be estimated category=OptimizeWarning)
たぶん初期値が駄目なパターンなので、初期値として1.1を与えてみます。それ以外は変更なし。
params, cov = curve_fit(func, x, y, p0=[1.1])
問題なくフィッティングできました。
p0を指定しなかった場合、パラメータの初期値は0になります。関数によっては0から始まると困ったことになるので、適当な初期値を与えることが重要です。
シグモイド関数
割と複雑な関数。シグモイドで曲線あてはめしたいというのはそれなりに実用的なユースケースでしょう。
yにノイズを足すと0より小さかったり1より大きい値が出てきてしまうので、xをずらしました。
import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit def func(x, a): return (np.tanh(a*x/2)+1)/2 def main(): n = 50 x = np.linspace(-10, 10, n) plt.figure() for i, pt in enumerate([[1], [10], [100]]): y = func(x + 0.2*np.random.randn(n), *pt) params, cov = curve_fit(func, x, y, p0=[1.1]) plt.scatter(x, y, c="rgb"[i]) plt.plot(x, func(x, *params), c="rgb"[i], label=("a={:.3f}").format(*params)) plt.legend() plt.savefig("result8.png") if __name__ == "__main__": main()
このまま実行すると、数回に一回くらいうまくフィットしないことがあります。
/*/lib/python3.5/site-packages/scipy/optimize/minpack.py:787: OptimizeWarning: Covariance of the parameters could not be estimated category=OptimizeWarning)
そのときの結果がこんな感じです。
データが悪い。初期値を真値に近づけると多少改善するかもしれません。
まだやっていないもの
- 各種確率密度関数など(正規分布、指数分布、多項分布あたり)
- シグモイド関数に似ているもの(ゴンペルツなど)
このあたりは現時点ではやっていませんが、基本的には同様にできると考えられるので、気が向いたら追記します。