I define a named tuple like this:

from typing import Any, Dict, NamedTuple

class checkpoint_t(NamedTuple):
    epoch: int
    model_state_dict: Dict[str, Any]
    optimizer_state_dict: Dict[str, Any]
    model_name: str | None = None
However, after saving, I cannot load this namedtuple via

import torch
from train import checkpoint_t

with torch.serialization.safe_globals([checkpoint_t]):
    print("safe globals: ", torch.serialization.get_safe_globals())
    checkpoint: checkpoint_t = torch.load(parsed_args.checkpoint, weights_only=True)
It's still saying:

WeightsUnpickler error: Unsupported global: GLOBAL __main__.checkpoint_t was not an allowed global by default. Please use torch.serialization.add_safe_globals([checkpoint_t]) or the torch.serialization.safe_globals([checkpoint_t]) context manager to allowlist this global if you trust this class/function.
Any idea why? And how to fix this?

Update about why

The module train.py has an if __name__ == "__main__" block, so the module can be executed as python -m package.subpackage.train. If one runs it like this instead of using the exposed console_scripts entry point, the train module's name becomes __main__, so the class gets pickled as __main__.checkpoint_t.
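For illustration, a minimal sketch of how the recorded name depends on how the module is started (file and class names here are illustrative, not from my project):

# where_am_i.py: a class picks up its defining module's __name__,
# and torch.save records that module name in the pickle
# (e.g. GLOBAL __main__.checkpoint_t).
from typing import NamedTuple

class checkpoint_t(NamedTuple):
    epoch: int

print(__name__, checkpoint_t.__module__)
# run directly or via python -m:  __main__ __main__
# imported from another module:   <module path> <module path>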
1 Answer
The problem is revealed in the error message that you get:

- PyTorch complains that __main__.checkpoint_t is not in safe_globals.
- What you actually put into safe_globals, however, is train.checkpoint_t (as can be seen from your import from train import checkpoint_t).
I guess what happened here is that, at some point, the checkpoint_t class was moved from the top-level code environment to the module train, while the weights that you are trying to load have been created with an earlier version of your code, with the checkpoint_t class still in its original place.
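(To verify which global the checkpoint actually references, one can inspect the pickle opcodes without unpickling anything; a small sketch, assuming the same file name as below:)

import pickletools
import zipfile

# List the GLOBAL references stored in the checkpoint's data.pkl.
# With torch's default pickle protocol (2), classes show up as GLOBAL opcodes;
# newer protocols would use STACK_GLOBAL instead.
with zipfile.ZipFile("test.pth", "r") as z:
    pickle_path = next(n for n in z.namelist() if n.endswith("/data.pkl"))
    data = z.read(pickle_path)

for opcode, arg, _pos in pickletools.genops(data):
    if opcode.name == "GLOBAL":
        print(arg)  # e.g. "__main__ checkpoint_t"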
A hacky solution
There is no really easy way to fix this, as far as I know (other than putting the definition of checkpoint_t back in its original place, of course); however, we can make use of the following fact: every *.pth file that is not too old is nothing but a bunch of zipped files, into which the actual contents have been serialized using Python's standard pickle module. Based on this, we can use the following, somewhat hacky approach, drawing some inspiration from this blog post:
- Unzip the *.pth file.
- Unserialize the content of the contained data.pkl file and adjust the module name for checkpoint_t instances from __main__ to train. Or, to be more precise: when the unpickling process is looking for the class __main__.checkpoint_t into which to deserialize the pickled instance, return the class train.checkpoint_t instead. Here, we can follow this answer to a related question and employ our own Unpickler subclass to make the adjustment.
- Reserialize the content of the data.pkl file and rezip everything into a new *.pth file.
It should then be possible to load the new *.pth file in the way you tried in the question.

All in all, this could look as follows:
import pickle
import zipfile

# TODO: Provide the paths to be read (original_path) and written (converted_path)
original_path = "test.pth"
converted_path = "converted.pth"


class RenameUnpickler(pickle.Unpickler):
    # Following https://stackoverflow.com/a/53327348/7395592:
    # redirect lookups of __main__.checkpoint_t to train.checkpoint_t
    def find_class(self, module, name):
        if module == "__main__" and name == "checkpoint_t":
            module = "train"
        return super().find_class(module, name)


# Read and adjust pickled data
with zipfile.ZipFile(original_path, "r") as z:
    pickle_path = next(n for n in z.namelist() if n.endswith("/data.pkl"))
    with z.open(pickle_path) as f:
        unpickled_and_renamed = RenameUnpickler(f, encoding="utf-8").load()

# Re-zip adjusted pickled data
with zipfile.ZipFile(converted_path, "w") as conv_z:
    with zipfile.ZipFile(original_path, "r") as orig_z:
        for item in orig_z.infolist():
            if item.filename.endswith("/data.pkl"):
                # Replace data.pkl with the re-pickled, renamed content
                with conv_z.open(item.filename, "w") as f:
                    pickle.dump(unpickled_and_renamed, f, protocol=2)
            else:
                # Copy all other archive members unchanged
                conv_z.writestr(item.filename, orig_z.open(item).read())
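As a usage sketch (assuming the converted_path from above and that checkpoint_t is importable from train), loading the converted file should then work as attempted in the question:

import torch
from train import checkpoint_t

# The converted checkpoint now references train.checkpoint_t,
# which matches the class we allowlist here.
with torch.serialization.safe_globals([checkpoint_t]):
    checkpoint = torch.load("converted.pth", weights_only=True)
print(type(checkpoint))  # expected: <class 'train.checkpoint_t'>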
There are a few caveats: I am not sure with which version of PyTorch your original *.pth file was written, so the approach might not work directly. Most notably, in my experiments, I had to force the pickle protocol version via pickle.dump(..., protocol=2), or else I could not load the re-written *.pth file. I am not sure if this is the case for all versions of PyTorch or if there are other conventions to follow for other versions (I tested with PyTorch 2.5.1).