How to scale learning rates per parameter with Penzai and Optax

I want to train a simple feed-forward neural net that I've built in Penzai, but I want to use a different learning rate for each parameter group. I store the learning-rate scale factor in each parameter's metadata, e.g. like this:

Parameter(
    label='mlp/Affine_0/Linear.weights',
    value=<NamedArray float32(| features:784, features_out:128) (wrapping jax.Array)>,
    metadata={'learning_rate': 0.0012755102040816326},
)

I use Penzai's StatefulTrainer for training and declare the optimizer like this:

optax.chain(
    optax.scale_by_adam(),
    scale_by_metadata_value("learning_rate"),
    optax.scale_by_learning_rate(0.01),
)

where I define scale_by_metadata_value like this:

def scale_by_metadata_value(metadata_field_name: str):
    def init_fn(params):
        learning_rates = jax.tree.map(
            lambda param: param.metadata[metadata_field_name],
            params,
            is_leaf=(lambda node: isinstance(node, pz.ParameterValue))
        )
        return {"learning_rates": learning_rates}

    def update_fn(updates, state, params):
        del params
        updates = jax.tree.map(
            # This is where the TypeError is thrown:
            lambda lr, g: lr * g, state["learning_rates"], updates
        )
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

However, when I run a training step, I get

TypeError: unsupported operand type(s) for *: 'jaxlib.xla_extension.ArrayImpl' and 'ParameterValue'

I am especially puzzled because everything works when I remove the scale_by_metadata_value("learning_rate") line, even though optax.scale_by_learning_rate(0.01) does essentially the same thing as what I do with scale_by_metadata_value.
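For reference, my understanding is that optax.scale_by_learning_rate boils down to a map over the updates tree alone, roughly like this (a simplified sketch, not optax's actual implementation):

import jax
import optax

# Rough sketch of what optax.scale_by_learning_rate(learning_rate) seems to
# amount to: multiply every leaf of the updates tree by a single scalar,
# with the sign flipped for gradient descent.
def scale_like_learning_rate(learning_rate):
    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params
        return jax.tree.map(lambda g: -learning_rate * g, updates), state

    return optax.GradientTransformation(init_fn, update_fn)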

What is the correct / best way to implement scale_by_metadata_value?

Here is a minimal failing example (penzai version 0.2.4):

import penzai.toolshed.basic_training
import penzai
import penzai.pz as pz
import jax
import jax.numpy as jnp
import optax
import numpy as np


model = pz.nn.Linear(
    weights=pz.Parameter(
        value=pz.nx.wrap(np.ones((8, 4)), "features", "features_out"),
        label="linear",
        metadata={"learning_rate": 0.5},
    ),
    in_axis_names=("features",),
    out_axis_names=("features_out",),
)


def softmax_cross_entropy_loss(
    model, rng, state, current_input, current_target: pz.nx.NamedArray
):
    del rng, state
    logits: pz.nx.NamedArray = model(current_input)
    loss = jnp.sum(
        optax.losses.softmax_cross_entropy(
            logits.unwrap("features_out"),
            current_target.unwrap("features_out"),
        )
    )
    return (loss, None, {"softmax_cross_entropy_loss": loss})


def scale_by_metadata_value(metadata_field_name: str):
    def init_fn(params):
        learning_rates = jax.tree.map(
            lambda param: param.metadata[metadata_field_name],
            params,
            is_leaf=(lambda node: isinstance(node, pz.ParameterValue)),
        )
        return {"learning_rates": learning_rates}

    def update_fn(updates, state, params):
        del params
        updates = jax.tree.map(lambda lr, g: lr * g, state["learning_rates"], updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)


trainer = penzai.toolshed.basic_training.StatefulTrainer.build(
    root_rng=jax.random.key(2025),
    model=model,
    optimizer_def=optax.chain(
        optax.scale_by_adam(),
        scale_by_metadata_value("learning_rate"),
        optax.scale_by_learning_rate(0.01),
    ),
    loss_fn=softmax_cross_entropy_loss,
    jit=False,
)

trainer.step(
    current_input=pz.nx.wrap(np.zeros(8), "features"),
    current_target=pz.nx.wrap(np.ones(4), "features_out"),
)
# TypeError: unsupported operand type(s) for *: 'jaxlib.xla_extension.ArrayImpl' and 'ParameterValue'

  • I just realized that the issue is that jax.tree.map sees the value attribute of the ParameterValue as a leaf, while the learning_rates tree has plain floats as leaves. jax.tree.map goes by the first tree's structure, so it tries to combine a float with the whole ParameterValue object, even though ParameterValue is not a leaf. I'm still not sure how to solve this in an elegant way, though; a toy reproduction with plain dicts is sketched below. – JEM_Mosig
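For illustration, here is a toy reproduction of that mismatch, with plain dicts as hypothetical stand-ins for the Penzai structures:

import jax
import jax.numpy as jnp

learning_rates = {"linear": 0.5}              # leaf is a plain float
updates = {"linear": {"value": jnp.ones(3)}}  # nested one level deeper

# jax.tree.map flattens the first tree completely, but unpacks the other
# trees only up to that structure, so the whole inner dict is handed to
# the lambda as g:
jax.tree.map(lambda lr, g: lr * g, learning_rates, updates)
# TypeError: unsupported operand type(s) for *: 'float' and 'dict'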

1 Answer

The issue is that jax.tree.map gets trees of different depths: in the learning-rates tree the leaves are plain floats, while in the updates tree the leaves are the values inside the ParameterValue object(s). To resolve this, I define

def pruned_tree_map(fn, parameter_value_tree, *trees):
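    # Flatten each of the other (shallower) trees and re-nest its leaves
    # into the structure of parameter_value_tree, so that jax.tree.map
    # sees matching structures.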
    structure = jax.tree.structure(parameter_value_tree)
    return jax.tree.map(
        fn,
        parameter_value_tree,
        *[jax.tree.unflatten(structure, jax.tree.leaves(tree)) for tree in trees],
    )
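On a toy example (plain dicts standing in for the Penzai structures, reusing jnp and pruned_tree_map from above), this re-nests the leaves of the shallower tree into the structure of the deeper one before mapping:

deep = {"linear": {"value": jnp.ones(3)}}  # plays the role of the updates tree
shallow = {"linear": 0.5}                  # plays the role of the learning-rate tree

pruned_tree_map(lambda u, lr: lr * u, deep, shallow)
# {'linear': {'value': Array([0.5, 0.5, 0.5], dtype=float32)}}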

Now I can use this in place of jax.tree.map in update_fn:

def update_fn(updates, state, params):
    del params
    updates = pruned_tree_map(
        lambda u, lr: lr * u,
        updates,
        state["learning_rates"],
    )
    return updates, state
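
Putting it together (the init_fn is unchanged from the question, only update_fn now uses pruned_tree_map), the fixed transformation reads:

def scale_by_metadata_value(metadata_field_name: str):
    def init_fn(params):
        learning_rates = jax.tree.map(
            lambda param: param.metadata[metadata_field_name],
            params,
            is_leaf=lambda node: isinstance(node, pz.ParameterValue),
        )
        return {"learning_rates": learning_rates}

    def update_fn(updates, state, params):
        del params
        # Re-nest the learning rates into the structure of the updates tree
        # so that each rate lines up with its parameter's gradient leaf.
        updates = pruned_tree_map(
            lambda u, lr: lr * u,
            updates,
            state["learning_rates"],
        )
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

Note that this relies on the two trees having the same number of leaves in the same flattening order (here: one learning-rate float and one underlying array per parameter), since jax.tree.unflatten simply re-attaches the flat list of leaves to the deeper tree's structure.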
