There is a function in my code base that is "already vmapped", i.e. when fed an array of shape (M, N) it outputs another array of shape (M, N). I would like to take the "row-wise Jacobian" of this function, i.e. get a function that returns an array of shape (M, N, N). So far I've achieved this in a hacky way by adding a dummy extra dimension and vmapping, as illustrated in the example below, but it really feels like there should be a better way to do this. Does anyone have any ideas?
Example of what I want:
import jax
from jax import numpy as jnp

rng = jax.random.PRNGKey(42)
A = jax.random.normal(rng, shape=(128, 16, 16))

# This is the function I would like to take the row-wise Jacobian of
def already_vmapped(x, A):
    vmap_mul = jax.vmap(jnp.matmul)
    return vmap_mul(A, x)

# This is how I am doing it now
function = lambda x, A: jax.vmap(
    jax.jacobian(lambda x, A: already_vmapped(x[None,], A[None,]), argnums=0)
)(x, A).squeeze()

x = jnp.ones((128, 16))
print(jnp.allclose(function(x, A), A))  # True

# A slightly cleaner way that is too memory intensive
function_mem = lambda x, A: jnp.diagonal(
    jax.jacobian(already_vmapped, argnums=0)(x, A), offset=0, axis1=0, axis2=2
).transpose(2, 0, 1)
print(jnp.allclose(function_mem(x, A), A))  # True
I understand that the absolute cleanest way would be to not vmap the original function in the first place, but for whatever reason that is not easy to undo right now given my codebase. Any other suggestions are welcome!
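To put a number on the memory concern, here is a minimal sketch that compares the output shapes of the two strategies without materializing either Jacobian, using `jax.eval_shape`. The name `f_single_row` is a hypothetical stand-in for the "dummy dimension" trick; the shapes and `already_vmapped` come from the example above.

```python
import math

import jax
from jax import numpy as jnp

M, N = 128, 16
x = jnp.ones((M, N))
A = jax.random.normal(jax.random.PRNGKey(42), shape=(M, N, N))

def already_vmapped(x, A):
    # Batched matrix-vector product, as in the question.
    return jax.vmap(jnp.matmul)(A, x)

def f_single_row(x_row, A_row):
    # Hypothetical helper: feed a size-1 batch, then drop the dummy axis.
    return already_vmapped(x_row[None], A_row[None]).squeeze(0)

# eval_shape traces abstractly, so no Jacobian is actually computed here.
full = jax.eval_shape(jax.jacobian(already_vmapped, argnums=0), x, A)
rowwise = jax.eval_shape(jax.vmap(jax.jacobian(f_single_row, argnums=0)), x, A)

print(full.shape)     # (128, 16, 128, 16)
print(rowwise.shape)  # (128, 16, 16)

# The full Jacobian holds M times as many entries as the row-wise one.
ratio = math.prod(full.shape) // math.prod(rowwise.shape)
print(ratio)          # 128
```

The full Jacobian carries an extra factor of M entries, all of which are zero off the batch diagonal, so its cost grows quadratically in M while the row-wise version grows linearly.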
Asked Mar 31 at 11:26 by Kimon P

1 Answer
I think what you're already doing is more-or-less the best approach. You want to vmap the jacobian over the rows, and within each row you want to compute a size-1 batch of your original "already vmapped" function.
For clarity, I'd probably re-express your initial answer this way:
def f_single_batch(x, A):
    return already_vmapped(x[None], A[None]).squeeze(0)

result = jax.vmap(jax.jacobian(f_single_batch, 0))(x, A)
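For anyone who wants to run this in isolation, here is a self-contained sketch of the version above, reusing `already_vmapped`, `A`, and `x` from the question. It also checks the result against the question's original one-liner (renamed `question_version` here, a hypothetical name).

```python
import jax
from jax import numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(42), shape=(128, 16, 16))
x = jnp.ones((128, 16))

def already_vmapped(x, A):
    # Batched matrix-vector product from the question.
    return jax.vmap(jnp.matmul)(A, x)

def f_single_batch(x, A):
    # x: (N,), A: (N, N). Add a size-1 batch axis, call the batched
    # function, then remove the batch axis from the (1, N) output.
    return already_vmapped(x[None], A[None]).squeeze(0)

result = jax.vmap(jax.jacobian(f_single_batch, 0))(x, A)
print(result.shape)  # (128, 16, 16)

# The question's original formulation, for comparison.
question_version = jax.vmap(
    jax.jacobian(lambda x, A: already_vmapped(x[None], A[None]), argnums=0)
)(x, A).squeeze()

same_as_question = bool(jnp.allclose(result, question_version))
matches_A = bool(jnp.allclose(result, A, atol=1e-5))
print(same_as_question, matches_A)
```

One small difference worth noting: `squeeze(0)` removes only the dummy batch axis, whereas a bare `.squeeze()` would also drop any other axis that happens to have length 1 (for example when M or N is 1).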
A slightly more direct approach to this might look like this:
result = jax.vmap(jax.jacobian(already_vmapped, 0))(x[:, None], A[:, None]).squeeze((1, 3))
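The `squeeze((1, 3))` in this one-liner is easier to follow with the intermediate shapes written out. The sketch below traces them step by step; the variable names `x_b`, `A_b`, and `raw` are introduced here purely for illustration.

```python
import jax
from jax import numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(42), shape=(128, 16, 16))
x = jnp.ones((128, 16))

def already_vmapped(x, A):
    # Batched matrix-vector product from the question.
    return jax.vmap(jnp.matmul)(A, x)

# Insert a dummy axis after the row axis, so that each vmapped call
# sees a size-1 batch.
x_b = x[:, None]  # (128, 1, 16)
A_b = A[:, None]  # (128, 1, 16, 16)

# Per row, the Jacobian of a (1, 16) output w.r.t. a (1, 16) input
# has shape (1, 16, 1, 16); vmap adds the leading row axis.
raw = jax.vmap(jax.jacobian(already_vmapped, 0))(x_b, A_b)
print(raw.shape)  # (128, 1, 16, 1, 16)

# Axes 1 and 3 are the two dummy batch axes (output side and input side).
result = raw.squeeze((1, 3))
print(result.shape)  # (128, 16, 16)
```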
But I would lean toward the first version because it's easier to understand, and anybody reading the code (including your future self) will thank you.
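If the same pattern is needed for several already-batched functions, the first version can be wrapped once. The following is only a sketch: `rowwise_jacobian` is a hypothetical helper name, not a JAX API, and it assumes that every positional argument is batched along axis 0 and that the function's output rows depend only on the matching input rows. It is checked here on a nonlinear variant of the question's function, `tanh(A @ x)`, whose per-row Jacobian is `diag(1 - tanh(A x)^2) @ A`.

```python
import jax
from jax import numpy as jnp

def rowwise_jacobian(batched_fn, argnums=0):
    """Hypothetical helper: row-wise Jacobian of an already-batched function.

    Assumes every positional argument carries the batch on axis 0.
    """
    def single_row(*row_args):
        # Give each per-row argument a size-1 batch axis, call the
        # batched function, then drop the batch axis from its output.
        batched_args = [a[None] for a in row_args]
        return batched_fn(*batched_args).squeeze(0)

    return jax.vmap(jax.jacobian(single_row, argnums=argnums))

def already_vmapped_nonlinear(x, A):
    # Illustrative nonlinear variant of the question's function.
    return jnp.tanh(jax.vmap(jnp.matmul)(A, x))

key_A, key_x = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_A, shape=(8, 4, 4))
x = jax.random.normal(key_x, shape=(8, 4))

J = rowwise_jacobian(already_vmapped_nonlinear)(x, A)
print(J.shape)  # (8, 4, 4)

# Analytic reference: d tanh(A x)_i / d x_j = (1 - y_i**2) * A_ij per row.
y = already_vmapped_nonlinear(x, A)
expected = (1.0 - y**2)[:, :, None] * A
print(jnp.allclose(J, expected, atol=1e-4))
```

Keeping the dummy-axis handling inside one small function means call sites only see `rowwise_jacobian(f)(x, A)`, which stays close to the readability argument made above.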
Source: python - Is there a convenient way to take the gradient/jacobian of an already vmapped function in JAX? - Stack Overflow