How do I do a jax get from a masked index?
The code below works without jit.
import jax
import jax.numpy as jnp

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([
    [1, 2],
    [2, 3],
    [1, 2],
    [1, 2],
])
coords_mask = jnp.array([True, True, False, True])

@jax.jit
def masked_gather(x, coords, coords_mask):
    coords_masked = coords[coords_mask]
    return x.at[coords_masked[:, 0], coords_masked[:, 1]].get()

masked_gather(x, coords, coords_mask)
With jit, it fails with NonConcreteBooleanIndexError.
It should return Array([ 7, 13, 7], dtype=int32)
1 Answer
There is no way to execute this function in a JIT-compatible way, because JAX does not support compilation of programs with dynamic shapes. In your case, the size of the returned array depends on the number of True elements in coords_mask, and so the shape is dynamic by definition.
See JAX Sharp Bits: Dynamic Shapes for more information.
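To make the dynamic-shape point concrete, here is a small sketch (the names mask_a and mask_b are made up for illustration). Two masks with the same shape give outputs of different shapes, which is what jit cannot trace. A jnp.where-based selection keeps the shape fixed:

```python
import jax
import jax.numpy as jnp

coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
mask_a = jnp.array([True, True, False, True])    # three True entries
mask_b = jnp.array([True, False, False, False])  # one True entry

# Same input shapes, but the output shape depends on the mask *values*.
print(coords[mask_a].shape)  # (3, 2)
print(coords[mask_b].shape)  # (1, 2)

# Under jit the mask is an abstract tracer, so its values are not available
# at trace time and boolean-mask indexing raises an error.
try:
    jax.jit(lambda c, m: c[m])(coords, mask_a)
except jax.errors.NonConcreteBooleanIndexError:
    print("boolean-mask indexing cannot be traced under jit")

# A shape-preserving alternative: replace masked-out rows with a sentinel.
print(jnp.where(mask_a[:, None], coords, -1).shape)  # (4, 2)
```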
Depending on what you are doing with the resulting value, there are a number of available approaches to work around this: for example, if the shape is truly unknown, you could return an array padded with zeros; it might look something like this:
@jax.jit
def masked_gather_padded(x, coords, coords_mask, fill_value=0):
    # Send masked-out coordinates out of bounds so the gather fills them.
    coords_masked = jnp.where(coords_mask[:, None], coords, max(x.shape))
    # Order the valid entries first and the padding last.
    order = jnp.argsort(~coords_mask)
    result = x.at[coords_masked[:, 0], coords_masked[:, 1]].get(mode='fill', fill_value=fill_value)
    return result[order]

masked_gather_padded(x, coords, coords_mask)
# Array([ 7, 13, 7, 0], dtype=int32)
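If a real value in x can equal the fill value, downstream code cannot tell padding from data. One possible way around that, sketched here under my own assumptions (the function name masked_gather_with_valid and the valid output are hypothetical, not part of the approach above), is to also return a boolean array marking which positions hold real values:

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_gather_with_valid(x, coords, coords_mask):
    # Hypothetical variant of the padded gather: also returns a validity mask.
    # max(x.shape) is out of bounds on every axis, so masked rows get filled.
    coords_masked = jnp.where(coords_mask[:, None], coords, max(x.shape))
    order = jnp.argsort(~coords_mask)  # valid entries first, padding last
    result = x.at[coords_masked[:, 0], coords_masked[:, 1]].get(
        mode="fill", fill_value=0)
    # After reordering, the first sum(coords_mask) positions are the real ones.
    valid = jnp.arange(coords_mask.shape[0]) < jnp.sum(coords_mask)
    return result[order], valid

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])

values, valid = masked_gather_with_valid(x, coords, coords_mask)
# A reduction that ignores the padding, with fully static shapes:
total = jnp.sum(jnp.where(valid, values, 0))
```

Both outputs have the same length as coords_mask, so the function and anything consuming it stay jit-compatible.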
Alternatively, if the number of True entries in the mask is known a priori, you could modify the function to accept a static size argument and use that to construct an appropriate output. It might look something like this:
from functools import partial

@partial(jax.jit, static_argnames=['size'])
def masked_gather_with_size(x, coords, coords_mask, *, size):
    # Send masked-out coordinates out of bounds, then keep the first `size` valid entries.
    coords_masked = jnp.where(coords_mask[:, None], coords, max(x.shape))
    order = jnp.argsort(~coords_mask)
    result = x.at[coords_masked[:, 0], coords_masked[:, 1]].get(mode='drop')
    return result[order[:size]]

masked_gather_with_size(x, coords, coords_mask, size=3)
# Array([ 7, 13, 7], dtype=int32)
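Another way to use a static size, offered only as a sketch and not as part of the answer above, is jnp.nonzero with its size and fill_value arguments, which gives the index array a fixed length. The function name masked_gather_nonzero is made up here. One caveat: if size exceeds the number of True entries, the extra slots repeat row 0 of coords instead of being dropped.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["size"])
def masked_gather_nonzero(x, coords, coords_mask, *, size):
    # Indices of the True entries, padded (or truncated) to a static length.
    (idx,) = jnp.nonzero(coords_mask, size=size, fill_value=0)
    selected = coords[idx]  # shape (size, 2), known at trace time
    return x[selected[:, 0], selected[:, 1]]

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])

out = masked_gather_nonzero(x, coords, coords_mask, size=3)
# Array([ 7, 13,  7], dtype=int32)
```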
The best approach will depend on your application.