python 如何在numba jitted函数中使用类的对象而不jitting整个类?

b1zrtrql  于 3个月前  发布在  Python
关注(0)|答案(1)|浏览(16)

我有一个类,它创建了一个对象,该对象将几个数组(numpy数组)作为其子数组。该类包含构建这些数组的所有复杂逻辑。一个只有一个属性的简单模型看起来像这样:

class System():
    def __init__(self,backend='numpy'):
        if backend == 'numpy':
            self.D = np.ones((2,2))
        else:
            # Mock-up for other backend types
            self.D = [[1,1],[1,1]]

字符串
在真实的代码中,这会为类添加多个属性,这些属性是在运行时初始化类时确定的。如果后端是例如'numpy',那么所有属性都将是numpy数组。为了简单起见,我在这里只添加一个属性D
这个类的目的是用户可以利用他提供的函数中的属性,因为它们很可能包含很多循环,因此使njit大大加快了代码的速度。我希望现在用户能够njit他的函数。再次,一个模拟示例:

### User code
a = System(backend='numpy')

@nb.njit()
def user_provided_function():
    result = a.D * 2
    return result

out = user_provided_function()
print(out)


显然,这是不起作用的,因为numba会抱怨函数不能被njitt,因为a的类型没有定义/ jitable。更明显的是,如果用户避免全局变量的坏习惯,而是使用下面的代码,这将不起作用:

### User code
@nb.njit()
def user_provided_function(a):
    result = a.D * 2
    return result

b = System(backend='numpy')
out = user_provided_function(b)
print(out)


我的主要问题是,我不能从上面的类中创建一个jitclass,因为有一些后端不兼容numba。然而,我喜欢这个类,我只有一段代码,可以提供多个后端,用户可以轻松地更改后端,而无需重构整个代码。
我喜欢在用户代码中输入a.D的美妙之处,并希望为用户保持尽可能干净的界面。

ejk8hzay

ejk8hzay1#

我建议不要将python对象传递给numba函数,而是numpy数组。
考虑一下numba编译函数,它们只是接受数组并返回新数组/修改这些数组,仅此而已。这将使您的程序简单:

import numba as nb
import numpy as np

class System:
    def __init__(self, backend="numpy"):
        if backend == "numpy":
            self.D = np.ones((2, 2), dtype=np.float32)
        else:
            # Mock-up for other backend types
            self.D = [[1, 1], [1, 1]]

@nb.njit("float32[:, :](float32[:, :])")
def user_provided_function(a):
    return a * 2

b = System(backend="numpy")
out = user_provided_function(b.D)  # <-- don't pass whole object, just array from this object
print(out)

字符串
印刷品:

[[2. 2.]
 [2. 2.]]

相关问题