least_squaresでフィッティング
はじめに
理解してないから間違ってるかもしれん。
least_squaresの使い方
least_squares(損失関数, x0=係数初期値, args=それ以外の引数)
2次関数のフィッティング
from scipy.optimize import least_squares import numpy as np import matplotlib.pyplot as plt def x2(x, a, b, c): return a*x**2 + b*x + c def loss(coeffs, x, y): return np.mean(abs(y - x2(x, *coeffs))) x = np.linspace(-10, 10, 20) y_true = x2(x, 1, 4, 6) y_noise = 10 * np.random.randn(*x.shape) + y_true y_noise[-5] += 100 * abs(np.random.randn()) y_noise[-8] += 100 * abs(np.random.randn()) plt.scatter(x, y_noise) opt = least_squares(loss_mae, x0=[0, 0, 0], args=(x, y_noise)) plt.plot(x, x2(x, *opt.x), label='mae') opt = least_squares(loss_mse, x0=[0, 0, 0], args=(x, y_noise)) plt.plot(x, x2(x, *opt.x), label='mse') plt.legend() plt.show()
MSEは外れ値に引っ張られやすい。