Is it possible to use vmap for auto-batching if your function isn't jittable?

I have a function that's not jittable:

def testfunc(model, x1, x2, x2_mask):
    # ... non-jittable stuff with masks ...
    ...

I'm trying to wrap it in vmap so I can benefit from auto-batching as explained here.

So I do:

testfunc_batched = jax.vmap(testfunc, in_axes=(None, 0, 0, 0))

The intention is that in batched mode, each of x1, x2, and x2_mask will have an additional outer dimension: the batching dimension. The model shouldn't be treated differently in batched mode, hence the None. Let me know if the syntax isn't right.
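To check the `in_axes` syntax in isolation, here is a minimal sketch using a hypothetical, jittable stand-in `toy_func` with the same argument layout as `testfunc`. The model is assumed to be a pytree (here a dict of arrays); `None` in `in_axes` broadcasts the whole pytree unchanged to every batch element, while `0` maps over the leading axis of the other three arguments.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for testfunc with the same signature.
# "model" is assumed to be a pytree of parameters (a dict of arrays).
def toy_func(model, x1, x2, x2_mask):
    # Multiplying by the boolean mask zeroes out unselected entries
    return model["w"] * jnp.sum(x1) + jnp.sum(x2 * x2_mask)

# None: model is shared across the batch; 0: map over the leading axis
toy_batched = jax.vmap(toy_func, in_axes=(None, 0, 0, 0))

model = {"w": jnp.array(2.0)}
x1s = jnp.ones((3, 4))                 # batch of 3, each x1 has shape (4,)
x2s = jnp.arange(15.0).reshape(3, 5)   # batch of 3, each x2 has shape (5,)
masks = x2s > 6.0                      # boolean mask with the same shape as x2s

out = toy_batched(model, x1s, x2s, masks)  # one scalar per batch element
print(out.shape)
```

So the `in_axes=(None, 0, 0, 0)` form in the question is the right way to express "share the model, batch the rest".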

I create batches of size one just to test, schematically:

x1s = x1.reshape(1, ...)
x2s = x2.reshape(1, ...)
x2_masks = x2_mask.reshape(1, ...)

testfunc_batched(model, x1s, x2s, x2_masks)
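For a function that *is* traceable, a batch of size one should reproduce the unbatched result with an extra leading axis. A sketch, again with a hypothetical jittable `toy_func` in place of `testfunc`, showing a few equivalent ways to add that leading axis:

```python
import jax
import jax.numpy as jnp

# Hypothetical jittable stand-in with the same signature as testfunc
def toy_func(model, x1, x2, x2_mask):
    return model["w"] * jnp.sum(x1) + jnp.sum(x2 * x2_mask)

toy_batched = jax.vmap(toy_func, in_axes=(None, 0, 0, 0))

model = {"w": jnp.array(2.0)}
x1 = jnp.ones(4)
x2 = jnp.arange(5.0)
x2_mask = x2 > 2.0

# Equivalent ways to add a leading batch axis of size 1
x1s = x1[None, ...]
x2s = jnp.expand_dims(x2, axis=0)
x2_masks = x2_mask.reshape(1, *x2_mask.shape)

batched_out = toy_batched(model, x1s, x2s, x2_masks)  # shape (1,)
single_out = toy_func(model, x1, x2, x2_mask)         # shape ()
```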

The last line fails with ConcretizationTypeError.

I've recently learned that stuff with masks makes functions not jittable. But does that mean that I also can't use vmap? Or am I doing something wrong?
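A minimal reproduction, assuming the "stuff with masks" is boolean-mask indexing (the real body of `testfunc` is elided, so `masked_sum_of_selected` below is hypothetical). Indexing with a boolean mask produces an array whose shape depends on the mask's values. That works in an eager call, but it fails both under `jax.jit` and under `jax.vmap`. The exceptions are caught as `TypeError`/`IndexError` because JAX's tracer errors (such as `ConcretizationTypeError`) subclass those built-ins.

```python
import jax
import jax.numpy as jnp

# Hypothetical non-jittable function: boolean-mask indexing yields an
# array whose shape depends on the *values* in the mask.
def masked_sum_of_selected(x, mask):
    selected = x[mask]  # data-dependent shape
    return jnp.sum(selected)

xs = jnp.arange(12.0).reshape(2, 6)  # batch of 2 rows
masks = xs > 2.0

# Eager (untransformed) call on a single row works: values are concrete.
eager = masked_sum_of_selected(xs[0], masks[0])

def error_name(fn, *args):
    """Return the name of the exception raised by fn(*args), or None."""
    try:
        fn(*args)
    except (TypeError, IndexError) as err:  # JAX tracer errors subclass these
        return type(err).__name__
    return None

jit_error = error_name(jax.jit(masked_sum_of_selected), xs[0], masks[0])
vmap_error = error_name(jax.vmap(masked_sum_of_selected), xs, masks)
print(jit_error, vmap_error)
```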

(There is further context in How to JIT code involving masked arrays without NonConcreteBooleanIndexError?, but you don't have to read that question to understand this one.)


Asked Jan 21 at 11:08 by oneloop; edited Jan 21 at 11:15 by jonrsharpe.

1 Answer


Is it possible to use jax.vmap for auto-batching if your function isn't jittable?

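No. In general, functions that are incompatible with jit will also be incompatible with vmap, because both jit and vmap use the same JAX tracing mechanism to transform the program. Under vmap, as under jit, JAX does not treat the per-example values as concrete while tracing, so any operation whose output shape depends on array values (such as indexing with a boolean mask) fails in the same way.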
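Two common ways around this, sketched under the assumption that the masked step is a reduction over the selected entries (the hypothetical `masked_sum_*` functions below stand in for the elided body of `testfunc`). (a) Rewrite the step so that shapes no longer depend on mask values, for example with `jnp.where`; the function then works under both `jit` and `vmap`. (b) If the function genuinely cannot be traced, batch it by hand with a Python loop and stack the results.

```python
import jax
import jax.numpy as jnp

# (a) Fixed-shape rewrite: unselected entries are replaced by 0 instead of
# being dropped, so intermediate shapes never depend on the mask values.
def masked_sum_fixed_shape(x, mask):
    return jnp.sum(jnp.where(mask, x, 0.0))

xs = jnp.arange(12.0).reshape(2, 6)  # batch of 2 rows
masks = xs > 2.0

vmapped = jax.vmap(masked_sum_fixed_shape)(xs, masks)
jitted_vmapped = jax.jit(jax.vmap(masked_sum_fixed_shape))(xs, masks)

# (b) Fallback for a function that cannot be traced at all: loop over the
# batch axis in Python and run the original function eagerly per item.
def masked_sum_of_selected(x, mask):
    return jnp.sum(x[mask])  # data-dependent shape: fine eagerly, not traced

looped = jnp.stack([masked_sum_of_selected(x, m) for x, m in zip(xs, masks)])
```

Option (b) keeps the original code untouched but gives up vectorization: each batch element is a separate eager call, so it is typically much slower than a vmapped, jitted version for large batches.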
