I want to train a simple feed-forward neural net that I've built in Penzai, but I want to use a different learning rate for each parameter group. I store the learning-rate scale factor in each parameter's metadata, e.g. like this:
Parameter(
    label='mlp/Affine_0/Linear.weights',
    value=<NamedArray float32(| features:784, features_out:128) (wrapping jax.Array)>,
    metadata={'learning_rate': 0.0012755102040816326},
)
I use Penzai's StatefulTrainer for training and declare the optimizer like this:
optax.chain(
    optax.scale_by_adam(),
    scale_by_metadata_value("learning_rate"),
    optax.scale_by_learning_rate(0.01),
)
where I define scale_by_metadata_value like this:
def scale_by_metadata_value(metadata_field_name: str):
    def init_fn(params):
        learning_rates = jax.tree.map(
            lambda param: param.metadata[metadata_field_name],
            params,
            is_leaf=(lambda node: isinstance(node, pz.ParameterValue)),
        )
        return {"learning_rates": learning_rates}

    def update_fn(updates, state, params):
        del params
        updates = jax.tree.map(
            # This is where the TypeError is thrown:
            lambda lr, g: lr * g, state["learning_rates"], updates
        )
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)
However, when I run a training step, I get
TypeError: unsupported operand type(s) for *: 'jaxlib.xla_extension.ArrayImpl' and 'ParameterValue'
I am especially puzzled because everything works when I remove the scale_by_metadata_value("learning_rate") line, even though optax.scale_by_learning_rate(0.01) does essentially the same thing as what I do with scale_by_metadata_value.
What is the correct / best way to implement scale_by_metadata_value?
Here is a minimal failing example (penzai version 0.2.4):
import penzai.toolshed.basic_training
import penzai
import penzai.pz as pz
import jax
import jax.numpy as jnp
import optax
import numpy as np

model = pz.nn.Linear(
    weights=pz.Parameter(
        value=pz.nx.wrap(np.ones((8, 4)), "features", "features_out"),
        label="linear",
        metadata={"learning_rate": 0.5},
    ),
    in_axis_names=("features",),
    out_axis_names=("features_out",),
)

def softmax_cross_entropy_loss(
    model, rng, state, current_input, current_target: pz.nx.NamedArray
):
    del rng, state
    logits: pz.nx.NamedArray = model(current_input)
    loss = jnp.sum(
        optax.losses.softmax_cross_entropy(
            logits.unwrap("features_out"),
            current_target.unwrap("features_out"),
        )
    )
    return (loss, None, {"softmax_cross_entropy_loss": loss})

def scale_by_metadata_value(metadata_field_name: str):
    def init_fn(params):
        learning_rates = jax.tree.map(
            lambda param: param.metadata[metadata_field_name],
            params,
            is_leaf=(lambda node: isinstance(node, pz.ParameterValue)),
        )
        return {"learning_rates": learning_rates}

    def update_fn(updates, state, params):
        del params
        updates = jax.tree.map(lambda lr, g: lr * g, state["learning_rates"], updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

trainer = penzai.toolshed.basic_training.StatefulTrainer.build(
    root_rng=jax.random.key(2025),
    model=model,
    optimizer_def=optax.chain(
        optax.scale_by_adam(),
        scale_by_metadata_value("learning_rate"),
        optax.scale_by_learning_rate(0.01),
    ),
    loss_fn=softmax_cross_entropy_loss,
    jit=False,
)

trainer.step(
    current_input=pz.nx.wrap(np.zeros(8), "features"),
    current_target=pz.nx.wrap(np.ones(4), "features_out"),
)
# TypeError: unsupported operand type(s) for *: 'jaxlib.xla_extension.ArrayImpl' and 'ParameterValue'
asked Mar 16 at 15:55 by JEM_Mosig, last edited Mar 16 at 16:03
1 Answer
The issue is that jax.tree.map gets trees with different depths: for the learning rates the leaves are the plain float learning rates, but for the updates the leaves are the values inside the ParameterValue object(s). To resolve this issue, I define
def pruned_tree_map(fn, parameter_value_tree, *trees):
    structure = jax.tree.structure(parameter_value_tree)
    return jax.tree.map(
        fn,
        parameter_value_tree,
        *[jax.tree.unflatten(structure, jax.tree.leaves(tree)) for tree in trees],
    )
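To make the mismatch concrete, here is a small standalone sketch that uses plain nested dicts in place of Penzai's ParameterValue wrappers (the dict layout is only an analogy; the jax.tree behaviour is the point). It repeats pruned_tree_map so the snippet runs on its own:

import jax
import jax.numpy as jnp

# The learning-rate tree bottoms out at a float, while the updates tree has
# one extra level of nesting (standing in for ParameterValue.value).
learning_rates = {"linear": 0.5}
updates = {"linear": {"value": jnp.ones(4)}}

# jax.tree.map walks by the FIRST tree's structure, so it pairs the float
# 0.5 with the whole inner dict and tries `0.5 * {"value": ...}`.
try:
    jax.tree.map(lambda lr, g: lr * g, learning_rates, updates)
except TypeError as err:
    print(err)  # unsupported operand type(s) for *: 'float' and 'dict'

def pruned_tree_map(fn, parameter_value_tree, *trees):
    # Re-nest the leaves of the shallower trees into the structure of the
    # deepest tree before mapping over all of them together.
    structure = jax.tree.structure(parameter_value_tree)
    return jax.tree.map(
        fn,
        parameter_value_tree,
        *[jax.tree.unflatten(structure, jax.tree.leaves(tree)) for tree in trees],
    )

scaled = pruned_tree_map(lambda u, lr: lr * u, updates, learning_rates)
print(scaled)  # {'linear': {'value': Array([0.5, 0.5, 0.5, 0.5], dtype=float32)}}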
Now I can use pruned_tree_map in place of jax.tree.map in update_fn:
def update_fn(updates, state, params):
    del params
    updates = pruned_tree_map(
        lambda u, lr: lr * u,
        updates,
        state["learning_rates"],
    )
    return updates, state
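For completeness, the full transformation with only update_fn changed would then look roughly like this (a sketch assembled from the snippets above, reusing pruned_tree_map and the imports from the question; not separately tested):

def scale_by_metadata_value(metadata_field_name: str):
    def init_fn(params):
        # Pull the per-parameter scale factor out of each ParameterValue's metadata.
        learning_rates = jax.tree.map(
            lambda param: param.metadata[metadata_field_name],
            params,
            is_leaf=lambda node: isinstance(node, pz.ParameterValue),
        )
        return {"learning_rates": learning_rates}

    def update_fn(updates, state, params):
        del params
        # pruned_tree_map re-nests the flat learning-rate leaves into the deeper
        # structure of `updates`, so the multiplication happens on the raw arrays.
        updates = pruned_tree_map(
            lambda u, lr: lr * u,
            updates,
            state["learning_rates"],
        )
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

Used in the question's optax.chain (scale_by_adam, then scale_by_metadata_value("learning_rate"), then scale_by_learning_rate(0.01)), this keeps Adam's normalization and the global learning rate and only inserts the per-parameter scaling in between.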
Comments:
– JEM_Mosig (Mar 16 at 16:17): jax.tree.map sees the value attribute of the ParameterValue as a leaf, but the learning_rates tree has floats as leaves. jax.tree.map goes by the first tree's structure, so it tries to combine the float with the full ParameterValue, even though ParameterValue is not a leaf. I'm still not sure how to solve this in an elegant way, though.
– JEM_Mosig (Mar 16 at 16:54): ...penzai and optax to this