我有一个类,它创建了一个对象,该对象将几个数组(numpy数组)作为其子数组。该类包含构建这些数组的所有复杂逻辑。一个只有一个属性的简单模型看起来像这样:
class System():
def __init__(self,backend='numpy'):
if backend == 'numpy':
self.D = np.ones((2,2))
else:
# Mock-up for other backend types
self.D = [[1,1],[1,1]]
字符串
在真实的代码中,这会为类添加多个属性,这些属性是在运行时初始化类时确定的。如果后端是例如'numpy'
,那么所有属性都将是numpy数组。为了简单起见,我在这里只添加一个属性D
。
这个类的目的是用户可以利用他提供的函数中的属性,因为它们很可能包含很多循环,因此使njit大大加快了代码的速度。我希望现在用户能够njit他的函数。再次,一个模拟示例:
### User code
a = System(backend='numpy')
@nb.njit()
def user_provided_function():
result = a.D * 2
return result
out = user_provided_function()
print(out)
型
显然,这是不起作用的,因为numba会抱怨函数不能被njitt,因为a
的类型没有定义/ jitable。更明显的是,如果用户避免全局变量的坏习惯,而是使用下面的代码,这将不起作用:
### User code
@nb.njit()
def user_provided_function(a):
result = a.D * 2
return result
b = System(backend='numpy')
out = user_provided_function(b)
print(out)
型
我的主要问题是,我不能从上面的类中创建一个jitclass,因为有一些后端不兼容numba。然而,我喜欢这个类,我只有一段代码,可以提供多个后端,用户可以轻松地更改后端,而无需重构整个代码。
我喜欢在用户代码中输入a.D
的美妙之处,并希望为用户保持尽可能干净的界面。
1条答案
按热度按时间ejk8hzay1#
我建议不要将python对象传递给numba函数,而是numpy数组。
考虑一下numba编译函数,它们只是接受数组并返回新数组/修改这些数组,仅此而已。这将使您的程序简单:
字符串
印刷品:
型