Q is a 3D matrix (a stack of square matrices) and could, for example, have the following shape: (4000, 25, 25).
I want to raise Q to the power n for n in {0, 1, ..., k-1} and sum it all.
Basically, I want to calculate
\sum_{n=0}^{k-1} Q^n
I have the following function that works as expected:
import numpy as np

def sum_of_powers(Q: np.ndarray, k: int) -> np.ndarray:
    Qs = np.sum([
        np.linalg.matrix_power(Q, n) for n in range(k)
    ], axis=0)
    return Qs
Is it possible to speed up my function or is there a faster method to obtain the same output?
Asked by HJA24
3 Answers
We can perform this calculation in O(log k) matrix operations.
Let M(k) represent the k'th power of the input, and S(k) represent the sum of those powers from 0 to k. Let I represent an appropriate identity matrix.
Approach 1
If you expand the product, you'll find that (M(1) - I) * S(k) = M(k+1) - I. That means we can compute M(k+1) using a standard matrix power (which takes O(log k) matrix multiplications), and compute S(k) by using numpy.linalg.solve to solve the equation (M(1) - I) * S(k) = M(k+1) - I:
import numpy.linalg

def option1(Q, k):
    # S(k) = Q^0 + ... + Q^k satisfies (Q - I) S(k) = Q^(k+1) - I
    identity = numpy.eye(Q.shape[-1])
    A = Q - identity
    B = numpy.linalg.matrix_power(Q, k+1) - identity
    return numpy.linalg.solve(A, B)
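The identity behind Approach 1 is a telescoping sum. Writing Q for M(1) and distributing the product:

```latex
\begin{aligned}
(Q - I)\,S(k) &= (Q - I)\sum_{n=0}^{k} Q^{n} \\
              &= \sum_{n=1}^{k+1} Q^{n} - \sum_{n=0}^{k} Q^{n} \\
              &= Q^{k+1} - I
\end{aligned}
```

Two consequences follow directly. First, S(k) includes the term Q^k, so the question's sum_of_powers(Q, k), which stops at Q^(k-1), corresponds to option1(Q, k-1). Second, numpy.linalg.solve needs Q - I to be nonsingular for every matrix in the stack; if some matrix has an eigenvalue equal or very close to 1 (as a row-stochastic matrix does), the solve either raises LinAlgError or loses accuracy.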
Approach 2
The standard exponentiation by squaring algorithm computes M(2*k) as M(k)*M(k) and M(2*k+1) as M(2*k)*M(1).
We can alter the algorithm to track both S(k-1) and M(k), by computing S(2*k-1) as S(k-1)*M(k) + S(k-1) and S(2*k) as S(2*k-1) * M(1) + I:
import numpy

def option2(Q, k):
    identity = numpy.eye(Q.shape[-1])
    if k == 0:
        res = numpy.empty_like(Q)
        res[:] = identity
        return res
    # Invariant: power = M(j) and sum_of_powers = S(j-1), starting from j = 1
    power = Q
    sum_of_powers = identity
    # Looping over a string might look dumb, but it's actually the most efficient option,
    # as well as the simplest. (It wouldn't be the bottleneck even if it wasn't efficient.)
    for bit in bin(k+1)[3:]:
        sum_of_powers = (sum_of_powers @ power) + sum_of_powers  # S(2j-1)
        power = power @ power                                    # M(2j)
        if bit == "1":
            sum_of_powers = sum_of_powers @ Q                    # S(2j) ...
            sum_of_powers += identity                            # ... = S(2j-1) M(1) + I
            power = power @ Q                                    # M(2j+1)
    return sum_of_powers
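The two update rules used in option2 can be sanity-checked numerically. The sketch below computes M(n) and S(n) naively (the helpers `M` and `S` are hypothetical names mirroring the notation above) and checks S(2j-1) = S(j-1) M(j) + S(j-1) and S(2j) = S(2j-1) M(1) + I on a small random stack. The shape (3, 4, 4), the seed, and the 0.3 scaling are arbitrary choices that keep the powers well-scaled.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small random stack of 4x4 matrices, scaled down so the powers stay moderate.
Q = 0.3 * rng.standard_normal((3, 4, 4))
I = np.eye(Q.shape[-1])


def M(n: int) -> np.ndarray:
    """Naive reference for M(n) = Q^n, batched over the first axis."""
    return np.linalg.matrix_power(Q, n)


def S(n: int) -> np.ndarray:
    """Naive reference for S(n) = Q^0 + Q^1 + ... + Q^n."""
    return sum(M(i) for i in range(n + 1))


for j in range(1, 6):
    # Squaring step: S(2j-1) = S(j-1) @ M(j) + S(j-1)
    assert np.allclose(S(2 * j - 1), S(j - 1) @ M(j) + S(j - 1))
    # Extra-bit step: S(2j) = S(2j-1) @ M(1) + I
    assert np.allclose(S(2 * j), S(2 * j - 1) @ M(1) + I)
print("doubling identities hold")
```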
Matrix exponentiation by integer powers is the same as chaining matrix multiplications. Since you sum up every power up to a limit, you can simply do regular matrix multiplications.
import numpy as np

def sum_of_powers(Q: np.ndarray, k: int) -> np.ndarray:
    def gen_powers():
        if k < 0:
            raise ValueError("negative power not implemented")
        # zeroth power
        yield np.broadcast_to(np.eye(Q.shape[-1], dtype=Q.dtype), Q.shape)
        Qi = Q
        for _ in range(1, k):
            yield Qi
            Qi = Qi @ Q
    return sum(gen_powers())
Performance considerations:
- Using the built-in sum instead of np.sum saves a ton of temporary memory. Even more effective versions using += are possible (avoiding the temporary allocations that come from repeated array + array).
- You may avoid temporary allocations for the matrix multiplication results by ping-ponging the Qi variable between two buffers and using np.matmul(…, out=…).
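A minimal sketch of the two suggestions above: accumulate with += and ping-pong between two preallocated buffers via np.matmul(..., out=...). The function name `sum_of_powers_inplace` is hypothetical. Like the question's function it sums Q^0 through Q^(k-1); unlike the generator version it returns zeros for k=0 (the empty sum).

```python
import numpy as np


def sum_of_powers_inplace(Q: np.ndarray, k: int) -> np.ndarray:
    """Q^0 + Q^1 + ... + Q^(k-1) for a stack of square matrices, reusing two buffers."""
    if k < 0:
        raise ValueError("negative power not implemented")
    total = np.zeros(Q.shape, dtype=Q.dtype)
    if k == 0:
        return total                                  # empty sum
    total += np.eye(Q.shape[-1], dtype=Q.dtype)       # Q^0, broadcast over the stack
    cur = Q.copy()                                    # buffer holding Q^i
    nxt = np.empty_like(Q)                            # scratch buffer for Q^(i+1)
    for i in range(1, k):
        total += cur                                  # in-place add, no temporary array
        if i + 1 < k:                                 # skip the power that would go unused
            np.matmul(cur, Q, out=nxt)
            cur, nxt = nxt, cur                       # ping-pong: swap the two buffers
    return total
```

Only three arrays of Q's size are alive during the loop (total, cur, nxt), independent of k.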
Accuracy considerations:
- You will find differences from your straightforward approach since the rounding errors are different.
- You can cache the intermediate results to implement exponentiation by squaring. That doesn't improve performance, but it reduces the rounding errors since the n-th power only requires a chain of O(log(n)) matrix multiplications instead of a chain of n matrix multiplications.
- np.sum uses pairwise summation, which also has lower errors but takes a lot of memory. (NumPy only applies pairwise summation along the fast axis in memory, which is not the case for a reduction over axis=0 here.)
Since both accuracy improvements need a history of past powers, you can recoup the costs by using both at the same time if you're smart about the implementation.
This more accurate version is implemented here:
import numpy as np

def sum_of_powers(Q: np.ndarray, k: int) -> np.ndarray:
    # Cache Q^0 ... Q^(k-1); allocate at least one slot so that k <= 1 still works
    Qis = np.empty((max(k, 1),) + Q.shape, dtype=Q.dtype)
    Qis[0] = np.eye(Q.shape[-1], dtype=Q.dtype)
    if k <= 1:
        return Qis[0]
    Qis[1] = Q
    for ki in range(2, k):
        # exponentiation by squaring
        # For even powers Q^k = Q^(k//2) @ Q^(k//2)
        # For odd powers Q^k = Q^(k//2) @ Q^(k//2 + 1)
        leftk = ki >> 1
        rightk = leftk + (ki & 1)
        np.matmul(Qis[leftk], Qis[rightk], out=Qis[ki])
    return Qis.sum(axis=0)
There is still a difference in accuracy, but I'm not sure my version is actually worse; it might even be better. Note that numpy.linalg.matrix_power computes integer powers by repeated squaring and multiplication (per its documentation), so its result is not free of rounding errors either. I have not done any further theoretical analysis.
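One way to probe the accuracy question empirically is to run both strategies in float32 and compare them against a float64 result. This is a hedged sketch: the helper names (`chain_sum`, `squaring_sum`, `max_rel_error`), the random test matrices rescaled to spectral radius 0.9, and the use of float64 as "ground truth" are assumptions for illustration, and which strategy comes out ahead can depend on the matrices.

```python
import numpy as np


def chain_sum(Q: np.ndarray, k: int) -> np.ndarray:
    """Q^0 + ... + Q^(k-1) via a chain of successive multiplications."""
    total = np.broadcast_to(np.eye(Q.shape[-1], dtype=Q.dtype), Q.shape).copy()
    power = Q.copy()
    for i in range(1, k):
        total += power
        if i + 1 < k:
            power = power @ Q
    return total


def squaring_sum(Q: np.ndarray, k: int) -> np.ndarray:
    """Same sum, building each Q^n from cached Q^(n//2) and Q^(n - n//2)."""
    powers = np.empty((max(k, 1),) + Q.shape, dtype=Q.dtype)
    powers[0] = np.eye(Q.shape[-1], dtype=Q.dtype)
    if k > 1:
        powers[1] = Q
    for n in range(2, k):
        half = n >> 1
        np.matmul(powers[half], powers[n - half], out=powers[n])
    return powers.sum(axis=0)


def max_rel_error(approx: np.ndarray, ref: np.ndarray) -> float:
    """Largest absolute deviation, relative to the largest reference entry."""
    return float(np.max(np.abs(approx - ref)) / np.max(np.abs(ref)))


rng = np.random.default_rng(42)
d, k = 25, 300
# Assumed test input: random matrices rescaled to spectral radius 0.9 so the sum converges.
Q64 = rng.standard_normal((8, d, d))
radius = np.max(np.abs(np.linalg.eigvals(Q64)), axis=-1)
Q64 *= (0.9 / radius)[:, None, None]
ref = chain_sum(Q64, k)          # float64 result used as the reference
Q32 = Q64.astype(np.float32)
print("chain    :", max_rel_error(chain_sum(Q32, k), ref))
print("squaring :", max_rel_error(squaring_sum(Q32, k), ref))
```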
TL;DR: Yes, the code can be sped up massively with more efficient code written in Numba that makes better use of CPU and memory resources and has a lower complexity.
Why the current code is inefficient
The code in the question is very inefficient for multiple reasons.
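First of all, the algorithmic complexity of the code is O(k log k) matrix multiplications (each call to matrix_power(Q, n) recomputes Q^n from scratch with O(log n) multiplications), while this can be performed in O(k) with successive matrix multiplications, as pointed out by @Homer512.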
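To see how the work grows with k, here is a rough cost model. It assumes, as the documentation of numpy.linalg.matrix_power describes, that a positive integer power n is computed by repeated squaring, which takes (bit_length(n) - 1) squarings plus (popcount(n) - 1) extra multiplications. The helper names are hypothetical, and the model counts matrix multiplications only, ignoring memory traffic.

```python
def matmuls_for_power(n: int) -> int:
    """Matrix multiplications for Q^n with binary exponentiation (0 for n <= 1)."""
    if n <= 1:
        return 0
    return (n.bit_length() - 1) + (bin(n).count("1") - 1)


def matmuls_question(k: int) -> int:
    """Total for the question's code: one independent matrix_power call per n < k."""
    return sum(matmuls_for_power(n) for n in range(k))


def matmuls_successive(k: int) -> int:
    """Total when each power reuses the previous one: Q^n = Q^(n-1) @ Q for 2 <= n < k."""
    return max(k - 2, 0)


for k in (10, 100, 300, 1000):
    print(k, matmuls_question(k), matmuls_successive(k))
```

The per-power cost of the first model grows like log n, which is where the extra log k factor comes from.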
Moreover, the operation is not cache-friendly. Each matrix multiplication operates on a batch of 4000 matrices of size 25x25, so 2*8*4000*25*25 bytes ≈ 38 MiB of memory are needed to perform the computation. This does not fit in the last-level cache (LLC) of most CPUs. As a result, the data lives in slow DRAM, and a significant amount of it needs to be fetched and stored back on each iteration (114 MiB/iteration).
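The 38 MiB figure is the size of two float64 stacks of shape (4000, 25, 25). A quick check of the arithmetic, assuming 8-byte float64 elements:

```python
import numpy as np

shape = (4000, 25, 25)
itemsize = np.dtype(np.float64).itemsize       # 8 bytes per element
one_stack = itemsize * int(np.prod(shape))      # 20,000,000 bytes, about 19.07 MiB
two_stacks = 2 * one_stack                      # about 38.15 MiB
print(one_stack / 2**20, "MiB per stack")
print(two_stacks / 2**20, "MiB for two stacks")
```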
Additionally, the operation is sequential: NumPy does not parallelize its own operations, and the BLAS library used to perform the matrix multiplications cannot parallelize them either, because 25x25 matrices are too small for multiple threads to be worth it.
Finally, each matrix multiplication is certainly sub-optimal because of the overhead of NumPy's internal iterators and of calling a BLAS function for each matrix.
Faster implementation
We can write Numba code that iteratively performs the matrix multiplication on each 25x25 matrix and sums the results. The code can be parallelized trivially so that each thread operates on a chunk along the first axis (of size 4000). The operation is thus parallel, cache-friendly, free of NumPy/CPython overhead, and SIMD-friendly thanks to Numba. Here is the resulting code:
import numpy as np
import numba as nb

@nb.njit('(float64[:,:,::1], int32)', parallel=True)
def sum_of_powers_numba(Q: np.ndarray, k: int) -> np.ndarray:
    s0, s1, s2 = Q.shape
    assert s1 == s2
    Qs = np.zeros((s0, s1, s1))
    # Each iteration of the parallel loop handles one small matrix independently
    for i in nb.prange(s0):
        tmp = np.eye(s1, s1)
        Qs[i,:,:] = tmp
        for n in range(k-1):
            tmp = tmp @ Q[i]      # tmp = Q[i]^(n+1)
            Qs[i,:,:] += tmp
    return Qs
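For readers without Numba, a rough pure-NumPy analogue of the cache-friendliness idea is to process the first axis in chunks small enough that the chunk's working set stays in cache, running all k steps on one chunk before moving on to the next. This is a hedged sketch, not a benchmarked replacement: the function name and the default chunk size of 256 (about 1.2 MiB per float64 (256, 25, 25) buffer) are assumptions to tune per machine, and it runs on a single thread.

```python
import numpy as np


def sum_of_powers_chunked(Q: np.ndarray, k: int, chunk: int = 256) -> np.ndarray:
    """Q^0 + ... + Q^(k-1) for a stack of square matrices, one chunk of the first axis at a time."""
    if k < 1:
        raise ValueError("k must be >= 1 in this sketch")
    eye = np.eye(Q.shape[-1], dtype=Q.dtype)
    out = np.empty_like(Q)
    for start in range(0, Q.shape[0], chunk):
        Qc = Q[start:start + chunk]            # view on this chunk of the first axis
        total = out[start:start + chunk]       # view: results are written straight into out
        total[...] = eye                       # Q^0, broadcast over the chunk
        power = Qc.copy()                      # Q^1 for this chunk
        scratch = np.empty_like(power)
        for i in range(1, k):
            total += power                     # add Q^i
            if i + 1 < k:
                np.matmul(power, Qc, out=scratch)
                power, scratch = scratch, power
    return out
```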
The BLAS library seems to be sub-optimal for 25x25 matrices, but not too bad either. Writing a faster implementation in Numba is very challenging; it is easier in native languages. I do not think it is worth the effort of writing a specialized implementation for small matrices. I expect the BLIS library to be a bit faster (close to optimal) than the default BLAS implementation (OpenBLAS), so it is certainly worth just trying it (or possibly the highly optimized Intel MKL).
Besides, if you want a more accurate version, you should adapt the code so as to perform exponentiation by squaring, as shown by @Homer512, while still using Numba.
Benchmark (updated)
Here are results on my i5-9600KF CPU (6 cores):
sum_of_powers_naive: 23.6 sec
sum_of_powers_homer512: 7.8 sec
sum_of_powers_numba: 1.5 sec <-----
user2357112's option2: 0.16 sec
user2357112's option1: 0.13 sec
Interestingly, profiling shows that about 80% of the time is spent in the BLAS library, so the main speedup comes from parallelizing the code (with nb.prange in Numba).
In the end, the approach of @user2357112 is the best, since it is the fastest and certainly the most numerically stable. Note that it could be made even faster by parallelizing the computation over the first axis.
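The last remark can be sketched with a thread pool that applies user2357112's option1 to independent chunks of the first axis. This rests on assumptions worth measuring rather than trusting: that NumPy's matrix_power and solve release the GIL while running compiled code, that the BLAS is not already using all cores (otherwise the threads just oversubscribe them), and that Q - I is nonsingular for every matrix. The name `option1_threaded`, the default of 4 workers, and the float64 output are hypothetical choices.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def option1(Q: np.ndarray, k: int) -> np.ndarray:
    """Approach 1: solve (Q - I) S(k) = Q^(k+1) - I, so S(k) = Q^0 + ... + Q^k."""
    identity = np.eye(Q.shape[-1])
    return np.linalg.solve(Q - identity, np.linalg.matrix_power(Q, k + 1) - identity)


def option1_threaded(Q: np.ndarray, k: int, workers: int = 4) -> np.ndarray:
    """Hypothetical helper: run option1 on contiguous chunks of the first axis in threads."""
    out = np.empty(Q.shape, dtype=np.float64)               # assumes a float64 result
    bounds = np.linspace(0, Q.shape[0], workers + 1).astype(int)

    def work(i: int) -> None:
        lo, hi = bounds[i], bounds[i + 1]
        if hi > lo:                                          # skip empty chunks
            out[lo:hi] = option1(Q[lo:hi], k)

    with ThreadPoolExecutor(max_workers=workers) as pool:
        # list() forces evaluation so an exception raised in a worker propagates here
        list(pool.map(work, range(workers)))
    return out
```

Since option1 sums up to Q^k, the question's sum up to Q^(k-1) corresponds to option1_threaded(Q, k - 1).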
Comment on the question by HJA24: … np.float64 and k=300
Source: python - Fast(est) exponentiation of numpy 3D matrix - Stack Overflow (http://www.betaflare.com/web/1736563716a1944677.html)