freedom-_-qの勉強履歴

メモ書きが主になるかと思います。勉強強制のために一日一記事目指してます。頭良くないので間違いが多々あるかと思います。

least_squaresでフィッティング

はじめに

理解してないから間違ってるかもしれん。

least_squaresの使い方

least_squares(損失関数, x0=係数初期値, args=それ以外の引数)

2次関数のフィッティング

from scipy.optimize import least_squares
import numpy as np
import matplotlib.pyplot as plt

def x2(x, a, b, c):
    return a*x**2 + b*x + c

def loss(coeffs, x, y):
    return np.mean(abs(y - x2(x, *coeffs)))

x = np.linspace(-10, 10, 20)
y_true = x2(x, 1, 4, 6)
y_noise = 10 * np.random.randn(*x.shape) + y_true
y_noise[-5] += 100 * abs(np.random.randn())
y_noise[-8] += 100 * abs(np.random.randn())

plt.scatter(x, y_noise)
opt = least_squares(loss_mae, x0=[0, 0, 0], args=(x, y_noise))
plt.plot(x, x2(x, *opt.x), label='mae')

opt = least_squares(loss_mse, x0=[0, 0, 0], args=(x, y_noise))
plt.plot(x, x2(x, *opt.x), label='mse')

plt.legend()
plt.show()

MSEは外れ値に引っ張られやすい。

f:id:freedom-_-q:20210623232552p:plain