我正试图使用numba加速一些代码,但这很难做到。例如,下面的函数不是numba-fy,
@jit(nopython=True)
def returns(Ft, x, delta):
T = len(x)
rets = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
return np.concatenate([[0], rets])因为numba找不到np.concatenate的签名。对此有规范的修正吗?
发布于 2022-03-03 08:46:53
有点晚了,但我希望还是有用的。既然您要求“规范修复”,我想解释为什么concatenate在处理数组时是个坏主意,特别是如果您表示希望消除瓶颈,从而使用numba。数组是内存中一个连续的字节序列(numpy知道通过创建视图来改变顺序而不复制顺序的一些技巧,但这是另一个主题,参见https://towardsdatascience.com/advanced-numpy-master-stride-tricks-with-25-illustrated-exercises-923a9393ab20)。如果要将值x添加到N元素数组中,则需要使用N+1元素创建一个新数组,将第一个值设置为x并复制其余部分。另外,对于python列表中的前缀项,也存在类似的参数,这也是collections.deque存在的原因。
现在,在您的jit修饰函数中,您可以希望编译器理解您想要做的事情,但是编写始终了解您想要做的事情的编译器几乎是不可能的。因此,当您知道正确的选择时,最好善待编译器,并帮助处理内存布局。因此,IMHO示例代码的“规范修复”如下所示:
@jit(nopython=True)
def returns(Ft, x, delta):
T = len(x)
rets = np.empty_like(x)
rets[0] = 0
rets[1:T] = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
return rets总的来说,我同意@Aaron的评论,这意味着对于jit修饰函数中调用的任何函数,在输入类型时都应该尽可能明确。在你的例子中,问问自己作为一个编译器“什么是[[0], rets]?”在严格类型中,您会看到一个包含整数列表和浮点(或复杂)数字数组的列表。对于编译器来说,这是一种具有挑战性的混合类型。输出应该成为整数数组还是浮动数组?
https://stackoverflow.com/questions/69702485
复制相似问题