I've been developing models based on this tutorial: .html

However, since my code can now take days to run, I want to save the model partway through so that I can load it again later.

I tried JSON, pickle, and several other options that GPT suggested, but to no avail.

So before I resort to saving all of the model's params as strings and converting them back when I load the model, I wanted to know whether there is an easier option.

Here's a sample of my code:
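For reference, the string route may not be needed at all. A parameter set like this is a JAX pytree whose leaves are arrays, so one option is to write the leaves as NumPy arrays and recreate the tree structure from a freshly initialized model of the same size. The sketch below assumes the params are a pytree of NamedTuples (as dynamax params appear to be); the classes `Initial`, `Emissions`, `Params` and the helpers `save_params`/`load_params` are hypothetical stand-ins, not dynamax API.

```python
import os
import tempfile
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins that mimic a nested NamedTuple parameter layout;
# they are NOT dynamax classes.
class Initial(NamedTuple):
    probs: jax.Array


class Emissions(NamedTuple):
    means: jax.Array
    covs: jax.Array


class Params(NamedTuple):
    initial: Initial
    emissions: Emissions


def save_params(params, filename):
    """Write every array leaf of the pytree to a .npz file, keyed by position."""
    leaves = jax.tree_util.tree_leaves(params)
    np.savez(filename, **{f"leaf_{i}": np.asarray(x) for i, x in enumerate(leaves)})


def load_params(template, filename):
    """Rebuild a pytree with the same structure as `template` from saved leaves."""
    treedef = jax.tree_util.tree_structure(template)
    with np.load(filename) as data:
        leaves = [jnp.asarray(data[f"leaf_{i}"]) for i in range(treedef.num_leaves)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# "trained" plays the role of fitted params; "template" plays the role of a
# freshly initialized model with the same shapes (here: zeros).
num_states, emission_dim = 5, 3
trained = Params(
    Initial(jnp.full(num_states, 1.0 / num_states)),
    Emissions(jnp.arange(15.0).reshape(num_states, emission_dim),
              jnp.tile(jnp.eye(emission_dim), (num_states, 1, 1))),
)
template = jax.tree_util.tree_map(jnp.zeros_like, trained)

path = os.path.join(tempfile.mkdtemp(), "hmm_params.npz")
save_params(trained, path)
restored = load_params(template, path)
```

Only the numeric arrays go to disk; the structure comes from the template, so nothing library-specific has to be serialized.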

```python
from dynamax.hidden_markov_model import GaussianHMM
import jax.random as jr

# 5 hidden states, 3-dimensional emissions
hmm = GaussianHMM(5, 3)
params, props = hmm.initialize(jr.PRNGKey(10))
```

To save the model, I tried, for instance, the following:

```python
import jax.numpy as jnp
import jax
import pickle

def backup_hmm(params, props, filename="hmm_backup_jax.pkl"):
    # Flatten each pytree into a list of leaves plus a tree definition
    params_flat, params_tree = jax.tree_util.tree_flatten(params)
    props_flat, props_tree = jax.tree_util.tree_flatten(props)

    # Copy each leaf into a JAX array
    params_flat = [jnp.array(p) for p in params_flat]
    props_flat = [jnp.array(p) for p in props_flat]

    # Save leaves and tree definitions to a pickle file
    with open(filename, "wb") as f:
        pickle.dump({
            "params": params_flat,
            "params_tree": params_tree,
            "props": props_flat,
            "props_tree": props_tree
        }, f)

    print("Backup completed.")

def restore_hmm(filename="hmm_backup_jax.pkl"):
    with open(filename, "rb") as f:
        data = pickle.load(f)

    # Turn the loaded leaves back into JAX arrays
    params_flat = [jnp.array(p) for p in data["params"]]
    props_flat = [jnp.array(p) for p in data["props"]]

    # Reconstruct the original structures
    params = jax.tree_util.tree_unflatten(data["params_tree"], params_flat)
    props = jax.tree_util.tree_unflatten(data["props_tree"], props_flat)

    print("Restoration completed.")
    return params, props
```
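A variant of the pickle approach that avoids pickling the tree definitions and `props` altogether: store only plain NumPy copies of the parameter leaves, and rebuild the structure from a template. This is a sketch under assumptions: `Params`, `backup_params`, and `restore_params` are hypothetical names, and it assumes `props` is fixed configuration that can be recreated by calling `hmm.initialize(...)` again rather than loaded from disk.

```python
import os
import pickle
import tempfile
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in for a NamedTuple of model parameters (not dynamax API).
class Params(NamedTuple):
    probs: jax.Array
    means: jax.Array


def backup_params(params, filename):
    """Pickle only NumPy copies of the leaves; no treedef, no props."""
    leaves = [np.asarray(x) for x in jax.tree_util.tree_leaves(params)]
    with open(filename, "wb") as f:
        pickle.dump(leaves, f)


def restore_params(template, filename):
    """Unflatten the pickled leaves into the structure of `template`."""
    with open(filename, "rb") as f:
        leaves = pickle.load(f)
    treedef = jax.tree_util.tree_structure(template)
    if len(leaves) != treedef.num_leaves:
        raise ValueError("saved leaves do not match the template structure")
    return jax.tree_util.tree_unflatten(treedef, [jnp.asarray(x) for x in leaves])


trained = Params(jnp.array([0.7, 0.3]), jnp.array([[1.0, 2.0], [3.0, 4.0]]))
template = Params(jnp.zeros(2), jnp.zeros((2, 2)))

path = os.path.join(tempfile.mkdtemp(), "hmm_backup.pkl")
backup_params(trained, path)
restored = restore_params(template, path)
```

In the dynamax setting, `template` would presumably be the `params` returned by a fresh `hmm.initialize(...)` call with the same model sizes, and `props` would come from that same call.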
