For each index along the third axis, I want to compute the mean over the second axis.
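```python
import numpy as np
from numba import njit

@njit
def mean_some_index(a):
    T = a.shape[2]
    b = np.zeros((a.shape[0], T))  # one row per index of axis 0
    for t in range(T):
        # Mean over axis 1 of the 2D slice a[:, :, t]
        b[:, t] = a[:, :, t].mean(axis=1)
    return b
```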
I call it like this:
```python
a = np.random.randn(5 * 5 * 5).reshape((5, 5, 5))
mean_some_index(a)
```
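For reference, the loop computes `b[i, t]` as the mean over `j` of `a[i, j, t]`, which in plain NumPy (without Numba) is a single call to `a.mean(axis=1)`. The sketch below checks that equivalence; the helper name `loop_reference` is hypothetical and is simply the question's function without the `@njit` decorator.

```python
import numpy as np

def loop_reference(a):
    # Same computation as mean_some_index, but not compiled by Numba
    T = a.shape[2]
    b = np.zeros((a.shape[0], T))
    for t in range(T):
        b[:, t] = a[:, :, t].mean(axis=1)
    return b

a = np.random.default_rng(0).standard_normal((5, 5, 5))
# Averaging over axis 1 for every index of axis 2 is exactly a.mean(axis=1)
vectorized = a.mean(axis=1)
print(np.allclose(loop_reference(a), vectorized))  # True
```

This one-liner is handy as a ground truth when checking any Numba rewrite, since the `axis` argument is exactly what Numba rejects.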
Without Numba this works fine; with Numba, however, it fails with an error saying:
```
resolving callee type: BoundFunction(array.mean for array(float64, 2d, A))
...
File "C:\Users\Mining-Base\AppData\Local\Temp\ipykernel_565300\1191607406.py", line 7:
def mean_some_index(a):
    <source elided>
    for t in range(T):
        b[:, t] = a[:,:,t].mean(axis = 1)
```
I don't really understand this error and would be grateful for an answer.
Answer 1
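Edit: added a version with swapped loops and a larger array for benchmarking. Also, numba seems to help a lot for small arrays (hundreds of elements) but much less for large arrays (millions of elements).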
Edit 2: added a parallel version.
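That's because numba supports these array methods, `mean()` included, only without their optional arguments, so `mean(axis=1)` is rejected; see https://numba.readthedocs.io/en/stable/reference/numpysupported.html#calculation.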
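One possible workaround, sketched below, divides an axis sum by the axis length. It assumes the installed Numba version accepts an integer `axis` for `sum` (the supported-features page linked above lists `sum()` that way); that is an assumption, not something the answer tested. The fallback decorator exists only so the sketch also runs where Numba is not installed, and `mean_axis1_via_sum` is a hypothetical name.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba
    def njit(func):
        return func

@njit
def mean_axis1_via_sum(a):
    n_rows, n_cols, T = a.shape
    b = np.empty((n_rows, T))
    for t in range(T):
        # sum over axis 1 divided by the axis length equals mean(axis=1)
        b[:, t] = a[:, :, t].sum(axis=1) / n_cols
    return b
```

Using a non-cubic test array such as shape `(3, 4, 5)` makes it obvious if the output shape or the averaged axis is mixed up.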
Even so, we can still get a speedup by computing the means inside explicit loops:
```python
import numpy as np
from numba import njit, prange

rng = np.random.default_rng()
# arr = rng.standard_normal(size=(5, 5, 5))
arr = rng.standard_normal(size=(500, 500, 500))
```
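```python
def np_mean(arr):
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    for ax in range(z_dim):
        # mean(axis=1) is fine here because this function is not compiled by Numba
        out[:, ax] = arr[:, :, ax].mean(axis=1)
    return out
```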
```python
@njit
def nb_mean(arr):
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    for ax in range(z_dim):
        for idx in range(x_dim):
            # arr[:, :, ax][idx] is the 1D row arr[idx, :, ax]; plain .mean() is supported
            out[idx, ax] = arr[:, :, ax][idx].mean()
    return out
```
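The chained indexing `arr[:, :, ax][idx]` in `nb_mean` first takes the 2D slice for one index of the last axis and then its row `idx`, so it selects the same elements as `arr[idx, :, ax]`: a 1D view running along the middle axis. The NumPy-only snippet below (array and index values chosen for illustration) checks that the two spellings agree and that no copy is made.

```python
import numpy as np

arr = np.arange(2 * 3 * 4, dtype=np.float64).reshape(2, 3, 4)
idx, ax = 1, 2

row = arr[:, :, ax][idx]   # chained indexing as used in nb_mean
direct = arr[idx, :, ax]   # the same elements in one indexing step

print(np.array_equal(row, direct))  # True
print(np.shares_memory(row, arr))   # True: basic slicing returns a view
print(row)                          # [14. 18. 22.]
```

Because this row is 1D, calling `.mean()` on it without arguments stays within what Numba supports.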
```python
@njit
def nb_mean_swapped(arr):
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    # Same computation as nb_mean, with the loop order swapped
    for idx in range(x_dim):
        for ax in range(z_dim):
            out[idx, ax] = arr[:, :, ax][idx].mean()
    return out
```
```python
@njit(parallel=True)
def nb_mean_swapped_parallel(arr):
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    # Numba parallelizes only the outermost prange; a nested prange would run serially
    for idx in prange(x_dim):
        for ax in range(z_dim):
            out[idx, ax] = arr[:, :, ax][idx].mean()
    return out
```
```
In [27]: %timeit np_mean(arr)
674 ms ± 23.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [28]: %timeit nb_mean(arr)
606 ms ± 28.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [29]: %timeit nb_mean_swapped(arr)
218 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [30]: %timeit nb_mean_swapped_parallel(arr)
64.3 ms ± 2.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
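A plausible factor in the difference between `nb_mean` and `nb_mean_swapped` is memory access order, though the timings above do not isolate it: NumPy arrays are C-ordered (row-major) by default, so the last index varies fastest in memory. The sketch below is an alternative that is not part of the answer and has not been benchmarked here: it drops the slicing and `.mean()` calls and accumulates sums with the last axis in the innermost loop, so reads walk through memory contiguously. It uses an optional-Numba fallback, and the name `nb_mean_accumulate` is hypothetical.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba
    def njit(func):
        return func

@njit
def nb_mean_accumulate(arr):
    x_dim, y_dim, z_dim = arr.shape
    out = np.zeros((x_dim, z_dim))
    for i in range(x_dim):
        for j in range(y_dim):
            # For a C-ordered array, arr[i, j, k] for consecutive k are adjacent in memory
            for k in range(z_dim):
                out[i, k] += arr[i, j, k]
    # Turn the sums over axis 1 into means
    return out / y_dim
```

Before timing it on the large array, comparing its result with `arr.mean(axis=1)` on a small non-cubic array is a cheap sanity check.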