Ordinarily, Numba-jitted functions cannot access global variables dynamically: the value is frozen at compile time. Pythonic behaviour can be forced by using objmode, at some cost in performance.
from numba import njit, objmode

@njit
def access_global_variable_1():
    # The global is read once, when the function is compiled
    return global_variable

@njit
def access_global_variable_2():
    # objmode runs this block in the interpreter, so the lookup happens on every call
    with objmode(retval='int64'):
        retval = globals()['global_variable']
    return retval

global_variable = 0
access_global_variable_1()  # returns 0
access_global_variable_2()  # returns 0

global_variable = 1
access_global_variable_1()  # returns 0
access_global_variable_2()  # returns 1
However, this breaks when we parallelise functions:
from numba import njit, objmode, prange
import numpy as np

@njit
def _add(a):
    # Update the module-level accumulator from inside an objmode block
    with objmode():
        globals()['acc'] += a

@njit(parallel=True)
def parallel_sum_global(arr):
    for i in prange(len(arr)):
        _add(arr[i])

@njit(parallel=False)
def sum_global(arr):
    for i in prange(len(arr)):
        _add(arr[i])

@njit(parallel=True)
def parallel_sum_local(arr):
    acc = 0
    for i in prange(len(arr)):
        acc += arr[i]
    return acc

n = 100
print('True answer:', np.arange(n).sum())  # True answer: 4950

acc = 0
parallel_sum_global(np.arange(n))
print('Numba parallel global answer:', acc)  # Numba parallel global answer: 78

acc = 0
sum_global(np.arange(n))
print('Numba global answer:', acc)  # Numba global answer: 4950

acc = parallel_sum_local(np.arange(n))
print('Numba parallel local answer:', acc)  # Numba parallel local answer: 4950
It can't be a problem with accessing global variables as such, since sum_global works as intended. It can't be a problem with parallelisation as such, since parallel_sum_local works as intended. The problem only occurs when accessing global variables and parallelising at the same time.
Investigating more closely, parallel_sum_global correctly handles the first 12 array items but not those after 12. I am developing on an 8-core computer.
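The reported result of 78 can be checked against that observation with a little arithmetic. The sketch below assumes, without confirmation from Numba itself, that the 100 iterations are split into 8 contiguous, near-equal chunks, one per core. The helper split_evenly is hypothetical and only models that assumption.

```python
def split_evenly(n_items, n_chunks):
    """Split range(n_items) into contiguous near-equal chunks.

    Hypothetical model of a static schedule; not Numba's actual scheduler.
    """
    base, extra = divmod(n_items, n_chunks)
    chunks = []
    start = 0
    for k in range(n_chunks):
        # The first `extra` chunks each take one additional item
        size = base + (1 if k < extra else 0)
        chunks.append(range(start, start + size))
        start += size
    return chunks

chunks = split_evenly(100, 8)
first_chunk_sum = sum(chunks[0])
total = sum(sum(c) for c in chunks)

print(list(chunks[0]))   # indices 0 through 12
print(first_chunk_sum)   # 78
print(total)             # 4950
```

Under this assumed split, the first chunk covers indices 0 through 12 and sums to exactly 78, which would be consistent with only one chunk's additions reaching acc.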
Has anyone encountered this issue? Are there any workarounds?
(Disclaimer: I came across this issue while working on something much more complicated; this is a minimal reproducer using a contrived example of summing arrays.)