Is it possible to use vmap for auto-batching if your function isn't jittable?
I have a function that's not jittable:
def testfunc(model, x1, x2, x2_mask):
    ...  # non-jittable stuff with masks
I'm trying to wrap it in vmap so I can benefit from auto-batching as explained here.
So I do:
testfunc_batched = jax.vmap(testfunc, in_axes=(None, 0, 0, 0))
The intention is that in batched mode, each of x1, x2, and x2_mask will have an additional outer dimension, the batching dimension. The model shouldn't be treated differently in batched mode, hence the None. Let me know if the syntax isn't right.
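As a sanity check on the in_axes syntax, here is a minimal sketch with a hypothetical jittable stand-in for testfunc (toy_testfunc, the dict-valued model, and all shapes are illustrative assumptions, not the original code). It suggests that in_axes=(None, 0, 0, 0) does what is intended: model is passed through unchanged, while x1, x2 and x2_mask are each mapped over their leading axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for testfunc (NOT the original function). It is
# jittable because it uses the mask arithmetically, keeping shapes fixed.
def toy_testfunc(model, x1, x2, x2_mask):
    masked_sum = jnp.sum(jnp.where(x2_mask, x2, 0.0))
    return model["w"] * jnp.sum(x1) + masked_sum

model = {"w": jnp.float32(2.0)}  # hypothetical parameters, shared by all examples

batch = 3
x1s = jnp.ones((batch, 4))  # leading axis is the batch axis
x2s = jnp.arange(batch * 5, dtype=jnp.float32).reshape(batch, 5)
x2_masks = x2s > 6.0        # boolean mask with the same shape as x2s

# None: model is not mapped; 0: map over axis 0 of x1, x2 and x2_mask
toy_batched = jax.vmap(toy_testfunc, in_axes=(None, 0, 0, 0))

out = toy_batched(model, x1s, x2s, x2_masks)
print(out)  # expected values: 8, 32, 68

# A batch of size one only needs a leading axis of length 1
single = toy_batched(model, x1s[:1], x2s[:1], x2_masks[:1])
print(single.shape)
```

Because model is marked None, it can be any pytree (here a dict) and is broadcast to every batch element rather than sliced.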
I create batches of size one just to test, schematically:
x1s = x1.reshape(1, ...)
x2s = x2.reshape(1, ...)
x2_masks = x2_mask.reshape(1, ...)
testfunc_batched(model, x1s, x2s, x2_masks)
The last line fails with ConcretizationTypeError.
I've recently learned that stuff with masks makes functions not jittable. But does that mean that I also can't use vmap? Or am I doing something wrong?
(There is further context in How to JIT code involving masked arrays without NonConcreteBooleanIndexError?, but you don't have to read that question to understand this one.)
Asked Jan 21 at 11:08 by oneloop; edited Jan 21 at 11:15 by jonrsharpe.

1 Answer
Is it possible to use jax.vmap for auto-batching if your function isn't jittable?
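No. In general, functions that are incompatible with jit will also be incompatible with vmap, because both jit and vmap use the same JAX tracing mechanism to transform the program. Under either transformation, the function's inputs are replaced by tracers rather than concrete arrays, so operations that need concrete values, such as Python control flow on array values or boolean-mask indexing whose output shape depends on the mask, cannot be traced. That is what triggers errors like ConcretizationTypeError or NonConcreteBooleanIndexError.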
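A small sketch makes the shared mechanism concrete. The functions below are hypothetical examples, not the asker's code: one uses boolean-mask indexing, one uses a Python if on a mask-derived value. Both run eagerly on concrete arrays but fail under jax.jit and under jax.vmap alike. The exact exception class (for example NonConcreteBooleanIndexError, or a ConcretizationTypeError subclass) may vary by JAX version, so the sketch only records whether each call fails. A fixed-shape rewrite with jnp.where is shown as one common way to make such code traceable; whether it applies to the original function depends on what its elided mask logic does.

```python
import jax
import jax.numpy as jnp

# Hypothetical examples, not the asker's function.
def masked_sum_indexing(x, mask):
    # x[mask] has a shape that depends on the VALUES of mask
    return jnp.sum(x[mask])

def masked_branch(x, mask):
    # Python control flow on an array value needs a concrete bool
    if jnp.any(mask):
        return jnp.sum(x)
    return jnp.zeros((), x.dtype)

def masked_sum_where(x, mask):
    # Fixed output shape: masked-out entries contribute 0 to the sum
    return jnp.sum(jnp.where(mask, x, 0.0))

def error_name(fn, *args):
    """Call fn(*args); return the exception class name, or None on success."""
    try:
        fn(*args)
        return None
    except Exception as e:
        return type(e).__name__

x = jnp.arange(6.0).reshape(2, 3)  # a batch of 2 rows
mask = x > 2.0

results = {}
for name, f in [("indexing", masked_sum_indexing), ("branch", masked_branch)]:
    # Eager call on one unbatched row works: the values are concrete
    results[(name, "eager")] = error_name(f, x[0], mask[0])
    # jit and vmap both trace f with tracers, so both fail
    results[(name, "jit")] = error_name(jax.jit(f), x[0], mask[0])
    results[(name, "vmap")] = error_name(jax.vmap(f), x, mask)

for key, err in results.items():
    print(key, err)

# The fixed-shape version works under vmap and jit, even combined
fixed = jax.jit(jax.vmap(masked_sum_where))(x, mask)
print(fixed)  # expected values: 0, 12 (row sums of entries > 2)
```

Note that jnp.where keeps every entry and only zeroes the masked-out ones, so this rewrite suits reductions and element-wise logic; code that genuinely needs a variable-length result per example cannot be batched this way without padding to a fixed size.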