首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将numpy数组的视图倒转两次,使函数运行得更快。

将numpy数组的视图倒转两次,使函数运行得更快。
EN

Stack Overflow用户
提问于 2019-09-21 09:59:28
回答 1查看 300关注 0票数 5

因此,我在测试相同函数的两个版本的速度;一个版本两次反转numpy数组的视图,另一个没有。守则如下:

代码语言:javascript
复制
import numpy as np
from numba import njit

@njit
def min_getter(arr):

    if len(arr) > 1:
        result = np.empty(len(arr), dtype = arr.dtype)
        local_min = arr[0]
        result[0] = local_min

        for i in range(1,len(arr)):
            if arr[i] < local_min:
                local_min = arr[i]
            result[i] = local_min
        return result

    else:
        return arr

@njit
def min_getter_rev1(arr1):

    if len(arr1) > 1:
        arr = arr1[::-1][::-1]
        result = np.empty(len(arr), dtype = arr.dtype)
        local_min = arr[0]
        result[0] = local_min

        for i in range(1,len(arr)):
            if arr[i] < local_min:
                local_min = arr[i]
            result[i] = local_min
        return result

    else:
        return arr1
size = 500000
x = np.arange(size)   
y = np.hstack((x[::-1], x))

y_min = min_getter(y)
yrev_min = min_getter_rev1(y)

令人惊讶的是,有额外操作的那个在多个场合运行得稍微快一些。我在这两个函数上使用了大约10次%timeit;尝试了不同大小的数组,这种差异是明显的(至少在我的计算机中是这样)。min_getter的运行时是这样的:

2.35 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)(有时为2.33,有时为2.37,但从未低于2时30分)

min_getter_rev1的运行时就是这样的:

2.22 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)(有时为2.25,有时为2.23,但很少超过2时30分)

对为什么和如何发生这件事有什么想法吗?速度差是4-6%的增长,这可能是一个很大的问题,在一些应用.加速的潜在机制可能有助于加速某些跳码。

Note1:我试过size=5000000,在每个函数上测试了5-10次,两者之间的区别更加明显。在23.2 ms ± 51.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)跑得越快,在24.4 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)跑得越慢

Note2:测试期间numpynumba的版本是1.16.50.45.1;python版本是3.7.4IPython版本是7.8.0;Python使用的是spyder。不同版本的测试结果可能有所不同。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-09-22 17:51:58

TL;DR:这可能只是一个幸运的巧合,第二个代码更快。

检查生成的类型可以发现有一个重要的区别:

在第一个示例中,您的

  • 被键入为array(int32, 1d, C),这是一个C-连续数组.

代码语言:javascript
复制
min_getter.inspect_types()

min_getter (array(int32, 1d, C),)  <--- THIS IS THE IMPORTANT LINE
--------------------------------------------------------------------------------
# File: <>
# --- LINE 4 --- 
# label 0

@njit

# --- LINE 5 --- 

def min_getter(arr):

[...]

在第二个示例中,

  • arr类型为array(int32, 1d, A),如果它是连续的,则为不知情的数组。这是因为,[::-1].

返回一个没有连续信息的数组,一旦它丢失了,就不能由第二个[::-1]恢复。

代码语言:javascript
复制
>>> min_getter_rev1.inspect_types()

[...]

    # --- LINE 18 --- 
    #   arr1 = arg(0, name=arr1)  :: array(int32, 1d, C)
    #   $const0.2 = const(NoneType, None)  :: none
    #   $const0.3 = const(NoneType, None)  :: none
    #   $const0.4 = const(int, -1)  :: Literal[int](-1)
    #   $0.5 = global(slice: <class 'slice'>)  :: Function(<class 'slice'>)
    #   $0.6 = call $0.5($const0.2, $const0.3, $const0.4, func=$0.5, args=(Var($const0.2, <> (18)), Var($const0.3, <> (18)), Var($const0.4, <> (18))), kws=(), vararg=None)  :: (none, none, int64) -> slice<a:b:c>
    #   del $const0.4
    #   del $const0.3
    #   del $const0.2
    #   del $0.5
    #   $0.7 = static_getitem(value=arr1, index=slice(None, None, -1), index_var=$0.6)  :: array(int32, 1d, A)
    #   del arr1
    #   del $0.6
    #   $const0.8 = const(NoneType, None)  :: none
    #   $const0.9 = const(NoneType, None)  :: none
    #   $const0.10 = const(int, -1)  :: Literal[int](-1)
    #   $0.11 = global(slice: <class 'slice'>)  :: Function(<class 'slice'>)
    #   $0.12 = call $0.11($const0.8, $const0.9, $const0.10, func=$0.11, args=(Var($const0.8, <> (18)), Var($const0.9, <> (18)), Var($const0.10, <> (18))), kws=(), vararg=None)  :: (none, none, int64) -> slice<a:b:c>
    #   del $const0.9
    #   del $const0.8
    #   del $const0.10
    #   del $0.11
    #   $0.13 = static_getitem(value=$0.7, index=slice(None, None, -1), index_var=$0.12)  :: array(int32, 1d, A)
    #   del $0.7
    #   del $0.12
    #   arr = $0.13  :: array(int32, 1d, A)  <---- THIS IS THE IMPORTANT LINE
    #   del $0.13

    arr = arr1[::-1][::-1]

[...]

(生成的其余代码几乎相同)

如果已知数组是连续的,索引和迭代应该更快。但这不是我们在这种情况下所观察到的-完全相反。

那么原因是什么呢?

Numba本身使用LLVM来“编译”the代码。因此,这涉及到一个实际的编译器,编译器可以进行优化。尽管inspect_types()检查的代码几乎相同,但实际的LLVM/ASM代码与inspect_llvm()inspect_asm()非常不同。所以编译器(或numba)能够在第二种情况下进行某种优化,在第一种情况下是不可能的。或者,应用于第一种情况的一些优化实际上使代码变得更糟。

然而,这意味着我们只是在第二种情况下“走运”了。它可能不能被控制,因为它取决于:

  • -- numba基于您的源代码创建的类型,
  • ,numba内部使用的源代码,用于对这些类型进行操作的
  • 、从这些类型生成的LLVM和从该LLVM生成的ASM的

这些是太多的运动部件,可以应用优化(或不适用)。

有趣的事实:如果您丢弃外部ifs:

代码语言:javascript
复制
import numpy as np
from numba import njit

@njit
def min_getter(arr):
    result = np.empty(len(arr), dtype = arr.dtype)
    local_min = arr[0]
    result[0] = local_min

    for i in range(1,len(arr)):
        if arr[i] < local_min:
            local_min = arr[i]
        result[i] = local_min
    return result

@njit
def min_getter_rev1(arr1):
    arr = arr1[::-1][::-1]
    result = np.empty(len(arr), dtype = arr.dtype)
    local_min = arr[0]
    result[0] = local_min

    for i in range(1,len(arr)):
        if arr[i] < local_min:
            local_min = arr[i]
        result[i] = local_min
    return result

size = 500000
x = np.arange(size)   
y = np.hstack((x[::-1], x))

y_min = min_getter(y)
yrev_min = min_getter_rev1(y)

%timeit min_getter(y)      # 2.29 ms ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit min_getter_rev1(y) # 2.37 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

在这种情况下,没有[::-1][::-1]的那个更快。

因此,如果您想使它可靠地更快:将if len(arr) > 1检查移到函数之外,不要使用[::-1][::-1],因为在大多数情况下,这会使函数运行得更慢(而且可读性更低)!

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58039192

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档