
How do I do a jax get from a masked index?

The code below works without jit.

import jax
import jax.numpy as jnp

x = jnp.arange(25).reshape((5,5))
coords = jnp.array([
    [1,2],
    [2,3],
    [1,2],
    [1,2],
])
coords_mask = jnp.array([True, True, False, True])

@jax.jit
def masked_gather(x, coords, coords_mask):
    coords_masked = coords[coords_mask]
    return x.at[coords_masked[:, 0], coords_masked[:, 1]].get()

masked_gather(x, coords, coords_mask)

With jax.jit applied, it fails with NonConcreteBooleanIndexError.

It should return Array([ 7, 13, 7], dtype=int32).

Asked Jan 21 at 16:15 by oneloop (edited by jonrsharpe).

1 Answer


There is no way to execute this function in a JIT-compatible way, because JAX does not support compilation of programs with dynamic shapes. In your case, the size of the returned array depends on the number of True elements in coords_mask, and so the shape is dynamic by definition.
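To see the dynamic shape concretely, here is a small sketch (the helper name select is purely illustrative). Outside of jit, boolean-mask indexing produces an output whose shape follows the values in the mask. Under jit, the mask is an abstract tracer whose shape is known but whose values are not, so JAX raises the same error as in the question.

```python
import jax
import jax.numpy as jnp

coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])

def select(coords, coords_mask):
  # Boolean-mask indexing: one output row per True entry in the mask.
  return coords[coords_mask]

# Outside of jit the mask is concrete, so the output shape follows its values.
shape_three = select(coords, jnp.array([True, True, False, True])).shape
shape_one = select(coords, jnp.array([True, False, False, False])).shape
print(shape_three, shape_one)  # (3, 2) (1, 2)

# Under jit the mask is a tracer with known shape (4,) but unknown values,
# so the output shape cannot be determined at trace time.
jit_error = None
try:
  jax.jit(select)(coords, jnp.array([True, True, False, True]))
except jax.errors.NonConcreteBooleanIndexError as err:
  jit_error = err
print(type(jit_error).__name__)  # NonConcreteBooleanIndexError
```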

See JAX Sharp Bits: Dynamic Shapes for more information.

Depending on what you are doing with the resulting value, there are several ways to work around this. For example, if the number of selected entries is truly unknown, you could return a fixed-size array with the selected values first and the remainder padded with a fill value (zero by default). It might look something like this:

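@jax.jit
def masked_gather_padded(x, coords, coords_mask, fill_value=0):
  # Replace masked-out coordinates with an index that is out of bounds.
  coords_masked = jnp.where(coords_mask[:, None], coords, max(x.shape))
  # Stable argsort of the inverted mask moves selected entries to the front, in order.
  order = jnp.argsort(~coords_mask)
  # mode='fill' returns fill_value for the out-of-bounds (masked-out) indices.
  result = x.at[coords_masked[:, 0], coords_masked[:, 1]].get(mode='fill', fill_value=fill_value)
  return result[order]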

masked_gather_padded(x, coords, coords_mask)
# Array([ 7, 13,  7,  0], dtype=int32)
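A closely related variant also returns the number of valid entries, so the caller can trim the padding once it is outside of jit. This minimal sketch assumes every entry of coords is in bounds, which lets it gather all coordinates directly and mask the values afterwards instead of using an out-of-bounds index; masked_gather_padded_with_count is a hypothetical name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_gather_padded_with_count(x, coords, coords_mask, fill_value=0):
  # Gather every coordinate; this assumes all entries of coords are in bounds.
  vals = x[coords[:, 0], coords[:, 1]]
  # Replace values at masked-out positions with the fill value.
  vals = jnp.where(coords_mask, vals, fill_value)
  # jnp.argsort is stable by default, so selected entries move to the
  # front in their original order, followed by the padding.
  order = jnp.argsort(~coords_mask)
  # Number of valid (selected) entries, returned as a traced scalar.
  count = jnp.sum(coords_mask)
  return vals[order], count

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])

padded, count = masked_gather_padded_with_count(x, coords, coords_mask)
print(padded)                # [ 7 13  7  0]
# Outside of jit, the count can be used to trim the padding.
print(padded[:int(count)])   # [ 7 13  7]
```

Returning the count keeps every shape inside the compiled function static. The data-dependent slice happens only in ordinary Python code.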

Alternatively, if the number of True entries in the mask is known a priori, you could modify the function to accept a static size argument and use that to construct an appropriate output. It might look something like this:

from functools import partial

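@partial(jax.jit, static_argnames=['size'])
def masked_gather_with_size(x, coords, coords_mask, *, size):
  # Replace masked-out coordinates with an index that is out of bounds.
  coords_masked = jnp.where(coords_mask[:, None], coords, max(x.shape))
  order = jnp.argsort(~coords_mask)
  # For a gather, mode='drop' behaves like mode='fill' and yields a default fill value out of bounds.
  result = x.at[coords_masked[:, 0], coords_masked[:, 1]].get(mode='drop')
  # size is static, so this slice has a static shape and drops the padded entries.
  return result[order[:size]]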

masked_gather_with_size(x, coords, coords_mask, size=3)
# Array([ 7, 13,  7], dtype=int32)
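For the known-size case, jnp.nonzero also accepts a static size argument, which gives a shorter alternative to the argsort trick. This is a sketch, not the approach used above; masked_gather_nonzero is a hypothetical name. Note the padding behavior: if fewer than size entries are True, the missing indices are filled with fill_value (here index 0), so those outputs repeat the value at coords[0]. If more than size entries are True, only the first size are kept.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['size'])
def masked_gather_nonzero(x, coords, coords_mask, *, size):
  # Indices of the True entries, in order, padded or truncated to length `size`.
  # Padding uses index 0 (fill_value=0), so padded outputs repeat coords[0]'s value.
  (idx,) = jnp.nonzero(coords_mask, size=size, fill_value=0)
  selected = coords[idx]
  return x[selected[:, 0], selected[:, 1]]

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])

result = masked_gather_nonzero(x, coords, coords_mask, size=3)
print(result)  # [ 7 13  7]
```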

The best approach will depend on your application.
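For instance, if the selected values are only going to be reduced (summed, averaged, and so on), you may not need to compact them at all. A masked reduction keeps every shape static. The following minimal sketch assumes all entries of coords are in bounds; masked_sum and masked_mean are hypothetical names.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_sum(x, coords, coords_mask):
  # Gather every coordinate (assumed in bounds), then zero out masked-out entries.
  vals = x[coords[:, 0], coords[:, 1]]
  return jnp.sum(jnp.where(coords_mask, vals, 0))

@jax.jit
def masked_mean(x, coords, coords_mask):
  vals = x[coords[:, 0], coords[:, 1]]
  total = jnp.sum(jnp.where(coords_mask, vals, 0))
  # Guard against division by zero when no entries are selected.
  count = jnp.maximum(jnp.sum(coords_mask), 1)
  return total / count

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])

print(masked_sum(x, coords, coords_mask))   # 27
print(masked_mean(x, coords, coords_mask))  # 9.0
```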
