scipy.linalg.invを使ってみる
引数一覧
引数 | 型 | 意味 |
---|---|---|
a | array_like | 逆行列を計算したい行列を指定する。 |
overwrite_a=False | bool | 引数aを上書きするかどうか。巨大な行列ではTrueにすればパフォーマンスが改善される。 |
check_finite=True | bool | aに有限数のみが含まれていることを確認するかどうか。infかnanでないことが分かっている場合Falseにするとパフォーマンス向上する場合がある |
適当に5*5の行列作って逆行列を求める
import numpy as np from scipy import linalg arr = np.random.randint(1, 10, size=(5, 5)) #array([[1, 8, 7, 1, 1], # [9, 5, 3, 3, 9], # [1, 8, 9, 6, 1], # [1, 6, 5, 7, 2], # [5, 5, 8, 2, 5]]) np.round(arr @ linalg.inv(arr), 10) #array([[ 1., 0., 0., 0., -0.], # [ 0., 1., -0., 0., -0.], # [ 0., 0., 1., 0., -0.], # [ 0., 0., 0., 1., 0.], # [-0., 0., 0., 0., 1.]])
巨大行列の逆行列を求める
コードとしてはsize=
が変わっただけ。
IPythonから%%timeitを利用して実行時間を見てみる。まずデフォルトで。
arr = np.random.randint(1, 10, size=(5000, 5000)) # default %%timeit linalg.inv(arr, overwrite_a=False, check_finite=True) #2.14 s ± 104 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
十分すぎる計算時間!
お次はoverwrite_a
をTrue
にしてみる。
%%timeit linalg.inv(arr, overwrite_a=True, check_finite=True) #2.11 s ± 63.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) arr #array([[3, 4, 6, ..., 5, 7, 1], # [8, 3, 3, ..., 5, 5, 5], # [6, 3, 4, ..., 8, 3, 1], # ..., # [8, 6, 3, ..., 7, 4, 7], # [5, 8, 9, ..., 2, 3, 6], # [7, 6, 5, ..., 8, 4, 6]])
微妙に早い?ばらつきの範囲内な気がする...
それよりarr
が上書きされていない気がするがこれは...
%%timeit linalg.inv(arr, overwrite_a=True, check_finite=False) #2.04 s ± 13.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
こちらは早くなったといっていいと思う。
にしてもこの早さで逆行列を求められるなんてすごいことだ。
行列にinfがある場合
arr2 = np.random.randn(5000, 5000) arr2[-2,-3] = np.inf #array([[ 2.02989041, 0.46380936, 0.91881227, ..., 0.76432204, # -0.803534 , 0.27665392], # [-1.15502314, 0.25701885, -0.38668162, ..., 0.27915569, # 2.26712878, -1.31973077], # [ 0.46494284, 0.19511784, -0.17409262, ..., -0.60674436, # -0.55114893, -1.16376405], # ..., # [-1.02789623, -0.67204559, 0.75562066, ..., 0.73532684, # -0.69068431, 0.66328013], # [ 1.11935754, -2.84095525, -0.91949745, ..., inf, # -1.94359856, 1.62964351], # [ 0.09731474, -1.79696413, 0.00693175, ..., -1.53071478, # 0.13959659, -0.22879148]]) linalg.inv(arr2, overwrite_a=False, check_finite=True) #ValueError: array must not contain infs or NaNs linalg.inv(arr2, overwrite_a=False, check_finite=False) #array([[nan, nan, nan, ..., nan, nan, nan], # [nan, nan, nan, ..., nan, nan, nan], # [nan, nan, nan, ..., nan, nan, nan], # ..., # [nan, nan, nan, ..., nan, nan, nan], # [nan, nan, nan, ..., nan, nan, nan], # [nan, nan, nan, ..., nan, nan, nan]])