I've been developing models based on this tutorial: .html
However, my code can now take days to run, so I want to save the model partway through and load it again later.
I tried JSON, pickle, and several other options that GPT suggested, but to no avail.
Before I resort to saving all of the model's parameters as strings and converting them back when loading, I wanted to know whether there is an easier option.
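For a run that takes days, one common pattern is to write a checkpoint every few iterations and, on restart, resume from the last one. Strings are not needed for this: JAX arrays can be copied to host NumPy arrays with `jax.device_get`, and those pickle directly. The sketch below rests on assumptions. `params` stands in for any JAX pytree of arrays, such as the parameters returned by `hmm.initialize`. `train_step`, `run`, the checkpoint path, and the interval are hypothetical placeholders. Whether a dynamax fit can be advanced one step at a time like this is also not established here. Only the parameters are checkpointed; `props` is left out.

```python
import os
import pickle

import jax
import jax.numpy as jnp


def save_checkpoint(path, step, params):
    """Write (step, params) to `path` via a temp file, then rename it into place."""
    # device_get copies every leaf to a host NumPy array before pickling.
    state = {"step": step, "params": jax.device_get(params)}
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(state, f)
    # Replacing only after a complete write keeps the previous checkpoint intact
    # if the process dies mid-dump (the rename is atomic on POSIX systems).
    os.replace(tmp_path, path)


def load_checkpoint(path):
    """Return (step, params) saved at `path`, or None if there is no checkpoint."""
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        state = pickle.load(f)
    # Move the NumPy leaves back to JAX arrays.
    params = jax.tree_util.tree_map(jnp.asarray, state["params"])
    return state["step"], params


def train_step(params):
    """Hypothetical placeholder for one expensive fitting iteration."""
    return jax.tree_util.tree_map(lambda x: x + 1.0, params)


def run(path, num_steps, init_params, every=10):
    """Iterate up to `num_steps`, resuming from `path` when a checkpoint exists."""
    restored = load_checkpoint(path)
    step, params = restored if restored is not None else (0, init_params)
    while step < num_steps:
        params = train_step(params)
        step += 1
        if step % every == 0 or step == num_steps:
            save_checkpoint(path, step, params)
    return params
```

Calling `run` again with the same path and a larger `num_steps` continues from the saved step instead of starting over.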
Here's a sample of my code:
from dynamax.hidden_markov_model import GaussianHMM
import jax.random as jr
hmm = GaussianHMM(5, 3)
param, properties = hmm.initialize(jr.PRNGKey(10))
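One option that avoids pickling the tree structure at all: store each array leaf in a NumPy `.npz` archive under its path in the tree. On load, rebuild the pytree from a template that has the same structure. This is a minimal sketch that assumes `param` is a JAX pytree whose leaves are all numeric arrays; what dynamax actually returns should be checked. `Params`, `Emissions`, `save_params`, and `load_params` are hypothetical names for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class Emissions(NamedTuple):
    """Hypothetical stand-in for a nested parameter container."""
    means: jax.Array
    covs: jax.Array


class Params(NamedTuple):
    """Hypothetical stand-in for the params returned by initialize()."""
    initial: jax.Array
    emissions: Emissions


def save_params(path, params):
    """Write every array leaf of `params` to an .npz file, keyed by its tree path."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(params)
    arrays = {jax.tree_util.keystr(kp): np.asarray(leaf) for kp, leaf in leaves}
    np.savez(path, **arrays)


def load_params(path, template):
    """Rebuild a pytree with the same structure as `template` from an .npz file."""
    leaves, treedef = jax.tree_util.tree_flatten_with_path(template)
    with np.load(path) as data:
        restored = [jnp.asarray(data[jax.tree_util.keystr(kp)]) for kp, _ in leaves]
    return jax.tree_util.tree_unflatten(treedef, restored)
```

With dynamax, a natural template would presumably be the params from calling `hmm.initialize` again on a `GaussianHMM(5, 3)`. That is an assumption, not something the tutorial states. `props` is left out here on the further assumption that it can be regenerated the same way.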
For instance, to save the model I tried:
import jax.numpy as jnp
import jax
import pickle

def backup_hmm(params, props, filename="hmm_backup_jax.pkl"):
    # Flatten both pytrees into leaves plus tree definitions
    params_flat, params_tree = jax.tree_util.tree_flatten(params)
    props_flat, props_tree = jax.tree_util.tree_flatten(props)
    # Make sure every leaf is a JAX array
    params_flat = [jnp.array(p) for p in params_flat]
    props_flat = [jnp.array(p) for p in props_flat]
    # Save leaves and tree definitions to a pickle file
    with open(filename, "wb") as f:
        pickle.dump({
            "params": params_flat,
            "params_tree": params_tree,
            "props": props_flat,
            "props_tree": props_tree
        }, f)
    print("Backup completed.")

def restore_hmm(filename="hmm_backup_jax.pkl"):
    with open(filename, "rb") as f:
        data = pickle.load(f)
    # Convert the saved leaves back to JAX arrays
    params_flat = [jnp.array(p) for p in data["params"]]
    props_flat = [jnp.array(p) for p in data["props"]]
    # Reconstruct the original structures
    params = jax.tree_util.tree_unflatten(data["params_tree"], params_flat)
    props = jax.tree_util.tree_unflatten(data["props_tree"], props_flat)
    print("Restoration completed.")
    return params, props
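Whichever format ends up being used, it helps to verify the round trip before relying on it for a multi-day run. This is a minimal sketch that assumes both objects are JAX pytrees; `assert_same_tree` is a hypothetical helper name. It compares the tree structure first, then each leaf's shape, dtype, and values. It raises `AssertionError` explicitly so the check still runs under `python -O`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def assert_same_tree(original, restored):
    """Raise AssertionError unless the two pytrees match in structure and values."""
    orig_leaves, orig_def = jax.tree_util.tree_flatten(original)
    rest_leaves, rest_def = jax.tree_util.tree_flatten(restored)
    if orig_def != rest_def:
        raise AssertionError(f"tree structure differs:\n{orig_def}\nvs\n{rest_def}")
    for i, (a, b) in enumerate(zip(orig_leaves, rest_leaves)):
        a, b = np.asarray(a), np.asarray(b)
        if a.shape != b.shape or a.dtype != b.dtype:
            raise AssertionError(f"leaf {i}: {a.shape}/{a.dtype} vs {b.shape}/{b.dtype}")
        if not np.array_equal(a, b):
            raise AssertionError(f"leaf {i}: values differ")
```

For instance, `assert_same_tree(param, restore_hmm()[0])` would compare the restored parameters against the ones created in the sample above.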
Source: "python - How do I save a dynamax GussianHMM model? - Stack Overflow", user-contributed content reposted at http://www.betaflare.com/web/1743984916a2571235.html.