freedom-_-qの勉強履歴

メモ書きが主になるかと思います。勉強強制のために一日一記事目指してます。頭良くないので間違いが多々あるかと思います。

scipy.linalg.invを使ってみる

引数一覧

引数 意味
a array_like 逆行列を計算したい行列を指定する。
overwrite_a=False bool 引数aを上書きするかどうか。巨大な行列ではTrueにすればパフォーマンスが改善される。
check_finite=True bool aに有限数のみが含まれていることを確認するかどうか。infかnanでないことが分かっている場合Falseにするとパフォーマンス向上する場合がある

適当に5*5の行列作って逆行列を求める

import numpy as np
from scipy import linalg

arr = np.random.randint(1, 10, size=(5, 5))
#array([[1, 8, 7, 1, 1],
#       [9, 5, 3, 3, 9],
#       [1, 8, 9, 6, 1],
#       [1, 6, 5, 7, 2],
#       [5, 5, 8, 2, 5]])

np.round(arr @ linalg.inv(arr), 10)
#array([[ 1.,  0.,  0.,  0., -0.],
#       [ 0.,  1., -0.,  0., -0.],
#       [ 0.,  0.,  1.,  0., -0.],
#       [ 0.,  0.,  0.,  1.,  0.],
#       [-0.,  0.,  0.,  0.,  1.]])

巨大行列の逆行列を求める

コードとしてはsize=が変わっただけ。 IPythonから%%timeitを利用して実行時間を見てみる。まずデフォルトで。

arr = np.random.randint(1, 10, size=(5000, 5000))

# default
%%timeit
linalg.inv(arr, overwrite_a=False, check_finite=True)
#2.14 s ± 104 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

十分すぎる計算時間!
お次はoverwrite_aTrueにしてみる。

%%timeit
linalg.inv(arr, overwrite_a=True, check_finite=True)
#2.11 s ± 63.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
arr
#array([[3, 4, 6, ..., 5, 7, 1],
#       [8, 3, 3, ..., 5, 5, 5],
#       [6, 3, 4, ..., 8, 3, 1],
#       ...,
#       [8, 6, 3, ..., 7, 4, 7],
#       [5, 8, 9, ..., 2, 3, 6],
#       [7, 6, 5, ..., 8, 4, 6]])

微妙に早い?ばらつきの範囲内な気がする...
それよりarrが上書きされていない気がするがこれは...

%%timeit
linalg.inv(arr, overwrite_a=True, check_finite=False)
#2.04 s ± 13.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

こちらは早くなったといっていいと思う。
にしてもこの早さで逆行列を求められるなんてすごいことだ。

行列にinfがある場合

arr2 = np.random.randn(5000, 5000)
arr2[-2,-3] = np.inf
#array([[ 2.02989041,  0.46380936,  0.91881227, ...,  0.76432204,
#        -0.803534  ,  0.27665392],
#       [-1.15502314,  0.25701885, -0.38668162, ...,  0.27915569,
#         2.26712878, -1.31973077],
#       [ 0.46494284,  0.19511784, -0.17409262, ..., -0.60674436,
#        -0.55114893, -1.16376405],
#       ...,
#       [-1.02789623, -0.67204559,  0.75562066, ...,  0.73532684,
#        -0.69068431,  0.66328013],
#       [ 1.11935754, -2.84095525, -0.91949745, ...,         inf,
#        -1.94359856,  1.62964351],
#       [ 0.09731474, -1.79696413,  0.00693175, ..., -1.53071478,
#         0.13959659, -0.22879148]])

linalg.inv(arr2, overwrite_a=False, check_finite=True)
#ValueError: array must not contain infs or NaNs

linalg.inv(arr2, overwrite_a=False, check_finite=False)
#array([[nan, nan, nan, ..., nan, nan, nan],
#       [nan, nan, nan, ..., nan, nan, nan],
#       [nan, nan, nan, ..., nan, nan, nan],
#       ...,
#       [nan, nan, nan, ..., nan, nan, nan],
#       [nan, nan, nan, ..., nan, nan, nan],
#       [nan, nan, nan, ..., nan, nan, nan]])