python - Numba: Invalid use of BoundFunction of array.mean

I want to compute the mean over the second index for each value of the third index.

import numpy as np
from numba import njit

@njit
def mean_some_index(a):
    T = a.shape[2]
    b = np.zeros((T,T))
    for t in range(T):
        b[:, t] = a[:,:,t].mean(axis = 1)
    return b

I would use it like this:

a = np.random.randn(5*5*5).reshape((5,5,5))
mean_some_index(a)

Without Numba this works fine; with Numba, however, it raises an error saying:

resolving callee type: BoundFunction(array.mean for array(float64, 2d, A))
...
File "C:\Users\Mining-Base\AppData\Local\Temp\ipykernel_565300\1191607406.py", line 7:
def mean_some_index(a):
    <source elided>
    for t in range(T):
        b[:, t] = a[:,:,t].mean(axis = 1)

I don't quite understand this error and would appreciate any help.

Answer 1

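Edit: added a version with the loops swapped and a larger array for benchmarking. Also, it seems that numba is very helpful for small arrays (hundreds of elements) but less so for large arrays (millions of elements).

Edit 2: added parallel code.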

That is because numba does not support arguments to these methods (here, the axis argument of mean); see https://numba.readthedocs.io/en/stable/reference/numpysupported.html#calculation

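Nevertheless, we can still benefit from the speedup by computing the mean inside a loop, calling mean() without arguments on one 1-D slice at a time.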

import numpy as np
from numba import njit, prange

rng = np.random.default_rng()


# arr = rng.standard_normal(size=(5, 5, 5))
arr = rng.standard_normal(size=(500, 500, 500))

def np_mean(arr):
    # Plain NumPy baseline: out[i, ax] = mean over j of arr[i, j, ax]
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    for ax in range(z_dim):
        out[:, ax] = arr[:, :, ax].mean(axis=1)
    return out


@njit
def nb_mean(arr):
    # idx runs over the first axis, so size the output from arr.shape[0]
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    for ax in range(z_dim):
        for idx in range(x_dim):
            # mean() without arguments on a 1-D slice is supported by numba
            out[idx, ax] = arr[:, :, ax][idx].mean()
    return out

@njit
def nb_mean_swapped(arr):
    # Same as nb_mean, with the loop order swapped
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    for idx in range(x_dim):
        for ax in range(z_dim):
            out[idx, ax] = arr[:, :, ax][idx].mean()
    return out

@njit(parallel=True)
def nb_mean_swapped_parallel(arr):
    # Same as nb_mean_swapped, with prange in place of range
    x_dim, z_dim = arr.shape[0], arr.shape[2]
    out = np.empty((x_dim, z_dim))
    for idx in prange(x_dim):
        for ax in prange(z_dim):
            out[idx, ax] = arr[:, :, ax][idx].mean()
    return out

In [27]: %timeit np_mean(arr)
674 ms ± 23.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [28]: %timeit nb_mean(arr)
606 ms ± 28.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [29]: %timeit nb_mean_swapped(arr)
218 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [30]: %timeit nb_mean_swapped_parallel(arr)
64.3 ms ± 2.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
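
The versions above still take a slice and call mean() on it for every output element. As a rough sketch of an alternative (not benchmarked here, so no claim about its speed), the mean can also be accumulated with explicit scalar loops, which uses only basic indexing inside the jitted function. The function name mean_axis1_loops and the small test shape are made up for illustration, and the try/except fallback is only there so the sketch still runs where Numba is not installed. The result is checked against a.mean(axis=1), which in plain NumPy gives the same array without any loop.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption for this sketch): run as plain Python without Numba.
    def njit(func):
        return func


@njit
def mean_axis1_loops(arr):
    # Hypothetical helper: out[i, t] = mean over j of arr[i, j, t]
    n_i, n_j, n_t = arr.shape
    out = np.zeros((n_i, n_t))
    for i in range(n_i):
        for j in range(n_j):
            for t in range(n_t):
                # t is innermost, so a C-ordered array is read contiguously
                out[i, t] += arr[i, j, t]
    for i in range(n_i):
        for t in range(n_t):
            out[i, t] /= n_j
    return out


# Non-cubic shape on purpose, to show the output is (first axis, third axis)
a = np.random.randn(4, 5, 6)
result = mean_axis1_loops(a)

# Reference: plain NumPy reduces the second axis in a single call
reference = a.mean(axis=1)
print(np.allclose(result, reference))
```

Because the output shape is taken from the first and third axes, this sketch does not rely on the array being cubic.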
